1 /*
2 * Copyright (c) 2021, Commonwealth Scientific and Industrial Research
3 * Organisation (CSIRO) ABN 41 687 119 230.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 *
7 * This is not exhaustive functional testing of the CMSIS-NN library.
8 *
9 * Individual tests have been pulled from CMSIS/NN/Tests/UnitTest to
10 * validate the integration of CMSIS-NN and Zephyr
11 */
12
#include <zephyr/ztest.h>
#include <zephyr/kernel.h>
#include <stdlib.h>
#include <string.h>

#include "arm_nnfunctions.h"
18
/* Number of back-to-back repetitions run by the max-pool, softmax and SVDF tests. */
#define REPEAT_NUM 3
20
/*
 * Average-pooling test vector, consumed by test_avgpool.
 *
 * Input: 1 batch, W=12, H=1, 5 channels (60 values).
 * Output: W=12, H=1, 5 channels (60 values).
 * Filter is 3 wide and 1 high, with a padding of 1 along the width.
 */
#define AVGPOOLING_2_OUT_CH 5
#define AVGPOOLING_2_IN_CH 5
#define AVGPOOLING_2_INPUT_W 12
#define AVGPOOLING_2_INPUT_H 1
#define AVGPOOLING_2_DST_SIZE 60
#define AVGPOOLING_2_INPUT_SIZE 60
#define AVGPOOLING_2_OUT_ACTIVATION_MIN -128
#define AVGPOOLING_2_OUT_ACTIVATION_MAX 127
#define AVGPOOLING_2_INPUT_BATCHES 1
#define AVGPOOLING_2_FILTER_X 3
#define AVGPOOLING_2_FILTER_Y 1
#define AVGPOOLING_2_STRIDE_X 1
#define AVGPOOLING_2_STRIDE_Y 2
#define AVGPOOLING_2_PAD_X 1
#define AVGPOOLING_2_PAD_Y 0
#define AVGPOOLING_2_OUTPUT_W 12
#define AVGPOOLING_2_OUTPUT_H 1

/* Input tensor fed to arm_avgpool_s8(). */
const int8_t avgpooling_2_input[60] = {
	-82, -104, 10, -28, -52, -51, -66, 52, 124, -74, -21, 4, 37, -7, -33,
	102, 110, 24, 52, 121, 13, -55, -79, -92, -35, -103, 86, 95, 46, 32,
	-24, -123, 120, 29, -77, -97, -69, -68, 58, 38, 3, 3, 79, -47, 112,
	-52, -113, -46, 107, 68, 83, -70, 91, 14, 113, 74, 73, -103, -98, 25};

/* Expected output the kernel result is compared against. */
const int8_t avgpooling_2_output_ref[60] = {
	-67, -85, 31, 48, -63, -51, -55, 33, 30, -53, 10, 16, 38, 56, 5,
	31, 20, -6, -16, 18, 4, 47, 13, 2, 39, -38, -31, 45, -6, -27,
	-75, -35, 49, 44, -2, -39, -63, 44, 13, 24, -49, -60, -12, 39, 73,
	11, -60, 41, 25, 98, 35, -37, -19, 8, 69, 79, 2, -6, -42, 69};
50
ZTEST(cmsis_nn,test_avgpool)51 ZTEST(cmsis_nn, test_avgpool)
52 {
53 int8_t output[AVGPOOLING_2_DST_SIZE] = {0};
54
55 cmsis_nn_context ctx;
56 cmsis_nn_pool_params pool_params;
57 cmsis_nn_dims input_dims;
58 cmsis_nn_dims filter_dims;
59 cmsis_nn_dims output_dims;
60
61 input_dims.n = AVGPOOLING_2_INPUT_BATCHES;
62 input_dims.w = AVGPOOLING_2_INPUT_W;
63 input_dims.h = AVGPOOLING_2_INPUT_H;
64 input_dims.c = AVGPOOLING_2_IN_CH;
65 filter_dims.w = AVGPOOLING_2_FILTER_X;
66 filter_dims.h = AVGPOOLING_2_FILTER_Y;
67 output_dims.w = AVGPOOLING_2_OUTPUT_W;
68 output_dims.h = AVGPOOLING_2_OUTPUT_H;
69 output_dims.c = AVGPOOLING_2_OUT_CH;
70
71 pool_params.padding.w = AVGPOOLING_2_PAD_X;
72 pool_params.padding.h = AVGPOOLING_2_PAD_Y;
73 pool_params.stride.w = AVGPOOLING_2_STRIDE_X;
74 pool_params.stride.h = AVGPOOLING_2_STRIDE_Y;
75
76 pool_params.activation.min = AVGPOOLING_2_OUT_ACTIVATION_MIN;
77 pool_params.activation.max = AVGPOOLING_2_OUT_ACTIVATION_MAX;
78
79 ctx.size = arm_avgpool_s8_get_buffer_size(AVGPOOLING_2_OUTPUT_W, AVGPOOLING_2_IN_CH);
80 ctx.buf = malloc(ctx.size);
81
82 arm_cmsis_nn_status result = arm_avgpool_s8(&ctx,
83 &pool_params,
84 &input_dims,
85 avgpooling_2_input,
86 &filter_dims,
87 &output_dims,
88 output);
89
90 free(ctx.buf);
91
92 zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
93 zassert_mem_equal(avgpooling_2_output_ref, output, sizeof(output), "");
94 }
95
/*
 * Convolution test vector, consumed by test_convolve.
 *
 * Input: 3 batches of W=5, H=5, 3 channels (75 values per batch, 225 total).
 * Filter: 2 wide, 3 high, stride 2 in both directions, no padding.
 * Output: W=2, H=2, 3 channels per batch (36 values over the 3 batches).
 */
