/*
 * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com> All rights
 * reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
19 
#include <stdlib.h>
#include <string.h>

#include <arm_nnfunctions.h>
#include <unity.h>

#include "../TestData/int16xint8/test_data.h"
#include "../TestData/int16xint8_dilation_1/test_data.h"
#include "../TestData/int16xint8_dilation_2/test_data.h"
#include "../TestData/int16xint8_dilation_3/test_data.h"
#include "../TestData/int16xint8_spill/test_data.h"
#include "../TestData/int16xint8_spill2/test_data.h"
#include "../TestData/int16xint8xint32_1/test_data.h"
#include "../TestData/int16xint8xint32_2/test_data.h"
#include "../TestData/int16xint8xint32_3/test_data.h"
#include "../TestData/int16xint8xint32_4/test_data.h"
#include "../TestData/int16xint8xint32_5/test_data.h"
#include "../TestData/int16xint8xint32_6/test_data.h"
#include "../TestData/requantize_s64/test_data.h"
#include "../Utils/validate.h"
39 
int16xint8_arm_convolve_s16(void)40 void int16xint8_arm_convolve_s16(void)
41 {
42     int16_t output[INT16XINT8_DST_SIZE] = {0};
43 
44     cmsis_nn_context ctx;
45     cmsis_nn_conv_params conv_params;
46     cmsis_nn_per_channel_quant_params quant_params;
47     cmsis_nn_dims input_dims;
48     cmsis_nn_dims filter_dims;
49     cmsis_nn_dims bias_dims;
50     cmsis_nn_dims output_dims;
51 
52     const int64_t *int64_bias_data = int16xint8_biases;
53     const cmsis_nn_bias_data bias_data = {int64_bias_data, false};
54     const int8_t *kernel_data = int16xint8_weights;
55     const int16_t *input_data = int16xint8_input;
56     const int16_t *output_ref = int16xint8_output_ref;
57     const int32_t output_ref_size = INT16XINT8_DST_SIZE;
58 
59     input_dims.n = INT16XINT8_INPUT_BATCHES;
60     input_dims.w = INT16XINT8_INPUT_W;
61     input_dims.h = INT16XINT8_INPUT_H;
62     input_dims.c = INT16XINT8_IN_CH;
63     filter_dims.w = INT16XINT8_FILTER_X;
64     filter_dims.h = INT16XINT8_FILTER_Y;
65     output_dims.w = INT16XINT8_OUTPUT_W;
66     output_dims.h = INT16XINT8_OUTPUT_H;
67     output_dims.c = INT16XINT8_OUT_CH;
68 
69     conv_params.padding.w = INT16XINT8_PAD_X;
70     conv_params.padding.h = INT16XINT8_PAD_Y;
71     conv_params.stride.w = INT16XINT8_STRIDE_X;
72     conv_params.stride.h = INT16XINT8_STRIDE_Y;
73     conv_params.dilation.w = INT16XINT8_DILATION_X;
74     conv_params.dilation.h = INT16XINT8_DILATION_Y;
75 
76     conv_params.input_offset = 0;
77     conv_params.output_offset = 0;
78     conv_params.activation.min = INT16XINT8_OUT_ACTIVATION_MIN;
79     conv_params.activation.max = INT16XINT8_OUT_ACTIVATION_MAX;
80     quant_params.multiplier = (int32_t *)int16xint8_output_mult;
81     quant_params.shift = (int32_t *)int16xint8_output_shift;
82 
83     int buf_size = arm_convolve_s16_get_buffer_size(&input_dims, &filter_dims);
84     ctx.buf = malloc(buf_size);
85     arm_cmsis_nn_status result;
86     result = arm_convolve_s16(&ctx,
87                               &conv_params,
88                               &quant_params,
89                               &input_dims,
90                               input_data,
91                               &filter_dims,
92                               kernel_data,
93                               &bias_dims,
94                               &bias_data,
95                               &output_dims,
96                               output);
97     if (ctx.buf)
98     {
99         // The caller is responsible to clear the scratch buffers for security reasons if applicable.
100         memset(ctx.buf, 0, buf_size);
101         free(ctx.buf);
102     }
103     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
104     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
105     memset(output, 0, sizeof(output));
106 
107     buf_size = arm_convolve_wrapper_s16_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
108     ctx.buf = malloc(buf_size);
109 
110     result = arm_convolve_wrapper_s16(&ctx,
111                                       &conv_params,
112                                       &quant_params,
113                                       &input_dims,
114                                       input_data,
115                                       &filter_dims,
116                                       kernel_data,
117                                       &bias_dims,
118                                       &bias_data,
119                                       &output_dims,
120                                       output);
121     if (ctx.buf)
122     {
123         memset(ctx.buf, 0, buf_size);
124         free(ctx.buf);
125     }
126     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
127     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
128 }
129 
requantize_s64_arm_convolve_s16(void)130 void requantize_s64_arm_convolve_s16(void)
131 {
132     int16_t output[REQUANTIZE_S64_DST_SIZE] = {0};
133 
134     cmsis_nn_context ctx;
135     cmsis_nn_conv_params conv_params;
136     cmsis_nn_per_channel_quant_params quant_params;
137     cmsis_nn_dims input_dims;
138     cmsis_nn_dims filter_dims;
139     cmsis_nn_dims bias_dims;
140     cmsis_nn_dims output_dims;
141 
142     const int64_t *int64_bias_data = requantize_s64_biases;
143     const cmsis_nn_bias_data bias_data = {int64_bias_data, false};
144     const int8_t *kernel_data = requantize_s64_weights;
145     const int16_t *input_data = requantize_s64_input;
146     const int16_t *output_ref = requantize_s64_output_ref;
147     const int32_t output_ref_size = REQUANTIZE_S64_DST_SIZE;
148 
149     input_dims.n = REQUANTIZE_S64_INPUT_BATCHES;
150     input_dims.w = REQUANTIZE_S64_INPUT_W;
151     input_dims.h = REQUANTIZE_S64_INPUT_H;
152     input_dims.c = REQUANTIZE_S64_IN_CH;
153     filter_dims.w = REQUANTIZE_S64_FILTER_X;
154     filter_dims.h = REQUANTIZE_S64_FILTER_Y;
155     output_dims.w = REQUANTIZE_S64_OUTPUT_W;
156     output_dims.h = REQUANTIZE_S64_OUTPUT_H;
157     output_dims.c = REQUANTIZE_S64_OUT_CH;
158 
159     conv_params.padding.w = REQUANTIZE_S64_PAD_X;
160     conv_params.padding.h = REQUANTIZE_S64_PAD_Y;
161     conv_params.stride.w = REQUANTIZE_S64_STRIDE_X;
162     conv_params.stride.h = REQUANTIZE_S64_STRIDE_Y;
163     conv_params.dilation.w = REQUANTIZE_S64_DILATION_X;
164     conv_params.dilation.h = REQUANTIZE_S64_DILATION_Y;
165 
166     conv_params.input_offset = REQUANTIZE_S64_INPUT_OFFSET;
167     conv_params.output_offset = REQUANTIZE_S64_OUTPUT_OFFSET;
168     conv_params.activation.min = REQUANTIZE_S64_OUT_ACTIVATION_MIN;
169     conv_params.activation.max = REQUANTIZE_S64_OUT_ACTIVATION_MAX;
170     quant_params.multiplier = (int32_t *)requantize_s64_output_mult;
171     quant_params.shift = (int32_t *)requantize_s64_output_shift;
172 
173     int buf_size = arm_convolve_s16_get_buffer_size(&input_dims, &filter_dims);
174     ctx.buf = malloc(buf_size);
175 
176     arm_cmsis_nn_status result = arm_convolve_s16(&ctx,
177                                                   &conv_params,
178                                                   &quant_params,
179                                                   &input_dims,
180                                                   input_data,
181                                                   &filter_dims,
182                                                   kernel_data,
183                                                   &bias_dims,
184                                                   &bias_data,
185                                                   &output_dims,
186                                                   output);
187     if (ctx.buf)
188     {
189         memset(ctx.buf, 0, buf_size);
190         free(ctx.buf);
191     }
192     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
193     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
194     memset(output, 0, sizeof(output));
195 
196     buf_size = arm_convolve_wrapper_s16_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
197     ctx.buf = malloc(buf_size);
198 
199     result = arm_convolve_wrapper_s16(&ctx,
200                                       &conv_params,
201                                       &quant_params,
202                                       &input_dims,
203                                       input_data,
204                                       &filter_dims,
205                                       kernel_data,
206                                       &bias_dims,
207                                       &bias_data,
208                                       &output_dims,
209                                       output);
210 
211     if (ctx.buf)
212     {
213         memset(ctx.buf, 0, buf_size);
214         free(ctx.buf);
215     }
216     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
217     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
218 }
219 
int16xint8_dilation_1_arm_convolve_s16(void)220 void int16xint8_dilation_1_arm_convolve_s16(void)
221 {
222     int16_t output[INT16XINT8_DILATION_1_DST_SIZE] = {0};
223 
224     cmsis_nn_context ctx;
225     cmsis_nn_conv_params conv_params;
226     cmsis_nn_per_channel_quant_params quant_params;
227     cmsis_nn_dims input_dims;
228     cmsis_nn_dims filter_dims;
229     cmsis_nn_dims bias_dims;
230     cmsis_nn_dims output_dims;
231 
232     const int64_t *int64_bias_data = int16xint8_dilation_1_biases;
233     const cmsis_nn_bias_data bias_data = {int64_bias_data, false};
234     const int8_t *kernel_data = int16xint8_dilation_1_weights;
235     const int16_t *input_data = int16xint8_dilation_1_input;
236     const int16_t *output_ref = int16xint8_dilation_1_output_ref;
237     const int32_t output_ref_size = INT16XINT8_DILATION_1_DST_SIZE;
238 
239     input_dims.n = INT16XINT8_DILATION_1_INPUT_BATCHES;
240     input_dims.w = INT16XINT8_DILATION_1_INPUT_W;
241     input_dims.h = INT16XINT8_DILATION_1_INPUT_H;
242     input_dims.c = INT16XINT8_DILATION_1_IN_CH;
243     filter_dims.w = INT16XINT8_DILATION_1_FILTER_X;
244     filter_dims.h = INT16XINT8_DILATION_1_FILTER_Y;
245     output_dims.w = INT16XINT8_DILATION_1_OUTPUT_W;
246     output_dims.h = INT16XINT8_DILATION_1_OUTPUT_H;
247     output_dims.c = INT16XINT8_DILATION_1_OUT_CH;
248 
249     conv_params.padding.w = INT16XINT8_DILATION_1_PAD_X;
250     conv_params.padding.h = INT16XINT8_DILATION_1_PAD_Y;
251     conv_params.stride.w = INT16XINT8_DILATION_1_STRIDE_X;
252     conv_params.stride.h = INT16XINT8_DILATION_1_STRIDE_Y;
253     conv_params.dilation.w = INT16XINT8_DILATION_1_DILATION_X;
254     conv_params.dilation.h = INT16XINT8_DILATION_1_DILATION_Y;
255 
256     conv_params.input_offset = INT16XINT8_DILATION_1_INPUT_OFFSET;
257     conv_params.output_offset = INT16XINT8_DILATION_1_OUTPUT_OFFSET;
258     conv_params.activation.min = INT16XINT8_DILATION_1_OUT_ACTIVATION_MIN;
259     conv_params.activation.max = INT16XINT8_DILATION_1_OUT_ACTIVATION_MAX;
260     quant_params.multiplier = (int32_t *)int16xint8_dilation_1_output_mult;
261     quant_params.shift = (int32_t *)int16xint8_dilation_1_output_shift;
262 
263     int buf_size = arm_convolve_s16_get_buffer_size(&input_dims, &filter_dims);
264     ctx.buf = malloc(buf_size);
265 
266     arm_cmsis_nn_status result = arm_convolve_s16(&ctx,
267                                                   &conv_params,
268                                                   &quant_params,
269                                                   &input_dims,
270                                                   input_data,
271                                                   &filter_dims,
272                                                   kernel_data,
273                                                   &bias_dims,
274                                                   &bias_data,
275                                                   &output_dims,
276                                                   output);
277     if (ctx.buf)
278     {
279         memset(ctx.buf, 0, buf_size);
280         free(ctx.buf);
281     }
282     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
283     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
284     memset(output, 0, sizeof(output));
285 
286     buf_size = arm_convolve_wrapper_s16_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
287     ctx.buf = malloc(buf_size);
288 
289     result = arm_convolve_wrapper_s16(&ctx,
290                                       &conv_params,
291                                       &quant_params,
292                                       &input_dims,
293                                       input_data,
294                                       &filter_dims,
295                                       kernel_data,
296                                       &bias_dims,
297                                       &bias_data,
298                                       &output_dims,
299                                       output);
300 
301     if (ctx.buf)
302     {
303         memset(ctx.buf, 0, buf_size);
304         free(ctx.buf);
305     }
306     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
307     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
308 }
309 
int16xint8_dilation_2_arm_convolve_s16(void)310 void int16xint8_dilation_2_arm_convolve_s16(void)
311 {
312     int16_t output[INT16XINT8_DILATION_2_DST_SIZE] = {0};
313 
314     cmsis_nn_context ctx;
315     cmsis_nn_conv_params conv_params;
316     cmsis_nn_per_channel_quant_params quant_params;
317     cmsis_nn_dims input_dims;
318     cmsis_nn_dims filter_dims;
319     cmsis_nn_dims bias_dims;
320     cmsis_nn_dims output_dims;
321 
322     const int64_t *int64_bias_data = int16xint8_dilation_2_biases;
323     const cmsis_nn_bias_data bias_data = {int64_bias_data, false};
324     const int8_t *kernel_data = int16xint8_dilation_2_weights;
325     const int16_t *input_data = int16xint8_dilation_2_input;
326     const int16_t *output_ref = int16xint8_dilation_2_output_ref;
327     const int32_t output_ref_size = INT16XINT8_DILATION_2_DST_SIZE;
328 
329     input_dims.n = INT16XINT8_DILATION_2_INPUT_BATCHES;
330     input_dims.w = INT16XINT8_DILATION_2_INPUT_W;
331     input_dims.h = INT16XINT8_DILATION_2_INPUT_H;
332     input_dims.c = INT16XINT8_DILATION_2_IN_CH;
333     filter_dims.w = INT16XINT8_DILATION_2_FILTER_X;
334     filter_dims.h = INT16XINT8_DILATION_2_FILTER_Y;
335     output_dims.w = INT16XINT8_DILATION_2_OUTPUT_W;
336     output_dims.h = INT16XINT8_DILATION_2_OUTPUT_H;
337     output_dims.c = INT16XINT8_DILATION_2_OUT_CH;
338 
339     conv_params.padding.w = INT16XINT8_DILATION_2_PAD_X;
340     conv_params.padding.h = INT16XINT8_DILATION_2_PAD_Y;
341     conv_params.stride.w = INT16XINT8_DILATION_2_STRIDE_X;
342     conv_params.stride.h = INT16XINT8_DILATION_2_STRIDE_Y;
343     conv_params.dilation.w = INT16XINT8_DILATION_2_DILATION_X;
344     conv_params.dilation.h = INT16XINT8_DILATION_2_DILATION_Y;
345 
346     conv_params.input_offset = INT16XINT8_DILATION_2_INPUT_OFFSET;
347     conv_params.output_offset = INT16XINT8_DILATION_2_OUTPUT_OFFSET;
348     conv_params.activation.min = INT16XINT8_DILATION_2_OUT_ACTIVATION_MIN;
349     conv_params.activation.max = INT16XINT8_DILATION_2_OUT_ACTIVATION_MAX;
350     quant_params.multiplier = (int32_t *)int16xint8_dilation_2_output_mult;
351     quant_params.shift = (int32_t *)int16xint8_dilation_2_output_shift;
352 
353     int buf_size = arm_convolve_s16_get_buffer_size(&input_dims, &filter_dims);
354     ctx.buf = malloc(buf_size);
355 
356     arm_cmsis_nn_status result = arm_convolve_s16(&ctx,
357                                                   &conv_params,
358                                                   &quant_params,
359                                                   &input_dims,
360                                                   input_data,
361                                                   &filter_dims,
362                                                   kernel_data,
363                                                   &bias_dims,
364                                                   &bias_data,
365                                                   &output_dims,
366                                                   output);
367     if (ctx.buf)
368     {
369         memset(ctx.buf, 0, buf_size);
370         free(ctx.buf);
371     }
372     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
373     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
374     memset(output, 0, sizeof(output));
375 
376     buf_size = arm_convolve_wrapper_s16_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
377     ctx.buf = malloc(buf_size);
378 
379     result = arm_convolve_wrapper_s16(&ctx,
380                                       &conv_params,
381                                       &quant_params,
382                                       &input_dims,
383                                       input_data,
384                                       &filter_dims,
385                                       kernel_data,
386                                       &bias_dims,
387                                       &bias_data,
388                                       &output_dims,
389                                       output);
390 
391     if (ctx.buf)
392     {
393         memset(ctx.buf, 0, buf_size);
394         free(ctx.buf);
395     }
396     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
397     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
398 }
399 
int16xint8_dilation_3_arm_convolve_s16(void)400 void int16xint8_dilation_3_arm_convolve_s16(void)
401 {
402     int16_t output[INT16XINT8_DILATION_3_DST_SIZE] = {0};
403 
404     cmsis_nn_context ctx;
405     cmsis_nn_conv_params conv_params;
406     cmsis_nn_per_channel_quant_params quant_params;
407     cmsis_nn_dims input_dims;
408     cmsis_nn_dims filter_dims;
409     cmsis_nn_dims bias_dims;
410     cmsis_nn_dims output_dims;
411 
412     const int64_t *int64_bias_data = int16xint8_dilation_3_biases;
413     const cmsis_nn_bias_data bias_data = {int64_bias_data, false};
414     const int8_t *kernel_data = int16xint8_dilation_3_weights;
415     const int16_t *input_data = int16xint8_dilation_3_input;
416     const int16_t *output_ref = int16xint8_dilation_3_output_ref;
417     const int32_t output_ref_size = INT16XINT8_DILATION_3_DST_SIZE;
418 
419     input_dims.n = INT16XINT8_DILATION_3_INPUT_BATCHES;
420     input_dims.w = INT16XINT8_DILATION_3_INPUT_W;
421     input_dims.h = INT16XINT8_DILATION_3_INPUT_H;
422     input_dims.c = INT16XINT8_DILATION_3_IN_CH;
423     filter_dims.w = INT16XINT8_DILATION_3_FILTER_X;
424     filter_dims.h = INT16XINT8_DILATION_3_FILTER_Y;
425     output_dims.w = INT16XINT8_DILATION_3_OUTPUT_W;
426     output_dims.h = INT16XINT8_DILATION_3_OUTPUT_H;
427     output_dims.c = INT16XINT8_DILATION_3_OUT_CH;
428 
429     conv_params.padding.w = INT16XINT8_DILATION_3_PAD_X;
430     conv_params.padding.h = INT16XINT8_DILATION_3_PAD_Y;
431     conv_params.stride.w = INT16XINT8_DILATION_3_STRIDE_X;
432     conv_params.stride.h = INT16XINT8_DILATION_3_STRIDE_Y;
433     conv_params.dilation.w = INT16XINT8_DILATION_3_DILATION_X;
434     conv_params.dilation.h = INT16XINT8_DILATION_3_DILATION_Y;
435 
436     conv_params.input_offset = INT16XINT8_DILATION_3_INPUT_OFFSET;
437     conv_params.output_offset = INT16XINT8_DILATION_3_OUTPUT_OFFSET;
438     conv_params.activation.min = INT16XINT8_DILATION_3_OUT_ACTIVATION_MIN;
439     conv_params.activation.max = INT16XINT8_DILATION_3_OUT_ACTIVATION_MAX;
440     quant_params.multiplier = (int32_t *)int16xint8_dilation_3_output_mult;
441     quant_params.shift = (int32_t *)int16xint8_dilation_3_output_shift;
442 
443     int buf_size = arm_convolve_s16_get_buffer_size(&input_dims, &filter_dims);
444     ctx.buf = malloc(buf_size);
445 
446     arm_cmsis_nn_status result = arm_convolve_s16(&ctx,
447                                                   &conv_params,
448                                                   &quant_params,
449                                                   &input_dims,
450                                                   input_data,
451                                                   &filter_dims,
452                                                   kernel_data,
453                                                   &bias_dims,
454                                                   &bias_data,
455                                                   &output_dims,
456                                                   output);
457     if (ctx.buf)
458     {
459         memset(ctx.buf, 0, buf_size);
460         free(ctx.buf);
461     }
462     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
463     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
464     memset(output, 0, sizeof(output));
465 
466     buf_size = arm_convolve_wrapper_s16_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
467     ctx.buf = malloc(buf_size);
468 
469     result = arm_convolve_wrapper_s16(&ctx,
470                                       &conv_params,
471                                       &quant_params,
472                                       &input_dims,
473                                       input_data,
474                                       &filter_dims,
475                                       kernel_data,
476                                       &bias_dims,
477                                       &bias_data,
478                                       &output_dims,
479                                       output);
480 
481     if (ctx.buf)
482     {
483         memset(ctx.buf, 0, buf_size);
484         free(ctx.buf);
485     }
486     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
487     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
488 }
489 
buffer_size_arm_convolve_s16(void)490 void buffer_size_arm_convolve_s16(void)
491 {
492     cmsis_nn_conv_params conv_params;
493     cmsis_nn_dims input_dims;
494     cmsis_nn_dims filter_dims;
495     cmsis_nn_dims output_dims;
496 
497     input_dims.n = INT16XINT8_DILATION_3_INPUT_BATCHES;
498     input_dims.w = INT16XINT8_DILATION_3_INPUT_W;
499     input_dims.h = INT16XINT8_DILATION_3_INPUT_H;
500     input_dims.c = INT16XINT8_DILATION_3_IN_CH;
501     filter_dims.w = INT16XINT8_DILATION_3_FILTER_X;
502     filter_dims.h = INT16XINT8_DILATION_3_FILTER_Y;
503     output_dims.w = INT16XINT8_DILATION_3_OUTPUT_W;
504     output_dims.h = INT16XINT8_DILATION_3_OUTPUT_H;
505     output_dims.c = INT16XINT8_DILATION_3_OUT_CH;
506 
507     conv_params.padding.w = INT16XINT8_DILATION_3_PAD_X;
508     conv_params.padding.h = INT16XINT8_DILATION_3_PAD_Y;
509     conv_params.stride.w = INT16XINT8_DILATION_3_STRIDE_X;
510     conv_params.stride.h = INT16XINT8_DILATION_3_STRIDE_Y;
511     conv_params.dilation.w = INT16XINT8_DILATION_3_DILATION_X;
512     conv_params.dilation.h = INT16XINT8_DILATION_3_DILATION_Y;
513 
514     conv_params.input_offset = INT16XINT8_DILATION_3_INPUT_OFFSET;
515     conv_params.output_offset = INT16XINT8_DILATION_3_OUTPUT_OFFSET;
516     conv_params.activation.min = INT16XINT8_DILATION_3_OUT_ACTIVATION_MIN;
517     conv_params.activation.max = INT16XINT8_DILATION_3_OUT_ACTIVATION_MAX;
518 
519     const int32_t buf_size = arm_convolve_s16_get_buffer_size(&input_dims, &filter_dims);
520     const int32_t wrapper_buf_size =
521         arm_convolve_wrapper_s16_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
522 
523     TEST_ASSERT_EQUAL(wrapper_buf_size, buf_size);
524 }
525 
/**
 * When ARM_MATH_MVEI is defined, asserts that
 * arm_convolve_wrapper_s16_get_buffer_size() and
 * arm_convolve_wrapper_s16_get_buffer_size_mve() report the same size for the
 * "int16xint8_dilation_3" configuration. Without ARM_MATH_MVEI the body is empty.
 */
