1 /*
2  * Copyright (c) 2021, Commonwealth Scientific and Industrial Research
3  * Organisation (CSIRO) ABN 41 687 119 230.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  *
7  * This is not exhaustive functional testing of the CMSIS-NN library.
8  *
9  * Individual tests have been pulled from CMSIS/NN/Tests/UnitTest to
10  * validate the integration of CMSIS-NN and Zephyr
11  */
12 
#include <zephyr/ztest.h>
#include <zephyr/kernel.h>
#include <stdlib.h>
#include <string.h>

#include "arm_nnfunctions.h"
18 
/* Iteration count for the tests that invoke their kernel in a loop (max pool, softmax, SVDF). */
#define REPEAT_NUM 3
20 
/*
 * Fixture for test_avgpool: layer parameters, the input tensor and the
 * output tensor the kernel is expected to produce.
 */
#define AVGPOOLING_2_OUT_CH		5
#define AVGPOOLING_2_IN_CH		5
#define AVGPOOLING_2_INPUT_W		12
#define AVGPOOLING_2_INPUT_H		1
#define AVGPOOLING_2_DST_SIZE		60
#define AVGPOOLING_2_INPUT_SIZE		60
#define AVGPOOLING_2_OUT_ACTIVATION_MIN -128
#define AVGPOOLING_2_OUT_ACTIVATION_MAX 127
#define AVGPOOLING_2_INPUT_BATCHES	1
#define AVGPOOLING_2_FILTER_X		3
#define AVGPOOLING_2_FILTER_Y		1
#define AVGPOOLING_2_STRIDE_X		1
#define AVGPOOLING_2_STRIDE_Y		2
#define AVGPOOLING_2_PAD_X		1
#define AVGPOOLING_2_PAD_Y		0
#define AVGPOOLING_2_OUTPUT_W		12
#define AVGPOOLING_2_OUTPUT_H		1

/* The tables are only read inside this file, so keep them out of the global namespace. */
static const int8_t avgpooling_2_input[60] = {
	-82, -104, 10,	-28, -52, -51, -66, 52,	 124, -74, -21,	 4,  37,   -7,	-33,
	102, 110,  24,	52,  121, 13,  -55, -79, -92, -35, -103, 86, 95,   46,	32,
	-24, -123, 120, 29,  -77, -97, -69, -68, 58,  38,  3,	 3,  79,   -47, 112,
	-52, -113, -46, 107, 68,  83,  -70, 91,	 14,  113, 74,	 73, -103, -98, 25};

static const int8_t avgpooling_2_output_ref[60] = {
	-67, -85, 31, 48,  -63, -51, -55, 33,  30, -53, 10,  16,  38,  56,  5,
	31,  20,  -6, -16, 18,	4,   47,  13,  2,  39,	-38, -31, 45,  -6,  -27,
	-75, -35, 49, 44,  -2,	-39, -63, 44,  13, 24,	-49, -60, -12, 39,  73,
	11,  -60, 41, 25,  98,	35,  -37, -19, 8,  69,	79,  2,	  -6,  -42, 69};
50 
ZTEST(cmsis_nn,test_avgpool)51 ZTEST(cmsis_nn, test_avgpool)
52 {
53 	int8_t output[AVGPOOLING_2_DST_SIZE] = {0};
54 
55 	cmsis_nn_context ctx;
56 	cmsis_nn_pool_params pool_params;
57 	cmsis_nn_dims input_dims;
58 	cmsis_nn_dims filter_dims;
59 	cmsis_nn_dims output_dims;
60 
61 	input_dims.n = AVGPOOLING_2_INPUT_BATCHES;
62 	input_dims.w = AVGPOOLING_2_INPUT_W;
63 	input_dims.h = AVGPOOLING_2_INPUT_H;
64 	input_dims.c = AVGPOOLING_2_IN_CH;
65 	filter_dims.w = AVGPOOLING_2_FILTER_X;
66 	filter_dims.h = AVGPOOLING_2_FILTER_Y;
67 	output_dims.w = AVGPOOLING_2_OUTPUT_W;
68 	output_dims.h = AVGPOOLING_2_OUTPUT_H;
69 	output_dims.c = AVGPOOLING_2_OUT_CH;
70 
71 	pool_params.padding.w = AVGPOOLING_2_PAD_X;
72 	pool_params.padding.h = AVGPOOLING_2_PAD_Y;
73 	pool_params.stride.w = AVGPOOLING_2_STRIDE_X;
74 	pool_params.stride.h = AVGPOOLING_2_STRIDE_Y;
75 
76 	pool_params.activation.min = AVGPOOLING_2_OUT_ACTIVATION_MIN;
77 	pool_params.activation.max = AVGPOOLING_2_OUT_ACTIVATION_MAX;
78 
79 	ctx.size = arm_avgpool_s8_get_buffer_size(AVGPOOLING_2_OUTPUT_W, AVGPOOLING_2_IN_CH);
80 	ctx.buf = malloc(ctx.size);
81 
82 	arm_cmsis_nn_status result = arm_avgpool_s8(&ctx,
83 						   &pool_params,
84 						   &input_dims,
85 						   avgpooling_2_input,
86 						   &filter_dims,
87 						   &output_dims,
88 						   output);
89 
90 	free(ctx.buf);
91 
92 	zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
93 	zassert_mem_equal(avgpooling_2_output_ref, output, sizeof(output), "");
94 }
95 
/*
 * Fixture for test_convolve: layer parameters, biases, weights, input tensor,
 * per-channel requantisation tables and the expected output tensor.
 */
