/*
 * Copyright 2019-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 */

#include "inference_process.hpp"

#include <tensorflow/lite/micro/all_ops_resolver.h>
#include <tensorflow/lite/micro/cortex_m_generic/debug_log_callback.h>
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/micro_profiler.h>
#include <tensorflow/lite/schema/schema_generated.h>

#include <algorithm> /* std::copy */
#include <cmsis_compiler.h>
#include <inttypes.h>
#include <zephyr/kernel.h>

using namespace std;

namespace
{
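/*
 * Copy a tensor's payload into the job's output buffer. A null destination is
 * treated as "no output requested" and skipped. Returns true on error
 * (destination buffer too small), false otherwise.
 */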
bool copyOutput(const TfLiteTensor &src, InferenceProcess::DataPtr &dst)
{
	if (dst.data == nullptr) {
		return false;
	}

	if (src.bytes > dst.size) {
		printk("Tensor size mismatch (bytes): actual=%zu, expected=%zu.\n", src.bytes,
		       dst.size);
		return true;
	}

	copy(src.data.uint8, src.data.uint8 + src.bytes, static_cast<uint8_t *>(dst.data));
	dst.size = src.bytes;

	return false;
}

} /* namespace */

namespace InferenceProcess
{
DataPtr::DataPtr(void *_data, size_t _size) : data(_data), size(_size)
{
}

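/*
 * Invalidate the data cache lines covering the buffer so the CPU re-reads main
 * memory that may have been updated by another bus master (e.g. a DMA engine
 * or an NPU).
 */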
void DataPtr::invalidate()
{
#if defined(__DCACHE_PRESENT) && (__DCACHE_PRESENT == 1U)
	SCB_InvalidateDCache_by_Addr(reinterpret_cast<uint32_t *>(data), size);
#endif
}

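/*
 * Clean (write back) the data cache lines covering the buffer so data written
 * by the CPU becomes visible to other bus masters reading main memory.
 */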
void DataPtr::clean()
{
#if defined(__DCACHE_PRESENT) && (__DCACHE_PRESENT == 1U)
	SCB_CleanDCache_by_Addr(reinterpret_cast<uint32_t *>(data), size);
#endif
}

InferenceJob::InferenceJob()
{
}

InferenceJob::InferenceJob(const string &_name, const DataPtr &_networkModel,
			   const vector<DataPtr> &_input, const vector<DataPtr> &_output,
			   const vector<DataPtr> &_expectedOutput)
	: name(_name), networkModel(_networkModel), input(_input), output(_output),
	  expectedOutput(_expectedOutput)
{
}

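/* Invalidate the caches for all buffers referenced by the job. */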
void InferenceJob::invalidate()
{
	networkModel.invalidate();

	for (auto &it : input) {
		it.invalidate();
	}

	for (auto &it : output) {
		it.invalidate();
	}

	for (auto &it : expectedOutput) {
		it.invalidate();
	}
}

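/* Clean the caches for all buffers referenced by the job. */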
void InferenceJob::clean()
{
	networkModel.clean();

	for (auto &it : input) {
		it.clean();
	}

	for (auto &it : output) {
		it.clean();
	}

	for (auto &it : expectedOutput) {
		it.clean();
	}
}

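/*
 * Run a single inference job: verify the model schema version, set up a micro
 * interpreter on the tensor arena, copy the job inputs into the input tensors,
 * invoke the network, copy the output tensors back and, if reference data is
 * provided, compare it byte for byte. Returns true on failure, false on success.
 */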
bool InferenceProcess::runJob(InferenceJob &job)
{
	/* Get model handle and verify that the version is correct */
	const tflite::Model *model = ::tflite::GetModel(job.networkModel.data);
	if (model->version() != TFLITE_SCHEMA_VERSION) {
		printk("Model schema version unsupported: version=%" PRIu32 ", supported=%d.\n",
		       model->version(), TFLITE_SCHEMA_VERSION);
		return true;
	}

	/* Create the TFL micro interpreter */
	tflite::AllOpsResolver resolver;
	tflite::MicroErrorReporter errorReporter;

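	/*
	 * The tensor arena is the interpreter's working memory; input, output and
	 * scratch tensors are all allocated from it by AllocateTensors() below.
	 */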
	tflite::MicroInterpreter interpreter(model, resolver, tensorArena, tensorArenaSize,
					     &errorReporter);

	/* Allocate tensors */
	TfLiteStatus allocate_status = interpreter.AllocateTensors();
	if (allocate_status != kTfLiteOk) {
		printk("Failed to allocate tensors for inference. job=%p\n", &job);
		return true;
	}

	if (job.input.size() != interpreter.inputs_size()) {
		printk("Number of job and network inputs do not match. input=%zu, network=%zu\n",
		       job.input.size(), interpreter.inputs_size());
		return true;
	}

	/* Copy input data */
	for (size_t i = 0; i < interpreter.inputs_size(); ++i) {
		const DataPtr &input = job.input[i];
		const TfLiteTensor *tensor = interpreter.input(i);

		if (input.size != tensor->bytes) {
			printk("Input tensor size mismatch. index=%zu, input=%zu, network=%zu\n", i,
			       input.size, tensor->bytes);
			return true;
		}

		copy(static_cast<char *>(input.data), static_cast<char *>(input.data) + input.size,
		     tensor->data.uint8);
	}

	/* Run the inference */
	TfLiteStatus invoke_status = interpreter.Invoke();
	if (invoke_status != kTfLiteOk) {
		printk("Invoke failed for inference. job=%s\n", job.name.c_str());
		return true;
	}

	/* Copy output data */
	if (job.output.size() > 0) {
		if (interpreter.outputs_size() != job.output.size()) {
			printk("Number of job and network outputs do not match. job=%zu, network=%zu\n",
			       job.output.size(), interpreter.outputs_size());
			return true;
		}

		for (unsigned i = 0; i < interpreter.outputs_size(); ++i) {
			if (copyOutput(*interpreter.output(i), job.output[i])) {
				return true;
			}
		}
	}

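	/* Compare against the expected (reference) output, if provided */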
	if (job.expectedOutput.size() > 0) {
		if (job.expectedOutput.size() != interpreter.outputs_size()) {
			printk("Number of job and network expected outputs do not match. job=%zu, network=%zu\n",
			       job.expectedOutput.size(), interpreter.outputs_size());
			return true;
		}

		for (unsigned int i = 0; i < interpreter.outputs_size(); i++) {
			const DataPtr &expected = job.expectedOutput[i];
			const TfLiteTensor *output = interpreter.output(i);

			if (expected.size != output->bytes) {
				printk("Expected output tensor size mismatch. index=%u, expected=%zu, network=%zu\n",
				       i, expected.size, output->bytes);
				return true;
			}

			for (unsigned int j = 0; j < output->bytes; ++j) {
				if (output->data.uint8[j] !=
				    static_cast<uint8_t *>(expected.data)[j]) {
					printk("Expected output tensor data mismatch. index=%u, offset=%u, expected=%02x, network=%02x\n",
					       i, j, static_cast<uint8_t *>(expected.data)[j],
					       output->data.uint8[j]);
					return true;
				}
			}
		}
	}

	return false;
}

} /* namespace InferenceProcess */