1 /*
2  * Copyright (c) 2021, Commonwealth Scientific and Industrial Research
3  * Organisation (CSIRO) ABN 41 687 119 230.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  *
7  * This is not exhaustive functional testing of the CMSIS-NN library.
8  *
 * Individual tests have been pulled from CMSIS/NN/Tests/UnitTest to
 * validate the integration of CMSIS-NN and Zephyr.
11  */
12 
#include <zephyr/ztest.h>
#include <zephyr/kernel.h>
#include <stdlib.h>
#include <string.h>

#include "arm_nnfunctions.h"
18 
/* Number of back-to-back kernel runs in test_max_pool, test_softmax and test_svdf. */
#define REPEAT_NUM 3
20 
/*
 * Average-pooling case (named after the CMSIS-NN UnitTest data set
 * "avgpooling_2", presumably): one 12x1 input with 5 channels, pooled with a
 * 3x1 window, stride 1x2 and padding 1x0, giving a 12x1x5 output.
 */
#define AVGPOOLING_2_OUT_CH		5
#define AVGPOOLING_2_IN_CH		5
#define AVGPOOLING_2_INPUT_W		12
#define AVGPOOLING_2_INPUT_H		1
#define AVGPOOLING_2_DST_SIZE		60
#define AVGPOOLING_2_INPUT_SIZE		60
#define AVGPOOLING_2_OUT_ACTIVATION_MIN -128
#define AVGPOOLING_2_OUT_ACTIVATION_MAX 127
#define AVGPOOLING_2_INPUT_BATCHES	1
#define AVGPOOLING_2_FILTER_X		3
#define AVGPOOLING_2_FILTER_Y		1
#define AVGPOOLING_2_STRIDE_X		1
#define AVGPOOLING_2_STRIDE_Y		2
#define AVGPOOLING_2_PAD_X		1
#define AVGPOOLING_2_PAD_Y		0
#define AVGPOOLING_2_OUTPUT_W		12
#define AVGPOOLING_2_OUTPUT_H		1

/* Input tensor: AVGPOOLING_2_INPUT_SIZE (12 x 1 x 5 = 60) values. */
const int8_t avgpooling_2_input[60] = {
	-82, -104, 10,	-28, -52, -51, -66, 52,	 124, -74, -21,	 4,  37,   -7,	-33,
	102, 110,  24,	52,  121, 13,  -55, -79, -92, -35, -103, 86, 95,   46,	32,
	-24, -123, 120, 29,  -77, -97, -69, -68, 58,  38,  3,	 3,  79,   -47, 112,
	-52, -113, -46, 107, 68,  83,  -70, 91,	 14,  113, 74,	 73, -103, -98, 25};

/* Expected arm_avgpool_s8 result: AVGPOOLING_2_DST_SIZE (60) values. */
const int8_t avgpooling_2_output_ref[60] = {
	-67, -85, 31, 48,  -63, -51, -55, 33,  30, -53, 10,  16,  38,  56,  5,
	31,  20,  -6, -16, 18,	4,   47,  13,  2,  39,	-38, -31, 45,  -6,  -27,
	-75, -35, 49, 44,  -2,	-39, -63, 44,  13, 24,	-49, -60, -12, 39,  73,
	11,  -60, 41, 25,  98,	35,  -37, -19, 8,  69,	79,  2,	  -6,  -42, 69};
50 
ZTEST(cmsis_nn,test_avgpool)51 ZTEST(cmsis_nn, test_avgpool)
52 {
53 	int8_t output[AVGPOOLING_2_DST_SIZE] = {0};
54 
55 	cmsis_nn_context ctx;
56 	cmsis_nn_pool_params pool_params;
57 	cmsis_nn_dims input_dims;
58 	cmsis_nn_dims filter_dims;
59 	cmsis_nn_dims output_dims;
60 
61 	input_dims.n = AVGPOOLING_2_INPUT_BATCHES;
62 	input_dims.w = AVGPOOLING_2_INPUT_W;
63 	input_dims.h = AVGPOOLING_2_INPUT_H;
64 	input_dims.c = AVGPOOLING_2_IN_CH;
65 	filter_dims.w = AVGPOOLING_2_FILTER_X;
66 	filter_dims.h = AVGPOOLING_2_FILTER_Y;
67 	output_dims.w = AVGPOOLING_2_OUTPUT_W;
68 	output_dims.h = AVGPOOLING_2_OUTPUT_H;
69 	output_dims.c = AVGPOOLING_2_OUT_CH;
70 
71 	pool_params.padding.w = AVGPOOLING_2_PAD_X;
72 	pool_params.padding.h = AVGPOOLING_2_PAD_Y;
73 	pool_params.stride.w = AVGPOOLING_2_STRIDE_X;
74 	pool_params.stride.h = AVGPOOLING_2_STRIDE_Y;
75 
76 	pool_params.activation.min = AVGPOOLING_2_OUT_ACTIVATION_MIN;
77 	pool_params.activation.max = AVGPOOLING_2_OUT_ACTIVATION_MAX;
78 
79 	ctx.size = arm_avgpool_s8_get_buffer_size(AVGPOOLING_2_OUTPUT_W, AVGPOOLING_2_IN_CH);
80 	ctx.buf = malloc(ctx.size);
81 
82 	arm_cmsis_nn_status result = arm_avgpool_s8(&ctx,
83 						   &pool_params,
84 						   &input_dims,
85 						   avgpooling_2_input,
86 						   &filter_dims,
87 						   &output_dims,
88 						   output);
89 
90 	free(ctx.buf);
91 
92 	zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
93 	zassert_mem_equal(avgpooling_2_output_ref, output, sizeof(output), "");
94 }
95 
/*
 * Convolution case "conv_4": CONV_4_INPUT_BATCHES (3) batches of a 5x5x3
 * input, a 2x3 filter (FILTER_X x FILTER_Y), stride 2x2 and no padding.
 * Each batch gives a 2x2x3 output, so there are 3 x 12 = 36 outputs in total.
 * Note that CONV_4_INPUT_SIZE (75) is the size of one batch, not of
 * conv_4_input as a whole.
 */
