1 /*
2  * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "main_functions.h"

#include <cmath>

#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/micro_log.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
#include <tensorflow/lite/micro/system_setup.h>
#include <tensorflow/lite/schema/schema_generated.h>

#include "constants.h"
#include "model.hpp"
#include "output_handler.hpp"
27 
28 /* Globals, used for compatibility with Arduino-style sketches. */
29 namespace {
30 	const tflite::Model *model = nullptr;
31 	tflite::MicroInterpreter *interpreter = nullptr;
32 	TfLiteTensor *input = nullptr;
33 	TfLiteTensor *output = nullptr;
34 	int inference_count = 0;
35 
36 	constexpr int kTensorArenaSize = 2000;
37 	uint8_t tensor_arena[kTensorArenaSize];
38 }  /* namespace */
39 
40 /* The name of this function is important for Arduino compatibility. */
setup(void)41 void setup(void)
42 {
43 	/* Map the model into a usable data structure. This doesn't involve any
44 	 * copying or parsing, it's a very lightweight operation.
45 	 */
46 	model = tflite::GetModel(g_model);
47 	if (model->version() != TFLITE_SCHEMA_VERSION) {
48 		MicroPrintf("Model provided is schema version %d not equal "
49 					"to supported version %d.",
50 					model->version(), TFLITE_SCHEMA_VERSION);
51 		return;
52 	}
53 
54 	/* This pulls in the operation implementations we need.
55 	 * NOLINTNEXTLINE(runtime-global-variables)
56 	 */
57 	static tflite::MicroMutableOpResolver <1> resolver;
58 	resolver.AddFullyConnected();
59 
60 	/* Build an interpreter to run the model with. */
61 	static tflite::MicroInterpreter static_interpreter(
62 		model, resolver, tensor_arena, kTensorArenaSize);
63 	interpreter = &static_interpreter;
64 
65 	/* Allocate memory from the tensor_arena for the model's tensors. */
66 	TfLiteStatus allocate_status = interpreter->AllocateTensors();
67 	if (allocate_status != kTfLiteOk) {
68 		MicroPrintf("AllocateTensors() failed");
69 		return;
70 	}
71 
72 	/* Obtain pointers to the model's input and output tensors. */
73 	input = interpreter->input(0);
74 	output = interpreter->output(0);
75 
76 	/* Keep track of how many inferences we have performed. */
77 	inference_count = 0;
78 }
79 
80 /* The name of this function is important for Arduino compatibility. */
loop(void)81 void loop(void)
82 {
83 	/* Calculate an x value to feed into the model. We compare the current
84 	 * inference_count to the number of inferences per cycle to determine
85 	 * our position within the range of possible x values the model was
86 	 * trained on, and use this to calculate a value.
87 	 */
88 	float position = static_cast < float > (inference_count) /
89 			 static_cast < float > (kInferencesPerCycle);
90 	float x = position * kXrange;
91 
92 	/* Quantize the input from floating-point to integer */
93 	int8_t x_quantized = x / input->params.scale + input->params.zero_point;
94 	/* Place the quantized input in the model's input tensor */
95 	input->data.int8[0] = x_quantized;
96 
97 	/* Run inference, and report any error */
98 	TfLiteStatus invoke_status = interpreter->Invoke();
99 	if (invoke_status != kTfLiteOk) {
100 		MicroPrintf("Invoke failed on x: %f\n", static_cast < double > (x));
101 		return;
102 	}
103 
104 	/* Obtain the quantized output from model's output tensor */
105 	int8_t y_quantized = output->data.int8[0];
106 	/* Dequantize the output from integer to floating-point */
107 	float y = (y_quantized - output->params.zero_point) * output->params.scale;
108 
109 	/* Output the results. A custom HandleOutput function can be implemented
110 	 * for each supported hardware target.
111 	 */
112 	HandleOutput(x, y);
113 
114 	/* Increment the inference_counter, and reset it if we have reached
115 	 * the total number per cycle
116 	 */
117 	inference_count += 1;
118 	if (inference_count >= kInferencesPerCycle) inference_count = 0;
119 }
120