#define CONV_4_OUT_CH 3
#define CONV_4_IN_CH 3
#define CONV_4_INPUT_W 5
#define CONV_4_INPUT_H 5
#define CONV_4_DST_SIZE 36
#define CONV_4_INPUT_SIZE 75
#define CONV_4_OUT_ACTIVATION_MIN -109
#define CONV_4_OUT_ACTIVATION_MAX 127
#define CONV_4_INPUT_BATCHES 3
#define CONV_4_FILTER_X 2
#define CONV_4_FILTER_Y 3
#define CONV_4_STRIDE_X 2
#define CONV_4_STRIDE_Y 2
#define CONV_4_PAD_X 0
#define CONV_4_PAD_Y 0
#define CONV_4_OUTPUT_W 2
#define CONV_4_OUTPUT_H 2
#define CONV_4_INPUT_OFFSET 128
#define CONV_4_OUTPUT_OFFSET -128
#define CONV_4_DILATION_X 1
#define CONV_4_DILATION_Y 1

/* One bias per output channel. */
const int32_t conv_4_biases[3] = {13175, 9050, 18215};

/* Filter weights: 54 values = 3 output channels * 3 * 2 * 3 input channels. */
const int8_t conv_4_weights[54] = {
	-25, -83, -74, 105, 30, 118, -32, 127, 34, 127, -112, 39, -43, 104, 41, -124, 115, 5,
	42, -48, -119, 93, 17, 57, 41, -41, -42, 23, 127, 18, 70, -99, 71, 67, 83, 76,
	-50, 98, 66, 64, 127, -6, -77, -48, -26, 45, 77, 1, 81, 27, 124, -103, 37, 36};

/* Input tensor covering all three batches. */
const int8_t conv_4_input[225] = {
	82, 120, -97, -44, -118, 73, 4, -84, -53, -122, -15, 77, 83, 43, 37,
	85, -11, 103, 45, -69, -12, -8, 21, 6, -68, -83, -15, -99, 90, -62,
	95, 62, -38, -32, -35, -105, -53, 70, 112, 14, -4, -33, -26, -93, -98,
	22, -5, 22, -104, 57, -92, 30, -62, 0, -43, -82, 60, 99, -83, 32,
	94, 49, 10, 112, -71, -27, -91, -79, 52, -92, -71, 86, -79, -15, -80,
	-74, -4, 76, -119, 91, -23, -12, -111, -72, 26, 11, 64, 116, 38, 99,
	125, 17, 6, -4, 46, 119, 113, -116, -125, 80, -57, 122, 75, 119, -117,
	87, -121, -70, -75, -127, 16, -124, -110, 10, 71, 29, 27, 37, -24, 52,
	28, -100, 86, -75, 117, -31, -115, -86, -122, 121, -96, -118, 32, 111, 25,
	-90, -8, 110, 37, 35, 124, -123, 94, -122, -114, 37, 85, -36, 53, -40,
	73, -99, 27, 10, 37, 41, 64, -97, -123, 75, 0, -107, -72, 58, -100,
	17, 77, 114, 120, -83, -96, 75, -12, -27, 3, 35, 85, 4, 119, -20,
	28, 99, 104, -78, -51, -82, -92, -40, -116, 35, -107, 39, 9, -120, -50,
	-102, -114, 25, -77, 25, 7, 64, 110, 80, -93, -20, 34, 115, 75, 37,
	47, 16, 6, -92, -25, 37, 69, 82, -61, -100, -85, -51, 6, -95, 58
};

/* Per-output-channel requantization multipliers. */
const int32_t conv_4_output_mult[3] = {2039209398, 2005068758, 2023002003};

/* Per-output-channel requantization shifts. */
const int32_t conv_4_output_shift[3] = {-9, -9, -9};

/* Expected output for both arm_convolve_s8() and arm_convolve_wrapper_s8(). */
const int8_t conv_4_output_ref[36] = {-5, -39, -31, 20, -37, -26, -109, -7, -10, -51, -58, 48,
				      -100, -32, 24, 4, 69, -38, -64, 65, -34, 95, -55, 39,
				      95, -54, 27, -49, 25, -68, -109, -66, 72, 38, -44, -40};
