1 /*
2 * Copyright (c) 2021, Commonwealth Scientific and Industrial Research
3 * Organisation (CSIRO) ABN 41 687 119 230.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 *
7 * This is not exhaustive functional testing of the CMSIS-NN library.
8 *
9 * Individual tests have been pulled from CMSIS/NN/Tests/UnitTest to
10 * validate the integration of CMSIS-NN and Zephyr
11 */
12
13 #include <zephyr/ztest.h>
14 #include <zephyr/kernel.h>
15 #include <stdlib.h>
16
17 #include "arm_nnfunctions.h"
18
/* Number of times the max-pool, softmax and SVDF tests repeat their check. */
#define REPEAT_NUM 3
20
/*
 * Test vector consumed by test_avgpool.
 *
 * The AVGPOOLING_2_* macros give the tensor geometry, filter/stride/padding
 * and activation clamp that the test copies into the CMSIS-NN descriptors;
 * avgpooling_2_input is the tensor handed to the kernel and
 * avgpooling_2_output_ref is the result the test expects back.
 */
#define AVGPOOLING_2_OUT_CH 5
#define AVGPOOLING_2_IN_CH 5
#define AVGPOOLING_2_INPUT_W 12
#define AVGPOOLING_2_INPUT_H 1
#define AVGPOOLING_2_DST_SIZE 60
#define AVGPOOLING_2_INPUT_SIZE 60
#define AVGPOOLING_2_OUT_ACTIVATION_MIN -128
#define AVGPOOLING_2_OUT_ACTIVATION_MAX 127
#define AVGPOOLING_2_INPUT_BATCHES 1
#define AVGPOOLING_2_FILTER_X 3
#define AVGPOOLING_2_FILTER_Y 1
#define AVGPOOLING_2_STRIDE_X 1
#define AVGPOOLING_2_STRIDE_Y 2
#define AVGPOOLING_2_PAD_X 1
#define AVGPOOLING_2_PAD_Y 0
#define AVGPOOLING_2_OUTPUT_W 12
#define AVGPOOLING_2_OUTPUT_H 1

/* Input tensor: 60 int8 values (batches 1 x height 1 x width 12 x channels 5). */
const int8_t avgpooling_2_input[60] = {
	-82, -104, 10, -28, -52, -51, -66, 52, 124, -74, -21, 4, 37, -7, -33,
	102, 110, 24, 52, 121, 13, -55, -79, -92, -35, -103, 86, 95, 46, 32,
	-24, -123, 120, 29, -77, -97, -69, -68, 58, 38, 3, 3, 79, -47, 112,
	-52, -113, -46, 107, 68, 83, -70, 91, 14, 113, 74, 73, -103, -98, 25};

/* Expected kernel output, compared byte-for-byte by test_avgpool. */
const int8_t avgpooling_2_output_ref[60] = {
	-67, -85, 31, 48, -63, -51, -55, 33, 30, -53, 10, 16, 38, 56, 5,
	31, 20, -6, -16, 18, 4, 47, 13, 2, 39, -38, -31, 45, -6, -27,
	-75, -35, 49, 44, -2, -39, -63, 44, 13, 24, -49, -60, -12, 39, 73,
	11, -60, 41, 25, 98, 35, -37, -19, 8, 69, 79, 2, -6, -42, 69};
50
ZTEST(cmsis_nn,test_avgpool)51 ZTEST(cmsis_nn, test_avgpool)
52 {
53 int8_t output[AVGPOOLING_2_DST_SIZE] = {0};
54
55 cmsis_nn_context ctx;
56 cmsis_nn_pool_params pool_params;
57 cmsis_nn_dims input_dims;
58 cmsis_nn_dims filter_dims;
59 cmsis_nn_dims output_dims;
60
61 input_dims.n = AVGPOOLING_2_INPUT_BATCHES;
62 input_dims.w = AVGPOOLING_2_INPUT_W;
63 input_dims.h = AVGPOOLING_2_INPUT_H;
64 input_dims.c = AVGPOOLING_2_IN_CH;
65 filter_dims.w = AVGPOOLING_2_FILTER_X;
66 filter_dims.h = AVGPOOLING_2_FILTER_Y;
67 output_dims.w = AVGPOOLING_2_OUTPUT_W;
68 output_dims.h = AVGPOOLING_2_OUTPUT_H;
69 output_dims.c = AVGPOOLING_2_OUT_CH;
70
71 pool_params.padding.w = AVGPOOLING_2_PAD_X;
72 pool_params.padding.h = AVGPOOLING_2_PAD_Y;
73 pool_params.stride.w = AVGPOOLING_2_STRIDE_X;
74 pool_params.stride.h = AVGPOOLING_2_STRIDE_Y;
75
76 pool_params.activation.min = AVGPOOLING_2_OUT_ACTIVATION_MIN;
77 pool_params.activation.max = AVGPOOLING_2_OUT_ACTIVATION_MAX;
78
79 ctx.size = arm_avgpool_s8_get_buffer_size(AVGPOOLING_2_OUTPUT_W, AVGPOOLING_2_IN_CH);
80 ctx.buf = malloc(ctx.size);
81
82 arm_cmsis_nn_status result = arm_avgpool_s8(&ctx,
83 &pool_params,
84 &input_dims,
85 avgpooling_2_input,
86 &filter_dims,
87 &output_dims,
88 output);
89
90 free(ctx.buf);
91
92 zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
93 zassert_mem_equal(avgpooling_2_output_ref, output, sizeof(output), "");
94 }
95
/*
 * Test vector consumed by test_convolve.
 *
 * The CONV_4_* macros give the tensor geometry, filter/stride/padding/dilation,
 * quantisation offsets and activation clamp that the test copies into the
 * CMSIS-NN descriptors. The arrays below hold the bias, weights, input,
 * per-channel requantisation parameters and the expected output.
 */