#define CONV_4_OUT_CH		  3
#define CONV_4_IN_CH		  3
#define CONV_4_INPUT_W		  5
#define CONV_4_INPUT_H		  5
#define CONV_4_DST_SIZE		  36
#define CONV_4_INPUT_SIZE	  75
#define CONV_4_OUT_ACTIVATION_MIN -109
#define CONV_4_OUT_ACTIVATION_MAX 127
#define CONV_4_INPUT_BATCHES	  3
#define CONV_4_FILTER_X		  2
#define CONV_4_FILTER_Y		  3
#define CONV_4_STRIDE_X		  2
#define CONV_4_STRIDE_Y		  2
#define CONV_4_PAD_X		  0
#define CONV_4_PAD_Y		  0
#define CONV_4_OUTPUT_W		  2
#define CONV_4_OUTPUT_H		  2
#define CONV_4_INPUT_OFFSET	  128
#define CONV_4_OUTPUT_OFFSET	  -128
#define CONV_4_DILATION_X	  1
#define CONV_4_DILATION_Y	  1

/* The tables are only read inside this file, so keep them out of the global namespace. */
static const int32_t conv_4_biases[3] = {13175, 9050, 18215};

static const int8_t conv_4_weights[54] = {
	-25, -83, -74,	105, 30,  118, -32, 127, 34,  127, -112, 39, -43, 104, 41,  -124, 115, 5,
	42,  -48, -119, 93,  17,  57,  41,  -41, -42, 23,  127,	 18, 70,  -99, 71,  67,	  83,  76,
	-50, 98,  66,	64,  127, -6,  -77, -48, -26, 45,  77,	 1,  81,  27,  124, -103, 37,  36};

/* 225 values = CONV_4_INPUT_BATCHES * CONV_4_INPUT_SIZE. */
static const int8_t conv_4_input[225] = {
	82,   120,  -97, -44,  -118, 73,   4,	 -84,  -53,  -122, -15,	 77,   83,  43,	  37,
	85,   -11,  103, 45,   -69,  -12,  -8,	 21,   6,    -68,  -83,	 -15,  -99, 90,	  -62,
	95,   62,   -38, -32,  -35,  -105, -53,	 70,   112,  14,   -4,	 -33,  -26, -93,  -98,
	22,   -5,   22,	 -104, 57,   -92,  30,	 -62,  0,    -43,  -82,	 60,   99,  -83,  32,
	94,   49,   10,	 112,  -71,  -27,  -91,	 -79,  52,   -92,  -71,	 86,   -79, -15,  -80,
	-74,  -4,   76,	 -119, 91,   -23,  -12,	 -111, -72,  26,   11,	 64,   116, 38,	  99,
	125,  17,   6,	 -4,   46,   119,  113,	 -116, -125, 80,   -57,	 122,  75,  119,  -117,
	87,   -121, -70, -75,  -127, 16,   -124, -110, 10,   71,   29,	 27,   37,  -24,  52,
	28,   -100, 86,	 -75,  117,  -31,  -115, -86,  -122, 121,  -96,	 -118, 32,  111,  25,
	-90,  -8,   110, 37,   35,   124,  -123, 94,   -122, -114, 37,	 85,   -36, 53,	  -40,
	73,   -99,  27,	 10,   37,   41,   64,	 -97,  -123, 75,   0,	 -107, -72, 58,	  -100,
	17,   77,   114, 120,  -83,  -96,  75,	 -12,  -27,  3,	   35,	 85,   4,   119,  -20,
	28,   99,   104, -78,  -51,  -82,  -92,	 -40,  -116, 35,   -107, 39,   9,   -120, -50,
	-102, -114, 25,	 -77,  25,   7,	   64,	 110,  80,   -93,  -20,	 34,   115, 75,	  37,
	47,   16,   6,	 -92,  -25,  37,   69,	 82,   -61,  -100, -85,	 -51,  6,   -95,  58
};

/* One multiplier/shift pair per output channel. */
static const int32_t conv_4_output_mult[3] = {2039209398, 2005068758, 2023002003};

static const int32_t conv_4_output_shift[3] = {-9, -9, -9};

static const int8_t conv_4_output_ref[36] = {-5,   -39, -31, 20,  -37, -26, -109, -7,  -10, -51,
					     -58,  48,	-100, -32, 24,	4,   69,   -38, -64, 65,
					     -34,  95,	-55, 39,  95,  -54, 27,	  -49, 25,  -68,
					     -109, -66, 72,  38,  -44, -40};