150
ZTEST(cmsis_nn,test_convolve)151 ZTEST(cmsis_nn, test_convolve)
152 {
153 int8_t output[CONV_4_DST_SIZE] = {0};
154
155 cmsis_nn_context ctx;
156 cmsis_nn_conv_params conv_params;
157 cmsis_nn_per_channel_quant_params quant_params;
158 cmsis_nn_dims input_dims;
159 cmsis_nn_dims filter_dims;
160 cmsis_nn_dims bias_dims;
161 cmsis_nn_dims output_dims;
162
163 const int32_t *bias_data = conv_4_biases;
164 const int8_t *kernel_data = conv_4_weights;
165 const int8_t *input_data = conv_4_input;
166
167 input_dims.n = CONV_4_INPUT_BATCHES;
168 input_dims.w = CONV_4_INPUT_W;
169 input_dims.h = CONV_4_INPUT_H;
170 input_dims.c = CONV_4_IN_CH;
171 filter_dims.w = CONV_4_FILTER_X;
172 filter_dims.h = CONV_4_FILTER_Y;
173 output_dims.w = CONV_4_OUTPUT_W;
174 output_dims.h = CONV_4_OUTPUT_H;
175 output_dims.c = CONV_4_OUT_CH;
176
177 conv_params.padding.w = CONV_4_PAD_X;
178 conv_params.padding.h = CONV_4_PAD_Y;
179 conv_params.stride.w = CONV_4_STRIDE_X;
180 conv_params.stride.h = CONV_4_STRIDE_Y;
181 conv_params.dilation.w = CONV_4_DILATION_X;
182 conv_params.dilation.h = CONV_4_DILATION_Y;
183
184 conv_params.input_offset = CONV_4_INPUT_OFFSET;
185 conv_params.output_offset = CONV_4_OUTPUT_OFFSET;
186 conv_params.activation.min = CONV_4_OUT_ACTIVATION_MIN;
187 conv_params.activation.max = CONV_4_OUT_ACTIVATION_MAX;
188 quant_params.multiplier = (int32_t *)conv_4_output_mult;
189 quant_params.shift = (int32_t *)conv_4_output_shift;
190
191 int32_t buf_size = arm_convolve_s8_get_buffer_size(&input_dims, &filter_dims);
192
193 ctx.buf = malloc(buf_size);
194 ctx.size = 0;
195
196 arm_cmsis_nn_status result = arm_convolve_s8(&ctx,
197 &conv_params,
198 &quant_params,
199 &input_dims,
200 input_data,
201 &filter_dims,
202 kernel_data,
203 &bias_dims,
204 bias_data,
205 &output_dims,
206 output);
207
208 free(ctx.buf);
209 zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
210 zassert_mem_equal(conv_4_output_ref, output, sizeof(output), "");
211
212 buf_size = arm_convolve_wrapper_s8_get_buffer_size(&conv_params, &input_dims,
213 &filter_dims, &output_dims);
214 ctx.buf = malloc(buf_size);
215 ctx.size = 0;
216
217 result = arm_convolve_wrapper_s8(&ctx,
218 &conv_params,
219 &quant_params,
220 &input_dims,
221 input_data,
222 &filter_dims,
223 kernel_data,
224 &bias_dims,
225 bias_data,
226 &output_dims,
227 output);
228
229 free(ctx.buf);
230 zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
231 zassert_mem_equal(conv_4_output_ref, output, sizeof(output), "");
232 }
233
/*
 * Depthwise-convolution test vector, consumed by test_depthwise_convolve.
 *
 * Input: 1 batch of W=7, H=7, 1 channel (49 values).
 * Filter: 3x3, stride 2 and padding 1 in both directions.
 * Output: W=4, H=4, 1 channel (16 values).
 */
#define STRIDE2PAD1_OUT_CH 1
#define STRIDE2PAD1_IN_CH 1
#define STRIDE2PAD1_INPUT_W 7
#define STRIDE2PAD1_INPUT_H 7
#define STRIDE2PAD1_DST_SIZE 16
#define STRIDE2PAD1_INPUT_SIZE 49
#define STRIDE2PAD1_OUT_ACTIVATION_MIN -128
#define STRIDE2PAD1_OUT_ACTIVATION_MAX 127
#define STRIDE2PAD1_INPUT_BATCHES 1
#define STRIDE2PAD1_FILTER_X 3
#define STRIDE2PAD1_FILTER_Y 3
#define STRIDE2PAD1_STRIDE_X 2
#define STRIDE2PAD1_STRIDE_Y 2
#define STRIDE2PAD1_PAD_X 1
#define STRIDE2PAD1_PAD_Y 1
#define STRIDE2PAD1_OUTPUT_W 4
#define STRIDE2PAD1_OUTPUT_H 4
#define STRIDE2PAD1_INPUT_OFFSET 128
#define STRIDE2PAD1_OUTPUT_OFFSET -20
#define STRIDE2PAD1_DILATION_X 1
#define STRIDE2PAD1_DILATION_Y 1