void buffer_size_mve_arm_convolve_s16(void)
{
#if defined(ARM_MATH_MVEI)
    // Zero-initialised because several fields (e.g. filter_dims.n/.c, output_dims.n)
    // are not assigned below and the structs are passed by address.
    cmsis_nn_conv_params conv_params = {0};
    cmsis_nn_dims input_dims = {0};
    cmsis_nn_dims filter_dims = {0};
    cmsis_nn_dims output_dims = {0};

    input_dims.n = INT16XINT8_DILATION_3_INPUT_BATCHES;
    input_dims.w = INT16XINT8_DILATION_3_INPUT_W;
    input_dims.h = INT16XINT8_DILATION_3_INPUT_H;
    input_dims.c = INT16XINT8_DILATION_3_IN_CH;
    filter_dims.w = INT16XINT8_DILATION_3_FILTER_X;
    filter_dims.h = INT16XINT8_DILATION_3_FILTER_Y;
    output_dims.w = INT16XINT8_DILATION_3_OUTPUT_W;
    output_dims.h = INT16XINT8_DILATION_3_OUTPUT_H;
    output_dims.c = INT16XINT8_DILATION_3_OUT_CH;

    conv_params.padding.w = INT16XINT8_DILATION_3_PAD_X;
    conv_params.padding.h = INT16XINT8_DILATION_3_PAD_Y;
    conv_params.stride.w = INT16XINT8_DILATION_3_STRIDE_X;
    conv_params.stride.h = INT16XINT8_DILATION_3_STRIDE_Y;
    conv_params.dilation.w = INT16XINT8_DILATION_3_DILATION_X;
    conv_params.dilation.h = INT16XINT8_DILATION_3_DILATION_Y;

    conv_params.input_offset = INT16XINT8_DILATION_3_INPUT_OFFSET;
    conv_params.output_offset = INT16XINT8_DILATION_3_OUTPUT_OFFSET;
    conv_params.activation.min = INT16XINT8_DILATION_3_OUT_ACTIVATION_MIN;
    conv_params.activation.max = INT16XINT8_DILATION_3_OUT_ACTIVATION_MAX;

    const int32_t wrapper_buf_size =
        arm_convolve_wrapper_s16_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
    const int32_t mve_wrapper_buf_size =
        arm_convolve_wrapper_s16_get_buffer_size_mve(&conv_params, &input_dims, &filter_dims, &output_dims);

    TEST_ASSERT_EQUAL(wrapper_buf_size, mve_wrapper_buf_size);
#endif
}
564 
/**
 * When ARM_MATH_DSP is defined and ARM_MATH_MVEI is not, asserts that
 * arm_convolve_wrapper_s16_get_buffer_size() and
 * arm_convolve_wrapper_s16_get_buffer_size_dsp() report the same size for the
 * "int16xint8_dilation_3" configuration. Otherwise the body is empty.
 */