150 
ZTEST(cmsis_nn,test_convolve)151 ZTEST(cmsis_nn, test_convolve)
152 {
153 	int8_t output[CONV_4_DST_SIZE] = {0};
154 
155 	cmsis_nn_context ctx;
156 	cmsis_nn_conv_params conv_params;
157 	cmsis_nn_per_channel_quant_params quant_params;
158 	cmsis_nn_dims input_dims;
159 	cmsis_nn_dims filter_dims;
160 	cmsis_nn_dims bias_dims;
161 	cmsis_nn_dims output_dims;
162 
163 	const int32_t *bias_data = conv_4_biases;
164 	const int8_t *kernel_data = conv_4_weights;
165 	const int8_t *input_data = conv_4_input;
166 
167 	input_dims.n = CONV_4_INPUT_BATCHES;
168 	input_dims.w = CONV_4_INPUT_W;
169 	input_dims.h = CONV_4_INPUT_H;
170 	input_dims.c = CONV_4_IN_CH;
171 	filter_dims.w = CONV_4_FILTER_X;
172 	filter_dims.h = CONV_4_FILTER_Y;
173 	output_dims.w = CONV_4_OUTPUT_W;
174 	output_dims.h = CONV_4_OUTPUT_H;
175 	output_dims.c = CONV_4_OUT_CH;
176 
177 	conv_params.padding.w = CONV_4_PAD_X;
178 	conv_params.padding.h = CONV_4_PAD_Y;
179 	conv_params.stride.w = CONV_4_STRIDE_X;
180 	conv_params.stride.h = CONV_4_STRIDE_Y;
181 	conv_params.dilation.w = CONV_4_DILATION_X;
182 	conv_params.dilation.h = CONV_4_DILATION_Y;
183 
184 	conv_params.input_offset = CONV_4_INPUT_OFFSET;
185 	conv_params.output_offset = CONV_4_OUTPUT_OFFSET;
186 	conv_params.activation.min = CONV_4_OUT_ACTIVATION_MIN;
187 	conv_params.activation.max = CONV_4_OUT_ACTIVATION_MAX;
188 	quant_params.multiplier = (int32_t *)conv_4_output_mult;
189 	quant_params.shift = (int32_t *)conv_4_output_shift;
190 
191 	int32_t buf_size = arm_convolve_s8_get_buffer_size(&input_dims, &filter_dims);
192 
193 	ctx.buf = malloc(buf_size);
194 	ctx.size = 0;
195 
196 	arm_cmsis_nn_status result = arm_convolve_s8(&ctx,
197 					    &conv_params,
198 					    &quant_params,
199 					    &input_dims,
200 					    input_data,
201 					    &filter_dims,
202 					    kernel_data,
203 					    &bias_dims,
204 					    bias_data,
205 					    &output_dims,
206 					    output);
207 
208 	free(ctx.buf);
209 	zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
210 	zassert_mem_equal(conv_4_output_ref, output, sizeof(output), "");
211 
212 	buf_size = arm_convolve_wrapper_s8_get_buffer_size(&conv_params, &input_dims,
213 							   &filter_dims, &output_dims);
214 	ctx.buf = malloc(buf_size);
215 	ctx.size = 0;
216 
217 	result = arm_convolve_wrapper_s8(&ctx,
218 					 &conv_params,
219 					 &quant_params,
220 					 &input_dims,
221 					 input_data,
222 					 &filter_dims,
223 					 kernel_data,
224 					 &bias_dims,
225 					 bias_data,
226 					 &output_dims,
227 					 output);
228 
229 	free(ctx.buf);
230 	zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
231 	zassert_mem_equal(conv_4_output_ref, output, sizeof(output), "");
232 }
233 
/*
 * Fixture for test_depthwise_convolve: layer parameters, bias, weights, input
 * tensor, requantisation values and the expected output tensor.
 */
#define STRIDE2PAD1_OUT_CH	       1
#define STRIDE2PAD1_IN_CH	       1
#define STRIDE2PAD1_INPUT_W	       7
#define STRIDE2PAD1_INPUT_H	       7
#define STRIDE2PAD1_DST_SIZE	       16
#define STRIDE2PAD1_INPUT_SIZE	       49
#define STRIDE2PAD1_OUT_ACTIVATION_MIN -128
#define STRIDE2PAD1_OUT_ACTIVATION_MAX 127
#define STRIDE2PAD1_INPUT_BATCHES      1
#define STRIDE2PAD1_FILTER_X	       3
#define STRIDE2PAD1_FILTER_Y	       3
#define STRIDE2PAD1_STRIDE_X	       2
#define STRIDE2PAD1_STRIDE_Y	       2
#define STRIDE2PAD1_PAD_X	       1
#define STRIDE2PAD1_PAD_Y	       1
#define STRIDE2PAD1_OUTPUT_W	       4
#define STRIDE2PAD1_OUTPUT_H	       4
#define STRIDE2PAD1_INPUT_OFFSET       128
#define STRIDE2PAD1_OUTPUT_OFFSET      -20
#define STRIDE2PAD1_DILATION_X	       1
#define STRIDE2PAD1_DILATION_Y	       1