#define CONV_4_OUT_CH		  3
#define CONV_4_IN_CH		  3
#define CONV_4_INPUT_W		  5
#define CONV_4_INPUT_H		  5
#define CONV_4_DST_SIZE		  36
#define CONV_4_INPUT_SIZE	  75
#define CONV_4_OUT_ACTIVATION_MIN -109
#define CONV_4_OUT_ACTIVATION_MAX 127
#define CONV_4_INPUT_BATCHES	  3
#define CONV_4_FILTER_X		  2
#define CONV_4_FILTER_Y		  3
#define CONV_4_STRIDE_X		  2
#define CONV_4_STRIDE_Y		  2
#define CONV_4_PAD_X		  0
#define CONV_4_PAD_Y		  0
#define CONV_4_OUTPUT_W		  2
#define CONV_4_OUTPUT_H		  2
#define CONV_4_INPUT_OFFSET	  128
#define CONV_4_OUTPUT_OFFSET	  -128
#define CONV_4_DILATION_X	  1
#define CONV_4_DILATION_Y	  1

/* One bias per output channel (CONV_4_OUT_CH = 3). */
const int32_t conv_4_biases[3] = {13175, 9050, 18215};

/* Filter weights: OUT_CH x FILTER_Y x FILTER_X x IN_CH = 3 x 3 x 2 x 3 = 54 values. */
const int8_t conv_4_weights[54] = {
	-25, -83, -74,	105, 30,  118, -32, 127, 34,  127, -112, 39, -43, 104, 41,  -124, 115, 5,
	42,  -48, -119, 93,  17,  57,  41,  -41, -42, 23,  127,	 18, 70,  -99, 71,  67,	  83,  76,
	-50, 98,  66,	64,  127, -6,  -77, -48, -26, 45,  77,	 1,  81,  27,  124, -103, 37,  36};

/* Input: CONV_4_INPUT_BATCHES x CONV_4_INPUT_SIZE = 3 x 75 = 225 values. */
const int8_t conv_4_input[225] = {
	82,   120,  -97, -44,  -118, 73,   4,	 -84,  -53,  -122, -15,	 77,   83,  43,	  37,
	85,   -11,  103, 45,   -69,  -12,  -8,	 21,   6,    -68,  -83,	 -15,  -99, 90,	  -62,
	95,   62,   -38, -32,  -35,  -105, -53,	 70,   112,  14,   -4,	 -33,  -26, -93,  -98,
	22,   -5,   22,	 -104, 57,   -92,  30,	 -62,  0,    -43,  -82,	 60,   99,  -83,  32,
	94,   49,   10,	 112,  -71,  -27,  -91,	 -79,  52,   -92,  -71,	 86,   -79, -15,  -80,
	-74,  -4,   76,	 -119, 91,   -23,  -12,	 -111, -72,  26,   11,	 64,   116, 38,	  99,
	125,  17,   6,	 -4,   46,   119,  113,	 -116, -125, 80,   -57,	 122,  75,  119,  -117,
	87,   -121, -70, -75,  -127, 16,   -124, -110, 10,   71,   29,	 27,   37,  -24,  52,
	28,   -100, 86,	 -75,  117,  -31,  -115, -86,  -122, 121,  -96,	 -118, 32,  111,  25,
	-90,  -8,   110, 37,   35,   124,  -123, 94,   -122, -114, 37,	 85,   -36, 53,	  -40,
	73,   -99,  27,	 10,   37,   41,   64,	 -97,  -123, 75,   0,	 -107, -72, 58,	  -100,
	17,   77,   114, 120,  -83,  -96,  75,	 -12,  -27,  3,	   35,	 85,   4,   119,  -20,
	28,   99,   104, -78,  -51,  -82,  -92,	 -40,  -116, 35,   -107, 39,   9,   -120, -50,
	-102, -114, 25,	 -77,  25,   7,	   64,	 110,  80,   -93,  -20,	 34,   115, 75,	  37,
	47,   16,   6,	 -92,  -25,  37,   69,	 82,   -61,  -100, -85,	 -51,  6,   -95,  58
};

/* Per-output-channel requantization multipliers and shifts (one per CONV_4_OUT_CH). */
const int32_t conv_4_output_mult[3] = {2039209398, 2005068758, 2023002003};

const int32_t conv_4_output_shift[3] = {-9, -9, -9};

/* Expected output of both arm_convolve_s8 and arm_convolve_wrapper_s8 (36 values). */
const int8_t conv_4_output_ref[36] = {-5,   -39, -31, 20,  -37, -26, -109, -7,	-10, -51, -58, 48,
				      -100, -32, 24,  4,   69,	-38, -64,  65,	-34, 95,  -55, 39,
				      95,   -54, 27,  -49, 25,	-68, -109, -66, 72,  38,  -44, -40};
