1 /*
2  * Copyright (c) 2021, Commonwealth Scientific and Industrial Research
3  * Organisation (CSIRO) ABN 41 687 119 230.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  *
7  * This is not exhaustive functional testing of the CMSIS-NN library.
8  *
9  * Individual tests have been pulled from CMSIS/NN/Tests/UnitTest to
 * validate the integration of CMSIS-NN and Zephyr.
11  */
12 
#include <zephyr/ztest.h>
#include <zephyr/kernel.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>

#include "arm_nnfunctions.h"
18 
/* Number of times the max-pool, softmax and SVDF tests repeat their kernel run. */
#define REPEAT_NUM 3

/*
 * Average-pooling test vectors: one batch of a 1 (H) x 12 (W) input with
 * 5 channels, pooled with a 3 (x) x 1 (y) window, stride 1 (x) / 2 (y) and
 * one column of padding on each side.
 */
#define AVGPOOLING_2_OUT_CH		5
#define AVGPOOLING_2_IN_CH		5
#define AVGPOOLING_2_INPUT_W		12
#define AVGPOOLING_2_INPUT_H		1
#define AVGPOOLING_2_DST_SIZE		60
#define AVGPOOLING_2_INPUT_SIZE		60
#define AVGPOOLING_2_OUT_ACTIVATION_MIN -128
#define AVGPOOLING_2_OUT_ACTIVATION_MAX 127
#define AVGPOOLING_2_INPUT_BATCHES	1
#define AVGPOOLING_2_FILTER_X		3
#define AVGPOOLING_2_FILTER_Y		1
#define AVGPOOLING_2_STRIDE_X		1
#define AVGPOOLING_2_STRIDE_Y		2
#define AVGPOOLING_2_PAD_X		1
#define AVGPOOLING_2_PAD_Y		0
#define AVGPOOLING_2_OUTPUT_W		12
#define AVGPOOLING_2_OUTPUT_H		1

/* Input tensor (AVGPOOLING_2_INPUT_SIZE values). */
static const int8_t avgpooling_2_input[AVGPOOLING_2_INPUT_SIZE] = {
	-82, -104, 10, -28, -52, -51, -66, 52, 124, -74, -21, 4, 37, -7, -33,
	102, 110, 24, 52, 121, 13, -55, -79, -92, -35, -103, 86, 95, 46, 32,
	-24, -123, 120, 29, -77, -97, -69, -68, 58, 38, 3, 3, 79, -47, 112,
	-52, -113, -46, 107, 68, 83, -70, 91, 14, 113, 74, 73, -103, -98, 25};

/* Expected output of arm_avgpool_s8() for the input above. */
static const int8_t avgpooling_2_output_ref[AVGPOOLING_2_DST_SIZE] = {
	-67, -85, 31, 48, -63, -51, -55, 33, 30, -53, 10, 16, 38, 56, 5,
	31, 20, -6, -16, 18, 4, 47, 13, 2, 39, -38, -31, 45, -6, -27,
	-75, -35, 49, 44, -2, -39, -63, 44, 13, 24, -49, -60, -12, 39, 73,
	11, -60, 41, 25, 98, 35, -37, -19, 8, 69, 79, 2, -6, -42, 69};
50 
ZTEST(cmsis_nn,test_avgpool)51 ZTEST(cmsis_nn, test_avgpool)
52 {
53 	int8_t output[AVGPOOLING_2_DST_SIZE] = {0};
54 
55 	cmsis_nn_context ctx;
56 	cmsis_nn_pool_params pool_params;
57 	cmsis_nn_dims input_dims;
58 	cmsis_nn_dims filter_dims;
59 	cmsis_nn_dims output_dims;
60 
61 	input_dims.n = AVGPOOLING_2_INPUT_BATCHES;
62 	input_dims.w = AVGPOOLING_2_INPUT_W;
63 	input_dims.h = AVGPOOLING_2_INPUT_H;
64 	input_dims.c = AVGPOOLING_2_IN_CH;
65 	filter_dims.w = AVGPOOLING_2_FILTER_X;
66 	filter_dims.h = AVGPOOLING_2_FILTER_Y;
67 	output_dims.w = AVGPOOLING_2_OUTPUT_W;
68 	output_dims.h = AVGPOOLING_2_OUTPUT_H;
69 	output_dims.c = AVGPOOLING_2_OUT_CH;
70 
71 	pool_params.padding.w = AVGPOOLING_2_PAD_X;
72 	pool_params.padding.h = AVGPOOLING_2_PAD_Y;
73 	pool_params.stride.w = AVGPOOLING_2_STRIDE_X;
74 	pool_params.stride.h = AVGPOOLING_2_STRIDE_Y;
75 
76 	pool_params.activation.min = AVGPOOLING_2_OUT_ACTIVATION_MIN;
77 	pool_params.activation.max = AVGPOOLING_2_OUT_ACTIVATION_MAX;
78 
79 	ctx.size = arm_avgpool_s8_get_buffer_size(AVGPOOLING_2_OUTPUT_W, AVGPOOLING_2_IN_CH);
80 	ctx.buf = malloc(ctx.size);
81 
82 	arm_cmsis_nn_status result = arm_avgpool_s8(&ctx,
83 						   &pool_params,
84 						   &input_dims,
85 						   avgpooling_2_input,
86 						   &filter_dims,
87 						   &output_dims,
88 						   output);
89 
90 	free(ctx.buf);
91 
92 	zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
93 	zassert_mem_equal(avgpooling_2_output_ref, output, sizeof(output), "");
94 }
95 
/*
 * Convolution test vectors: 3 batches of a 5 x 5 input with 3 channels,
 * convolved with three 3 (y) x 2 (x) x 3 filters, stride 2, no padding and
 * no dilation. Requantisation uses one multiplier/shift pair per output
 * channel.
 */