#define CONV_4_OUT_CH 3
#define CONV_4_IN_CH 3
#define CONV_4_INPUT_W 5
#define CONV_4_INPUT_H 5
#define CONV_4_DST_SIZE 36
#define CONV_4_INPUT_SIZE 75
#define CONV_4_OUT_ACTIVATION_MIN -109
#define CONV_4_OUT_ACTIVATION_MAX 127
#define CONV_4_INPUT_BATCHES 3
#define CONV_4_FILTER_X 2
#define CONV_4_FILTER_Y 3
#define CONV_4_STRIDE_X 2
#define CONV_4_STRIDE_Y 2
#define CONV_4_PAD_X 0
#define CONV_4_PAD_Y 0
#define CONV_4_OUTPUT_W 2
#define CONV_4_OUTPUT_H 2
#define CONV_4_INPUT_OFFSET 128
#define CONV_4_OUTPUT_OFFSET -128
#define CONV_4_DILATION_X 1
#define CONV_4_DILATION_Y 1

/* One bias value per output channel. */
const int32_t conv_4_biases[3] = {13175, 9050, 18215};

/* Filter weights: 54 values (3 output channels x 3 x 2 filter x 3 input channels). */
const int8_t conv_4_weights[54] = {
	-25, -83, -74, 105, 30, 118, -32, 127, 34, 127, -112, 39, -43, 104, 41, -124, 115, 5,
	42, -48, -119, 93, 17, 57, 41, -41, -42, 23, 127, 18, 70, -99, 71, 67, 83, 76,
	-50, 98, 66, 64, 127, -6, -77, -48, -26, 45, 77, 1, 81, 27, 124, -103, 37, 36};

/* Input tensor: 225 values (3 batches of CONV_4_INPUT_SIZE = 75 each). */
const int8_t conv_4_input[225] = {
	82, 120, -97, -44, -118, 73, 4, -84, -53, -122, -15, 77, 83, 43, 37,
	85, -11, 103, 45, -69, -12, -8, 21, 6, -68, -83, -15, -99, 90, -62,
	95, 62, -38, -32, -35, -105, -53, 70, 112, 14, -4, -33, -26, -93, -98,
	22, -5, 22, -104, 57, -92, 30, -62, 0, -43, -82, 60, 99, -83, 32,
	94, 49, 10, 112, -71, -27, -91, -79, 52, -92, -71, 86, -79, -15, -80,
	-74, -4, 76, -119, 91, -23, -12, -111, -72, 26, 11, 64, 116, 38, 99,
	125, 17, 6, -4, 46, 119, 113, -116, -125, 80, -57, 122, 75, 119, -117,
	87, -121, -70, -75, -127, 16, -124, -110, 10, 71, 29, 27, 37, -24, 52,
	28, -100, 86, -75, 117, -31, -115, -86, -122, 121, -96, -118, 32, 111, 25,
	-90, -8, 110, 37, 35, 124, -123, 94, -122, -114, 37, 85, -36, 53, -40,
	73, -99, 27, 10, 37, 41, 64, -97, -123, 75, 0, -107, -72, 58, -100,
	17, 77, 114, 120, -83, -96, 75, -12, -27, 3, 35, 85, 4, 119, -20,
	28, 99, 104, -78, -51, -82, -92, -40, -116, 35, -107, 39, 9, -120, -50,
	-102, -114, 25, -77, 25, 7, 64, 110, 80, -93, -20, 34, 115, 75, 37,
	47, 16, 6, -92, -25, 37, 69, 82, -61, -100, -85, -51, 6, -95, 58
};

/* Requantisation multiplier and shift, one entry per output channel. */
const int32_t conv_4_output_mult[3] = {2039209398, 2005068758, 2023002003};

const int32_t conv_4_output_shift[3] = {-9, -9, -9};

/* Expected kernel output, compared byte-for-byte by test_convolve. */
const int8_t conv_4_output_ref[36] = {-5, -39, -31, 20, -37, -26, -109, -7, -10, -51, -58, 48,
				      -100, -32, 24, 4, 69, -38, -64, 65, -34, 95, -55, 39,
				      95, -54, 27, -49, 25, -68, -109, -66, 72, 38, -44, -40};