150 
ZTEST(cmsis_nn,test_convolve)151 ZTEST(cmsis_nn, test_convolve)
152 {
153 	int8_t output[CONV_4_DST_SIZE] = {0};
154 
155 	cmsis_nn_context ctx;
156 	cmsis_nn_conv_params conv_params;
157 	cmsis_nn_per_channel_quant_params quant_params;
158 	cmsis_nn_dims input_dims;
159 	cmsis_nn_dims filter_dims;
160 	cmsis_nn_dims bias_dims;
161 	cmsis_nn_dims output_dims;
162 
163 	const int32_t *bias_data = conv_4_biases;
164 	const int8_t *kernel_data = conv_4_weights;
165 	const int8_t *input_data = conv_4_input;
166 
167 	input_dims.n = CONV_4_INPUT_BATCHES;
168 	input_dims.w = CONV_4_INPUT_W;
169 	input_dims.h = CONV_4_INPUT_H;
170 	input_dims.c = CONV_4_IN_CH;
171 	filter_dims.w = CONV_4_FILTER_X;
172 	filter_dims.h = CONV_4_FILTER_Y;
173 	filter_dims.c = CONV_4_IN_CH;
174 	output_dims.w = CONV_4_OUTPUT_W;
175 	output_dims.h = CONV_4_OUTPUT_H;
176 	output_dims.c = CONV_4_OUT_CH;
177 
178 	conv_params.padding.w = CONV_4_PAD_X;
179 	conv_params.padding.h = CONV_4_PAD_Y;
180 	conv_params.stride.w = CONV_4_STRIDE_X;
181 	conv_params.stride.h = CONV_4_STRIDE_Y;
182 	conv_params.dilation.w = CONV_4_DILATION_X;
183 	conv_params.dilation.h = CONV_4_DILATION_Y;
184 
185 	conv_params.input_offset = CONV_4_INPUT_OFFSET;
186 	conv_params.output_offset = CONV_4_OUTPUT_OFFSET;
187 	conv_params.activation.min = CONV_4_OUT_ACTIVATION_MIN;
188 	conv_params.activation.max = CONV_4_OUT_ACTIVATION_MAX;
189 	quant_params.multiplier = (int32_t *)conv_4_output_mult;
190 	quant_params.shift = (int32_t *)conv_4_output_shift;
191 
192 	int32_t buf_size = arm_convolve_s8_get_buffer_size(&input_dims, &filter_dims);
193 
194 	ctx.buf = malloc(buf_size);
195 	ctx.size = 0;
196 
197 	arm_cmsis_nn_status result = arm_convolve_s8(&ctx,
198 					    &conv_params,
199 					    &quant_params,
200 					    &input_dims,
201 					    input_data,
202 					    &filter_dims,
203 					    kernel_data,
204 					    &bias_dims,
205 					    bias_data,
206 					    &output_dims,
207 					    output);
208 
209 	free(ctx.buf);
210 	zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
211 	zassert_mem_equal(conv_4_output_ref, output, sizeof(output), "");
212 
213 	buf_size = arm_convolve_wrapper_s8_get_buffer_size(&conv_params, &input_dims,
214 							   &filter_dims, &output_dims);
215 	ctx.buf = malloc(buf_size);
216 	ctx.size = 0;
217 
218 	result = arm_convolve_wrapper_s8(&ctx,
219 					 &conv_params,
220 					 &quant_params,
221 					 &input_dims,
222 					 input_data,
223 					 &filter_dims,
224 					 kernel_data,
225 					 &bias_dims,
226 					 bias_data,
227 					 &output_dims,
228 					 output);
229 
230 	free(ctx.buf);
231 	zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
232 	zassert_mem_equal(conv_4_output_ref, output, sizeof(output), "");
233 }
234 
/*
 * Depthwise convolution case "stride2pad1": a single-channel 7x7 input with a
 * 3x3 filter, stride 2x2 and padding 1x1, giving a 4x4 output
 * (STRIDE2PAD1_DST_SIZE = 16 values).
 */
#define STRIDE2PAD1_OUT_CH	       1
#define STRIDE2PAD1_IN_CH	       1
#define STRIDE2PAD1_INPUT_W	       7
#define STRIDE2PAD1_INPUT_H	       7
#define STRIDE2PAD1_DST_SIZE	       16
#define STRIDE2PAD1_INPUT_SIZE	       49
#define STRIDE2PAD1_OUT_ACTIVATION_MIN -128
#define STRIDE2PAD1_OUT_ACTIVATION_MAX 127
#define STRIDE2PAD1_INPUT_BATCHES      1
#define STRIDE2PAD1_FILTER_X	       3
#define STRIDE2PAD1_FILTER_Y	       3
#define STRIDE2PAD1_STRIDE_X	       2
#define STRIDE2PAD1_STRIDE_Y	       2
#define STRIDE2PAD1_PAD_X	       1
#define STRIDE2PAD1_PAD_Y	       1
#define STRIDE2PAD1_OUTPUT_W	       4
#define STRIDE2PAD1_OUTPUT_H	       4
#define STRIDE2PAD1_INPUT_OFFSET       128
#define STRIDE2PAD1_OUTPUT_OFFSET      -20
#define STRIDE2PAD1_DILATION_X	       1
#define STRIDE2PAD1_DILATION_Y	       1