#define CONV_4_OUT_CH		  3
#define CONV_4_IN_CH		  3
#define CONV_4_INPUT_W		  5
#define CONV_4_INPUT_H		  5
#define CONV_4_DST_SIZE		  36
#define CONV_4_INPUT_SIZE	  75
#define CONV_4_OUT_ACTIVATION_MIN -109
#define CONV_4_OUT_ACTIVATION_MAX 127
#define CONV_4_INPUT_BATCHES	  3
#define CONV_4_FILTER_X		  2
#define CONV_4_FILTER_Y		  3
#define CONV_4_STRIDE_X		  2
#define CONV_4_STRIDE_Y		  2
#define CONV_4_PAD_X		  0
#define CONV_4_PAD_Y		  0
#define CONV_4_OUTPUT_W		  2
#define CONV_4_OUTPUT_H		  2
#define CONV_4_INPUT_OFFSET	  128
#define CONV_4_OUTPUT_OFFSET	  -128
#define CONV_4_DILATION_X	  1
#define CONV_4_DILATION_Y	  1

/* One bias value per output channel. */
static const int32_t conv_4_biases[CONV_4_OUT_CH] = {13175, 9050, 18215};

static const int8_t conv_4_weights[CONV_4_OUT_CH * CONV_4_FILTER_Y * CONV_4_FILTER_X *
				   CONV_4_IN_CH] = {
	-25, -83, -74, 105, 30, 118, -32, 127, 34, 127, -112, 39, -43, 104, 41, -124, 115, 5,
	42, -48, -119, 93, 17, 57, 41, -41, -42, 23, 127, 18, 70, -99, 71, 67, 83, 76,
	-50, 98, 66, 64, 127, -6, -77, -48, -26, 45, 77, 1, 81, 27, 124, -103, 37, 36};

/* CONV_4_INPUT_BATCHES inputs of CONV_4_INPUT_SIZE values each. */
static const int8_t conv_4_input[CONV_4_INPUT_BATCHES * CONV_4_INPUT_SIZE] = {
	82, 120, -97, -44, -118, 73, 4, -84, -53, -122, -15, 77, 83, 43, 37,
	85, -11, 103, 45, -69, -12, -8, 21, 6, -68, -83, -15, -99, 90, -62,
	95, 62, -38, -32, -35, -105, -53, 70, 112, 14, -4, -33, -26, -93, -98,
	22, -5, 22, -104, 57, -92, 30, -62, 0, -43, -82, 60, 99, -83, 32,
	94, 49, 10, 112, -71, -27, -91, -79, 52, -92, -71, 86, -79, -15, -80,
	-74, -4, 76, -119, 91, -23, -12, -111, -72, 26, 11, 64, 116, 38, 99,
	125, 17, 6, -4, 46, 119, 113, -116, -125, 80, -57, 122, 75, 119, -117,
	87, -121, -70, -75, -127, 16, -124, -110, 10, 71, 29, 27, 37, -24, 52,
	28, -100, 86, -75, 117, -31, -115, -86, -122, 121, -96, -118, 32, 111, 25,
	-90, -8, 110, 37, 35, 124, -123, 94, -122, -114, 37, 85, -36, 53, -40,
	73, -99, 27, 10, 37, 41, 64, -97, -123, 75, 0, -107, -72, 58, -100,
	17, 77, 114, 120, -83, -96, 75, -12, -27, 3, 35, 85, 4, 119, -20,
	28, 99, 104, -78, -51, -82, -92, -40, -116, 35, -107, 39, 9, -120, -50,
	-102, -114, 25, -77, 25, 7, 64, 110, 80, -93, -20, 34, 115, 75, 37,
	47, 16, 6, -92, -25, 37, 69, 82, -61, -100, -85, -51, 6, -95, 58};

/* Per-output-channel requantisation multipliers and shifts. */
static const int32_t conv_4_output_mult[CONV_4_OUT_CH] = {2039209398, 2005068758, 2023002003};

static const int32_t conv_4_output_shift[CONV_4_OUT_CH] = {-9, -9, -9};

/* Expected convolution output for all batches. */
static const int8_t conv_4_output_ref[CONV_4_DST_SIZE] = {
	-5, -39, -31, 20, -37, -26, -109, -7, -10, -51, -58, 48,
	-100, -32, 24, 4, 69, -38, -64, 65, -34, 95, -55, 39,
	95, -54, 27, -49, 25, -68, -109, -66, 72, 38, -44, -40};