150
ZTEST(cmsis_nn,test_convolve)151 ZTEST(cmsis_nn, test_convolve)
152 {
153 int8_t output[CONV_4_DST_SIZE] = {0};
154
155 cmsis_nn_context ctx;
156 cmsis_nn_conv_params conv_params;
157 cmsis_nn_per_channel_quant_params quant_params;
158 cmsis_nn_dims input_dims;
159 cmsis_nn_dims filter_dims;
160 cmsis_nn_dims bias_dims;
161 cmsis_nn_dims output_dims;
162
163 const int32_t *bias_data = conv_4_biases;
164 const int8_t *kernel_data = conv_4_weights;
165 const int8_t *input_data = conv_4_input;
166
167 input_dims.n = CONV_4_INPUT_BATCHES;
168 input_dims.w = CONV_4_INPUT_W;
169 input_dims.h = CONV_4_INPUT_H;
170 input_dims.c = CONV_4_IN_CH;
171 filter_dims.w = CONV_4_FILTER_X;
172 filter_dims.h = CONV_4_FILTER_Y;
173 filter_dims.c = CONV_4_IN_CH;
174 output_dims.w = CONV_4_OUTPUT_W;
175 output_dims.h = CONV_4_OUTPUT_H;
176 output_dims.c = CONV_4_OUT_CH;
177
178 conv_params.padding.w = CONV_4_PAD_X;
179 conv_params.padding.h = CONV_4_PAD_Y;
180 conv_params.stride.w = CONV_4_STRIDE_X;
181 conv_params.stride.h = CONV_4_STRIDE_Y;
182 conv_params.dilation.w = CONV_4_DILATION_X;
183 conv_params.dilation.h = CONV_4_DILATION_Y;
184
185 conv_params.input_offset = CONV_4_INPUT_OFFSET;
186 conv_params.output_offset = CONV_4_OUTPUT_OFFSET;
187 conv_params.activation.min = CONV_4_OUT_ACTIVATION_MIN;
188 conv_params.activation.max = CONV_4_OUT_ACTIVATION_MAX;
189 quant_params.multiplier = (int32_t *)conv_4_output_mult;
190 quant_params.shift = (int32_t *)conv_4_output_shift;
191
192 int32_t buf_size = arm_convolve_s8_get_buffer_size(&input_dims, &filter_dims);
193
194 ctx.buf = malloc(buf_size);
195 ctx.size = 0;
196
197 arm_cmsis_nn_status result = arm_convolve_s8(&ctx,
198 &conv_params,
199 &quant_params,
200 &input_dims,
201 input_data,
202 &filter_dims,
203 kernel_data,
204 &bias_dims,
205 bias_data,
206 &output_dims,
207 output);
208
209 free(ctx.buf);
210 zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
211 zassert_mem_equal(conv_4_output_ref, output, sizeof(output), "");
212
213 buf_size = arm_convolve_wrapper_s8_get_buffer_size(&conv_params, &input_dims,
214 &filter_dims, &output_dims);
215 ctx.buf = malloc(buf_size);
216 ctx.size = 0;
217
218 result = arm_convolve_wrapper_s8(&ctx,
219 &conv_params,
220 &quant_params,
221 &input_dims,
222 input_data,
223 &filter_dims,
224 kernel_data,
225 &bias_dims,
226 bias_data,
227 &output_dims,
228 output);
229
230 free(ctx.buf);
231 zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
232 zassert_mem_equal(conv_4_output_ref, output, sizeof(output), "");
233 }
234
/*
 * Test vector consumed by test_depthwise_convolve.
 *
 * The STRIDE2PAD1_* macros give the tensor geometry, filter/stride/padding/
 * dilation, quantisation offsets and activation clamp that the test copies
 * into the CMSIS-NN descriptors. The arrays below hold the bias, weights,
 * input, requantisation parameters and the expected output.
 */
#define STRIDE2PAD1_OUT_CH 1
#define STRIDE2PAD1_IN_CH 1
#define STRIDE2PAD1_INPUT_W 7
#define STRIDE2PAD1_INPUT_H 7
#define STRIDE2PAD1_DST_SIZE 16
#define STRIDE2PAD1_INPUT_SIZE 49
#define STRIDE2PAD1_OUT_ACTIVATION_MIN -128
#define STRIDE2PAD1_OUT_ACTIVATION_MAX 127
#define STRIDE2PAD1_INPUT_BATCHES 1
#define STRIDE2PAD1_FILTER_X 3
#define STRIDE2PAD1_FILTER_Y 3
#define STRIDE2PAD1_STRIDE_X 2
#define STRIDE2PAD1_STRIDE_Y 2
#define STRIDE2PAD1_PAD_X 1
#define STRIDE2PAD1_PAD_Y 1
#define STRIDE2PAD1_OUTPUT_W 4
#define STRIDE2PAD1_OUTPUT_H 4
#define STRIDE2PAD1_INPUT_OFFSET 128
#define STRIDE2PAD1_OUTPUT_OFFSET -20
#define STRIDE2PAD1_DILATION_X 1
#define STRIDE2PAD1_DILATION_Y 1