/* One bias for the single output channel. */
const int32_t stride2pad1_biases[1] = {-9794};

/* 3x3 filter weights. */
const int8_t stride2pad1_weights[9] = {-54, 57, -19, -127, 87, 70, 74, -110, 66};

/* 7x7 single-channel input (STRIDE2PAD1_INPUT_SIZE = 49 values). */
const int8_t stride2pad1_input[49] = {
	-91, -30, -57, -76, 32,	 -13, 14,   -96, 108, -4,  41,	48,  107, -68,	-101, 30,  95,
	95,  91,  -66, -80, 114, -49, 7,    -67, -35, -1,  -88, -77, -56, -103, 5,    -39, -118,
	-24, -32, 67,  11,  38,	 -16, -124, 44,	 -46, -92, -24, 108, 80,  -29,	-3};

/* Requantization multiplier and shift for the single output channel. */
const int32_t stride2pad1_output_mult[1] = {2033801520};

const int32_t stride2pad1_output_shift[1] = {-8};

/* Expected arm_depthwise_conv_s8 output (4x4 = 16 values). */
const int8_t stride2pad1_output_ref[16] = {26, -11, 33,	 -25,  -96, -52, -78, -86,
					   33, -2,  -88, -113, -14, 0,	 -84, -27};
272 
ZTEST(cmsis_nn,test_depthwise_convolve)273 ZTEST(cmsis_nn, test_depthwise_convolve)
274 {
275 	int8_t output[STRIDE2PAD1_DST_SIZE] = {0};
276 
277 	cmsis_nn_context ctx;
278 	cmsis_nn_dw_conv_params dw_conv_params;
279 	cmsis_nn_per_channel_quant_params quant_params;
280 	cmsis_nn_dims input_dims;
281 	cmsis_nn_dims filter_dims;
282 	cmsis_nn_dims bias_dims = {0};
283 	cmsis_nn_dims output_dims;
284 
285 	const int32_t *bias_data = stride2pad1_biases;
286 	const int8_t *kernel_data = stride2pad1_weights;
287 	const int8_t *input_data = stride2pad1_input;
288 
289 	input_dims.n = STRIDE2PAD1_INPUT_BATCHES;
290 	input_dims.w = STRIDE2PAD1_INPUT_W;
291 	input_dims.h = STRIDE2PAD1_INPUT_H;
292 	input_dims.c = STRIDE2PAD1_IN_CH;
293 	filter_dims.w = STRIDE2PAD1_FILTER_X;
294 	filter_dims.h = STRIDE2PAD1_FILTER_Y;
295 	output_dims.w = STRIDE2PAD1_OUTPUT_W;
296 	output_dims.h = STRIDE2PAD1_OUTPUT_H;
297 	output_dims.c = STRIDE2PAD1_OUT_CH;
298 
299 	dw_conv_params.padding.w = STRIDE2PAD1_PAD_X;
300 	dw_conv_params.padding.h = STRIDE2PAD1_PAD_Y;
301 	dw_conv_params.stride.w = STRIDE2PAD1_STRIDE_X;
302 	dw_conv_params.stride.h = STRIDE2PAD1_STRIDE_Y;
303 	dw_conv_params.dilation.w = STRIDE2PAD1_DILATION_X;
304 	dw_conv_params.dilation.h = STRIDE2PAD1_DILATION_Y;
305 
306 	dw_conv_params.ch_mult = 1;
307 
308 	dw_conv_params.input_offset = STRIDE2PAD1_INPUT_OFFSET;
309 	dw_conv_params.output_offset = STRIDE2PAD1_OUTPUT_OFFSET;
310 	dw_conv_params.activation.min = STRIDE2PAD1_OUT_ACTIVATION_MIN;
311 	dw_conv_params.activation.max = STRIDE2PAD1_OUT_ACTIVATION_MAX;
312 	quant_params.multiplier = (int32_t *)stride2pad1_output_mult;
313 	quant_params.shift = (int32_t *)stride2pad1_output_shift;
314 
315 	ctx.buf = NULL;
316 	ctx.size = 0;
317 
318 	arm_cmsis_nn_status result = arm_depthwise_conv_s8(&ctx,
319 						  &dw_conv_params,
320 						  &quant_params,
321 						  &input_dims,
322 						  input_data,
323 						  &filter_dims,
324 						  kernel_data,
325 						  &bias_dims,
326 						  bias_data,
327 						  &output_dims,
328 						  output);
329 
330 	free(ctx.buf);
331 	zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
332 	zassert_mem_equal(stride2pad1_output_ref, output, sizeof(output), "");
333 }
334 
/*
 * Fully-connected case "fully_connected_mve_0": one batch of 16 inputs mapped
 * to 9 outputs, with per-tensor requantization
 * (FULLY_CONNECTED_MVE_0_OUTPUT_MULTIPLIER / _OUTPUT_SHIFT).
 */