150 
ZTEST(cmsis_nn,test_convolve)151 ZTEST(cmsis_nn, test_convolve)
152 {
153 	int8_t output[CONV_4_DST_SIZE] = {0};
154 
155 	cmsis_nn_context ctx;
156 	cmsis_nn_conv_params conv_params;
157 	cmsis_nn_per_channel_quant_params quant_params;
158 	cmsis_nn_dims input_dims;
159 	cmsis_nn_dims filter_dims;
160 	cmsis_nn_dims bias_dims;
161 	cmsis_nn_dims output_dims;
162 
163 	const int32_t *bias_data = conv_4_biases;
164 	const int8_t *kernel_data = conv_4_weights;
165 	const int8_t *input_data = conv_4_input;
166 
167 	input_dims.n = CONV_4_INPUT_BATCHES;
168 	input_dims.w = CONV_4_INPUT_W;
169 	input_dims.h = CONV_4_INPUT_H;
170 	input_dims.c = CONV_4_IN_CH;
171 	filter_dims.w = CONV_4_FILTER_X;
172 	filter_dims.h = CONV_4_FILTER_Y;
173 	filter_dims.c = CONV_4_IN_CH;
174 	output_dims.w = CONV_4_OUTPUT_W;
175 	output_dims.h = CONV_4_OUTPUT_H;
176 	output_dims.c = CONV_4_OUT_CH;
177 
178 	conv_params.padding.w = CONV_4_PAD_X;
179 	conv_params.padding.h = CONV_4_PAD_Y;
180 	conv_params.stride.w = CONV_4_STRIDE_X;
181 	conv_params.stride.h = CONV_4_STRIDE_Y;
182 	conv_params.dilation.w = CONV_4_DILATION_X;
183 	conv_params.dilation.h = CONV_4_DILATION_Y;
184 
185 	conv_params.input_offset = CONV_4_INPUT_OFFSET;
186 	conv_params.output_offset = CONV_4_OUTPUT_OFFSET;
187 	conv_params.activation.min = CONV_4_OUT_ACTIVATION_MIN;
188 	conv_params.activation.max = CONV_4_OUT_ACTIVATION_MAX;
189 	quant_params.multiplier = (int32_t *)conv_4_output_mult;
190 	quant_params.shift = (int32_t *)conv_4_output_shift;
191 
192 	int32_t buf_size = arm_convolve_s8_get_buffer_size(&input_dims, &filter_dims);
193 
194 	ctx.buf = malloc(buf_size);
195 	ctx.size = 0;
196 
197 	arm_cmsis_nn_status result = arm_convolve_s8(&ctx,
198 					    &conv_params,
199 					    &quant_params,
200 					    &input_dims,
201 					    input_data,
202 					    &filter_dims,
203 					    kernel_data,
204 					    &bias_dims,
205 					    bias_data,
206 					    NULL,
207 					    &output_dims,
208 					    output);
209 
210 	free(ctx.buf);
211 	zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
212 	zassert_mem_equal(conv_4_output_ref, output, sizeof(output), "");
213 
214 	buf_size = arm_convolve_wrapper_s8_get_buffer_size(&conv_params, &input_dims,
215 							   &filter_dims, &output_dims);
216 	ctx.buf = malloc(buf_size);
217 	ctx.size = 0;
218 
219 	result = arm_convolve_wrapper_s8(&ctx,
220 					 &conv_params,
221 					 &quant_params,
222 					 &input_dims,
223 					 input_data,
224 					 &filter_dims,
225 					 kernel_data,
226 					 &bias_dims,
227 					 bias_data,
228 					 &output_dims,
229 					 output);
230 
231 	free(ctx.buf);
232 	zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
233 	zassert_mem_equal(conv_4_output_ref, output, sizeof(output), "");
234 }
235 
/*
 * Depthwise convolution test vectors: one batch of a single-channel 7 x 7
 * input, a 3 x 3 filter, stride 2 and padding 1, giving a 4 x 4 output.
 */
#define STRIDE2PAD1_OUT_CH	       1
#define STRIDE2PAD1_IN_CH	       1
#define STRIDE2PAD1_INPUT_W	       7
#define STRIDE2PAD1_INPUT_H	       7
#define STRIDE2PAD1_DST_SIZE	       16
#define STRIDE2PAD1_INPUT_SIZE	       49
#define STRIDE2PAD1_OUT_ACTIVATION_MIN -128
#define STRIDE2PAD1_OUT_ACTIVATION_MAX 127
#define STRIDE2PAD1_INPUT_BATCHES      1
#define STRIDE2PAD1_FILTER_X	       3
#define STRIDE2PAD1_FILTER_Y	       3
#define STRIDE2PAD1_STRIDE_X	       2
#define STRIDE2PAD1_STRIDE_Y	       2
#define STRIDE2PAD1_PAD_X	       1
#define STRIDE2PAD1_PAD_Y	       1
#define STRIDE2PAD1_OUTPUT_W	       4
#define STRIDE2PAD1_OUTPUT_H	       4
#define STRIDE2PAD1_INPUT_OFFSET       128
#define STRIDE2PAD1_OUTPUT_OFFSET      -20
#define STRIDE2PAD1_DILATION_X	       1
#define STRIDE2PAD1_DILATION_Y	       1

static const int32_t stride2pad1_biases[STRIDE2PAD1_OUT_CH] = {-9794};