/* Single bias for the single output channel. */
const int32_t stride2pad1_biases[1] = {-9794};

/* 3x3 filter weights. */
const int8_t stride2pad1_weights[9] = {-54, 57, -19, -127, 87, 70, 74, -110, 66};

/* Input tensor fed to arm_depthwise_conv_s8(). */
const int8_t stride2pad1_input[49] = {
	-91, -30, -57, -76, 32, -13, 14, -96, 108, -4, 41, 48, 107, -68, -101, 30, 95,
	95, 91, -66, -80, 114, -49, 7, -67, -35, -1, -88, -77, -56, -103, 5, -39, -118,
	-24, -32, 67, 11, 38, -16, -124, 44, -46, -92, -24, 108, 80, -29, -3};

/* Requantization multiplier for the single output channel. */
const int32_t stride2pad1_output_mult[1] = {2033801520};

/* Requantization shift for the single output channel. */
const int32_t stride2pad1_output_shift[1] = {-8};

/* Expected output the kernel result is compared against. */
const int8_t stride2pad1_output_ref[16] = {26, -11, 33, -25, -96, -52, -78, -86,
					   33, -2, -88, -113, -14, 0, -84, -27};
271
ZTEST(cmsis_nn,test_depthwise_convolve)272 ZTEST(cmsis_nn, test_depthwise_convolve)
273 {
274 int8_t output[STRIDE2PAD1_DST_SIZE] = {0};
275
276 cmsis_nn_context ctx;
277 cmsis_nn_dw_conv_params dw_conv_params;
278 cmsis_nn_per_channel_quant_params quant_params;
279 cmsis_nn_dims input_dims;
280 cmsis_nn_dims filter_dims;
281 cmsis_nn_dims bias_dims = {0};
282 cmsis_nn_dims output_dims;
283
284 const int32_t *bias_data = stride2pad1_biases;
285 const int8_t *kernel_data = stride2pad1_weights;
286 const int8_t *input_data = stride2pad1_input;
287
288 input_dims.n = STRIDE2PAD1_INPUT_BATCHES;
289 input_dims.w = STRIDE2PAD1_INPUT_W;
290 input_dims.h = STRIDE2PAD1_INPUT_H;
291 input_dims.c = STRIDE2PAD1_IN_CH;
292 filter_dims.w = STRIDE2PAD1_FILTER_X;
293 filter_dims.h = STRIDE2PAD1_FILTER_Y;
294 output_dims.w = STRIDE2PAD1_OUTPUT_W;
295 output_dims.h = STRIDE2PAD1_OUTPUT_H;
296 output_dims.c = STRIDE2PAD1_OUT_CH;
297
298 dw_conv_params.padding.w = STRIDE2PAD1_PAD_X;
299 dw_conv_params.padding.h = STRIDE2PAD1_PAD_Y;
300 dw_conv_params.stride.w = STRIDE2PAD1_STRIDE_X;
301 dw_conv_params.stride.h = STRIDE2PAD1_STRIDE_Y;
302 dw_conv_params.dilation.w = STRIDE2PAD1_DILATION_X;
303 dw_conv_params.dilation.h = STRIDE2PAD1_DILATION_Y;
304
305 dw_conv_params.ch_mult = 1;
306
307 dw_conv_params.input_offset = STRIDE2PAD1_INPUT_OFFSET;
308 dw_conv_params.output_offset = STRIDE2PAD1_OUTPUT_OFFSET;
309 dw_conv_params.activation.min = STRIDE2PAD1_OUT_ACTIVATION_MIN;
310 dw_conv_params.activation.max = STRIDE2PAD1_OUT_ACTIVATION_MAX;
311 quant_params.multiplier = (int32_t *)stride2pad1_output_mult;
312 quant_params.shift = (int32_t *)stride2pad1_output_shift;
313
314 ctx.buf = NULL;
315 ctx.size = 0;
316
317 arm_cmsis_nn_status result = arm_depthwise_conv_s8(&ctx,
318 &dw_conv_params,
319 &quant_params,
320 &input_dims,
321 input_data,
322 &filter_dims,
323 kernel_data,
324 &bias_dims,
325 bias_data,
326 &output_dims,
327 output);
328
329 free(ctx.buf);
330 zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
331 zassert_mem_equal(stride2pad1_output_ref, output, sizeof(output), "");
332 }
333
/*
 * Fully-connected test vector, consumed by test_fully_connected.
 *
 * Input: 1 batch of 16 values; output: 9 values.
 * Weights: 144 values = accumulation depth 16 * 9 output channels.
 * A single multiplier/shift pair is used for the whole tensor.
 */