/* The tables are only read inside this file, so keep them out of the global namespace. */
static const int32_t stride2pad1_biases[1] = {-9794};

static const int8_t stride2pad1_weights[9] = {-54, 57, -19, -127, 87, 70, 74, -110, 66};

static const int8_t stride2pad1_input[49] = {
	-91, -30, -57, -76, 32,	 -13, 14,   -96, 108, -4,  41,	48,  107, -68,	-101, 30,  95,
	95,  91,  -66, -80, 114, -49, 7,    -67, -35, -1,  -88, -77, -56, -103, 5,    -39, -118,
	-24, -32, 67,  11,  38,	 -16, -124, 44,	 -46, -92, -24, 108, 80,  -29,	-3};

static const int32_t stride2pad1_output_mult[1] = {2033801520};

static const int32_t stride2pad1_output_shift[1] = {-8};

static const int8_t stride2pad1_output_ref[16] = {26, -11, 33,	-25,  -96, -52, -78, -86,
						  33, -2,  -88, -113, -14, 0,	-84, -27};
271 
ZTEST(cmsis_nn,test_depthwise_convolve)272 ZTEST(cmsis_nn, test_depthwise_convolve)
273 {
274 	int8_t output[STRIDE2PAD1_DST_SIZE] = {0};
275 
276 	cmsis_nn_context ctx;
277 	cmsis_nn_dw_conv_params dw_conv_params;
278 	cmsis_nn_per_channel_quant_params quant_params;
279 	cmsis_nn_dims input_dims;
280 	cmsis_nn_dims filter_dims;
281 	cmsis_nn_dims bias_dims = {0};
282 	cmsis_nn_dims output_dims;
283 
284 	const int32_t *bias_data = stride2pad1_biases;
285 	const int8_t *kernel_data = stride2pad1_weights;
286 	const int8_t *input_data = stride2pad1_input;
287 
288 	input_dims.n = STRIDE2PAD1_INPUT_BATCHES;
289 	input_dims.w = STRIDE2PAD1_INPUT_W;
290 	input_dims.h = STRIDE2PAD1_INPUT_H;
291 	input_dims.c = STRIDE2PAD1_IN_CH;
292 	filter_dims.w = STRIDE2PAD1_FILTER_X;
293 	filter_dims.h = STRIDE2PAD1_FILTER_Y;
294 	output_dims.w = STRIDE2PAD1_OUTPUT_W;
295 	output_dims.h = STRIDE2PAD1_OUTPUT_H;
296 	output_dims.c = STRIDE2PAD1_OUT_CH;
297 
298 	dw_conv_params.padding.w = STRIDE2PAD1_PAD_X;
299 	dw_conv_params.padding.h = STRIDE2PAD1_PAD_Y;
300 	dw_conv_params.stride.w = STRIDE2PAD1_STRIDE_X;
301 	dw_conv_params.stride.h = STRIDE2PAD1_STRIDE_Y;
302 	dw_conv_params.dilation.w = STRIDE2PAD1_DILATION_X;
303 	dw_conv_params.dilation.h = STRIDE2PAD1_DILATION_Y;
304 
305 	dw_conv_params.ch_mult = 1;
306 
307 	dw_conv_params.input_offset = STRIDE2PAD1_INPUT_OFFSET;
308 	dw_conv_params.output_offset = STRIDE2PAD1_OUTPUT_OFFSET;
309 	dw_conv_params.activation.min = STRIDE2PAD1_OUT_ACTIVATION_MIN;
310 	dw_conv_params.activation.max = STRIDE2PAD1_OUT_ACTIVATION_MAX;
311 	quant_params.multiplier = (int32_t *)stride2pad1_output_mult;
312 	quant_params.shift = (int32_t *)stride2pad1_output_shift;
313 
314 	ctx.buf = NULL;
315 	ctx.size = 0;
316 
317 	arm_cmsis_nn_status result = arm_depthwise_conv_s8(&ctx,
318 						  &dw_conv_params,
319 						  &quant_params,
320 						  &input_dims,
321 						  input_data,
322 						  &filter_dims,
323 						  kernel_data,
324 						  &bias_dims,
325 						  bias_data,
326 						  &output_dims,
327 						  output);
328 
329 	free(ctx.buf);
330 	zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
331 	zassert_mem_equal(stride2pad1_output_ref, output, sizeof(output), "");
332 }
333 
/*
 * Fixture for test_fully_connected: layer parameters, biases, input vector,
 * expected output vector and the weight matrix.
 */