static const int8_t stride2pad1_weights[STRIDE2PAD1_FILTER_Y * STRIDE2PAD1_FILTER_X *
					STRIDE2PAD1_OUT_CH] = {
	-54, 57, -19, -127, 87, 70, 74, -110, 66};

static const int8_t stride2pad1_input[STRIDE2PAD1_INPUT_SIZE] = {
	-91, -30, -57, -76, 32, -13, 14, -96, 108, -4, 41, 48, 107, -68, -101, 30, 95,
	95, 91, -66, -80, 114, -49, 7, -67, -35, -1, -88, -77, -56, -103, 5, -39, -118,
	-24, -32, 67, 11, 38, -16, -124, 44, -46, -92, -24, 108, 80, -29, -3};

/* Per-output-channel requantisation multiplier and shift. */
static const int32_t stride2pad1_output_mult[STRIDE2PAD1_OUT_CH] = {2033801520};

static const int32_t stride2pad1_output_shift[STRIDE2PAD1_OUT_CH] = {-8};

/* Expected depthwise convolution output. */
static const int8_t stride2pad1_output_ref[STRIDE2PAD1_DST_SIZE] = {
	26, -11, 33, -25, -96, -52, -78, -86,
	33, -2, -88, -113, -14, 0, -84, -27};
273 
ZTEST(cmsis_nn,test_depthwise_convolve)274 ZTEST(cmsis_nn, test_depthwise_convolve)
275 {
276 	int8_t output[STRIDE2PAD1_DST_SIZE] = {0};
277 
278 	cmsis_nn_context ctx;
279 	cmsis_nn_dw_conv_params dw_conv_params;
280 	cmsis_nn_per_channel_quant_params quant_params;
281 	cmsis_nn_dims input_dims;
282 	cmsis_nn_dims filter_dims;
283 	cmsis_nn_dims bias_dims = {0};
284 	cmsis_nn_dims output_dims;
285 
286 	const int32_t *bias_data = stride2pad1_biases;
287 	const int8_t *kernel_data = stride2pad1_weights;
288 	const int8_t *input_data = stride2pad1_input;
289 
290 	input_dims.n = STRIDE2PAD1_INPUT_BATCHES;
291 	input_dims.w = STRIDE2PAD1_INPUT_W;
292 	input_dims.h = STRIDE2PAD1_INPUT_H;
293 	input_dims.c = STRIDE2PAD1_IN_CH;
294 	filter_dims.w = STRIDE2PAD1_FILTER_X;
295 	filter_dims.h = STRIDE2PAD1_FILTER_Y;
296 	output_dims.w = STRIDE2PAD1_OUTPUT_W;
297 	output_dims.h = STRIDE2PAD1_OUTPUT_H;
298 	output_dims.c = STRIDE2PAD1_OUT_CH;
299 
300 	dw_conv_params.padding.w = STRIDE2PAD1_PAD_X;
301 	dw_conv_params.padding.h = STRIDE2PAD1_PAD_Y;
302 	dw_conv_params.stride.w = STRIDE2PAD1_STRIDE_X;
303 	dw_conv_params.stride.h = STRIDE2PAD1_STRIDE_Y;
304 	dw_conv_params.dilation.w = STRIDE2PAD1_DILATION_X;
305 	dw_conv_params.dilation.h = STRIDE2PAD1_DILATION_Y;
306 
307 	dw_conv_params.ch_mult = 1;
308 
309 	dw_conv_params.input_offset = STRIDE2PAD1_INPUT_OFFSET;
310 	dw_conv_params.output_offset = STRIDE2PAD1_OUTPUT_OFFSET;
311 	dw_conv_params.activation.min = STRIDE2PAD1_OUT_ACTIVATION_MIN;
312 	dw_conv_params.activation.max = STRIDE2PAD1_OUT_ACTIVATION_MAX;
313 	quant_params.multiplier = (int32_t *)stride2pad1_output_mult;
314 	quant_params.shift = (int32_t *)stride2pad1_output_shift;
315 
316 	ctx.buf = NULL;
317 	ctx.size = 0;
318 
319 	arm_cmsis_nn_status result = arm_depthwise_conv_s8(&ctx,
320 						  &dw_conv_params,
321 						  &quant_params,
322 						  &input_dims,
323 						  input_data,
324 						  &filter_dims,
325 						  kernel_data,
326 						  &bias_dims,
327 						  bias_data,
328 						  &output_dims,
329 						  output);
330 
331 	free(ctx.buf);
332 	zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
333 	zassert_mem_equal(stride2pad1_output_ref, output, sizeof(output), "");
334 }
335 
/*
 * Fully-connected test vectors: one batch of 16 inputs (the accumulation
 * depth) mapped to 9 outputs, with a single per-tensor requantisation
 * multiplier/shift.
 */