#define FULLY_CONNECTED_MVE_0_OUT_CH 9
#define FULLY_CONNECTED_MVE_0_IN_CH 16
#define FULLY_CONNECTED_MVE_0_INPUT_W 1
#define FULLY_CONNECTED_MVE_0_INPUT_H 1
#define FULLY_CONNECTED_MVE_0_DST_SIZE 9
#define FULLY_CONNECTED_MVE_0_INPUT_SIZE 16
#define FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MIN -128
#define FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MAX 127
#define FULLY_CONNECTED_MVE_0_INPUT_BATCHES 1
#define FULLY_CONNECTED_MVE_0_OUTPUT_MULTIPLIER 1244038257
#define FULLY_CONNECTED_MVE_0_OUTPUT_SHIFT -9
#define FULLY_CONNECTED_MVE_0_ACCUMULATION_DEPTH 16
#define FULLY_CONNECTED_MVE_0_INPUT_OFFSET 128
#define FULLY_CONNECTED_MVE_0_OUTPUT_OFFSET -26

/* One bias per output channel. */
const int32_t fully_connected_mve_0_biases[9] = {11295, -30752, -3196, 10489, -5120,
						 18598, 27393, 29746, 22967};

/* Input vector fed to arm_fully_connected_s8(). */
const int8_t fully_connected_mve_0_input[16] = {-43, 68, 79, -12, -119, -56, -102, -46,
						107, -65, -109, -7, 92, -99, -80, -29};

/* Expected output the kernel result is compared against. */
const int8_t fully_connected_mve_0_output_ref[9] = {-9, -3, 26, 8, 3, -88, 75, 34, 5};

/* Weight matrix. */
const int8_t fully_connected_mve_0_weights[144] = {
	37, -46, 75, -33, -52, -82, -94, 64, 71, 65, 64, 16, -66, -5, -65, -44,
	82, 42, 84, 105, 18, 79, -103, -75, -95, 65, 87, 103, 43, -25, -66, 75,
	125, 40, -34, 24, 9, -79, 4, 73, 98, -75, 42, 81, 18, -58, -119, 92,
	0, -72, 48, 23, -69, 11, -95, -103, 66, 117, 107, -96, 114, -29, 75, -93,
	118, 66, -19, 83, -14, 86, -110, 44, 37, -9, 17, -107, 50, -116, -116, -27,
	-84, -126, -108, -127, -71, 8, 81, 108, -61, 126, 69, -45, 37, -78, -102, -55,
	116, 112, -111, -89, -57, 82, -47, 22, 125, -84, 97, -9, 88, 74, -15, 118,
	-95, 112, 89, 44, -17, -112, -71, -94, 1, -117, 112, -92, 52, 57, -22, 80,
	-60, 95, -106, -1, -27, 105, 6, 123, 6, 96, 126, -65, -29, 103, 19, -45};
367
ZTEST(cmsis_nn,test_fully_connected)368 ZTEST(cmsis_nn, test_fully_connected)
369 {
370 int8_t output[FULLY_CONNECTED_MVE_0_DST_SIZE] = {0};
371
372 cmsis_nn_context ctx;
373 cmsis_nn_fc_params fc_params;
374 cmsis_nn_per_tensor_quant_params quant_params;
375 cmsis_nn_dims input_dims;
376 cmsis_nn_dims filter_dims;
377 cmsis_nn_dims bias_dims;
378 cmsis_nn_dims output_dims;
379
380 const int32_t *bias_data = fully_connected_mve_0_biases;
381 const int8_t *kernel_data = fully_connected_mve_0_weights;
382 const int8_t *input_data = fully_connected_mve_0_input;
383
384 input_dims.n = FULLY_CONNECTED_MVE_0_INPUT_BATCHES;
385 input_dims.w = FULLY_CONNECTED_MVE_0_INPUT_W;
386 input_dims.h = FULLY_CONNECTED_MVE_0_INPUT_H;
387 input_dims.c = FULLY_CONNECTED_MVE_0_IN_CH;
388 filter_dims.n = FULLY_CONNECTED_MVE_0_ACCUMULATION_DEPTH;
389 filter_dims.c = FULLY_CONNECTED_MVE_0_OUT_CH;
390 output_dims.n = FULLY_CONNECTED_MVE_0_INPUT_BATCHES;
391 output_dims.c = FULLY_CONNECTED_MVE_0_OUT_CH;
392
393 fc_params.input_offset = FULLY_CONNECTED_MVE_0_INPUT_OFFSET;
394 fc_params.filter_offset = 0;
395 fc_params.output_offset = FULLY_CONNECTED_MVE_0_OUTPUT_OFFSET;
396 fc_params.activation.min = FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MIN;
397 fc_params.activation.max = FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MAX;
398
399 quant_params.multiplier = FULLY_CONNECTED_MVE_0_OUTPUT_MULTIPLIER;
400 quant_params.shift = FULLY_CONNECTED_MVE_0_OUTPUT_SHIFT;
401
402 int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);
403
404 ctx.buf = malloc(buf_size);
405 ctx.size = buf_size;
406 arm_cmsis_nn_status result = arm_fully_connected_s8(&ctx,
407 &fc_params,
408 &quant_params,
409 &input_dims,
410 input_data,
411 &filter_dims,
412 kernel_data,
413 &bias_dims,
414 bias_data,
415 &output_dims,
416 output);
417
418 free(ctx.buf);
419 zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
420 zassert_mem_equal(fully_connected_mve_0_output_ref, output, sizeof(output), "");
421 }
422
/*
 * Max-pooling test vector, consumed by test_max_pool.
 *
 * Input: 1 batch, W=12, H=1, 5 channels (60 values).
 * Output: W=12, H=1, 5 channels (60 values).
 * Filter is 3 wide and 1 high, with a padding of 1 along the width.
 */