#define FULLY_CONNECTED_MVE_0_OUT_CH		 9
#define FULLY_CONNECTED_MVE_0_IN_CH		 16
#define FULLY_CONNECTED_MVE_0_INPUT_W		 1
#define FULLY_CONNECTED_MVE_0_INPUT_H		 1
#define FULLY_CONNECTED_MVE_0_DST_SIZE		 9
#define FULLY_CONNECTED_MVE_0_INPUT_SIZE	 16
#define FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MIN -128
#define FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MAX 127
#define FULLY_CONNECTED_MVE_0_INPUT_BATCHES	 1
#define FULLY_CONNECTED_MVE_0_OUTPUT_MULTIPLIER	 1244038257
#define FULLY_CONNECTED_MVE_0_OUTPUT_SHIFT	 -9
#define FULLY_CONNECTED_MVE_0_ACCUMULATION_DEPTH 16
#define FULLY_CONNECTED_MVE_0_INPUT_OFFSET	 128
#define FULLY_CONNECTED_MVE_0_OUTPUT_OFFSET	 -26

/* One bias per output (FULLY_CONNECTED_MVE_0_OUT_CH = 9). */
const int32_t fully_connected_mve_0_biases[9] = {11295, -30752, -3196, 10489, -5120,
						 18598, 27393,	29746, 22967};

/* Input vector (FULLY_CONNECTED_MVE_0_INPUT_SIZE = 16 values). */
const int8_t fully_connected_mve_0_input[16] = {-43, 68,  79,	-12, -119, -56, -102, -46,
						107, -65, -109, -7,  92,   -99, -80,  -29};

/* Expected arm_fully_connected_s8 output (9 values). */
const int8_t fully_connected_mve_0_output_ref[9] = {-9, -3, 26, 8, 3, -88, 75, 34, 5};

/* Weights: ACCUMULATION_DEPTH x OUT_CH = 16 x 9 = 144 values. */
const int8_t fully_connected_mve_0_weights[144] = {
	37,  -46,  75,	 -33,  -52, -82,  -94,	64,   71,  65,	 64,  16,   -66, -5,   -65,  -44,
	82,  42,   84,	 105,  18,  79,	  -103, -75,  -95, 65,	 87,  103,  43,	 -25,  -66,  75,
	125, 40,   -34,	 24,   9,   -79,  4,	73,   98,  -75,	 42,  81,   18,	 -58,  -119, 92,
	0,   -72,  48,	 23,   -69, 11,	  -95,	-103, 66,  117,	 107, -96,  114, -29,  75,   -93,
	118, 66,   -19,	 83,   -14, 86,	  -110, 44,   37,  -9,	 17,  -107, 50,	 -116, -116, -27,
	-84, -126, -108, -127, -71, 8,	  81,	108,  -61, 126,	 69,  -45,  37,	 -78,  -102, -55,
	116, 112,  -111, -89,  -57, 82,	  -47,	22,   125, -84,	 97,  -9,   88,	 74,   -15,  118,
	-95, 112,  89,	 44,   -17, -112, -71,	-94,  1,   -117, 112, -92,  52,	 57,   -22,  80,
	-60, 95,   -106, -1,   -27, 105,  6,	123,  6,   96,	 126, -65,  -29, 103,  19,   -45};
368 
ZTEST(cmsis_nn,test_fully_connected)369 ZTEST(cmsis_nn, test_fully_connected)
370 {
371 	int8_t output[FULLY_CONNECTED_MVE_0_DST_SIZE] = {0};
372 
373 	cmsis_nn_context ctx;
374 	cmsis_nn_fc_params fc_params;
375 	cmsis_nn_per_tensor_quant_params quant_params;
376 	cmsis_nn_dims input_dims;
377 	cmsis_nn_dims filter_dims;
378 	cmsis_nn_dims bias_dims;
379 	cmsis_nn_dims output_dims;
380 
381 	const int32_t *bias_data = fully_connected_mve_0_biases;
382 	const int8_t *kernel_data = fully_connected_mve_0_weights;
383 	const int8_t *input_data = fully_connected_mve_0_input;
384 
385 	input_dims.n = FULLY_CONNECTED_MVE_0_INPUT_BATCHES;
386 	input_dims.w = FULLY_CONNECTED_MVE_0_INPUT_W;
387 	input_dims.h = FULLY_CONNECTED_MVE_0_INPUT_H;
388 	input_dims.c = FULLY_CONNECTED_MVE_0_IN_CH;
389 	filter_dims.n = FULLY_CONNECTED_MVE_0_ACCUMULATION_DEPTH;
390 	filter_dims.c = FULLY_CONNECTED_MVE_0_OUT_CH;
391 	output_dims.n = FULLY_CONNECTED_MVE_0_INPUT_BATCHES;
392 	output_dims.c = FULLY_CONNECTED_MVE_0_OUT_CH;
393 
394 	fc_params.input_offset = FULLY_CONNECTED_MVE_0_INPUT_OFFSET;
395 	fc_params.filter_offset = 0;
396 	fc_params.output_offset = FULLY_CONNECTED_MVE_0_OUTPUT_OFFSET;
397 	fc_params.activation.min = FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MIN;
398 	fc_params.activation.max = FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MAX;
399 
400 	quant_params.multiplier = FULLY_CONNECTED_MVE_0_OUTPUT_MULTIPLIER;
401 	quant_params.shift = FULLY_CONNECTED_MVE_0_OUTPUT_SHIFT;
402 
403 	int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);
404 
405 	ctx.buf = malloc(buf_size);
406 	ctx.size = buf_size;
407 	arm_cmsis_nn_status result = arm_fully_connected_s8(&ctx,
408 						   &fc_params,
409 						   &quant_params,
410 						   &input_dims,
411 						   input_data,
412 						   &filter_dims,
413 						   kernel_data,
414 						   &bias_dims,
415 						   bias_data,
416 						   &output_dims,
417 						   output);
418 
419 	free(ctx.buf);
420 	zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
421 	zassert_mem_equal(fully_connected_mve_0_output_ref, output, sizeof(output), "");
422 }
423 
/*
 * Max-pooling case "maxpooling_2": the same geometry as the avgpooling_2
 * case. A 12x1x5 input is pooled with a 3x1 window, stride 1x2 and padding
 * 1x0, giving a 12x1x5 output.
 */