#define FULLY_CONNECTED_MVE_0_OUT_CH		 9
#define FULLY_CONNECTED_MVE_0_IN_CH		 16
#define FULLY_CONNECTED_MVE_0_INPUT_W		 1
#define FULLY_CONNECTED_MVE_0_INPUT_H		 1
#define FULLY_CONNECTED_MVE_0_DST_SIZE		 9
#define FULLY_CONNECTED_MVE_0_INPUT_SIZE	 16
#define FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MIN -128
#define FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MAX 127
#define FULLY_CONNECTED_MVE_0_INPUT_BATCHES	 1
#define FULLY_CONNECTED_MVE_0_OUTPUT_MULTIPLIER	 1244038257
#define FULLY_CONNECTED_MVE_0_OUTPUT_SHIFT	 -9
#define FULLY_CONNECTED_MVE_0_ACCUMULATION_DEPTH 16
#define FULLY_CONNECTED_MVE_0_INPUT_OFFSET	 128
#define FULLY_CONNECTED_MVE_0_OUTPUT_OFFSET	 -26

/* The tables are only read inside this file, so keep them out of the global namespace. */
static const int32_t fully_connected_mve_0_biases[9] = {11295, -30752, -3196, 10489, -5120,
							18598, 27393,  29746, 22967};

static const int8_t fully_connected_mve_0_input[16] = {-43, 68,	 79,   -12, -119, -56, -102, -46,
						       107, -65, -109, -7,  92,	  -99, -80,  -29};

static const int8_t fully_connected_mve_0_output_ref[9] = {-9, -3, 26, 8, 3, -88, 75, 34, 5};

/* 144 values = FULLY_CONNECTED_MVE_0_ACCUMULATION_DEPTH * FULLY_CONNECTED_MVE_0_OUT_CH. */
static const int8_t fully_connected_mve_0_weights[144] = {
	37,  -46,  75,	 -33,  -52, -82,  -94,	64,   71,  65,	 64,  16,   -66, -5,   -65,  -44,
	82,  42,   84,	 105,  18,  79,	  -103, -75,  -95, 65,	 87,  103,  43,	 -25,  -66,  75,
	125, 40,   -34,	 24,   9,   -79,  4,	73,   98,  -75,	 42,  81,   18,	 -58,  -119, 92,
	0,   -72,  48,	 23,   -69, 11,	  -95,	-103, 66,  117,	 107, -96,  114, -29,  75,   -93,
	118, 66,   -19,	 83,   -14, 86,	  -110, 44,   37,  -9,	 17,  -107, 50,	 -116, -116, -27,
	-84, -126, -108, -127, -71, 8,	  81,	108,  -61, 126,	 69,  -45,  37,	 -78,  -102, -55,
	116, 112,  -111, -89,  -57, 82,	  -47,	22,   125, -84,	 97,  -9,   88,	 74,   -15,  118,
	-95, 112,  89,	 44,   -17, -112, -71,	-94,  1,   -117, 112, -92,  52,	 57,   -22,  80,
	-60, 95,   -106, -1,   -27, 105,  6,	123,  6,   96,	 126, -65,  -29, 103,  19,   -45};
367 
ZTEST(cmsis_nn,test_fully_connected)368 ZTEST(cmsis_nn, test_fully_connected)
369 {
370 	int8_t output[FULLY_CONNECTED_MVE_0_DST_SIZE] = {0};
371 
372 	cmsis_nn_context ctx;
373 	cmsis_nn_fc_params fc_params;
374 	cmsis_nn_per_tensor_quant_params quant_params;
375 	cmsis_nn_dims input_dims;
376 	cmsis_nn_dims filter_dims;
377 	cmsis_nn_dims bias_dims;
378 	cmsis_nn_dims output_dims;
379 
380 	const int32_t *bias_data = fully_connected_mve_0_biases;
381 	const int8_t *kernel_data = fully_connected_mve_0_weights;
382 	const int8_t *input_data = fully_connected_mve_0_input;
383 
384 	input_dims.n = FULLY_CONNECTED_MVE_0_INPUT_BATCHES;
385 	input_dims.w = FULLY_CONNECTED_MVE_0_INPUT_W;
386 	input_dims.h = FULLY_CONNECTED_MVE_0_INPUT_H;
387 	input_dims.c = FULLY_CONNECTED_MVE_0_IN_CH;
388 	filter_dims.n = FULLY_CONNECTED_MVE_0_ACCUMULATION_DEPTH;
389 	filter_dims.c = FULLY_CONNECTED_MVE_0_OUT_CH;
390 	output_dims.n = FULLY_CONNECTED_MVE_0_INPUT_BATCHES;
391 	output_dims.c = FULLY_CONNECTED_MVE_0_OUT_CH;
392 
393 	fc_params.input_offset = FULLY_CONNECTED_MVE_0_INPUT_OFFSET;
394 	fc_params.filter_offset = 0;
395 	fc_params.output_offset = FULLY_CONNECTED_MVE_0_OUTPUT_OFFSET;
396 	fc_params.activation.min = FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MIN;
397 	fc_params.activation.max = FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MAX;
398 
399 	quant_params.multiplier = FULLY_CONNECTED_MVE_0_OUTPUT_MULTIPLIER;
400 	quant_params.shift = FULLY_CONNECTED_MVE_0_OUTPUT_SHIFT;
401 
402 	int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);
403 
404 	ctx.buf = malloc(buf_size);
405 	ctx.size = buf_size;
406 	arm_cmsis_nn_status result = arm_fully_connected_s8(&ctx,
407 						   &fc_params,
408 						   &quant_params,
409 						   &input_dims,
410 						   input_data,
411 						   &filter_dims,
412 						   kernel_data,
413 						   &bias_dims,
414 						   bias_data,
415 						   &output_dims,
416 						   output);
417 
418 	free(ctx.buf);
419 	zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
420 	zassert_mem_equal(fully_connected_mve_0_output_ref, output, sizeof(output), "");
421 }
422 
/*
 * Fixture for test_max_pool: layer parameters, the input tensor and the
 * output tensor the kernel is expected to produce.
 */