#define MAXPOOLING_2_OUT_CH 5
#define MAXPOOLING_2_IN_CH 5
#define MAXPOOLING_2_INPUT_W 12
#define MAXPOOLING_2_INPUT_H 1
#define MAXPOOLING_2_DST_SIZE 60
#define MAXPOOLING_2_INPUT_SIZE 60
#define MAXPOOLING_2_OUT_ACTIVATION_MIN -128
#define MAXPOOLING_2_OUT_ACTIVATION_MAX 127
#define MAXPOOLING_2_INPUT_BATCHES 1
#define MAXPOOLING_2_FILTER_X 3
#define MAXPOOLING_2_FILTER_Y 1
#define MAXPOOLING_2_STRIDE_X 1
#define MAXPOOLING_2_STRIDE_Y 2
#define MAXPOOLING_2_PAD_X 1
#define MAXPOOLING_2_PAD_Y 0
#define MAXPOOLING_2_OUTPUT_W 12
#define MAXPOOLING_2_OUTPUT_H 1

/* Input tensor fed to arm_max_pool_s8(). */
const int8_t maxpooling_2_input[60] = {
	75, -52, -42, -30, 56, 64, 106, -36, 120, -3, 34, -105, 69, 75, -39,
	15, 93, -71, 39, 34, -11, 65, 22, 59, 106, 105, 45, -116, -75, 123,
	-65, 75, -61, 13, -25, -123, 59, 110, -65, 86, -108, -107, -17, 38, 27,
	-1, -115, -123, 75, -75, 68, 52, 12, -35, 116, -68, 22, 15, 76, -81};

/* Expected output the kernel result is compared against. */
const int8_t maxpooling_2_output_ref[60] = {
	75, 106, -36, 120, 56, 75, 106, 69, 120, 56, 64, 106, 69, 120, 34,
	34, 93, 69, 75, 106, 105, 93, 22, 59, 123, 105, 75, 22, 59, 123,
	105, 75, 110, 13, 123, -65, 75, 110, 38, 86, -1, 59, 110, 75, 86,
	68, 52, 12, 75, 116, 68, 52, 15, 76, 116, 68, 52, 15, 76, 116};
452
ZTEST(cmsis_nn,test_max_pool)453 ZTEST(cmsis_nn, test_max_pool)
454 {
455 int8_t output[MAXPOOLING_2_DST_SIZE] = {0};
456
457 cmsis_nn_context ctx;
458 cmsis_nn_pool_params pool_params;
459 cmsis_nn_dims input_dims;
460 cmsis_nn_dims filter_dims;
461 cmsis_nn_dims output_dims;
462
463 const int8_t *input_data = maxpooling_2_input;
464
465 input_dims.n = MAXPOOLING_2_INPUT_BATCHES;
466 input_dims.w = MAXPOOLING_2_INPUT_W;
467 input_dims.h = MAXPOOLING_2_INPUT_H;
468 input_dims.c = MAXPOOLING_2_IN_CH;
469 filter_dims.w = MAXPOOLING_2_FILTER_X;
470 filter_dims.h = MAXPOOLING_2_FILTER_Y;
471 output_dims.w = MAXPOOLING_2_OUTPUT_W;
472 output_dims.h = MAXPOOLING_2_OUTPUT_H;
473 output_dims.c = MAXPOOLING_2_OUT_CH;
474
475 pool_params.padding.w = MAXPOOLING_2_PAD_X;
476 pool_params.padding.h = MAXPOOLING_2_PAD_Y;
477 pool_params.stride.w = MAXPOOLING_2_STRIDE_X;
478 pool_params.stride.h = MAXPOOLING_2_STRIDE_Y;
479
480 pool_params.activation.min = MAXPOOLING_2_OUT_ACTIVATION_MIN;
481 pool_params.activation.max = MAXPOOLING_2_OUT_ACTIVATION_MAX;
482
483 for (int i = 0; i < REPEAT_NUM; i++) {
484 arm_cmsis_nn_status result =
485 arm_max_pool_s8(&ctx, &pool_params, &input_dims, input_data, &filter_dims,
486 &output_dims, output);
487
488 zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
489 zassert_mem_equal(maxpooling_2_output_ref, output, sizeof(output), "");
490 }
491 }
492
/*
 * Softmax test vector, consumed by test_softmax.
 *
 * The input holds 2 rows of 5 values each (10 values), as does the output.
 */
