/*
 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "main_functions.h"

#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/micro_log.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
#include <tensorflow/lite/micro/system_setup.h>
#include <tensorflow/lite/schema/schema_generated.h>

#include "constants.h"
#include "model.hpp"
#include "output_handler.hpp"

/* Globals, used for compatibility with Arduino-style sketches. */
namespace {
const tflite::Model *model = nullptr;
tflite::MicroInterpreter *interpreter = nullptr;
TfLiteTensor *input = nullptr;
TfLiteTensor *output = nullptr;
int inference_count = 0;

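/* The tensor arena provides the working memory for the model's input,
 * output, and intermediate tensors. 2000 bytes is enough for this small
 * model; if AllocateTensors() fails in setup(), increase this size.
 */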
constexpr int kTensorArenaSize = 2000;
uint8_t tensor_arena[kTensorArenaSize];
} /* namespace */

/* The name of this function is important for Arduino compatibility. */
void setup(void)
{
	/* Map the model into a usable data structure. This doesn't involve any
	 * copying or parsing; it's a very lightweight operation.
	 */
	model = tflite::GetModel(g_model);
	if (model->version() != TFLITE_SCHEMA_VERSION) {
		MicroPrintf("Model provided is schema version %d not equal "
			    "to supported version %d.",
			    model->version(), TFLITE_SCHEMA_VERSION);
		return;
	}

	/* This pulls in the operation implementations we need.
	 * NOLINTNEXTLINE(runtime-global-variables)
	 */
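	/* The template argument is the maximum number of operators the
	 * resolver can register; it must cover every Add*() call below.
	 */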
	static tflite::MicroMutableOpResolver<1> resolver;
	resolver.AddFullyConnected();

	/* Build an interpreter to run the model with. */
	static tflite::MicroInterpreter static_interpreter(
		model, resolver, tensor_arena, kTensorArenaSize);
	interpreter = &static_interpreter;

	/* Allocate memory from the tensor_arena for the model's tensors. */
	TfLiteStatus allocate_status = interpreter->AllocateTensors();
	if (allocate_status != kTfLiteOk) {
		MicroPrintf("AllocateTensors() failed");
		return;
	}

	/* Obtain pointers to the model's input and output tensors. */
	input = interpreter->input(0);
	output = interpreter->output(0);
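	/* The data.int8[0] accesses in loop() show that both tensors are
	 * quantized int8 with a single element; their scale and zero_point
	 * parameters drive the conversions there.
	 */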

	/* Keep track of how many inferences we have performed. */
	inference_count = 0;
}

/* The name of this function is important for Arduino compatibility. */
void loop(void)
{
	/* Calculate an x value to feed into the model. We compare the current
	 * inference_count to the number of inferences per cycle to determine
	 * our position within the range of possible x values the model was
	 * trained on, and use this to calculate a value.
	 */
	float position = static_cast<float>(inference_count) /
			 static_cast<float>(kInferencesPerCycle);
	float x = position * kXrange;
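	/* x covers 0 to kXrange (defined in constants.h); in this sample that
	 * is one full period of the sine curve the model approximates.
	 */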

	/* Quantize the input from floating-point to integer */
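	/* Int8 tensors use the affine scheme real = (q - zero_point) * scale,
	 * so the forward mapping is q = real / scale + zero_point.
	 */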
	int8_t x_quantized = x / input->params.scale + input->params.zero_point;
	/* Place the quantized input in the model's input tensor */
	input->data.int8[0] = x_quantized;

	/* Run inference, and report any error */
	TfLiteStatus invoke_status = interpreter->Invoke();
	if (invoke_status != kTfLiteOk) {
		MicroPrintf("Invoke failed on x: %f\n", static_cast<double>(x));
		return;
	}

	/* Obtain the quantized output from the model's output tensor */
	int8_t y_quantized = output->data.int8[0];
	/* Dequantize the output from integer to floating-point */
	float y = (y_quantized - output->params.zero_point) * output->params.scale;

	/* Output the results. A custom HandleOutput function can be implemented
	 * for each supported hardware target.
	 */
	HandleOutput(x, y);

	/* Increment the inference counter, and reset it once we have reached
	 * the total number per cycle.
	 */
	inference_count += 1;
	if (inference_count >= kInferencesPerCycle) {
		inference_count = 0;
	}
}