#define MAXPOOLING_2_OUT_CH		5
#define MAXPOOLING_2_IN_CH		5
#define MAXPOOLING_2_INPUT_W		12
#define MAXPOOLING_2_INPUT_H		1
#define MAXPOOLING_2_DST_SIZE		60
#define MAXPOOLING_2_INPUT_SIZE		60
#define MAXPOOLING_2_OUT_ACTIVATION_MIN -128
#define MAXPOOLING_2_OUT_ACTIVATION_MAX 127
#define MAXPOOLING_2_INPUT_BATCHES	1
#define MAXPOOLING_2_FILTER_X		3
#define MAXPOOLING_2_FILTER_Y		1
#define MAXPOOLING_2_STRIDE_X		1
#define MAXPOOLING_2_STRIDE_Y		2
#define MAXPOOLING_2_PAD_X		1
#define MAXPOOLING_2_PAD_Y		0
#define MAXPOOLING_2_OUTPUT_W		12
#define MAXPOOLING_2_OUTPUT_H		1

/* Input tensor: MAXPOOLING_2_INPUT_SIZE (60) values. */
const int8_t maxpooling_2_input[60] = {
	75,  -52,  -42,	 -30, 56,  64,	 106, -36, 120, -3,  34,   -105, 69,   75,  -39,
	15,  93,   -71,	 39,  34,  -11,	 65,  22,  59,	106, 105,  45,	 -116, -75, 123,
	-65, 75,   -61,	 13,  -25, -123, 59,  110, -65, 86,  -108, -107, -17,  38,  27,
	-1,  -115, -123, 75,  -75, 68,	 52,  12,  -35, 116, -68,  22,	 15,   76,  -81};

/* Expected arm_max_pool_s8 output: MAXPOOLING_2_DST_SIZE (60) values. */
const int8_t maxpooling_2_output_ref[60] = {
	75,  106, -36, 120, 56,	 75,  106, 69,	120, 56,  64,  106, 69,	 120, 34,
	34,  93,  69,  75,  106, 105, 93,  22,	59,  123, 105, 75,  22,	 59,  123,
	105, 75,  110, 13,  123, -65, 75,  110, 38,  86,  -1,  59,  110, 75,  86,
	68,  52,  12,  75,  116, 68,  52,  15,	76,  116, 68,  52,  15,	 76,  116};
453 
ZTEST(cmsis_nn,test_max_pool)454 ZTEST(cmsis_nn, test_max_pool)
455 {
456 	int8_t output[MAXPOOLING_2_DST_SIZE] = {0};
457 
458 	cmsis_nn_context ctx;
459 	cmsis_nn_pool_params pool_params;
460 	cmsis_nn_dims input_dims;
461 	cmsis_nn_dims filter_dims;
462 	cmsis_nn_dims output_dims;
463 
464 	const int8_t *input_data = maxpooling_2_input;
465 
466 	input_dims.n = MAXPOOLING_2_INPUT_BATCHES;
467 	input_dims.w = MAXPOOLING_2_INPUT_W;
468 	input_dims.h = MAXPOOLING_2_INPUT_H;
469 	input_dims.c = MAXPOOLING_2_IN_CH;
470 	filter_dims.w = MAXPOOLING_2_FILTER_X;
471 	filter_dims.h = MAXPOOLING_2_FILTER_Y;
472 	output_dims.w = MAXPOOLING_2_OUTPUT_W;
473 	output_dims.h = MAXPOOLING_2_OUTPUT_H;
474 	output_dims.c = MAXPOOLING_2_OUT_CH;
475 
476 	pool_params.padding.w = MAXPOOLING_2_PAD_X;
477 	pool_params.padding.h = MAXPOOLING_2_PAD_Y;
478 	pool_params.stride.w = MAXPOOLING_2_STRIDE_X;
479 	pool_params.stride.h = MAXPOOLING_2_STRIDE_Y;
480 
481 	pool_params.activation.min = MAXPOOLING_2_OUT_ACTIVATION_MIN;
482 	pool_params.activation.max = MAXPOOLING_2_OUT_ACTIVATION_MAX;
483 
484 	for (int i = 0; i < REPEAT_NUM; i++) {
485 		arm_cmsis_nn_status result =
486 			arm_max_pool_s8(&ctx, &pool_params, &input_dims, input_data, &filter_dims,
487 					&output_dims, output);
488 
489 		zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
490 		zassert_mem_equal(maxpooling_2_output_ref, output, sizeof(output), "");
491 	}
492 }
493 
/*
 * Softmax case: SOFTMAX_NUM_ROWS (2) rows of SOFTMAX_ROW_SIZE (5) int8 values.
 * SOFTMAX_INPUT_MULT, SOFTMAX_INPUT_LEFT_SHIFT and SOFTMAX_DIFF_MIN are passed
 * to arm_softmax_s8 as its mult, shift and diff_min arguments.
 */