#define SOFTMAX_NUM_ROWS 2
#define SOFTMAX_ROW_SIZE 5
#define SOFTMAX_INPUT_MULT 1077952640
#define SOFTMAX_INPUT_LEFT_SHIFT 19
#define SOFTMAX_DIFF_MIN -3968
#define SOFTMAX_DST_SIZE 10

/* Input tensor fed to arm_softmax_s8(). */
const int8_t softmax_input[10] = {101, 49, 6, -34, -75, -79, -38, 120, -55, 115};

/* Expected output the kernel result is compared against. */
const int8_t softmax_output_ref[10] = {-57, -70, -79, -86, -92, -94, -88, -54, -91, -56};
503
ZTEST(cmsis_nn,test_softmax)504 ZTEST(cmsis_nn, test_softmax)
505 {
506 const int32_t num_rows = SOFTMAX_NUM_ROWS;
507 const int32_t row_size = SOFTMAX_ROW_SIZE;
508 const int32_t mult = SOFTMAX_INPUT_MULT;
509 const int32_t shift = SOFTMAX_INPUT_LEFT_SHIFT;
510 const int32_t diff_min = SOFTMAX_DIFF_MIN;
511 const int8_t *input_data = softmax_input;
512 int8_t output[SOFTMAX_DST_SIZE];
513
514 for (int i = 0; i < REPEAT_NUM; i++) {
515 arm_softmax_s8(input_data, num_rows, row_size, mult, shift, diff_min, output);
516 zassert_mem_equal(softmax_output_ref, output, sizeof(output), "");
517 }
518 }
519
/*
 * SVDF test vector, consumed by test_svdf.
 *
 * Each round feeds SVDF_2_INPUT_BATCHES (3) batches of SVDF_2_INPUT_SIZE (7)
 * values, i.e. 21 values; the 42-value input sequence therefore holds 2 rounds.
 * The output holds 15 values, and 15 / 3 batches = 5 bias entries are checked.
 */
#define SVDF_2_MULTIPLIER_IN 1717987072
#define SVDF_2_MULTIPLIER_OUT 1099511552
#define SVDF_2_SHIFT_1 -3
#define SVDF_2_SHIFT_2 -11
#define SVDF_2_IN_ACTIVATION_MIN -32768
#define SVDF_2_IN_ACTIVATION_MAX 32767
#define SVDF_2_RANK 2
#define SVDF_2_FEATURE_BATCHES 10
#define SVDF_2_TIME_BATCHES 2
#define SVDF_2_INPUT_SIZE 7
#define SVDF_2_DST_SIZE 15
#define SVDF_2_OUT_ACTIVATION_MIN -128
#define SVDF_2_OUT_ACTIVATION_MAX 127
#define SVDF_2_INPUT_BATCHES 3
#define SVDF_2_INPUT_OFFSET 0
#define SVDF_2_OUTPUT_OFFSET 0

/* All-zero biases; test_svdf detects this and passes NULL to the kernel instead. */
const int32_t svdf_2_biases[5] = {0, 0, 0, 0, 0};

/* Initial (all-zero) state, copied into a working buffer before each repetition. */
const int16_t svdf_2_state[60] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};

/* Feature weights: 70 values = 10 feature batches * 7 inputs. */
const int8_t svdf_2_weights_feature[70] = {
	27, 82, -108, -127, 85, 3, -51, 32, 110, -6, -14, -16, 31, 101,
	-122, 19, 76, 74, -80, 12, -22, -17, 10, -28, 55, 109, 2, -107,
	-4, 72, -65, -59, 36, -69, 105, -97, 25, 38, 110, -121, -88, -126,
	-14, 16, -88, -66, 3, -93, 69, -64, 44, 103, 95, -95, 68, -46,
	106, -31, -63, 23, -38, 36, -95, -43, 93, 77, 91, -26, 33, 59};

/* Time weights: 20 values = 10 feature batches * 2 time batches. */
const int16_t svdf_2_weights_time[20] = {-31, -88, -10, -72, -119, -6, -70, 63, -10, 93,
					 5, 42, -6, 22, 6, 51, 37, -38, 5, 117};

/* Input sequence, consumed one 21-value round at a time. */
const int8_t svdf_2_input_sequence[42] = {
	29, 81, -38, 17, -116, 43, 119, -127, 74, 115, 9, 118, 7, -56,
	-53, -14, -98, 60, -128, 10, 28, -18, 12, -28, -126, 87, -115, -44,
	-123, -109, -59, -87, -69, 121, -128, -95, -70, 2, 81, -119, 84, -122};

/* Expected output after the whole input sequence has been processed. */
const int8_t svdf_2_output_ref[15] = {-53, 45, 27, -24, -53, 26, -82, -38,
				      11, -85, 94, -16, -32, 31, 4};
