/*
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "main_functions.hpp"

#include "accelerometer_handler.hpp"
#include "constants.hpp"
#include "gesture_predictor.hpp"
#include "magic_wand_model_data.hpp"
#include "output_handler.hpp"
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/micro_log.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
#include <tensorflow/lite/schema/schema_generated.h>

/* Globals, used for compatibility with Arduino-style sketches. */
namespace {
	const tflite::Model *model = nullptr;
	tflite::MicroInterpreter *interpreter = nullptr;
	TfLiteTensor *model_input = nullptr;
	int input_length;

	/* Create an area of memory to use for input, output, and intermediate arrays.
	 * The size of this will depend on the model you're using, and may need to be
	 * determined by experimentation.
	 */
	constexpr int kTensorArenaSize = 60 * 1024;
	uint8_t tensor_arena[kTensorArenaSize];
} /* namespace */

/* The name of this function is important for Arduino compatibility. */
void setup(void)
{
	/* Map the model into a usable data structure. This doesn't involve any
	 * copying or parsing; it's a very lightweight operation.
	 */
	model = tflite::GetModel(g_magic_wand_model_data);
	if (model->version() != TFLITE_SCHEMA_VERSION) {
		MicroPrintf("Model provided is schema version %d not equal "
			    "to supported version %d.",
			    model->version(), TFLITE_SCHEMA_VERSION);
		return;
	}

	/* Pull in only the operation implementations we need.
	 * This relies on a complete list of all the ops needed by this graph.
	 * An easier approach is to just use the AllOpsResolver, but this will
	 * incur some penalty in code space for op implementations that are not
	 * needed by this graph.
	 */
	static tflite::MicroMutableOpResolver<5> micro_op_resolver; /* NOLINT */
	micro_op_resolver.AddConv2D();
	micro_op_resolver.AddDepthwiseConv2D();
	micro_op_resolver.AddFullyConnected();
	micro_op_resolver.AddMaxPool2D();
	micro_op_resolver.AddSoftmax();
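
	/* Note: if the model is retrained with a different architecture, the
	 * <5> template argument and the Add...() calls above must be updated
	 * to cover the ops actually used by the graph; an unregistered op
	 * will cause interpreter initialization to fail.
	 */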

	/* Build an interpreter to run the model with. */
	static tflite::MicroInterpreter static_interpreter(
		model, micro_op_resolver, tensor_arena, kTensorArenaSize);
	interpreter = &static_interpreter;

	/* Allocate memory from the tensor_arena for the model's tensors. */
	if (interpreter->AllocateTensors() != kTfLiteOk) {
		MicroPrintf("AllocateTensors() failed");
		return;
	}
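
	/* If allocation fails, kTensorArenaSize is likely too small; after a
	 * successful AllocateTensors(), interpreter->arena_used_bytes() can
	 * be used to measure how much of the arena the model actually needs.
	 */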

	/* Obtain pointer to the model's input tensor. */
	model_input = interpreter->input(0);
	if ((model_input->dims->size != 4) || (model_input->dims->data[0] != 1) ||
	    (model_input->dims->data[1] != 128) ||
	    (model_input->dims->data[2] != kChannelNumber) ||
	    (model_input->type != kTfLiteFloat32)) {
		MicroPrintf("Bad input tensor parameters in model");
		return;
	}

	input_length = model_input->bytes / sizeof(float);
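
	/* Given the dimensions checked above, this works out to 128 readings
	 * times kChannelNumber channels (typically the 3 accelerometer axes),
	 * i.e. the number of floats ReadAccelerometer() must fill per window.
	 */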

	TfLiteStatus setup_status = SetupAccelerometer();
	if (setup_status != kTfLiteOk) {
		MicroPrintf("Setup failed\n");
	}
}

void loop(void)
{
	/* Attempt to read new data from the accelerometer. */
	bool got_data = ReadAccelerometer(model_input->data.f, input_length);

	/* If there was no new data, wait until next time. */
	if (!got_data) {
		return;
	}

	/* Run inference, and report any error. */
	TfLiteStatus invoke_status = interpreter->Invoke();
	if (invoke_status != kTfLiteOk) {
		MicroPrintf("Invoke failed on index: %d\n", begin_index);
		return;
	}

	/* Analyze the results to obtain a prediction. */
	int gesture_index = PredictGesture(interpreter->output(0)->data.f);
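
	/* The output tensor holds one score per gesture class (the model ends
	 * in a Softmax); PredictGesture() is expected to turn those scores
	 * into the index of the recognized gesture, or a "no gesture" value.
	 */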

	/* Produce an output. */
	HandleOutput(gesture_index);
}