#define SOFTMAX_NUM_ROWS	 2
#define SOFTMAX_ROW_SIZE	 5
#define SOFTMAX_INPUT_MULT	 1077952640
#define SOFTMAX_INPUT_LEFT_SHIFT 19
#define SOFTMAX_DIFF_MIN	 -3968
#define SOFTMAX_DST_SIZE	 10

/* Two rows of five input values, stored row after row. */
const int8_t softmax_input[10] = {101, 49, 6, -34, -75, -79, -38, 120, -55, 115};

/* Expected arm_softmax_s8 output (SOFTMAX_DST_SIZE = 10 values). */
const int8_t softmax_output_ref[10] = {-57, -70, -79, -86, -92, -94, -88, -54, -91, -56};
504 
ZTEST(cmsis_nn,test_softmax)505 ZTEST(cmsis_nn, test_softmax)
506 {
507 	const int32_t num_rows = SOFTMAX_NUM_ROWS;
508 	const int32_t row_size = SOFTMAX_ROW_SIZE;
509 	const int32_t mult = SOFTMAX_INPUT_MULT;
510 	const int32_t shift = SOFTMAX_INPUT_LEFT_SHIFT;
511 	const int32_t diff_min = SOFTMAX_DIFF_MIN;
512 	const int8_t *input_data = softmax_input;
513 	int8_t output[SOFTMAX_DST_SIZE];
514 
515 	for (int i = 0; i < REPEAT_NUM; i++) {
516 		arm_softmax_s8(input_data, num_rows, row_size, mult, shift, diff_min, output);
517 		zassert_mem_equal(softmax_output_ref, output, sizeof(output), "");
518 	}
519 }
520 
/*
 * SVDF case "svdf_2": SVDF_2_INPUT_BATCHES (3) batches of SVDF_2_INPUT_SIZE (7)
 * int8 features, rank 2 and 10 feature batches (10 / 2 = 5 units), and
 * SVDF_2_TIME_BATCHES (2) time entries per feature batch. The output is
 * 3 x 5 = SVDF_2_DST_SIZE (15) values.
 */
#define SVDF_2_MULTIPLIER_IN	  1717987072
#define SVDF_2_MULTIPLIER_OUT	  1099511552
#define SVDF_2_SHIFT_1		  -3
#define SVDF_2_SHIFT_2		  -11
#define SVDF_2_IN_ACTIVATION_MIN  -32768
#define SVDF_2_IN_ACTIVATION_MAX  32767
#define SVDF_2_RANK		  2
#define SVDF_2_FEATURE_BATCHES	  10
#define SVDF_2_TIME_BATCHES	  2
#define SVDF_2_INPUT_SIZE	  7
#define SVDF_2_DST_SIZE		  15
#define SVDF_2_OUT_ACTIVATION_MIN -128
#define SVDF_2_OUT_ACTIVATION_MAX 127
#define SVDF_2_INPUT_BATCHES	  3
#define SVDF_2_INPUT_OFFSET	  0
#define SVDF_2_OUTPUT_OFFSET	  0

/* All-zero biases, one per unit (5). test_svdf passes NULL when they are all zero. */
const int32_t svdf_2_biases[5] = {0, 0, 0, 0, 0};

/*
 * Initial (all-zero) state: 60 = 3 x 10 x 2 values, presumably
 * INPUT_BATCHES x FEATURE_BATCHES x TIME_BATCHES. test_svdf copies it into a
 * writable buffer before every repetition.
 */
const int16_t svdf_2_state[60] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};

/* Feature weights: FEATURE_BATCHES x INPUT_SIZE = 10 x 7 = 70 values. */
const int8_t svdf_2_weights_feature[70] = {
	27,   82,  -108, -127, 85,  3,	 -51, 32,  110, -6,  -14, -16,	31,  101,
	-122, 19,  76,	 74,   -80, 12,	 -22, -17, 10,	-28, 55,  109,	2,   -107,
	-4,   72,  -65,	 -59,  36,  -69, 105, -97, 25,	38,  110, -121, -88, -126,
	-14,  16,  -88,	 -66,  3,   -93, 69,  -64, 44,	103, 95,  -95,	68,  -46,
	106,  -31, -63,	 23,   -38, 36,	 -95, -43, 93,	77,  91,  -26,	33,  59};

/* Time weights: FEATURE_BATCHES x TIME_BATCHES = 10 x 2 = 20 values. */
const int16_t svdf_2_weights_time[20] = {-31, -88, -10, -72, -119, -6, -70, 63,	 -10, 93,
					 5,   42,  -6,	22,  6,	   51, 37,  -38, 5,   117};

/*
 * Input sequence: 42 values, i.e. two steps of INPUT_BATCHES x INPUT_SIZE
 * (3 x 7 = 21). test_svdf feeds one step per kernel call.
 */
const int8_t svdf_2_input_sequence[42] = {
	29,   81,   -38, 17,  -116, 43,	 119,  -127, 74,  115, 9,    118,  7,	 -56,
	-53,  -14,  -98, 60,  -128, 10,	 28,   -18,  12,  -28, -126, 87,   -115, -44,
	-123, -109, -59, -87, -69,  121, -128, -95,  -70, 2,   81,   -119, 84,	 -122};