#define MAXPOOLING_2_OUT_CH		5
#define MAXPOOLING_2_IN_CH		5
#define MAXPOOLING_2_INPUT_W		12
#define MAXPOOLING_2_INPUT_H		1
#define MAXPOOLING_2_DST_SIZE		60
#define MAXPOOLING_2_INPUT_SIZE		60
#define MAXPOOLING_2_OUT_ACTIVATION_MIN -128
#define MAXPOOLING_2_OUT_ACTIVATION_MAX 127
#define MAXPOOLING_2_INPUT_BATCHES	1
#define MAXPOOLING_2_FILTER_X		3
#define MAXPOOLING_2_FILTER_Y		1
#define MAXPOOLING_2_STRIDE_X		1
#define MAXPOOLING_2_STRIDE_Y		2
#define MAXPOOLING_2_PAD_X		1
#define MAXPOOLING_2_PAD_Y		0
#define MAXPOOLING_2_OUTPUT_W		12
#define MAXPOOLING_2_OUTPUT_H		1

/* The tables are only read inside this file, so keep them out of the global namespace. */
static const int8_t maxpooling_2_input[60] = {
	75,  -52,  -42,	 -30, 56,  64,	 106, -36, 120, -3,  34,   -105, 69,   75,  -39,
	15,  93,   -71,	 39,  34,  -11,	 65,  22,  59,	106, 105,  45,	 -116, -75, 123,
	-65, 75,   -61,	 13,  -25, -123, 59,  110, -65, 86,  -108, -107, -17,  38,  27,
	-1,  -115, -123, 75,  -75, 68,	 52,  12,  -35, 116, -68,  22,	 15,   76,  -81};

static const int8_t maxpooling_2_output_ref[60] = {
	75,  106, -36, 120, 56,	 75,  106, 69,	120, 56,  64,  106, 69,	 120, 34,
	34,  93,  69,  75,  106, 105, 93,  22,	59,  123, 105, 75,  22,	 59,  123,
	105, 75,  110, 13,  123, -65, 75,  110, 38,  86,  -1,  59,  110, 75,  86,
	68,  52,  12,  75,  116, 68,  52,  15,	76,  116, 68,  52,  15,	 76,  116};
452 
ZTEST(cmsis_nn,test_max_pool)453 ZTEST(cmsis_nn, test_max_pool)
454 {
455 	int8_t output[MAXPOOLING_2_DST_SIZE] = {0};
456 
457 	cmsis_nn_context ctx;
458 	cmsis_nn_pool_params pool_params;
459 	cmsis_nn_dims input_dims;
460 	cmsis_nn_dims filter_dims;
461 	cmsis_nn_dims output_dims;
462 
463 	const int8_t *input_data = maxpooling_2_input;
464 
465 	input_dims.n = MAXPOOLING_2_INPUT_BATCHES;
466 	input_dims.w = MAXPOOLING_2_INPUT_W;
467 	input_dims.h = MAXPOOLING_2_INPUT_H;
468 	input_dims.c = MAXPOOLING_2_IN_CH;
469 	filter_dims.w = MAXPOOLING_2_FILTER_X;
470 	filter_dims.h = MAXPOOLING_2_FILTER_Y;
471 	output_dims.w = MAXPOOLING_2_OUTPUT_W;
472 	output_dims.h = MAXPOOLING_2_OUTPUT_H;
473 	output_dims.c = MAXPOOLING_2_OUT_CH;
474 
475 	pool_params.padding.w = MAXPOOLING_2_PAD_X;
476 	pool_params.padding.h = MAXPOOLING_2_PAD_Y;
477 	pool_params.stride.w = MAXPOOLING_2_STRIDE_X;
478 	pool_params.stride.h = MAXPOOLING_2_STRIDE_Y;
479 
480 	pool_params.activation.min = MAXPOOLING_2_OUT_ACTIVATION_MIN;
481 	pool_params.activation.max = MAXPOOLING_2_OUT_ACTIVATION_MAX;
482 
483 	for (int i = 0; i < REPEAT_NUM; i++) {
484 		arm_cmsis_nn_status result =
485 			arm_max_pool_s8(&ctx, &pool_params, &input_dims, input_data, &filter_dims,
486 					&output_dims, output);
487 
488 		zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
489 		zassert_mem_equal(maxpooling_2_output_ref, output, sizeof(output), "");
490 	}
491 }
492 
/* Fixture for test_softmax: kernel parameters, input rows and expected output. */
#define SOFTMAX_NUM_ROWS	 2
#define SOFTMAX_ROW_SIZE	 5
#define SOFTMAX_INPUT_MULT	 1077952640
#define SOFTMAX_INPUT_LEFT_SHIFT 19
#define SOFTMAX_DIFF_MIN	 -3968
#define SOFTMAX_DST_SIZE	 10