#define FULLY_CONNECTED_MVE_0_OUT_CH		 9
#define FULLY_CONNECTED_MVE_0_IN_CH		 16
#define FULLY_CONNECTED_MVE_0_INPUT_W		 1
#define FULLY_CONNECTED_MVE_0_INPUT_H		 1
#define FULLY_CONNECTED_MVE_0_DST_SIZE		 9
#define FULLY_CONNECTED_MVE_0_INPUT_SIZE	 16
#define FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MIN -128
#define FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MAX 127
#define FULLY_CONNECTED_MVE_0_INPUT_BATCHES	 1
#define FULLY_CONNECTED_MVE_0_OUTPUT_MULTIPLIER	 1244038257
#define FULLY_CONNECTED_MVE_0_OUTPUT_SHIFT	 -9
#define FULLY_CONNECTED_MVE_0_ACCUMULATION_DEPTH 16
#define FULLY_CONNECTED_MVE_0_INPUT_OFFSET	 128
#define FULLY_CONNECTED_MVE_0_OUTPUT_OFFSET	 -26

static const int32_t fully_connected_mve_0_biases[FULLY_CONNECTED_MVE_0_OUT_CH] = {
	11295, -30752, -3196, 10489, -5120, 18598, 27393, 29746, 22967};

static const int8_t fully_connected_mve_0_input[FULLY_CONNECTED_MVE_0_INPUT_SIZE] = {
	-43, 68, 79, -12, -119, -56, -102, -46,
	107, -65, -109, -7, 92, -99, -80, -29};

/* Expected fully-connected output. */
static const int8_t fully_connected_mve_0_output_ref[FULLY_CONNECTED_MVE_0_DST_SIZE] = {
	-9, -3, 26, 8, 3, -88, 75, 34, 5};

static const int8_t fully_connected_mve_0_weights[FULLY_CONNECTED_MVE_0_OUT_CH *
						  FULLY_CONNECTED_MVE_0_ACCUMULATION_DEPTH] = {
	37, -46, 75, -33, -52, -82, -94, 64, 71, 65, 64, 16, -66, -5, -65, -44,
	82, 42, 84, 105, 18, 79, -103, -75, -95, 65, 87, 103, 43, -25, -66, 75,
	125, 40, -34, 24, 9, -79, 4, 73, 98, -75, 42, 81, 18, -58, -119, 92,
	0, -72, 48, 23, -69, 11, -95, -103, 66, 117, 107, -96, 114, -29, 75, -93,
	118, 66, -19, 83, -14, 86, -110, 44, 37, -9, 17, -107, 50, -116, -116, -27,
	-84, -126, -108, -127, -71, 8, 81, 108, -61, 126, 69, -45, 37, -78, -102, -55,
	116, 112, -111, -89, -57, 82, -47, 22, 125, -84, 97, -9, 88, 74, -15, 118,
	-95, 112, 89, 44, -17, -112, -71, -94, 1, -117, 112, -92, 52, 57, -22, 80,
	-60, 95, -106, -1, -27, 105, 6, 123, 6, 96, 126, -65, -29, 103, 19, -45};
369 
ZTEST(cmsis_nn,test_fully_connected)370 ZTEST(cmsis_nn, test_fully_connected)
371 {
372 	int8_t output[FULLY_CONNECTED_MVE_0_DST_SIZE] = {0};
373 
374 	cmsis_nn_context ctx;
375 	cmsis_nn_fc_params fc_params;
376 	cmsis_nn_per_tensor_quant_params quant_params;
377 	cmsis_nn_dims input_dims;
378 	cmsis_nn_dims filter_dims;
379 	cmsis_nn_dims bias_dims;
380 	cmsis_nn_dims output_dims;
381 
382 	const int32_t *bias_data = fully_connected_mve_0_biases;
383 	const int8_t *kernel_data = fully_connected_mve_0_weights;
384 	const int8_t *input_data = fully_connected_mve_0_input;
385 
386 	input_dims.n = FULLY_CONNECTED_MVE_0_INPUT_BATCHES;
387 	input_dims.w = FULLY_CONNECTED_MVE_0_INPUT_W;
388 	input_dims.h = FULLY_CONNECTED_MVE_0_INPUT_H;
389 	input_dims.c = FULLY_CONNECTED_MVE_0_IN_CH;
390 	filter_dims.n = FULLY_CONNECTED_MVE_0_ACCUMULATION_DEPTH;
391 	filter_dims.c = FULLY_CONNECTED_MVE_0_OUT_CH;
392 	output_dims.n = FULLY_CONNECTED_MVE_0_INPUT_BATCHES;
393 	output_dims.c = FULLY_CONNECTED_MVE_0_OUT_CH;
394 
395 	fc_params.input_offset = FULLY_CONNECTED_MVE_0_INPUT_OFFSET;
396 	fc_params.filter_offset = 0;
397 	fc_params.output_offset = FULLY_CONNECTED_MVE_0_OUTPUT_OFFSET;
398 	fc_params.activation.min = FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MIN;
399 	fc_params.activation.max = FULLY_CONNECTED_MVE_0_OUT_ACTIVATION_MAX;
400 
401 	quant_params.multiplier = FULLY_CONNECTED_MVE_0_OUTPUT_MULTIPLIER;
402 	quant_params.shift = FULLY_CONNECTED_MVE_0_OUTPUT_SHIFT;
403 
404 	int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);
405 
406 	ctx.buf = malloc(buf_size);
407 	ctx.size = buf_size;
408 	arm_cmsis_nn_status result = arm_fully_connected_s8(&ctx,
409 						   &fc_params,
410 						   &quant_params,
411 						   &input_dims,
412 						   input_data,
413 						   &filter_dims,
414 						   kernel_data,
415 						   &bias_dims,
416 						   bias_data,
417 						   &output_dims,
418 						   output);
419 
420 	free(ctx.buf);
421 	zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
422 	zassert_mem_equal(fully_connected_mve_0_output_ref, output, sizeof(output), "");
423 }
424 
/*
 * Max-pooling test vectors: one batch of a 1 (H) x 12 (W) input with
 * 5 channels, pooled with a 3 (x) x 1 (y) window, stride 1 (x) / 2 (y) and
 * one column of padding on each side.
 */