/* Expected output after the final input step (SVDF_2_DST_SIZE = 15 values). */
const int8_t svdf_2_output_ref[15] = {-53, 45,	27, -24, -53, 26, -82, -38,
				      11,  -85, 94, -16, -32, 31, 4};
561 
/*
 * Report whether the first `size` entries of `bias` are all zero.
 * Returns true for size <= 0, because no non-zero entry is inspected.
 */
static bool check_null_bias(const int32_t *bias, int32_t size)
{
	for (int idx = 0; idx < size; ++idx) {
		if (bias[idx] != 0) {
			/* A single non-zero entry is enough to need the bias. */
			return false;
		}
	}

	return true;
}
574 
ZTEST(cmsis_nn,test_svdf)575 ZTEST(cmsis_nn, test_svdf)
576 {
577 	cmsis_nn_context input_ctx;
578 	cmsis_nn_context output_ctx;
579 	cmsis_nn_svdf_params svdf_2_params;
580 	cmsis_nn_dims input_dims;
581 	cmsis_nn_dims weights_feature_dims;
582 	cmsis_nn_dims weights_time_dims;
583 	cmsis_nn_dims state_dims;
584 	cmsis_nn_dims output_dims;
585 	cmsis_nn_dims bias_dims;
586 	cmsis_nn_per_tensor_quant_params input_quant_params;
587 	cmsis_nn_per_tensor_quant_params output_quant_params;
588 	int8_t output_data[SVDF_2_DST_SIZE];
589 
590 	const int8_t *weights_feature_data = svdf_2_weights_feature;
591 	const int16_t *weights_time_data = svdf_2_weights_time;
592 
593 	input_dims.n = SVDF_2_INPUT_BATCHES;
594 	input_dims.h = SVDF_2_INPUT_SIZE;
595 	weights_feature_dims.n = SVDF_2_FEATURE_BATCHES;
596 	weights_time_dims.h = SVDF_2_TIME_BATCHES;
597 
598 	input_quant_params.multiplier = SVDF_2_MULTIPLIER_IN;
599 	input_quant_params.shift = SVDF_2_SHIFT_1;
600 	output_quant_params.multiplier = SVDF_2_MULTIPLIER_OUT;
601 	output_quant_params.shift = SVDF_2_SHIFT_2;
602 
603 	svdf_2_params.input_activation.min = SVDF_2_IN_ACTIVATION_MIN;
604 	svdf_2_params.input_activation.max = SVDF_2_IN_ACTIVATION_MAX;
605 	svdf_2_params.output_activation.min = SVDF_2_OUT_ACTIVATION_MIN;
606 	svdf_2_params.output_activation.max = SVDF_2_OUT_ACTIVATION_MAX;
607 	svdf_2_params.input_offset = SVDF_2_INPUT_OFFSET;
608 	svdf_2_params.output_offset = SVDF_2_OUTPUT_OFFSET;
609 	svdf_2_params.rank = SVDF_2_RANK;
610 
611 	const int input_round_size = SVDF_2_INPUT_BATCHES * SVDF_2_INPUT_SIZE;
612 	const int number_inputs = sizeof(svdf_2_input_sequence) / input_round_size;
613 	const int32_t number_units = SVDF_2_FEATURE_BATCHES / SVDF_2_RANK;
614 	const int scratch_size = SVDF_2_INPUT_BATCHES * SVDF_2_FEATURE_BATCHES * sizeof(int32_t);
615 	const int scratch_size_out = SVDF_2_INPUT_BATCHES * number_units * sizeof(int32_t);
616 
617 	input_ctx.buf = malloc(scratch_size);
618 	output_ctx.buf = malloc(scratch_size_out);
619 
620 	int8_t *input_data = malloc(input_round_size);
621 	int16_t *state_data = malloc(sizeof(svdf_2_state));
622 	const bool null_bias = check_null_bias(svdf_2_biases,
623 					       SVDF_2_DST_SIZE / SVDF_2_INPUT_BATCHES);
624 
625 	for (int i = 0; i < REPEAT_NUM; i++) {
626 		memcpy(state_data, svdf_2_state, sizeof(svdf_2_state));
627 		for (int j = 0; j < number_inputs; j++) {
628 			memcpy(input_data, svdf_2_input_sequence + j * input_round_size,
629 			       input_round_size);
630 			arm_cmsis_nn_status result = arm_svdf_state_s16_s8(&input_ctx,
631 							&output_ctx,
632 							&svdf_2_params,
633 							&input_quant_params,
634 							&output_quant_params,
635 							&input_dims,
636 							input_data,
637 							&state_dims,
638 							state_data,
639 							&weights_feature_dims,
640 							weights_feature_data,
641 							&weights_time_dims,
642 							weights_time_data,
643 							&bias_dims,
644 							null_bias == true ? NULL : svdf_2_biases,
645 							&output_dims,
646 							output_data);
647 			zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
648 		}
649 
650 		zassert_mem_equal(svdf_2_output_ref, output_data, sizeof(output_data), "");
651 	}
652 	free(state_data);
653 	free(input_data);
654 	free(input_ctx.buf);
655 	free(output_ctx.buf);
656 }
657 
/* Register the cmsis_nn suite with no predicate and no setup/before/after/teardown hooks. */
ZTEST_SUITE(cmsis_nn, NULL, NULL, NULL, NULL, NULL);
659