/* The tables are only read inside this file, so keep them out of the global namespace. */
static const int8_t softmax_input[10] = {101, 49, 6, -34, -75, -79, -38, 120, -55, 115};

static const int8_t softmax_output_ref[10] = {-57, -70, -79, -86, -92, -94, -88, -54, -91, -56};
503 
ZTEST(cmsis_nn,test_softmax)504 ZTEST(cmsis_nn, test_softmax)
505 {
506 	const int32_t num_rows = SOFTMAX_NUM_ROWS;
507 	const int32_t row_size = SOFTMAX_ROW_SIZE;
508 	const int32_t mult = SOFTMAX_INPUT_MULT;
509 	const int32_t shift = SOFTMAX_INPUT_LEFT_SHIFT;
510 	const int32_t diff_min = SOFTMAX_DIFF_MIN;
511 	const int8_t *input_data = softmax_input;
512 	int8_t output[SOFTMAX_DST_SIZE];
513 
514 	for (int i = 0; i < REPEAT_NUM; i++) {
515 		arm_softmax_s8(input_data, num_rows, row_size, mult, shift, diff_min, output);
516 		zassert_mem_equal(softmax_output_ref, output, sizeof(output), "");
517 	}
518 }
519 
/*
 * Fixture for test_svdf: layer parameters, biases, initial state, feature and
 * time weights, the input sequence and the expected output.
 */
#define SVDF_2_MULTIPLIER_IN	  1717987072
#define SVDF_2_MULTIPLIER_OUT	  1099511552
#define SVDF_2_SHIFT_1		  -3
#define SVDF_2_SHIFT_2		  -11
#define SVDF_2_IN_ACTIVATION_MIN  -32768
#define SVDF_2_IN_ACTIVATION_MAX  32767
#define SVDF_2_RANK		  2
#define SVDF_2_FEATURE_BATCHES	  10
#define SVDF_2_TIME_BATCHES	  2
#define SVDF_2_INPUT_SIZE	  7
#define SVDF_2_DST_SIZE		  15
#define SVDF_2_OUT_ACTIVATION_MIN -128
#define SVDF_2_OUT_ACTIVATION_MAX 127
#define SVDF_2_INPUT_BATCHES	  3
#define SVDF_2_INPUT_OFFSET	  0
#define SVDF_2_OUTPUT_OFFSET	  0

/* The tables are only read inside this file, so keep them out of the global namespace. */
static const int32_t svdf_2_biases[5] = {0, 0, 0, 0, 0};

/* Initial (all-zero) state; test_svdf copies it into a writable buffer before each run. */
static const int16_t svdf_2_state[60] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
					 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
					 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
					 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};

static const int8_t svdf_2_weights_feature[70] = {
	27,   82,  -108, -127, 85,  3,	 -51, 32,  110, -6,  -14, -16,	31,  101,
	-122, 19,  76,	 74,   -80, 12,	 -22, -17, 10,	-28, 55,  109,	2,   -107,
	-4,   72,  -65,	 -59,  36,  -69, 105, -97, 25,	38,  110, -121, -88, -126,
	-14,  16,  -88,	 -66,  3,   -93, 69,  -64, 44,	103, 95,  -95,	68,  -46,
	106,  -31, -63,	 23,   -38, 36,	 -95, -43, 93,	77,  91,  -26,	33,  59};

static const int16_t svdf_2_weights_time[20] = {-31, -88, -10, -72, -119, -6, -70, 63,	-10, 93,
						5,   42,  -6,  22,  6,	  51, 37,  -38, 5,   117};

/* 42 values: two consecutive inputs of SVDF_2_INPUT_BATCHES * SVDF_2_INPUT_SIZE each. */
static const int8_t svdf_2_input_sequence[42] = {
	29,   81,   -38, 17,  -116, 43,	 119,  -127, 74,  115, 9,    118,  7,	 -56,
	-53,  -14,  -98, 60,  -128, 10,	 28,   -18,  12,  -28, -126, 87,   -115, -44,
	-123, -109, -59, -87, -69,  121, -128, -95,  -70, 2,   81,   -119, 84,	 -122};

static const int8_t svdf_2_output_ref[15] = {-53, 45,  27, -24, -53, 26, -82, -38,
					     11,  -85, 94, -16, -32, 31, 4};
560 
/*
 * Tell whether a bias vector contributes nothing.
 *
 * bias: vector to inspect; may be NULL, which is reported as a null bias.
 * size: number of elements to inspect; nothing is read when it is <= 0.
 *
 * Returns true when bias is NULL or none of the first size elements is
 * non-zero, false as soon as a non-zero element is found.
 */