#define MAXPOOLING_2_OUT_CH		5
#define MAXPOOLING_2_IN_CH		5
#define MAXPOOLING_2_INPUT_W		12
#define MAXPOOLING_2_INPUT_H		1
#define MAXPOOLING_2_DST_SIZE		60
#define MAXPOOLING_2_INPUT_SIZE		60
#define MAXPOOLING_2_OUT_ACTIVATION_MIN -128
#define MAXPOOLING_2_OUT_ACTIVATION_MAX 127
#define MAXPOOLING_2_INPUT_BATCHES	1
#define MAXPOOLING_2_FILTER_X		3
#define MAXPOOLING_2_FILTER_Y		1
#define MAXPOOLING_2_STRIDE_X		1
#define MAXPOOLING_2_STRIDE_Y		2
#define MAXPOOLING_2_PAD_X		1
#define MAXPOOLING_2_PAD_Y		0
#define MAXPOOLING_2_OUTPUT_W		12
#define MAXPOOLING_2_OUTPUT_H		1

static const int8_t maxpooling_2_input[MAXPOOLING_2_INPUT_SIZE] = {
	75, -52, -42, -30, 56, 64, 106, -36, 120, -3, 34, -105, 69, 75, -39,
	15, 93, -71, 39, 34, -11, 65, 22, 59, 106, 105, 45, -116, -75, 123,
	-65, 75, -61, 13, -25, -123, 59, 110, -65, 86, -108, -107, -17, 38, 27,
	-1, -115, -123, 75, -75, 68, 52, 12, -35, 116, -68, 22, 15, 76, -81};

/* Expected output of arm_max_pool_s8() for the input above. */
static const int8_t maxpooling_2_output_ref[MAXPOOLING_2_DST_SIZE] = {
	75, 106, -36, 120, 56, 75, 106, 69, 120, 56, 64, 106, 69, 120, 34,
	34, 93, 69, 75, 106, 105, 93, 22, 59, 123, 105, 75, 22, 59, 123,
	105, 75, 110, 13, 123, -65, 75, 110, 38, 86, -1, 59, 110, 75, 86,
	68, 52, 12, 75, 116, 68, 52, 15, 76, 116, 68, 52, 15, 76, 116};
454 
ZTEST(cmsis_nn,test_max_pool)455 ZTEST(cmsis_nn, test_max_pool)
456 {
457 	int8_t output[MAXPOOLING_2_DST_SIZE] = {0};
458 
459 	cmsis_nn_context ctx;
460 	cmsis_nn_pool_params pool_params;
461 	cmsis_nn_dims input_dims;
462 	cmsis_nn_dims filter_dims;
463 	cmsis_nn_dims output_dims;
464 
465 	const int8_t *input_data = maxpooling_2_input;
466 
467 	input_dims.n = MAXPOOLING_2_INPUT_BATCHES;
468 	input_dims.w = MAXPOOLING_2_INPUT_W;
469 	input_dims.h = MAXPOOLING_2_INPUT_H;
470 	input_dims.c = MAXPOOLING_2_IN_CH;
471 	filter_dims.w = MAXPOOLING_2_FILTER_X;
472 	filter_dims.h = MAXPOOLING_2_FILTER_Y;
473 	output_dims.w = MAXPOOLING_2_OUTPUT_W;
474 	output_dims.h = MAXPOOLING_2_OUTPUT_H;
475 	output_dims.c = MAXPOOLING_2_OUT_CH;
476 
477 	pool_params.padding.w = MAXPOOLING_2_PAD_X;
478 	pool_params.padding.h = MAXPOOLING_2_PAD_Y;
479 	pool_params.stride.w = MAXPOOLING_2_STRIDE_X;
480 	pool_params.stride.h = MAXPOOLING_2_STRIDE_Y;
481 
482 	pool_params.activation.min = MAXPOOLING_2_OUT_ACTIVATION_MIN;
483 	pool_params.activation.max = MAXPOOLING_2_OUT_ACTIVATION_MAX;
484 
485 	for (int i = 0; i < REPEAT_NUM; i++) {
486 		arm_cmsis_nn_status result =
487 			arm_max_pool_s8(&ctx, &pool_params, &input_dims, input_data, &filter_dims,
488 					&output_dims, output);
489 
490 		zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
491 		zassert_mem_equal(maxpooling_2_output_ref, output, sizeof(output), "");
492 	}
493 }
494 
/*
 * Softmax test vectors: SOFTMAX_NUM_ROWS rows of SOFTMAX_ROW_SIZE int8 values.
 * The multiplier, left shift and diff_min are passed to arm_softmax_s8() as-is.
 */