/* Single bias value (one output channel). */
const int32_t stride2pad1_biases[1] = {-9794};

/* 3 x 3 filter for the single channel. */
const int8_t stride2pad1_weights[9] = {-54, 57, -19, -127, 87, 70, 74, -110, 66};

/* Input tensor: 49 values (7 x 7, one channel, one batch). */
const int8_t stride2pad1_input[49] = {
	-91, -30, -57, -76, 32, -13, 14, -96, 108, -4, 41, 48, 107, -68, -101, 30, 95,
	95, 91, -66, -80, 114, -49, 7, -67, -35, -1, -88, -77, -56, -103, 5, -39, -118,
	-24, -32, 67, 11, 38, -16, -124, 44, -46, -92, -24, 108, 80, -29, -3};

/* Requantisation multiplier and shift for the single output channel. */
const int32_t stride2pad1_output_mult[1] = {2033801520};

const int32_t stride2pad1_output_shift[1] = {-8};

/* Expected kernel output, compared byte-for-byte by test_depthwise_convolve. */
const int8_t stride2pad1_output_ref[16] = {26, -11, 33, -25, -96, -52, -78, -86,
					   33, -2, -88, -113, -14, 0, -84, -27};
272
ZTEST(cmsis_nn,test_depthwise_convolve)273 ZTEST(cmsis_nn, test_depthwise_convolve)
274 {
275 int8_t output[STRIDE2PAD1_DST_SIZE] = {0};
276
277 cmsis_nn_context ctx;
278 cmsis_nn_dw_conv_params dw_conv_params;
279 cmsis_nn_per_channel_quant_params quant_params;
280 cmsis_nn_dims input_dims;
281 cmsis_nn_dims filter_dims;
282 cmsis_nn_dims bias_dims = {0};
283 cmsis_nn_dims output_dims;
284
285 const int32_t *bias_data = stride2pad1_biases;
286 const int8_t *kernel_data = stride2pad1_weights;
287 const int8_t *input_data = stride2pad1_input;
288
289 input_dims.n = STRIDE2PAD1_INPUT_BATCHES;
290 input_dims.w = STRIDE2PAD1_INPUT_W;
291 input_dims.h = STRIDE2PAD1_INPUT_H;
292 input_dims.c = STRIDE2PAD1_IN_CH;
293 filter_dims.w = STRIDE2PAD1_FILTER_X;
294 filter_dims.h = STRIDE2PAD1_FILTER_Y;
295 output_dims.w = STRIDE2PAD1_OUTPUT_W;
296 output_dims.h = STRIDE2PAD1_OUTPUT_H;
297 output_dims.c = STRIDE2PAD1_OUT_CH;
298
299 dw_conv_params.padding.w = STRIDE2PAD1_PAD_X;
300 dw_conv_params.padding.h = STRIDE2PAD1_PAD_Y;
301 dw_conv_params.stride.w = STRIDE2PAD1_STRIDE_X;
302 dw_conv_params.stride.h = STRIDE2PAD1_STRIDE_Y;
303 dw_conv_params.dilation.w = STRIDE2PAD1_DILATION_X;
304 dw_conv_params.dilation.h = STRIDE2PAD1_DILATION_Y;
305
306 dw_conv_params.ch_mult = 1;
307
308 dw_conv_params.input_offset = STRIDE2PAD1_INPUT_OFFSET;
309 dw_conv_params.output_offset = STRIDE2PAD1_OUTPUT_OFFSET;
310 dw_conv_params.activation.min = STRIDE2PAD1_OUT_ACTIVATION_MIN;
311 dw_conv_params.activation.max = STRIDE2PAD1_OUT_ACTIVATION_MAX;
312 quant_params.multiplier = (int32_t *)stride2pad1_output_mult;
313 quant_params.shift = (int32_t *)stride2pad1_output_shift;
314
315 ctx.buf = NULL;
316 ctx.size = 0;
317
318 arm_cmsis_nn_status result = arm_depthwise_conv_s8(&ctx,
319 &dw_conv_params,
320 &quant_params,
321 &input_dims,
322 input_data,
323 &filter_dims,
324 kernel_data,
325 &bias_dims,
326 bias_data,
327 &output_dims,
328 output);
329
330 free(ctx.buf);
331 zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
332 zassert_mem_equal(stride2pad1_output_ref, output, sizeof(output), "");
333 }
334
/*
 * Test vector consumed by test_fully_connected.
 *
 * The FULLY_CONNECTED_MVE_0_* macros give the tensor geometry, quantisation
 * parameters and activation clamp that the test copies into the CMSIS-NN
 * descriptors. The arrays below hold the bias, input, expected output and
 * weights.
 */