560
/*
 * Reports whether the first `size` entries of `bias` are all zero.
 *
 * Returns true when no non-zero entry is found, which includes the case of
 * `size` being zero or negative; returns false at the first non-zero entry.
 */
static bool check_null_bias(const int32_t *bias, int32_t size)
{
	int32_t idx = 0;

	/* Walk forward while entries are zero; stop early on the first non-zero. */
	while (idx < size && bias[idx] == 0) {
		idx++;
	}

	/* Reaching the end without stopping means every entry was zero. */
	return idx >= size;
}
573
ZTEST(cmsis_nn,test_svdf)574 ZTEST(cmsis_nn, test_svdf)
575 {
576 cmsis_nn_context input_ctx;
577 cmsis_nn_context output_ctx;
578 cmsis_nn_svdf_params svdf_2_params;
579 cmsis_nn_dims input_dims;
580 cmsis_nn_dims weights_feature_dims;
581 cmsis_nn_dims weights_time_dims;
582 cmsis_nn_dims state_dims;
583 cmsis_nn_dims output_dims;
584 cmsis_nn_dims bias_dims;
585 cmsis_nn_per_tensor_quant_params input_quant_params;
586 cmsis_nn_per_tensor_quant_params output_quant_params;
587 int8_t output_data[SVDF_2_DST_SIZE];
588
589 const int8_t *weights_feature_data = svdf_2_weights_feature;
590 const int16_t *weights_time_data = svdf_2_weights_time;
591
592 input_dims.n = SVDF_2_INPUT_BATCHES;
593 input_dims.h = SVDF_2_INPUT_SIZE;
594 weights_feature_dims.n = SVDF_2_FEATURE_BATCHES;
595 weights_time_dims.h = SVDF_2_TIME_BATCHES;
596
597 input_quant_params.multiplier = SVDF_2_MULTIPLIER_IN;
598 input_quant_params.shift = SVDF_2_SHIFT_1;
599 output_quant_params.multiplier = SVDF_2_MULTIPLIER_OUT;
600 output_quant_params.shift = SVDF_2_SHIFT_2;
601
602 svdf_2_params.input_activation.min = SVDF_2_IN_ACTIVATION_MIN;
603 svdf_2_params.input_activation.max = SVDF_2_IN_ACTIVATION_MAX;
604 svdf_2_params.output_activation.min = SVDF_2_OUT_ACTIVATION_MIN;
605 svdf_2_params.output_activation.max = SVDF_2_OUT_ACTIVATION_MAX;
606 svdf_2_params.input_offset = SVDF_2_INPUT_OFFSET;
607 svdf_2_params.output_offset = SVDF_2_OUTPUT_OFFSET;
608 svdf_2_params.rank = SVDF_2_RANK;
609
610 const int input_round_size = SVDF_2_INPUT_BATCHES * SVDF_2_INPUT_SIZE;
611 const int number_inputs = sizeof(svdf_2_input_sequence) / input_round_size;
612 const int32_t number_units = SVDF_2_FEATURE_BATCHES / SVDF_2_RANK;
613 const int scratch_size = SVDF_2_INPUT_BATCHES * SVDF_2_FEATURE_BATCHES * sizeof(int32_t);
614 const int scratch_size_out = SVDF_2_INPUT_BATCHES * number_units * sizeof(int32_t);
615
616 input_ctx.buf = malloc(scratch_size);
617 output_ctx.buf = malloc(scratch_size_out);
618
619 int8_t *input_data = malloc(input_round_size);
620 int16_t *state_data = malloc(sizeof(svdf_2_state));
621 const bool null_bias = check_null_bias(svdf_2_biases,
622 SVDF_2_DST_SIZE / SVDF_2_INPUT_BATCHES);
623
624 for (int i = 0; i < REPEAT_NUM; i++) {
625 memcpy(state_data, svdf_2_state, sizeof(svdf_2_state));
626 for (int j = 0; j < number_inputs; j++) {
627 memcpy(input_data, svdf_2_input_sequence + j * input_round_size,
628 input_round_size);
629 arm_cmsis_nn_status result = arm_svdf_state_s16_s8(&input_ctx,
630 &output_ctx,
631 &svdf_2_params,
632 &input_quant_params,
633 &output_quant_params,
634 &input_dims,
635 input_data,
636 &state_dims,
637 state_data,
638 &weights_feature_dims,
639 weights_feature_data,
640 &weights_time_dims,
641 weights_time_data,
642 &bias_dims,
643 null_bias == true ? NULL : svdf_2_biases,
644 &output_dims,
645 output_data);
646 zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
647 }
648
649 zassert_mem_equal(svdf_2_output_ref, output_data, sizeof(output_data), "");
650 }
651 free(state_data);
652 free(input_data);
653 free(input_ctx.buf);
654 free(output_ctx.buf);
655 }
656
/* Registers the cmsis_nn suite used by every ZTEST above; all five optional arguments are NULL. */
ZTEST_SUITE(cmsis_nn, NULL, NULL, NULL, NULL, NULL);
658