#define SOFTMAX_NUM_ROWS	 2
#define SOFTMAX_ROW_SIZE	 5
#define SOFTMAX_INPUT_MULT	 1077952640
#define SOFTMAX_INPUT_LEFT_SHIFT 19
#define SOFTMAX_DIFF_MIN	 -3968
#define SOFTMAX_DST_SIZE	 10

static const int8_t softmax_input[SOFTMAX_NUM_ROWS * SOFTMAX_ROW_SIZE] = {
	101, 49, 6, -34, -75, -79, -38, 120, -55, 115};

/* Expected output of arm_softmax_s8() for the input above. */
static const int8_t softmax_output_ref[SOFTMAX_DST_SIZE] = {
	-57, -70, -79, -86, -92, -94, -88, -54, -91, -56};
505 
ZTEST(cmsis_nn,test_softmax)506 ZTEST(cmsis_nn, test_softmax)
507 {
508 	const int32_t num_rows = SOFTMAX_NUM_ROWS;
509 	const int32_t row_size = SOFTMAX_ROW_SIZE;
510 	const int32_t mult = SOFTMAX_INPUT_MULT;
511 	const int32_t shift = SOFTMAX_INPUT_LEFT_SHIFT;
512 	const int32_t diff_min = SOFTMAX_DIFF_MIN;
513 	const int8_t *input_data = softmax_input;
514 	int8_t output[SOFTMAX_DST_SIZE];
515 
516 	for (int i = 0; i < REPEAT_NUM; i++) {
517 		arm_softmax_s8(input_data, num_rows, row_size, mult, shift, diff_min, output);
518 		zassert_mem_equal(softmax_output_ref, output, sizeof(output), "");
519 	}
520 }
521 
/*
 * SVDF test vectors: SVDF_2_INPUT_BATCHES batches of SVDF_2_INPUT_SIZE
 * inputs, SVDF_2_FEATURE_BATCHES feature filters with rank SVDF_2_RANK
 * (so FEATURE_BATCHES / RANK units) and SVDF_2_TIME_BATCHES time batches.
 * svdf_2_input_sequence holds two consecutive input rounds.
 */
#define SVDF_2_MULTIPLIER_IN	  1717987072
#define SVDF_2_MULTIPLIER_OUT	  1099511552
#define SVDF_2_SHIFT_1		  -3
#define SVDF_2_SHIFT_2		  -11
#define SVDF_2_IN_ACTIVATION_MIN  -32768
#define SVDF_2_IN_ACTIVATION_MAX  32767
#define SVDF_2_RANK		  2
#define SVDF_2_FEATURE_BATCHES	  10
#define SVDF_2_TIME_BATCHES	  2
#define SVDF_2_INPUT_SIZE	  7
#define SVDF_2_DST_SIZE		  15
#define SVDF_2_OUT_ACTIVATION_MIN -128
#define SVDF_2_OUT_ACTIVATION_MAX 127
#define SVDF_2_INPUT_BATCHES	  3
#define SVDF_2_INPUT_OFFSET	  0
#define SVDF_2_OUTPUT_OFFSET	  0

/* All-zero biases, one per unit. */
static const int32_t svdf_2_biases[SVDF_2_FEATURE_BATCHES / SVDF_2_RANK] = {0, 0, 0, 0, 0};

/* Initial state (all zeros); the test copies it into a working buffer. */
static const int16_t svdf_2_state[SVDF_2_INPUT_BATCHES * SVDF_2_FEATURE_BATCHES *
				  SVDF_2_TIME_BATCHES] = {0};

static const int8_t svdf_2_weights_feature[SVDF_2_FEATURE_BATCHES * SVDF_2_INPUT_SIZE] = {
	27, 82, -108, -127, 85, 3, -51, 32, 110, -6, -14, -16, 31, 101,
	-122, 19, 76, 74, -80, 12, -22, -17, 10, -28, 55, 109, 2, -107,
	-4, 72, -65, -59, 36, -69, 105, -97, 25, 38, 110, -121, -88, -126,
	-14, 16, -88, -66, 3, -93, 69, -64, 44, 103, 95, -95, 68, -46,
	106, -31, -63, 23, -38, 36, -95, -43, 93, 77, 91, -26, 33, 59};

static const int16_t svdf_2_weights_time[SVDF_2_FEATURE_BATCHES * SVDF_2_TIME_BATCHES] = {
	-31, -88, -10, -72, -119, -6, -70, 63, -10, 93,
	5, 42, -6, 22, 6, 51, 37, -38, 5, 117};

/* Two rounds of SVDF_2_INPUT_BATCHES * SVDF_2_INPUT_SIZE inputs each. */
static const int8_t svdf_2_input_sequence[42] = {
	29, 81, -38, 17, -116, 43, 119, -127, 74, 115, 9, 118, 7, -56,
	-53, -14, -98, 60, -128, 10, 28, -18, 12, -28, -126, 87, -115, -44,
	-123, -109, -59, -87, -69, 121, -128, -95, -70, 2, 81, -119, 84, -122};

/* Expected output after the last input round. */
static const int8_t svdf_2_output_ref[SVDF_2_DST_SIZE] = {
	-53, 45, 27, -24, -53, 26, -82, -38,
	11, -85, 94, -16, -32, 31, 4};