#define FULLY_CONNECTED_MVE_0_OUT_CH 9
#define FULLY_CONNECTED_MVE_0_IN_CH 16
#define FULLY_CONNECTED_MVE_0_INPUT_W 1
#define FULLY_CONNECTED_MVE_0_INPUT_H 1
#define FULLY_CONNECTED_MVE_0_DST_SIZE 9
#define FULLY_CONNECTED_MVE_0_INPUT_SIZE 16
#define FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MIN -128
#define FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MAX 127
#define FULLY_CONNECTED_MVE_0_INPUT_BATCHES 1
#define FULLY_CONNECTED_MVE_0_OUTPUT_MULTIPLIER 1244038257
#define FULLY_CONNECTED_MVE_0_OUTPUT_SHIFT -9
#define FULLY_CONNECTED_MVE_0_ACCUMULATION_DEPTH 16
#define FULLY_CONNECTED_MVE_0_INPUT_OFFSET 128
#define FULLY_CONNECTED_MVE_0_OUTPUT_OFFSET -26

/* One bias value per output channel. */
const int32_t fully_connected_mve_0_biases[9] = {11295, -30752, -3196, 10489, -5120,
						 18598, 27393, 29746, 22967};

/* Input vector: 16 values (one batch). */
const int8_t fully_connected_mve_0_input[16] = {-43, 68, 79, -12, -119, -56, -102, -46,
						107, -65, -109, -7, 92, -99, -80, -29};

/* Expected kernel output, compared byte-for-byte by test_fully_connected. */
const int8_t fully_connected_mve_0_output_ref[9] = {-9, -3, 26, 8, 3, -88, 75, 34, 5};

/* Weights: 144 values (9 output channels x accumulation depth 16). */
const int8_t fully_connected_mve_0_weights[144] = {
	37, -46, 75, -33, -52, -82, -94, 64, 71, 65, 64, 16, -66, -5, -65, -44,
	82, 42, 84, 105, 18, 79, -103, -75, -95, 65, 87, 103, 43, -25, -66, 75,
	125, 40, -34, 24, 9, -79, 4, 73, 98, -75, 42, 81, 18, -58, -119, 92,
	0, -72, 48, 23, -69, 11, -95, -103, 66, 117, 107, -96, 114, -29, 75, -93,
	118, 66, -19, 83, -14, 86, -110, 44, 37, -9, 17, -107, 50, -116, -116, -27,
	-84, -126, -108, -127, -71, 8, 81, 108, -61, 126, 69, -45, 37, -78, -102, -55,
	116, 112, -111, -89, -57, 82, -47, 22, 125, -84, 97, -9, 88, 74, -15, 118,
	-95, 112, 89, 44, -17, -112, -71, -94, 1, -117, 112, -92, 52, 57, -22, 80,
	-60, 95, -106, -1, -27, 105, 6, 123, 6, 96, 126, -65, -29, 103, 19, -45};
368
ZTEST(cmsis_nn,test_fully_connected)369 ZTEST(cmsis_nn, test_fully_connected)
370 {
371 int8_t output[FULLY_CONNECTED_MVE_0_DST_SIZE] = {0};
372
373 cmsis_nn_context ctx;
374 cmsis_nn_fc_params fc_params;
375 cmsis_nn_per_tensor_quant_params quant_params;
376 cmsis_nn_dims input_dims;
377 cmsis_nn_dims filter_dims;
378 cmsis_nn_dims bias_dims;
379 cmsis_nn_dims output_dims;
380
381 const int32_t *bias_data = fully_connected_mve_0_biases;
382 const int8_t *kernel_data = fully_connected_mve_0_weights;
383 const int8_t *input_data = fully_connected_mve_0_input;
384
385 input_dims.n = FULLY_CONNECTED_MVE_0_INPUT_BATCHES;
386 input_dims.w = FULLY_CONNECTED_MVE_0_INPUT_W;
387 input_dims.h = FULLY_CONNECTED_MVE_0_INPUT_H;
388 input_dims.c = FULLY_CONNECTED_MVE_0_IN_CH;
389 filter_dims.n = FULLY_CONNECTED_MVE_0_ACCUMULATION_DEPTH;
390 filter_dims.c = FULLY_CONNECTED_MVE_0_OUT_CH;
391 output_dims.n = FULLY_CONNECTED_MVE_0_INPUT_BATCHES;
392 output_dims.c = FULLY_CONNECTED_MVE_0_OUT_CH;
393
394 fc_params.input_offset = FULLY_CONNECTED_MVE_0_INPUT_OFFSET;
395 fc_params.filter_offset = 0;
396 fc_params.output_offset = FULLY_CONNECTED_MVE_0_OUTPUT_OFFSET;
397 fc_params.activation.min = FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MIN;
398 fc_params.activation.max = FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MAX;
399
400 quant_params.multiplier = FULLY_CONNECTED_MVE_0_OUTPUT_MULTIPLIER;
401 quant_params.shift = FULLY_CONNECTED_MVE_0_OUTPUT_SHIFT;
402
403 int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);
404
405 ctx.buf = malloc(buf_size);
406 ctx.size = buf_size;
407 arm_cmsis_nn_status result = arm_fully_connected_s8(&ctx,
408 &fc_params,
409 &quant_params,
410 &input_dims,
411 input_data,
412 &filter_dims,
413 kernel_data,
414 &bias_dims,
415 bias_data,
416 &output_dims,
417 output);
418
419 free(ctx.buf);
420 zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
421 zassert_mem_equal(fully_connected_mve_0_output_ref, output, sizeof(output), "");
422 }
423
/*
 * Test vector consumed by test_max_pool.
 *
 * The MAXPOOLING_2_* macros give the tensor geometry, filter/stride/padding
 * and activation clamp that the test copies into the CMSIS-NN descriptors;
 * maxpooling_2_input is the tensor handed to the kernel and
 * maxpooling_2_output_ref is the result the test expects back.
 */
