1 /*
2 * Copyright 2019-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
7 #include "inference_process.hpp"
8
9 #include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
10 #include <tensorflow/lite/micro/cortex_m_generic/debug_log_callback.h>
11 #include <tensorflow/lite/micro/micro_log.h>
12 #include <tensorflow/lite/micro/micro_interpreter.h>
13 #include <tensorflow/lite/micro/micro_profiler.h>
14 #include <tensorflow/lite/schema/schema_generated.h>
15
16 #include <cmsis_compiler.h>
17 #include <inttypes.h>
18 #include <zephyr/kernel.h>
19
20 using namespace std;
21
22 namespace
23 {
copyOutput(const TfLiteTensor & src,InferenceProcess::DataPtr & dst)24 bool copyOutput(const TfLiteTensor &src, InferenceProcess::DataPtr &dst)
25 {
26 if (dst.data == nullptr) {
27 return false;
28 }
29
30 if (src.bytes > dst.size) {
31 printf("Tensor size mismatch (bytes): actual=%d, expected%d.\n", src.bytes,
32 dst.size);
33 return true;
34 }
35
36 copy(src.data.uint8, src.data.uint8 + src.bytes, static_cast<uint8_t *>(dst.data));
37 dst.size = src.bytes;
38
39 return false;
40 }
41
42 } /* namespace */
43
44 namespace InferenceProcess
45 {
DataPtr(void * _data,size_t _size)46 DataPtr::DataPtr(void *_data, size_t _size) : data(_data), size(_size)
47 {
48 }
49
/* Invalidate (discard) any data-cache lines covering this buffer, so the
 * next CPU read fetches from memory. No-op on targets without a data cache.
 */
void DataPtr::invalidate()
{
#if defined(__DCACHE_PRESENT) && (__DCACHE_PRESENT == 1U)
	SCB_InvalidateDCache_by_Addr(reinterpret_cast<uint32_t *>(data), size);
#endif
}
56
/* Clean (write back) any dirty data-cache lines covering this buffer, so the
 * contents are visible in memory. No-op on targets without a data cache.
 */
void DataPtr::clean()
{
#if defined(__DCACHE_PRESENT) && (__DCACHE_PRESENT == 1U)
	SCB_CleanDCache_by_Addr(reinterpret_cast<uint32_t *>(data), size);
#endif
}
63
InferenceJob()64 InferenceJob::InferenceJob()
65 {
66 }
67
InferenceJob(const string & _name,const DataPtr & _networkModel,const vector<DataPtr> & _input,const vector<DataPtr> & _output,const vector<DataPtr> & _expectedOutput)68 InferenceJob::InferenceJob(const string &_name, const DataPtr &_networkModel,
69 const vector<DataPtr> &_input, const vector<DataPtr> &_output,
70 const vector<DataPtr> &_expectedOutput)
71 : name(_name), networkModel(_networkModel), input(_input), output(_output),
72 expectedOutput(_expectedOutput)
73 {
74 }
75
invalidate()76 void InferenceJob::invalidate()
77 {
78 networkModel.invalidate();
79
80 for (auto &it : input) {
81 it.invalidate();
82 }
83
84 for (auto &it : output) {
85 it.invalidate();
86 }
87
88 for (auto &it : expectedOutput) {
89 it.invalidate();
90 }
91 }
92
clean()93 void InferenceJob::clean()
94 {
95 networkModel.clean();
96
97 for (auto &it : input) {
98 it.clean();
99 }
100
101 for (auto &it : output) {
102 it.clean();
103 }
104
105 for (auto &it : expectedOutput) {
106 it.clean();
107 }
108 }
109
runJob(InferenceJob & job)110 bool InferenceProcess::runJob(InferenceJob &job)
111 {
112 /* Get model handle and verify that the version is correct */
113 const tflite::Model *model = ::tflite::GetModel(job.networkModel.data);
114 if (model->version() != TFLITE_SCHEMA_VERSION) {
115 printf("Model schema version unsupported: version=%" PRIu32 ", supported=%d.\n",
116 model->version(), TFLITE_SCHEMA_VERSION);
117 return true;
118 }
119
120 /* Create the TFL micro interpreter */
121 #ifdef CONFIG_TAINT_BLOBS_TFLM_ETHOSU
122 tflite::MicroMutableOpResolver <1> resolver;
123 resolver.AddEthosU();
124 #else
125 tflite::MicroMutableOpResolver <4> resolver;
126 resolver.AddReshape();
127 resolver.AddConv2D();
128 resolver.AddFullyConnected();
129 resolver.AddSoftmax();
130 #endif
131 tflite::MicroInterpreter interpreter(model, resolver, tensorArena, tensorArenaSize);
132
133 /* Allocate tensors */
134 TfLiteStatus allocate_status = interpreter.AllocateTensors();
135 if (allocate_status != kTfLiteOk) {
136 printf("Failed to allocate tensors for inference. job=%p\n", &job);
137 return true;
138 }
139
140 if (job.input.size() != interpreter.inputs_size()) {
141 printf("Number of job and network inputs do not match. input=%zu, network=%zu\n",
142 job.input.size(), interpreter.inputs_size());
143 return true;
144 }
145
146 /* Copy input data */
147 for (size_t i = 0; i < interpreter.inputs_size(); ++i) {
148 const DataPtr &input = job.input[i];
149 const TfLiteTensor *tensor = interpreter.input(i);
150
151 if (input.size != tensor->bytes) {
152 printf("Input tensor size mismatch. index=%zu, input=%zu, network=%u\n", i,
153 input.size, tensor->bytes);
154 return true;
155 }
156
157 copy(static_cast<char *>(input.data), static_cast<char *>(input.data) + input.size,
158 tensor->data.uint8);
159 }
160
161 /* Run the inference */
162 TfLiteStatus invoke_status = interpreter.Invoke();
163 if (invoke_status != kTfLiteOk) {
164 printf("Invoke failed for inference. job=%s\n", job.name.c_str());
165 return true;
166 }
167
168 /* Copy output data */
169 if (job.output.size() > 0) {
170 if (interpreter.outputs_size() != job.output.size()) {
171 printf("Number of job and network outputs do not match. job=%zu, network=%u\n",
172 job.output.size(), interpreter.outputs_size());
173 return true;
174 }
175
176 for (unsigned i = 0; i < interpreter.outputs_size(); ++i) {
177 if (copyOutput(*interpreter.output(i), job.output[i])) {
178 return true;
179 }
180 }
181 }
182
183 if (job.expectedOutput.size() > 0) {
184 if (job.expectedOutput.size() != interpreter.outputs_size()) {
185 printf("Number of job and network expected outputs do not match. job=%zu, network=%zu\n",
186 job.expectedOutput.size(), interpreter.outputs_size());
187 return true;
188 }
189
190 for (unsigned int i = 0; i < interpreter.outputs_size(); i++) {
191 const DataPtr &expected = job.expectedOutput[i];
192 const TfLiteTensor *output = interpreter.output(i);
193
194 if (expected.size != output->bytes) {
195 printf("Expected output tensor size mismatch. index=%u, expected=%zu, network=%zu\n",
196 i, expected.size, output->bytes);
197 return true;
198 }
199
200 for (unsigned int j = 0; j < output->bytes; ++j) {
201 if (output->data.uint8[j] !=
202 static_cast<uint8_t *>(expected.data)[j]) {
203 printf("Expected output tensor data mismatch. index=%u, offset=%u, expected=%02x, network=%02x\n",
204 i, j, static_cast<uint8_t *>(expected.data)[j],
205 output->data.uint8[j]);
206 return true;
207 }
208 }
209 }
210 }
211
212 return false;
213 }
214
215 } /* namespace InferenceProcess */
216