562 
/*
 * Report whether the first @p size entries of @p bias are all zero.
 *
 * Returns false at the first non-zero entry and true otherwise. This includes
 * size <= 0, where no entries are read and @p bias may be NULL.
 */
static bool check_null_bias(const int32_t *bias, int32_t size)
{
	for (int32_t idx = 0; idx < size; ++idx) {
		if (bias[idx] != 0) {
			return false;
		}
	}

	return true;
}
575 
ZTEST(cmsis_nn,test_svdf)576 ZTEST(cmsis_nn, test_svdf)
577 {
578 	cmsis_nn_context input_ctx;
579 	cmsis_nn_context output_ctx;
580 	cmsis_nn_svdf_params svdf_2_params;
581 	cmsis_nn_dims input_dims;
582 	cmsis_nn_dims weights_feature_dims;
583 	cmsis_nn_dims weights_time_dims;
584 	cmsis_nn_dims state_dims;
585 	cmsis_nn_dims output_dims;
586 	cmsis_nn_dims bias_dims;
587 	cmsis_nn_per_tensor_quant_params input_quant_params;
588 	cmsis_nn_per_tensor_quant_params output_quant_params;
589 	int8_t output_data[SVDF_2_DST_SIZE];
590 
591 	const int8_t *weights_feature_data = svdf_2_weights_feature;
592 	const int16_t *weights_time_data = svdf_2_weights_time;
593 
594 	input_dims.n = SVDF_2_INPUT_BATCHES;
595 	input_dims.h = SVDF_2_INPUT_SIZE;
596 	weights_feature_dims.n = SVDF_2_FEATURE_BATCHES;
597 	weights_time_dims.h = SVDF_2_TIME_BATCHES;
598 
599 	input_quant_params.multiplier = SVDF_2_MULTIPLIER_IN;
600 	input_quant_params.shift = SVDF_2_SHIFT_1;
601 	output_quant_params.multiplier = SVDF_2_MULTIPLIER_OUT;
602 	output_quant_params.shift = SVDF_2_SHIFT_2;
603 
604 	svdf_2_params.input_activation.min = SVDF_2_IN_ACTIVATION_MIN;
605 	svdf_2_params.input_activation.max = SVDF_2_IN_ACTIVATION_MAX;
606 	svdf_2_params.output_activation.min = SVDF_2_OUT_ACTIVATION_MIN;
607 	svdf_2_params.output_activation.max = SVDF_2_OUT_ACTIVATION_MAX;
608 	svdf_2_params.input_offset = SVDF_2_INPUT_OFFSET;
609 	svdf_2_params.output_offset = SVDF_2_OUTPUT_OFFSET;
610 	svdf_2_params.rank = SVDF_2_RANK;
611 
612 	const int input_round_size = SVDF_2_INPUT_BATCHES * SVDF_2_INPUT_SIZE;
613 	const int number_inputs = sizeof(svdf_2_input_sequence) / input_round_size;
614 	const int32_t number_units = SVDF_2_FEATURE_BATCHES / SVDF_2_RANK;
615 	const int scratch_size = SVDF_2_INPUT_BATCHES * SVDF_2_FEATURE_BATCHES * sizeof(int32_t);
616 	const int scratch_size_out = SVDF_2_INPUT_BATCHES * number_units * sizeof(int32_t);
617 
618 	input_ctx.buf = malloc(scratch_size);
619 	output_ctx.buf = malloc(scratch_size_out);
620 
621 	int8_t *input_data = malloc(input_round_size);
622 	int16_t *state_data = malloc(sizeof(svdf_2_state));
623 	const bool null_bias = check_null_bias(svdf_2_biases,
624 					       SVDF_2_DST_SIZE / SVDF_2_INPUT_BATCHES);
625 
626 	for (int i = 0; i < REPEAT_NUM; i++) {
627 		memcpy(state_data, svdf_2_state, sizeof(svdf_2_state));
628 		for (int j = 0; j < number_inputs; j++) {
629 			memcpy(input_data, svdf_2_input_sequence + j * input_round_size,
630 			       input_round_size);
631 			arm_cmsis_nn_status result = arm_svdf_state_s16_s8(&input_ctx,
632 							&output_ctx,
633 							&svdf_2_params,
634 							&input_quant_params,
635 							&output_quant_params,
636 							&input_dims,
637 							input_data,
638 							&state_dims,
639 							state_data,
640 							&weights_feature_dims,
641 							weights_feature_data,
642 							&weights_time_dims,
643 							weights_time_data,
644 							&bias_dims,
645 							null_bias == true ? NULL : svdf_2_biases,
646 							&output_dims,
647 							output_data);
648 			zassert_equal(ARM_CMSIS_NN_SUCCESS, result, "");
649 		}
650 
651 		zassert_mem_equal(svdf_2_output_ref, output_data, sizeof(output_data), "");
652 	}
653 	free(state_data);
654 	free(input_data);
655 	free(input_ctx.buf);
656 	free(output_ctx.buf);
657 }
658 
/*
 * Register the cmsis_nn suite that the ZTEST(cmsis_nn, ...) cases above
 * belong to; every argument after the suite name is NULL.
 */
ZTEST_SUITE(cmsis_nn, NULL, NULL, NULL, NULL, NULL);
660