#define MAXPOOLING_2_OUT_CH 5
#define MAXPOOLING_2_IN_CH 5
#define MAXPOOLING_2_INPUT_W 12
#define MAXPOOLING_2_INPUT_H 1
#define MAXPOOLING_2_DST_SIZE 60
#define MAXPOOLING_2_INPUT_SIZE 60
#define MAXPOOLING_2_OUT_ACTIVATION_MIN -128
#define MAXPOOLING_2_OUT_ACTIVATION_MAX 127
#define MAXPOOLING_2_INPUT_BATCHES 1
#define MAXPOOLING_2_FILTER_X 3
#define MAXPOOLING_2_FILTER_Y 1
#define MAXPOOLING_2_STRIDE_X 1
#define MAXPOOLING_2_STRIDE_Y 2
#define MAXPOOLING_2_PAD_X 1
#define MAXPOOLING_2_PAD_Y 0
#define MAXPOOLING_2_OUTPUT_W 12
#define MAXPOOLING_2_OUTPUT_H 1

/* Input tensor: 60 int8 values (batches 1 x height 1 x width 12 x channels 5). */
const int8_t maxpooling_2_input[60] = {
	75, -52, -42, -30, 56, 64, 106, -36, 120, -3, 34, -105, 69, 75, -39,
	15, 93, -71, 39, 34, -11, 65, 22, 59, 106, 105, 45, -116, -75, 123,
	-65, 75, -61, 13, -25, -123, 59, 110, -65, 86, -108, -107, -17, 38, 27,
	-1, -115, -123, 75, -75, 68, 52, 12, -35, 116, -68, 22, 15, 76, -81};

/* Expected kernel output, compared byte-for-byte by test_max_pool. */
const int8_t maxpooling_2_output_ref[60] = {
	75, 106, -36, 120, 56, 75, 106, 69, 120, 56, 64, 106, 69, 120, 34,
	34, 93, 69, 75, 106, 105, 93, 22, 59, 123, 105, 75, 22, 59, 123,
	105, 75, 110, 13, 123, -65, 75, 110, 38, 86, -1, 59, 110, 75, 86,
	68, 52, 12, 75, 116, 68, 52, 15, 76, 116, 68, 52, 15, 76, 116};
453
ZTEST(cmsis_nn,test_max_pool)454 ZTEST(cmsis_nn, test_max_pool)
455 {
456 int8_t output[MAXPOOLING_2_DST_SIZE] = {0};
457
458 cmsis_nn_context ctx;
459 cmsis_nn_pool_params pool_params;
460 cmsis_nn_dims input_dims;
461 cmsis_nn_dims filter_dims;
462 cmsis_nn_dims output_dims;
463
464 const int8_t *input_data = maxpooling_2_input;
465
466 input_dims.n = MAXPOOLING_2_INPUT_BATCHES;
467 input_dims.w = MAXPOOLING_2_INPUT_W;
468 input_dims.h = MAXPOOLING_2_INPUT_H;
469 input_dims.c = MAXPOOLING_2_IN_CH;
470 filter_dims.w = MAXPOOLING_2_FILTER_X;
471 filter_dims.h = MAXPOOLING_2_FILTER_Y;
472 output_dims.w = MAXPOOLING_2_OUTPUT_W;
473 output_dims.h = MAXPOOLING_2_OUTPUT_H;
474 output_dims.c = MAXPOOLING_2_OUT_CH;
475
476 pool_params.padding.w = MAXPOOLING_2_PAD_X;
477 pool_params.padding.h = MAXPOOLING_2_PAD_Y;
478 pool_params.stride.w = MAXPOOLING_2_STRIDE_X;
479 pool_params.stride.h = MAXPOOLING_2_STRIDE_Y;
480
481 pool_params.activation.min = MAXPOOLING_2_OUT_ACTIVATION_MIN;
482 pool_params.activation.max = MAXPOOLING_2_OUT_ACTIVATION_MAX;
483
484 for (int i = 0; i < REPEAT_NUM; i++) {
485 arm_cmsis_nn_status result =
486 arm_max_pool_s8(&ctx, &pool_params, &input_dims, input_data, &filter_dims,
487 &output_dims, output);
488
489 zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
490 zassert_mem_equal(maxpooling_2_output_ref, output, sizeof(output), "");
491 }
492 }
493
/*
 * Test vector consumed by test_softmax.
 *
 * The SOFTMAX_* macros are passed to arm_softmax_s8() as the row count, row
 * size, input multiplier, input left shift and diff-min arguments;
 * SOFTMAX_DST_SIZE sizes the output buffer (2 rows x 5 values).
 */
