1 /*
2 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "main_functions.hpp"
18
19 #include "accelerometer_handler.hpp"
20 #include "constants.hpp"
21 #include "gesture_predictor.hpp"
22 #include "magic_wand_model_data.hpp"
23 #include "output_handler.hpp"
24 #include <tensorflow/lite/micro/micro_log.h>
25 #include <tensorflow/lite/micro/micro_interpreter.h>
26 #include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
27 #include <tensorflow/lite/schema/schema_generated.h>
28
29 /* Globals, used for compatibility with Arduino-style sketches. */
30 namespace {
31 const tflite::Model *model = nullptr;
32 tflite::MicroInterpreter *interpreter = nullptr;
33 TfLiteTensor *model_input = nullptr;
34 int input_length;
35
36 /* Create an area of memory to use for input, output, and intermediate arrays.
37 * The size of this will depend on the model you're using, and may need to be
38 * determined by experimentation.
39 */
40 constexpr int kTensorArenaSize = 60 * 1024;
41 uint8_t tensor_arena[kTensorArenaSize];
42 } /* namespace */
43
44 /* The name of this function is important for Arduino compatibility. */
setup(void)45 void setup(void)
46 {
47 /* Map the model into a usable data structure. This doesn't involve any
48 * copying or parsing, it's a very lightweight operation.
49 */
50 model = tflite::GetModel(g_magic_wand_model_data);
51 if (model->version() != TFLITE_SCHEMA_VERSION) {
52 MicroPrintf("Model provided is schema version %d not equal "
53 "to supported version %d.",
54 model->version(), TFLITE_SCHEMA_VERSION);
55 return;
56 }
57
58 /* Pull in only the operation implementations we need.
59 * This relies on a complete list of all the ops needed by this graph.
60 * An easier approach is to just use the AllOpsResolver, but this will
61 * incur some penalty in code space for op implementations that are not
62 * needed by this graph.
63 */
64 static tflite::MicroMutableOpResolver < 5 > micro_op_resolver; /* NOLINT */
65 micro_op_resolver.AddConv2D();
66 micro_op_resolver.AddDepthwiseConv2D();
67 micro_op_resolver.AddFullyConnected();
68 micro_op_resolver.AddMaxPool2D();
69 micro_op_resolver.AddSoftmax();
70
71 /* Build an interpreter to run the model with. */
72 static tflite::MicroInterpreter static_interpreter(
73 model, micro_op_resolver, tensor_arena, kTensorArenaSize);
74 interpreter = &static_interpreter;
75
76 /* Allocate memory from the tensor_arena for the model's tensors. */
77 interpreter->AllocateTensors();
78
79 /* Obtain pointer to the model's input tensor. */
80 model_input = interpreter->input(0);
81 if ((model_input->dims->size != 4) || (model_input->dims->data[0] != 1) ||
82 (model_input->dims->data[1] != 128) ||
83 (model_input->dims->data[2] != kChannelNumber) ||
84 (model_input->type != kTfLiteFloat32)) {
85 MicroPrintf("Bad input tensor parameters in model");
86 return;
87 }
88
89 input_length = model_input->bytes / sizeof(float);
90
91 TfLiteStatus setup_status = SetupAccelerometer();
92 if (setup_status != kTfLiteOk) {
93 MicroPrintf("Set up failed\n");
94 }
95 }
96
loop(void)97 void loop(void)
98 {
99 /* Attempt to read new data from the accelerometer. */
100 bool got_data =
101 ReadAccelerometer(model_input->data.f, input_length);
102
103 /* If there was no new data, wait until next time. */
104 if (!got_data) {
105 return;
106 }
107
108 /* Run inference, and report any error */
109 TfLiteStatus invoke_status = interpreter->Invoke();
110 if (invoke_status != kTfLiteOk) {
111 MicroPrintf("Invoke failed on index: %d\n", begin_index);
112 return;
113 }
114 /* Analyze the results to obtain a prediction */
115 int gesture_index = PredictGesture(interpreter->output(0)->data.f);
116
117 /* Produce an output */
118 HandleOutput(gesture_index);
119 }
120