static bool check_null_bias(const int32_t *bias, int32_t size)
{
	/* An absent vector cannot hold a non-zero element; do not dereference it. */
	if (bias == NULL) {
		return true;
	}

	for (int32_t i = 0; i < size; i++) {
		if (bias[i] != 0) {
			return false;
		}
	}

	return true;
}
573 
ZTEST(cmsis_nn,test_svdf)574 ZTEST(cmsis_nn, test_svdf)
575 {
576 	cmsis_nn_context input_ctx;
577 	cmsis_nn_context output_ctx;
578 	cmsis_nn_svdf_params svdf_2_params;
579 	cmsis_nn_dims input_dims;
580 	cmsis_nn_dims weights_feature_dims;
581 	cmsis_nn_dims weights_time_dims;
582 	cmsis_nn_dims state_dims;
583 	cmsis_nn_dims output_dims;
584 	cmsis_nn_dims bias_dims;
585 	cmsis_nn_per_tensor_quant_params input_quant_params;
586 	cmsis_nn_per_tensor_quant_params output_quant_params;
587 	int8_t output_data[SVDF_2_DST_SIZE];
588 
589 	const int8_t *weights_feature_data = svdf_2_weights_feature;
590 	const int16_t *weights_time_data = svdf_2_weights_time;
591 
592 	input_dims.n = SVDF_2_INPUT_BATCHES;
593 	input_dims.h = SVDF_2_INPUT_SIZE;
594 	weights_feature_dims.n = SVDF_2_FEATURE_BATCHES;
595 	weights_time_dims.h = SVDF_2_TIME_BATCHES;
596 
597 	input_quant_params.multiplier = SVDF_2_MULTIPLIER_IN;
598 	input_quant_params.shift = SVDF_2_SHIFT_1;
599 	output_quant_params.multiplier = SVDF_2_MULTIPLIER_OUT;
600 	output_quant_params.shift = SVDF_2_SHIFT_2;
601 
602 	svdf_2_params.input_activation.min = SVDF_2_IN_ACTIVATION_MIN;
603 	svdf_2_params.input_activation.max = SVDF_2_IN_ACTIVATION_MAX;
604 	svdf_2_params.output_activation.min = SVDF_2_OUT_ACTIVATION_MIN;
605 	svdf_2_params.output_activation.max = SVDF_2_OUT_ACTIVATION_MAX;
606 	svdf_2_params.input_offset = SVDF_2_INPUT_OFFSET;
607 	svdf_2_params.output_offset = SVDF_2_OUTPUT_OFFSET;
608 	svdf_2_params.rank = SVDF_2_RANK;
609 
610 	const int input_round_size = SVDF_2_INPUT_BATCHES * SVDF_2_INPUT_SIZE;
611 	const int number_inputs = sizeof(svdf_2_input_sequence) / input_round_size;
612 	const int32_t number_units = SVDF_2_FEATURE_BATCHES / SVDF_2_RANK;
613 	const int scratch_size = SVDF_2_INPUT_BATCHES * SVDF_2_FEATURE_BATCHES * sizeof(int32_t);
614 	const int scratch_size_out = SVDF_2_INPUT_BATCHES * number_units * sizeof(int32_t);
615 
616 	input_ctx.buf = malloc(scratch_size);
617 	output_ctx.buf = malloc(scratch_size_out);
618 
619 	int8_t *input_data = malloc(input_round_size);
620 	int16_t *state_data = malloc(sizeof(svdf_2_state));
621 	const bool null_bias = check_null_bias(svdf_2_biases,
622 					       SVDF_2_DST_SIZE / SVDF_2_INPUT_BATCHES);
623 
624 	for (int i = 0; i < REPEAT_NUM; i++) {
625 		memcpy(state_data, svdf_2_state, sizeof(svdf_2_state));
626 		for (int j = 0; j < number_inputs; j++) {
627 			memcpy(input_data, svdf_2_input_sequence + j * input_round_size,
628 			       input_round_size);
629 			arm_cmsis_nn_status result = arm_svdf_state_s16_s8(&input_ctx,
630 							&output_ctx,
631 							&svdf_2_params,
632 							&input_quant_params,
633 							&output_quant_params,
634 							&input_dims,
635 							input_data,
636 							&state_dims,
637 							state_data,
638 							&weights_feature_dims,
639 							weights_feature_data,
640 							&weights_time_dims,
641 							weights_time_data,
642 							&bias_dims,
643 							null_bias == true ? NULL : svdf_2_biases,
644 							&output_dims,
645 							output_data);
646 			zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
647 		}
648 
649 		zassert_mem_equal(svdf_2_output_ref, output_data, sizeof(output_data), "");
650 	}
651 	free(state_data);
652 	free(input_data);
653 	free(input_ctx.buf);
654 	free(output_ctx.buf);
655 }
656 
/*
 * Declare the cmsis_nn suite that the ZTEST() cases in this file are attached
 * to. Every optional argument of the suite declaration is left NULL.
 */
ZTEST_SUITE(cmsis_nn, NULL, NULL, NULL, NULL, NULL);
658