#define SOFTMAX_NUM_ROWS 2
#define SOFTMAX_ROW_SIZE 5
#define SOFTMAX_INPUT_MULT 1077952640
#define SOFTMAX_INPUT_LEFT_SHIFT 19
#define SOFTMAX_DIFF_MIN -3968
#define SOFTMAX_DST_SIZE 10

/* Input values handed to the kernel. */
const int8_t softmax_input[10] = {101, 49, 6, -34, -75, -79, -38, 120, -55, 115};

/* Expected kernel output, compared byte-for-byte by test_softmax. */
const int8_t softmax_output_ref[10] = {-57, -70, -79, -86, -92, -94, -88, -54, -91, -56};
504
ZTEST(cmsis_nn,test_softmax)505 ZTEST(cmsis_nn, test_softmax)
506 {
507 const int32_t num_rows = SOFTMAX_NUM_ROWS;
508 const int32_t row_size = SOFTMAX_ROW_SIZE;
509 const int32_t mult = SOFTMAX_INPUT_MULT;
510 const int32_t shift = SOFTMAX_INPUT_LEFT_SHIFT;
511 const int32_t diff_min = SOFTMAX_DIFF_MIN;
512 const int8_t *input_data = softmax_input;
513 int8_t output[SOFTMAX_DST_SIZE];
514
515 for (int i = 0; i < REPEAT_NUM; i++) {
516 arm_softmax_s8(input_data, num_rows, row_size, mult, shift, diff_min, output);
517 zassert_mem_equal(softmax_output_ref, output, sizeof(output), "");
518 }
519 }
520
/*
 * Test vector consumed by test_svdf.
 *
 * The SVDF_2_* macros give the quantisation multipliers/shifts, activation
 * clamps, rank, batch counts, sizes and offsets that the test copies into the
 * CMSIS-NN descriptors. The arrays below hold the bias, the initial state,
 * the feature and time weights, the input sequence and the expected output.
 */
#define SVDF_2_MULTIPLIER_IN 1717987072
#define SVDF_2_MULTIPLIER_OUT 1099511552
#define SVDF_2_SHIFT_1 -3
#define SVDF_2_SHIFT_2 -11
#define SVDF_2_IN_ACTIVATION_MIN -32768
#define SVDF_2_IN_ACTIVATION_MAX 32767
#define SVDF_2_RANK 2
#define SVDF_2_FEATURE_BATCHES 10
#define SVDF_2_TIME_BATCHES 2
#define SVDF_2_INPUT_SIZE 7
#define SVDF_2_DST_SIZE 15
#define SVDF_2_OUT_ACTIVATION_MIN -128
#define SVDF_2_OUT_ACTIVATION_MAX 127
#define SVDF_2_INPUT_BATCHES 3
#define SVDF_2_INPUT_OFFSET 0
#define SVDF_2_OUTPUT_OFFSET 0

/* All-zero bias; test_svdf detects this and passes NULL to the kernel instead. */
const int32_t svdf_2_biases[5] = {0, 0, 0, 0, 0};

/* Initial (all-zero) state, copied into a working buffer before each repetition. */
const int16_t svdf_2_state[60] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};

/* Feature weights: 70 values (10 feature batches x input size 7). */
const int8_t svdf_2_weights_feature[70] = {
	27, 82, -108, -127, 85, 3, -51, 32, 110, -6, -14, -16, 31, 101,
	-122, 19, 76, 74, -80, 12, -22, -17, 10, -28, 55, 109, 2, -107,
	-4, 72, -65, -59, 36, -69, 105, -97, 25, 38, 110, -121, -88, -126,
	-14, 16, -88, -66, 3, -93, 69, -64, 44, 103, 95, -95, 68, -46,
	106, -31, -63, 23, -38, 36, -95, -43, 93, 77, 91, -26, 33, 59};

/* Time weights: 20 values (10 feature batches x 2 time batches). */
const int16_t svdf_2_weights_time[20] = {-31, -88, -10, -72, -119, -6, -70, 63, -10, 93,
					 5, 42, -6, 22, 6, 51, 37, -38, 5, 117};

/* Input sequence: 42 values, fed to the kernel in rounds of 3 batches x 7 inputs. */
const int8_t svdf_2_input_sequence[42] = {
	29, 81, -38, 17, -116, 43, 119, -127, 74, 115, 9, 118, 7, -56,
	-53, -14, -98, 60, -128, 10, 28, -18, 12, -28, -126, 87, -115, -44,
	-123, -109, -59, -87, -69, 121, -128, -95, -70, 2, 81, -119, 84, -122};

/* Expected output after the last input round, compared byte-for-byte by test_svdf. */
const int8_t svdf_2_output_ref[15] = {-53, 45, 27, -24, -53, 26, -82, -38,
				      11, -85, 94, -16, -32, 31, 4};