void buffer_size_dsp_arm_convolve_s16(void)
{
#if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
    // Zero-initialised because several fields (e.g. filter_dims.n/.c, output_dims.n)
    // are not assigned below and the structs are passed by address.
    cmsis_nn_conv_params conv_params = {0};
    cmsis_nn_dims input_dims = {0};
    cmsis_nn_dims filter_dims = {0};
    cmsis_nn_dims output_dims = {0};

    input_dims.n = INT16XINT8_DILATION_3_INPUT_BATCHES;
    input_dims.w = INT16XINT8_DILATION_3_INPUT_W;
    input_dims.h = INT16XINT8_DILATION_3_INPUT_H;
    input_dims.c = INT16XINT8_DILATION_3_IN_CH;
    filter_dims.w = INT16XINT8_DILATION_3_FILTER_X;
    filter_dims.h = INT16XINT8_DILATION_3_FILTER_Y;
    output_dims.w = INT16XINT8_DILATION_3_OUTPUT_W;
    output_dims.h = INT16XINT8_DILATION_3_OUTPUT_H;
    output_dims.c = INT16XINT8_DILATION_3_OUT_CH;

    conv_params.padding.w = INT16XINT8_DILATION_3_PAD_X;
    conv_params.padding.h = INT16XINT8_DILATION_3_PAD_Y;
    conv_params.stride.w = INT16XINT8_DILATION_3_STRIDE_X;
    conv_params.stride.h = INT16XINT8_DILATION_3_STRIDE_Y;
    conv_params.dilation.w = INT16XINT8_DILATION_3_DILATION_X;
    conv_params.dilation.h = INT16XINT8_DILATION_3_DILATION_Y;

    conv_params.input_offset = INT16XINT8_DILATION_3_INPUT_OFFSET;
    conv_params.output_offset = INT16XINT8_DILATION_3_OUTPUT_OFFSET;
    conv_params.activation.min = INT16XINT8_DILATION_3_OUT_ACTIVATION_MIN;
    conv_params.activation.max = INT16XINT8_DILATION_3_OUT_ACTIVATION_MAX;

    const int32_t wrapper_buf_size =
        arm_convolve_wrapper_s16_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
    const int32_t dsp_wrapper_buf_size =
        arm_convolve_wrapper_s16_get_buffer_size_dsp(&conv_params, &input_dims, &filter_dims, &output_dims);

    TEST_ASSERT_EQUAL(wrapper_buf_size, dsp_wrapper_buf_size);
#endif
}
603 
int16xint8_spill_arm_convolve_s16(void)604 void int16xint8_spill_arm_convolve_s16(void)
605 {
606     int16_t output[INT16XINT8_SPILL_DST_SIZE] = {0};
607 
608     cmsis_nn_context ctx;
609     cmsis_nn_conv_params conv_params;
610     cmsis_nn_per_channel_quant_params quant_params;
611     cmsis_nn_dims input_dims;
612     cmsis_nn_dims filter_dims;
613     cmsis_nn_dims bias_dims;
614     cmsis_nn_dims output_dims;
615 
616     const int64_t *int64_bias_data = int16xint8_spill_biases;
617     const cmsis_nn_bias_data bias_data = {int64_bias_data, false};
618     const int8_t *kernel_data = int16xint8_spill_weights;
619     const int16_t *input_data = int16xint8_spill_input;
620     const int16_t *output_ref = int16xint8_spill_output_ref;
621     const int32_t output_ref_size = INT16XINT8_SPILL_DST_SIZE;
622 
623     input_dims.n = INT16XINT8_SPILL_INPUT_BATCHES;
624     input_dims.w = INT16XINT8_SPILL_INPUT_W;
625     input_dims.h = INT16XINT8_SPILL_INPUT_H;
626     input_dims.c = INT16XINT8_SPILL_IN_CH;
627     filter_dims.w = INT16XINT8_SPILL_FILTER_X;
628     filter_dims.h = INT16XINT8_SPILL_FILTER_Y;
629     output_dims.w = INT16XINT8_SPILL_OUTPUT_W;
630     output_dims.h = INT16XINT8_SPILL_OUTPUT_H;
631     output_dims.c = INT16XINT8_SPILL_OUT_CH;
632 
633     conv_params.padding.w = INT16XINT8_SPILL_PAD_X;
634     conv_params.padding.h = INT16XINT8_SPILL_PAD_Y;
635     conv_params.stride.w = INT16XINT8_SPILL_STRIDE_X;
636     conv_params.stride.h = INT16XINT8_SPILL_STRIDE_Y;
637     conv_params.dilation.w = INT16XINT8_SPILL_DILATION_X;
638     conv_params.dilation.h = INT16XINT8_SPILL_DILATION_Y;
639 
640     conv_params.input_offset = 0;
641     conv_params.output_offset = 0;
642     conv_params.activation.min = INT16XINT8_SPILL_OUT_ACTIVATION_MIN;
643     conv_params.activation.max = INT16XINT8_SPILL_OUT_ACTIVATION_MAX;
644     quant_params.multiplier = (int32_t *)int16xint8_spill_output_mult;
645     quant_params.shift = (int32_t *)int16xint8_spill_output_shift;
646 
647     int buf_size = arm_convolve_s16_get_buffer_size(&input_dims, &filter_dims);
648     ctx.buf = malloc(buf_size);
649     arm_cmsis_nn_status result;
650     result = arm_convolve_s16(&ctx,
651                               &conv_params,
652                               &quant_params,
653                               &input_dims,
654                               input_data,
655                               &filter_dims,
656                               kernel_data,
657                               &bias_dims,
658                               &bias_data,
659                               &output_dims,
660                               output);
661     if (ctx.buf)
662     {
663         // The caller is responsible to clear the scratch buffers for security reasons if applicable.
664         memset(ctx.buf, 0, buf_size);
665         free(ctx.buf);
666     }
667     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
668     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
669     memset(output, 0, sizeof(output));
670 
671     buf_size = arm_convolve_wrapper_s16_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
672     ctx.buf = malloc(buf_size);
673 
674     result = arm_convolve_wrapper_s16(&ctx,
675                                       &conv_params,
676                                       &quant_params,
677                                       &input_dims,
678                                       input_data,
679                                       &filter_dims,
680                                       kernel_data,
681                                       &bias_dims,
682                                       &bias_data,
683                                       &output_dims,
684                                       output);
685     if (ctx.buf)
686     {
687         memset(ctx.buf, 0, buf_size);
688         free(ctx.buf);
689     }
690     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
691     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
692 }
693 
int16xint8_spill2_arm_convolve_s16(void)694 void int16xint8_spill2_arm_convolve_s16(void)
695 {
696     int16_t output[INT16XINT8_SPILL2_DST_SIZE] = {0};
697 
698     cmsis_nn_context ctx;
699     cmsis_nn_conv_params conv_params;
700     cmsis_nn_per_channel_quant_params quant_params;
701     cmsis_nn_dims input_dims;
702     cmsis_nn_dims filter_dims;
703     cmsis_nn_dims bias_dims;
704     cmsis_nn_dims output_dims;
705 
706     const int64_t *int64_bias_data = int16xint8_spill2_biases;
707     const cmsis_nn_bias_data bias_data = {int64_bias_data, false};
708     const int8_t *kernel_data = int16xint8_spill2_weights;
709     const int16_t *input_data = int16xint8_spill2_input;
710     const int16_t *output_ref = int16xint8_spill2_output_ref;
711     const int32_t output_ref_size = INT16XINT8_SPILL2_DST_SIZE;
712 
713     input_dims.n = INT16XINT8_SPILL2_INPUT_BATCHES;
714     input_dims.w = INT16XINT8_SPILL2_INPUT_W;
715     input_dims.h = INT16XINT8_SPILL2_INPUT_H;
716     input_dims.c = INT16XINT8_SPILL2_IN_CH;
717     filter_dims.w = INT16XINT8_SPILL2_FILTER_X;
718     filter_dims.h = INT16XINT8_SPILL2_FILTER_Y;
719     output_dims.w = INT16XINT8_SPILL2_OUTPUT_W;
720     output_dims.h = INT16XINT8_SPILL2_OUTPUT_H;
721     output_dims.c = INT16XINT8_SPILL2_OUT_CH;
722 
723     conv_params.padding.w = INT16XINT8_SPILL2_PAD_X;
724     conv_params.padding.h = INT16XINT8_SPILL2_PAD_Y;
725     conv_params.stride.w = INT16XINT8_SPILL2_STRIDE_X;
726     conv_params.stride.h = INT16XINT8_SPILL2_STRIDE_Y;
727     conv_params.dilation.w = INT16XINT8_SPILL2_DILATION_X;
728     conv_params.dilation.h = INT16XINT8_SPILL2_DILATION_Y;
729 
730     conv_params.input_offset = 0;
731     conv_params.output_offset = 0;
732     conv_params.activation.min = INT16XINT8_SPILL2_OUT_ACTIVATION_MIN;
733     conv_params.activation.max = INT16XINT8_SPILL2_OUT_ACTIVATION_MAX;
734     quant_params.multiplier = (int32_t *)int16xint8_spill2_output_mult;
735     quant_params.shift = (int32_t *)int16xint8_spill2_output_shift;
736 
737     int buf_size = arm_convolve_s16_get_buffer_size(&input_dims, &filter_dims);
738     ctx.buf = malloc(buf_size);
739     arm_cmsis_nn_status result;
740     result = arm_convolve_s16(&ctx,
741                               &conv_params,
742                               &quant_params,
743                               &input_dims,
744                               input_data,
745                               &filter_dims,
746                               kernel_data,
747                               &bias_dims,
748                               &bias_data,
749                               &output_dims,
750                               output);
751     if (ctx.buf)
752     {
753         // The caller is responsible to clear the scratch buffers for security reasons if applicable.
754         memset(ctx.buf, 0, buf_size);
755         free(ctx.buf);
756     }
757     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
758     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
759     memset(output, 0, sizeof(output));
760 
761     buf_size = arm_convolve_wrapper_s16_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
762     ctx.buf = malloc(buf_size);
763 
764     result = arm_convolve_wrapper_s16(&ctx,
765                                       &conv_params,
766                                       &quant_params,
767                                       &input_dims,
768                                       input_data,
769                                       &filter_dims,
770                                       kernel_data,
771                                       &bias_dims,
772                                       &bias_data,
773                                       &output_dims,
774                                       output);
775     if (ctx.buf)
776     {
777         memset(ctx.buf, 0, buf_size);
778         free(ctx.buf);
779     }
780     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
781     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
782 }
783 
int16xint8xint32_1_arm_convolve_s16(void)784 void int16xint8xint32_1_arm_convolve_s16(void)
785 {
786     int16_t output[INT16XINT8XINT32_1_DST_SIZE] = {0};
787 
788     cmsis_nn_context ctx;
789     cmsis_nn_conv_params conv_params;
790     cmsis_nn_per_channel_quant_params quant_params;
791     cmsis_nn_dims input_dims;
792     cmsis_nn_dims filter_dims;
793     cmsis_nn_dims bias_dims;
794     cmsis_nn_dims output_dims;
795 
796     const int32_t *int32_bias_data = int16xint8xint32_1_biases;
797     const cmsis_nn_bias_data bias_data = {int32_bias_data, true};
798     const int8_t *kernel_data = int16xint8xint32_1_weights;
799     const int16_t *input_data = int16xint8xint32_1_input;
800     const int16_t *output_ref = int16xint8xint32_1_output_ref;
801     const int32_t output_ref_size = INT16XINT8XINT32_1_DST_SIZE;
802 
803     input_dims.n = INT16XINT8XINT32_1_INPUT_BATCHES;
804     input_dims.w = INT16XINT8XINT32_1_INPUT_W;
805     input_dims.h = INT16XINT8XINT32_1_INPUT_H;
806     input_dims.c = INT16XINT8XINT32_1_IN_CH;
807     filter_dims.w = INT16XINT8XINT32_1_FILTER_X;
808     filter_dims.h = INT16XINT8XINT32_1_FILTER_Y;
809     output_dims.w = INT16XINT8XINT32_1_OUTPUT_W;
810     output_dims.h = INT16XINT8XINT32_1_OUTPUT_H;
811     output_dims.c = INT16XINT8XINT32_1_OUT_CH;
812 
813     conv_params.padding.w = INT16XINT8XINT32_1_PAD_X;
814     conv_params.padding.h = INT16XINT8XINT32_1_PAD_Y;
815     conv_params.stride.w = INT16XINT8XINT32_1_STRIDE_X;
816     conv_params.stride.h = INT16XINT8XINT32_1_STRIDE_Y;
817     conv_params.dilation.w = INT16XINT8XINT32_1_DILATION_X;
818     conv_params.dilation.h = INT16XINT8XINT32_1_DILATION_Y;
819 
820     conv_params.input_offset = 0;
821     conv_params.output_offset = 0;
822     conv_params.activation.min = INT16XINT8XINT32_1_OUT_ACTIVATION_MIN;
823     conv_params.activation.max = INT16XINT8XINT32_1_OUT_ACTIVATION_MAX;
824     quant_params.multiplier = (int32_t *)int16xint8xint32_1_output_mult;
825     quant_params.shift = (int32_t *)int16xint8xint32_1_output_shift;
826 
827     int buf_size = arm_convolve_s16_get_buffer_size(&input_dims, &filter_dims);
828     ctx.buf = malloc(buf_size);
829     arm_cmsis_nn_status result;
830     result = arm_convolve_s16(&ctx,
831                               &conv_params,
832                               &quant_params,
833                               &input_dims,
834                               input_data,
835                               &filter_dims,
836                               kernel_data,
837                               &bias_dims,
838                               &bias_data,
839                               &output_dims,
840                               output);
841     if (ctx.buf)
842     {
843         // The caller is responsible to clear the scratch buffers for security reasons if applicable.
844         memset(ctx.buf, 0, buf_size);
845         free(ctx.buf);
846     }
847     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
848     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
849     memset(output, 0, sizeof(output));
850 
851     buf_size = arm_convolve_wrapper_s16_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
852     ctx.buf = malloc(buf_size);
853 
854     result = arm_convolve_wrapper_s16(&ctx,
855                                       &conv_params,
856                                       &quant_params,
857                                       &input_dims,
858                                       input_data,
859                                       &filter_dims,
860                                       kernel_data,
861                                       &bias_dims,
862                                       &bias_data,
863                                       &output_dims,
864                                       output);
865     if (ctx.buf)
866     {
867         memset(ctx.buf, 0, buf_size);
868         free(ctx.buf);
869     }
870     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
871     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
872 }
873 
int16xint8xint32_2_arm_convolve_s16(void)874 void int16xint8xint32_2_arm_convolve_s16(void)
875 {
876     int16_t output[INT16XINT8XINT32_2_DST_SIZE] = {0};
877 
878     cmsis_nn_context ctx;
879     cmsis_nn_conv_params conv_params;
880     cmsis_nn_per_channel_quant_params quant_params;
881     cmsis_nn_dims input_dims;
882     cmsis_nn_dims filter_dims;
883     cmsis_nn_dims bias_dims;
884     cmsis_nn_dims output_dims;
885 
886     const int32_t *int32_bias_data = int16xint8xint32_2_biases;
887     const cmsis_nn_bias_data bias_data = {int32_bias_data, true};
888     const int8_t *kernel_data = int16xint8xint32_2_weights;
889     const int16_t *input_data = int16xint8xint32_2_input;
890     const int16_t *output_ref = int16xint8xint32_2_output_ref;
891     const int32_t output_ref_size = INT16XINT8XINT32_2_DST_SIZE;
892 
893     input_dims.n = INT16XINT8XINT32_2_INPUT_BATCHES;
894     input_dims.w = INT16XINT8XINT32_2_INPUT_W;
895     input_dims.h = INT16XINT8XINT32_2_INPUT_H;
896     input_dims.c = INT16XINT8XINT32_2_IN_CH;
897     filter_dims.w = INT16XINT8XINT32_2_FILTER_X;
898     filter_dims.h = INT16XINT8XINT32_2_FILTER_Y;
899     output_dims.w = INT16XINT8XINT32_2_OUTPUT_W;
900     output_dims.h = INT16XINT8XINT32_2_OUTPUT_H;
901     output_dims.c = INT16XINT8XINT32_2_OUT_CH;
902 
903     conv_params.padding.w = INT16XINT8XINT32_2_PAD_X;
904     conv_params.padding.h = INT16XINT8XINT32_2_PAD_Y;
905     conv_params.stride.w = INT16XINT8XINT32_2_STRIDE_X;
906     conv_params.stride.h = INT16XINT8XINT32_2_STRIDE_Y;
907     conv_params.dilation.w = INT16XINT8XINT32_2_DILATION_X;
908     conv_params.dilation.h = INT16XINT8XINT32_2_DILATION_Y;
909 
910     conv_params.input_offset = 0;
911     conv_params.output_offset = 0;
912     conv_params.activation.min = INT16XINT8XINT32_2_OUT_ACTIVATION_MIN;
913     conv_params.activation.max = INT16XINT8XINT32_2_OUT_ACTIVATION_MAX;
914     quant_params.multiplier = (int32_t *)int16xint8xint32_2_output_mult;
915     quant_params.shift = (int32_t *)int16xint8xint32_2_output_shift;
916 
917     int buf_size = arm_convolve_s16_get_buffer_size(&input_dims, &filter_dims);
918     ctx.buf = malloc(buf_size);
919     arm_cmsis_nn_status result;
920     result = arm_convolve_s16(&ctx,
921                               &conv_params,
922                               &quant_params,
923                               &input_dims,
924                               input_data,
925                               &filter_dims,
926                               kernel_data,
927                               &bias_dims,
928                               &bias_data,
929                               &output_dims,
930                               output);
931     if (ctx.buf)
932     {
933         // The caller is responsible to clear the scratch buffers for security reasons if applicable.
934         memset(ctx.buf, 0, buf_size);
935         free(ctx.buf);
936     }
937     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
938     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
939     memset(output, 0, sizeof(output));
940 
941     buf_size = arm_convolve_wrapper_s16_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
942     ctx.buf = malloc(buf_size);
943 
944     result = arm_convolve_wrapper_s16(&ctx,
945                                       &conv_params,
946                                       &quant_params,
947                                       &input_dims,
948                                       input_data,
949                                       &filter_dims,
950                                       kernel_data,
951                                       &bias_dims,
952                                       &bias_data,
953                                       &output_dims,
954                                       output);
955     if (ctx.buf)
956     {
957         memset(ctx.buf, 0, buf_size);
958         free(ctx.buf);
959     }
960     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
961     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
962 }
963 
int16xint8xint32_3_arm_convolve_s16(void)964 void int16xint8xint32_3_arm_convolve_s16(void)
965 {
966     int16_t output[INT16XINT8XINT32_3_DST_SIZE] = {0};
967 
968     cmsis_nn_context ctx;
969     cmsis_nn_conv_params conv_params;
970     cmsis_nn_per_channel_quant_params quant_params;
971     cmsis_nn_dims input_dims;
972     cmsis_nn_dims filter_dims;
973     cmsis_nn_dims bias_dims;
974     cmsis_nn_dims output_dims;
975 
976     const int32_t *int32_bias_data = int16xint8xint32_3_biases;
977     const cmsis_nn_bias_data bias_data = {int32_bias_data, true};
978     const int8_t *kernel_data = int16xint8xint32_3_weights;
979     const int16_t *input_data = int16xint8xint32_3_input;
980     const int16_t *output_ref = int16xint8xint32_3_output_ref;
981     const int32_t output_ref_size = INT16XINT8XINT32_3_DST_SIZE;
982 
983     input_dims.n = INT16XINT8XINT32_3_INPUT_BATCHES;
984     input_dims.w = INT16XINT8XINT32_3_INPUT_W;
985     input_dims.h = INT16XINT8XINT32_3_INPUT_H;
986     input_dims.c = INT16XINT8XINT32_3_IN_CH;
987     filter_dims.w = INT16XINT8XINT32_3_FILTER_X;
988     filter_dims.h = INT16XINT8XINT32_3_FILTER_Y;
989     output_dims.w = INT16XINT8XINT32_3_OUTPUT_W;
990     output_dims.h = INT16XINT8XINT32_3_OUTPUT_H;
991     output_dims.c = INT16XINT8XINT32_3_OUT_CH;
992 
993     conv_params.padding.w = INT16XINT8XINT32_3_PAD_X;
994     conv_params.padding.h = INT16XINT8XINT32_3_PAD_Y;
995     conv_params.stride.w = INT16XINT8XINT32_3_STRIDE_X;
996     conv_params.stride.h = INT16XINT8XINT32_3_STRIDE_Y;
997     conv_params.dilation.w = INT16XINT8XINT32_3_DILATION_X;
998     conv_params.dilation.h = INT16XINT8XINT32_3_DILATION_Y;
999 
1000     conv_params.input_offset = 0;
1001     conv_params.output_offset = 0;
1002     conv_params.activation.min = INT16XINT8XINT32_3_OUT_ACTIVATION_MIN;
1003     conv_params.activation.max = INT16XINT8XINT32_3_OUT_ACTIVATION_MAX;
1004     quant_params.multiplier = (int32_t *)int16xint8xint32_3_output_mult;
1005     quant_params.shift = (int32_t *)int16xint8xint32_3_output_shift;
1006 
1007     int buf_size = arm_convolve_s16_get_buffer_size(&input_dims, &filter_dims);
1008     ctx.buf = malloc(buf_size);
1009     arm_cmsis_nn_status result;
1010     result = arm_convolve_s16(&ctx,
1011                               &conv_params,
1012                               &quant_params,
1013                               &input_dims,
1014                               input_data,
1015                               &filter_dims,
1016                               kernel_data,
1017                               &bias_dims,
1018                               &bias_data,
1019                               &output_dims,
1020                               output);
1021     if (ctx.buf)
1022     {
1023         // The caller is responsible to clear the scratch buffers for security reasons if applicable.
1024         memset(ctx.buf, 0, buf_size);
1025         free(ctx.buf);
1026     }
1027     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
1028     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
1029     memset(output, 0, sizeof(output));
1030 
1031     buf_size = arm_convolve_wrapper_s16_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
1032     ctx.buf = malloc(buf_size);
1033 
1034     result = arm_convolve_wrapper_s16(&ctx,
1035                                       &conv_params,
1036                                       &quant_params,
1037                                       &input_dims,
1038                                       input_data,
1039                                       &filter_dims,
1040                                       kernel_data,
1041                                       &bias_dims,
1042                                       &bias_data,
1043                                       &output_dims,
1044                                       output);
1045     if (ctx.buf)
1046     {
1047         memset(ctx.buf, 0, buf_size);
1048         free(ctx.buf);
1049     }
1050     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
1051     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
1052 }
1053 
int16xint8xint32_4_arm_convolve_s16(void)1054 void int16xint8xint32_4_arm_convolve_s16(void)
1055 {
1056     int16_t output[INT16XINT8XINT32_4_DST_SIZE] = {0};
1057 
1058     cmsis_nn_context ctx;
1059     cmsis_nn_conv_params conv_params;
1060     cmsis_nn_per_channel_quant_params quant_params;
1061     cmsis_nn_dims input_dims;
1062     cmsis_nn_dims filter_dims;
1063     cmsis_nn_dims bias_dims;
1064     cmsis_nn_dims output_dims;
1065 
1066     const int32_t *int32_bias_data = int16xint8xint32_4_biases;
1067     const cmsis_nn_bias_data bias_data = {int32_bias_data, true};
1068     const int8_t *kernel_data = int16xint8xint32_4_weights;
1069     const int16_t *input_data = int16xint8xint32_4_input;
1070     const int16_t *output_ref = int16xint8xint32_4_output_ref;
1071     const int32_t output_ref_size = INT16XINT8XINT32_4_DST_SIZE;
1072 
1073     input_dims.n = INT16XINT8XINT32_4_INPUT_BATCHES;
1074     input_dims.w = INT16XINT8XINT32_4_INPUT_W;
1075     input_dims.h = INT16XINT8XINT32_4_INPUT_H;
1076     input_dims.c = INT16XINT8XINT32_4_IN_CH;
1077     filter_dims.w = INT16XINT8XINT32_4_FILTER_X;
1078     filter_dims.h = INT16XINT8XINT32_4_FILTER_Y;
1079     output_dims.w = INT16XINT8XINT32_4_OUTPUT_W;
1080     output_dims.h = INT16XINT8XINT32_4_OUTPUT_H;
1081     output_dims.c = INT16XINT8XINT32_4_OUT_CH;
1082 
1083     conv_params.padding.w = INT16XINT8XINT32_4_PAD_X;
1084     conv_params.padding.h = INT16XINT8XINT32_4_PAD_Y;
1085     conv_params.stride.w = INT16XINT8XINT32_4_STRIDE_X;
1086     conv_params.stride.h = INT16XINT8XINT32_4_STRIDE_Y;
1087     conv_params.dilation.w = INT16XINT8XINT32_4_DILATION_X;
1088     conv_params.dilation.h = INT16XINT8XINT32_4_DILATION_Y;
1089 
1090     conv_params.input_offset = 0;
1091     conv_params.output_offset = 0;
1092     conv_params.activation.min = INT16XINT8XINT32_4_OUT_ACTIVATION_MIN;
1093     conv_params.activation.max = INT16XINT8XINT32_4_OUT_ACTIVATION_MAX;
1094     quant_params.multiplier = (int32_t *)int16xint8xint32_4_output_mult;
1095     quant_params.shift = (int32_t *)int16xint8xint32_4_output_shift;
1096 
1097     int buf_size = arm_convolve_s16_get_buffer_size(&input_dims, &filter_dims);
1098     ctx.buf = malloc(buf_size);
1099     arm_cmsis_nn_status result;
1100     result = arm_convolve_s16(&ctx,
1101                               &conv_params,
1102                               &quant_params,
1103                               &input_dims,
1104                               input_data,
1105                               &filter_dims,
1106                               kernel_data,
1107                               &bias_dims,
1108                               &bias_data,
1109                               &output_dims,
1110                               output);
1111     if (ctx.buf)
1112     {
1113         // The caller is responsible to clear the scratch buffers for security reasons if applicable.
1114         memset(ctx.buf, 0, buf_size);
1115         free(ctx.buf);
1116     }
1117     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
1118     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
1119     memset(output, 0, sizeof(output));
1120 
1121     buf_size = arm_convolve_wrapper_s16_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
1122     ctx.buf = malloc(buf_size);
1123 
1124     result = arm_convolve_wrapper_s16(&ctx,
1125                                       &conv_params,
1126                                       &quant_params,
1127                                       &input_dims,
1128                                       input_data,
1129                                       &filter_dims,
1130                                       kernel_data,
1131                                       &bias_dims,
1132                                       &bias_data,
1133                                       &output_dims,
1134                                       output);
1135     if (ctx.buf)
1136     {
1137         memset(ctx.buf, 0, buf_size);
1138         free(ctx.buf);
1139     }
1140     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
1141     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
1142 }
1143 
int16xint8xint32_5_arm_convolve_s16(void)1144 void int16xint8xint32_5_arm_convolve_s16(void)
1145 {
1146     int16_t output[INT16XINT8XINT32_5_DST_SIZE] = {0};
1147 
1148     cmsis_nn_context ctx;
1149     cmsis_nn_conv_params conv_params;
1150     cmsis_nn_per_channel_quant_params quant_params;
1151     cmsis_nn_dims input_dims;
1152     cmsis_nn_dims filter_dims;
1153     cmsis_nn_dims bias_dims;
1154     cmsis_nn_dims output_dims;
1155 
1156     const int32_t *int32_bias_data = int16xint8xint32_5_biases;
1157     const cmsis_nn_bias_data bias_data = {int32_bias_data, true};
1158     const int8_t *kernel_data = int16xint8xint32_5_weights;
1159     const int16_t *input_data = int16xint8xint32_5_input;
1160     const int16_t *output_ref = int16xint8xint32_5_output_ref;
1161     const int32_t output_ref_size = INT16XINT8XINT32_5_DST_SIZE;
1162 
1163     input_dims.n = INT16XINT8XINT32_5_INPUT_BATCHES;
1164     input_dims.w = INT16XINT8XINT32_5_INPUT_W;
1165     input_dims.h = INT16XINT8XINT32_5_INPUT_H;
1166     input_dims.c = INT16XINT8XINT32_5_IN_CH;
1167     filter_dims.w = INT16XINT8XINT32_5_FILTER_X;
1168     filter_dims.h = INT16XINT8XINT32_5_FILTER_Y;
1169     output_dims.w = INT16XINT8XINT32_5_OUTPUT_W;
1170     output_dims.h = INT16XINT8XINT32_5_OUTPUT_H;
1171     output_dims.c = INT16XINT8XINT32_5_OUT_CH;
1172 
1173     conv_params.padding.w = INT16XINT8XINT32_5_PAD_X;
1174     conv_params.padding.h = INT16XINT8XINT32_5_PAD_Y;
1175     conv_params.stride.w = INT16XINT8XINT32_5_STRIDE_X;
1176     conv_params.stride.h = INT16XINT8XINT32_5_STRIDE_Y;
1177     conv_params.dilation.w = INT16XINT8XINT32_5_DILATION_X;
1178     conv_params.dilation.h = INT16XINT8XINT32_5_DILATION_Y;
1179 
1180     conv_params.input_offset = 0;
1181     conv_params.output_offset = 0;
1182     conv_params.activation.min = INT16XINT8XINT32_5_OUT_ACTIVATION_MIN;
1183     conv_params.activation.max = INT16XINT8XINT32_5_OUT_ACTIVATION_MAX;
1184     quant_params.multiplier = (int32_t *)int16xint8xint32_5_output_mult;
1185     quant_params.shift = (int32_t *)int16xint8xint32_5_output_shift;
1186 
1187     int buf_size = arm_convolve_s16_get_buffer_size(&input_dims, &filter_dims);
1188     ctx.buf = malloc(buf_size);
1189     arm_cmsis_nn_status result;
1190     result = arm_convolve_s16(&ctx,
1191                               &conv_params,
1192                               &quant_params,
1193                               &input_dims,
1194                               input_data,
1195                               &filter_dims,
1196                               kernel_data,
1197                               &bias_dims,
1198                               &bias_data,
1199                               &output_dims,
1200                               output);
1201     if (ctx.buf)
1202     {
1203         // The caller is responsible to clear the scratch buffers for security reasons if applicable.
1204         memset(ctx.buf, 0, buf_size);
1205         free(ctx.buf);
1206     }
1207     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
1208     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
1209     memset(output, 0, sizeof(output));
1210 
1211     buf_size = arm_convolve_wrapper_s16_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
1212     ctx.buf = malloc(buf_size);
1213 
1214     result = arm_convolve_wrapper_s16(&ctx,
1215                                       &conv_params,
1216                                       &quant_params,
1217                                       &input_dims,
1218                                       input_data,
1219                                       &filter_dims,
1220                                       kernel_data,
1221                                       &bias_dims,
1222                                       &bias_data,
1223                                       &output_dims,
1224                                       output);
1225     if (ctx.buf)
1226     {
1227         memset(ctx.buf, 0, buf_size);
1228         free(ctx.buf);
1229     }
1230     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
1231     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
1232 }
1233 
int16xint8xint32_6_arm_convolve_s16(void)1234 void int16xint8xint32_6_arm_convolve_s16(void)
1235 {
1236     int16_t output[INT16XINT8XINT32_6_DST_SIZE] = {0};
1237 
1238     cmsis_nn_context ctx;
1239     cmsis_nn_conv_params conv_params;
1240     cmsis_nn_per_channel_quant_params quant_params;
1241     cmsis_nn_dims input_dims;
1242     cmsis_nn_dims filter_dims;
1243     cmsis_nn_dims bias_dims;
1244     cmsis_nn_dims output_dims;
1245 
1246     const int32_t *int32_bias_data = int16xint8xint32_6_biases;
1247     const cmsis_nn_bias_data bias_data = {int32_bias_data, true};
1248     const int8_t *kernel_data = int16xint8xint32_6_weights;
1249     const int16_t *input_data = int16xint8xint32_6_input;
1250     const int16_t *output_ref = int16xint8xint32_6_output_ref;
1251     const int32_t output_ref_size = INT16XINT8XINT32_6_DST_SIZE;
1252 
1253     input_dims.n = INT16XINT8XINT32_6_INPUT_BATCHES;
1254     input_dims.w = INT16XINT8XINT32_6_INPUT_W;
1255     input_dims.h = INT16XINT8XINT32_6_INPUT_H;
1256     input_dims.c = INT16XINT8XINT32_6_IN_CH;
1257     filter_dims.w = INT16XINT8XINT32_6_FILTER_X;
1258     filter_dims.h = INT16XINT8XINT32_6_FILTER_Y;
1259     output_dims.w = INT16XINT8XINT32_6_OUTPUT_W;
1260     output_dims.h = INT16XINT8XINT32_6_OUTPUT_H;
1261     output_dims.c = INT16XINT8XINT32_6_OUT_CH;
1262 
1263     conv_params.padding.w = INT16XINT8XINT32_6_PAD_X;
1264     conv_params.padding.h = INT16XINT8XINT32_6_PAD_Y;
1265     conv_params.stride.w = INT16XINT8XINT32_6_STRIDE_X;
1266     conv_params.stride.h = INT16XINT8XINT32_6_STRIDE_Y;
1267     conv_params.dilation.w = INT16XINT8XINT32_6_DILATION_X;
1268     conv_params.dilation.h = INT16XINT8XINT32_6_DILATION_Y;
1269 
1270     conv_params.input_offset = 0;
1271     conv_params.output_offset = 0;
1272     conv_params.activation.min = INT16XINT8XINT32_6_OUT_ACTIVATION_MIN;
1273     conv_params.activation.max = INT16XINT8XINT32_6_OUT_ACTIVATION_MAX;
1274     quant_params.multiplier = (int32_t *)int16xint8xint32_6_output_mult;
1275     quant_params.shift = (int32_t *)int16xint8xint32_6_output_shift;
1276 
1277     int buf_size = arm_convolve_s16_get_buffer_size(&input_dims, &filter_dims);
1278     ctx.buf = malloc(buf_size);
1279     arm_cmsis_nn_status result;
1280     result = arm_convolve_s16(&ctx,
1281                               &conv_params,
1282                               &quant_params,
1283                               &input_dims,
1284                               input_data,
1285                               &filter_dims,
1286                               kernel_data,
1287                               &bias_dims,
1288                               &bias_data,
1289                               &output_dims,
1290                               output);
1291     if (ctx.buf)
1292     {
1293         // The caller is responsible to clear the scratch buffers for security reasons if applicable.
1294         memset(ctx.buf, 0, buf_size);
1295         free(ctx.buf);
1296     }
1297     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
1298     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
1299     memset(output, 0, sizeof(output));
1300 
1301     buf_size = arm_convolve_wrapper_s16_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
1302     ctx.buf = malloc(buf_size);
1303 
1304     result = arm_convolve_wrapper_s16(&ctx,
1305                                       &conv_params,
1306                                       &quant_params,
1307                                       &input_dims,
1308                                       input_data,
1309                                       &filter_dims,
1310                                       kernel_data,
1311                                       &bias_dims,
1312                                       &bias_data,
1313                                       &output_dims,
1314                                       output);
1315     if (ctx.buf)
1316     {
1317         memset(ctx.buf, 0, buf_size);
1318         free(ctx.buf);
1319     }
1320     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
1321     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
1322 }
1323