561
/**
 * @brief Tell whether a bias vector consists solely of zeros.
 *
 * @param bias Bias values to inspect; must hold at least @p size entries.
 * @param size Number of entries to inspect.
 *
 * @return true if none of the first @p size entries is non-zero (including
 *         when @p size is zero or negative), false otherwise.
 */
static bool check_null_bias(const int32_t *bias, int32_t size)
{
	for (int32_t idx = 0; idx < size; idx++) {
		if (bias[idx] != 0) {
			/* A single non-zero entry settles the answer. */
			return false;
		}
	}

	return true;
}
574
ZTEST(cmsis_nn,test_svdf)575 ZTEST(cmsis_nn, test_svdf)
576 {
577 cmsis_nn_context input_ctx;
578 cmsis_nn_context output_ctx;
579 cmsis_nn_svdf_params svdf_2_params;
580 cmsis_nn_dims input_dims;
581 cmsis_nn_dims weights_feature_dims;
582 cmsis_nn_dims weights_time_dims;
583 cmsis_nn_dims state_dims;
584 cmsis_nn_dims output_dims;
585 cmsis_nn_dims bias_dims;
586 cmsis_nn_per_tensor_quant_params input_quant_params;
587 cmsis_nn_per_tensor_quant_params output_quant_params;
588 int8_t output_data[SVDF_2_DST_SIZE];
589
590 const int8_t *weights_feature_data = svdf_2_weights_feature;
591 const int16_t *weights_time_data = svdf_2_weights_time;
592
593 input_dims.n = SVDF_2_INPUT_BATCHES;
594 input_dims.h = SVDF_2_INPUT_SIZE;
595 weights_feature_dims.n = SVDF_2_FEATURE_BATCHES;
596 weights_time_dims.h = SVDF_2_TIME_BATCHES;
597
598 input_quant_params.multiplier = SVDF_2_MULTIPLIER_IN;
599 input_quant_params.shift = SVDF_2_SHIFT_1;
600 output_quant_params.multiplier = SVDF_2_MULTIPLIER_OUT;
601 output_quant_params.shift = SVDF_2_SHIFT_2;
602
603 svdf_2_params.input_activation.min = SVDF_2_IN_ACTIVATION_MIN;
604 svdf_2_params.input_activation.max = SVDF_2_IN_ACTIVATION_MAX;
605 svdf_2_params.output_activation.min = SVDF_2_OUT_ACTIVATION_MIN;
606 svdf_2_params.output_activation.max = SVDF_2_OUT_ACTIVATION_MAX;
607 svdf_2_params.input_offset = SVDF_2_INPUT_OFFSET;
608 svdf_2_params.output_offset = SVDF_2_OUTPUT_OFFSET;
609 svdf_2_params.rank = SVDF_2_RANK;
610
611 const int input_round_size = SVDF_2_INPUT_BATCHES * SVDF_2_INPUT_SIZE;
612 const int number_inputs = sizeof(svdf_2_input_sequence) / input_round_size;
613 const int32_t number_units = SVDF_2_FEATURE_BATCHES / SVDF_2_RANK;
614 const int scratch_size = SVDF_2_INPUT_BATCHES * SVDF_2_FEATURE_BATCHES * sizeof(int32_t);
615 const int scratch_size_out = SVDF_2_INPUT_BATCHES * number_units * sizeof(int32_t);
616
617 input_ctx.buf = malloc(scratch_size);
618 output_ctx.buf = malloc(scratch_size_out);
619
620 int8_t *input_data = malloc(input_round_size);
621 int16_t *state_data = malloc(sizeof(svdf_2_state));
622 const bool null_bias = check_null_bias(svdf_2_biases,
623 SVDF_2_DST_SIZE / SVDF_2_INPUT_BATCHES);
624
625 for (int i = 0; i < REPEAT_NUM; i++) {
626 memcpy(state_data, svdf_2_state, sizeof(svdf_2_state));
627 for (int j = 0; j < number_inputs; j++) {
628 memcpy(input_data, svdf_2_input_sequence + j * input_round_size,
629 input_round_size);
630 arm_cmsis_nn_status result = arm_svdf_state_s16_s8(&input_ctx,
631 &output_ctx,
632 &svdf_2_params,
633 &input_quant_params,
634 &output_quant_params,
635 &input_dims,
636 input_data,
637 &state_dims,
638 state_data,
639 &weights_feature_dims,
640 weights_feature_data,
641 &weights_time_dims,
642 weights_time_data,
643 &bias_dims,
644 null_bias == true ? NULL : svdf_2_biases,
645 &output_dims,
646 output_data);
647 zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
648 }
649
650 zassert_mem_equal(svdf_2_output_ref, output_data, sizeof(output_data), "");
651 }
652 free(state_data);
653 free(input_data);
654 free(input_ctx.buf);
655 free(output_ctx.buf);
656 }
657
/* Register the cmsis_nn suite used by the ZTEST() cases above; all five optional hook arguments are left NULL. */
ZTEST_SUITE(cmsis_nn, NULL, NULL, NULL, NULL, NULL);
659