1 /*
2  * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include <stdlib.h>
20 
21 #include <arm_nnfunctions.h>
22 #include <unity.h>
23 
24 #include "../TestData/basic_2_int4/test_data.h"
25 #include "../TestData/basic_int4/test_data.h"
26 #include "../TestData/conv_1_x_n_1_int4/test_data.h"
27 #include "../TestData/conv_1_x_n_2_int4/test_data.h"
28 #include "../TestData/conv_1_x_n_3_int4/test_data.h"
29 #include "../TestData/conv_1_x_n_4_int4/test_data.h"
30 #include "../TestData/conv_1_x_n_5_int4/test_data.h"
31 #include "../TestData/conv_2_int4/test_data.h"
32 #include "../TestData/conv_2x2_dilation_5x5_input_int4/test_data.h"
33 #include "../TestData/conv_2x2_dilation_int4/test_data.h"
34 #include "../TestData/conv_2x3_dilation_int4/test_data.h"
35 #include "../TestData/conv_3_int4/test_data.h"
36 #include "../TestData/conv_3x2_dilation_int4/test_data.h"
37 #include "../TestData/conv_3x3_dilation_5x5_input_int4/test_data.h"
38 #include "../TestData/conv_4_int4/test_data.h"
39 #include "../TestData/conv_5_int4/test_data.h"
40 #include "../TestData/conv_dilation_golden_int4/test_data.h"
41 #include "../TestData/conv_out_activation_int4/test_data.h"
42 #include "../TestData/stride2pad1_int4/test_data.h"
43 #include "../Utils/validate.h"
44 
basic_arm_convolve_s4(void)45 void basic_arm_convolve_s4(void)
46 {
47     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
48     int8_t output[BASIC_INT4_DST_SIZE] = {0};
49 
50     cmsis_nn_context ctx;
51     cmsis_nn_conv_params conv_params;
52     cmsis_nn_per_channel_quant_params quant_params;
53     cmsis_nn_dims input_dims;
54     cmsis_nn_dims filter_dims;
55     cmsis_nn_dims bias_dims;
56     cmsis_nn_dims output_dims;
57 
58     const int32_t *bias_data = basic_int4_biases;
59     const int8_t *kernel_data = basic_int4_weights;
60     const int8_t *input_data = basic_int4_input;
61     const int8_t *output_ref = basic_int4_output_ref;
62     const int32_t output_ref_size = BASIC_INT4_DST_SIZE;
63 
64     input_dims.n = BASIC_INT4_INPUT_BATCHES;
65     input_dims.w = BASIC_INT4_INPUT_W;
66     input_dims.h = BASIC_INT4_INPUT_H;
67     input_dims.c = BASIC_INT4_IN_CH;
68     filter_dims.w = BASIC_INT4_FILTER_X;
69     filter_dims.h = BASIC_INT4_FILTER_Y;
70     output_dims.w = BASIC_INT4_OUTPUT_W;
71     output_dims.h = BASIC_INT4_OUTPUT_H;
72     output_dims.c = BASIC_INT4_OUT_CH;
73 
74     conv_params.padding.w = BASIC_INT4_PAD_X;
75     conv_params.padding.h = BASIC_INT4_PAD_Y;
76     conv_params.stride.w = BASIC_INT4_STRIDE_X;
77     conv_params.stride.h = BASIC_INT4_STRIDE_Y;
78     conv_params.dilation.w = BASIC_INT4_DILATION_X;
79     conv_params.dilation.h = BASIC_INT4_DILATION_Y;
80 
81     conv_params.input_offset = BASIC_INT4_INPUT_OFFSET;
82     conv_params.output_offset = BASIC_INT4_OUTPUT_OFFSET;
83     conv_params.activation.min = BASIC_INT4_OUT_ACTIVATION_MIN;
84     conv_params.activation.max = BASIC_INT4_OUT_ACTIVATION_MAX;
85     quant_params.multiplier = (int32_t *)basic_int4_output_mult;
86     quant_params.shift = (int32_t *)basic_int4_output_shift;
87 
88     int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
89     ctx.buf = malloc(buf_size);
90     ctx.size = 0;
91 
92     arm_cmsis_nn_status result = arm_convolve_s4(&ctx,
93                                                  &conv_params,
94                                                  &quant_params,
95                                                  &input_dims,
96                                                  input_data,
97                                                  &filter_dims,
98                                                  kernel_data,
99                                                  &bias_dims,
100                                                  bias_data,
101                                                  &output_dims,
102                                                  output);
103 
104     if (ctx.buf)
105     {
106         // The caller is responsible to clear the scratch buffers for security reasons if applicable.
107         memset(ctx.buf, 0, buf_size);
108         free(ctx.buf);
109     }
110     TEST_ASSERT_EQUAL(expected, result);
111     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
112     memset(output, 0, sizeof(output));
113 
114     buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
115     ctx.buf = malloc(buf_size);
116     ctx.size = 0;
117 
118     result = arm_convolve_wrapper_s4(&ctx,
119                                      &conv_params,
120                                      &quant_params,
121                                      &input_dims,
122                                      input_data,
123                                      &filter_dims,
124                                      kernel_data,
125                                      &bias_dims,
126                                      bias_data,
127                                      &output_dims,
128                                      output);
129 
130     if (ctx.buf)
131     {
132         memset(ctx.buf, 0, buf_size);
133         free(ctx.buf);
134     }
135     TEST_ASSERT_EQUAL(expected, result);
136     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
137 }
138 
basic_2_arm_convolve_s4(void)139 void basic_2_arm_convolve_s4(void)
140 {
141     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
142     int8_t output[BASIC_2_INT4_DST_SIZE] = {0};
143 
144     cmsis_nn_context ctx;
145     cmsis_nn_conv_params conv_params;
146     cmsis_nn_per_channel_quant_params quant_params;
147     cmsis_nn_dims input_dims;
148     cmsis_nn_dims filter_dims;
149     cmsis_nn_dims bias_dims;
150     cmsis_nn_dims output_dims;
151 
152     const int32_t *bias_data = basic_2_int4_biases;
153     const int8_t *kernel_data = basic_2_int4_weights;
154     const int8_t *input_data = basic_2_int4_input;
155     const int8_t *output_ref = basic_2_int4_output_ref;
156     const int32_t output_ref_size = BASIC_2_INT4_DST_SIZE;
157 
158     input_dims.n = BASIC_2_INT4_INPUT_BATCHES;
159     input_dims.w = BASIC_2_INT4_INPUT_W;
160     input_dims.h = BASIC_2_INT4_INPUT_H;
161     input_dims.c = BASIC_2_INT4_IN_CH;
162     filter_dims.w = BASIC_2_INT4_FILTER_X;
163     filter_dims.h = BASIC_2_INT4_FILTER_Y;
164     output_dims.w = BASIC_2_INT4_OUTPUT_W;
165     output_dims.h = BASIC_2_INT4_OUTPUT_H;
166     output_dims.c = BASIC_2_INT4_OUT_CH;
167 
168     conv_params.padding.w = BASIC_2_INT4_PAD_X;
169     conv_params.padding.h = BASIC_2_INT4_PAD_Y;
170     conv_params.stride.w = BASIC_2_INT4_STRIDE_X;
171     conv_params.stride.h = BASIC_2_INT4_STRIDE_Y;
172     conv_params.dilation.w = BASIC_2_INT4_DILATION_X;
173     conv_params.dilation.h = BASIC_2_INT4_DILATION_Y;
174 
175     conv_params.input_offset = BASIC_2_INT4_INPUT_OFFSET;
176     conv_params.output_offset = BASIC_2_INT4_OUTPUT_OFFSET;
177     conv_params.activation.min = BASIC_2_INT4_OUT_ACTIVATION_MIN;
178     conv_params.activation.max = BASIC_2_INT4_OUT_ACTIVATION_MAX;
179     quant_params.multiplier = (int32_t *)basic_2_int4_output_mult;
180     quant_params.shift = (int32_t *)basic_2_int4_output_shift;
181 
182     int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
183     ctx.buf = malloc(buf_size);
184     ctx.size = 0;
185 
186     arm_cmsis_nn_status result = arm_convolve_s4(&ctx,
187                                                  &conv_params,
188                                                  &quant_params,
189                                                  &input_dims,
190                                                  input_data,
191                                                  &filter_dims,
192                                                  kernel_data,
193                                                  &bias_dims,
194                                                  bias_data,
195                                                  &output_dims,
196                                                  output);
197 
198     if (ctx.buf)
199     {
200         // The caller is responsible to clear the scratch buffers for security reasons if applicable.
201         memset(ctx.buf, 0, buf_size);
202         free(ctx.buf);
203     }
204     TEST_ASSERT_EQUAL(expected, result);
205     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
206     memset(output, 0, sizeof(output));
207 
208     buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
209     ctx.buf = malloc(buf_size);
210     ctx.size = 0;
211 
212     result = arm_convolve_wrapper_s4(&ctx,
213                                      &conv_params,
214                                      &quant_params,
215                                      &input_dims,
216                                      input_data,
217                                      &filter_dims,
218                                      kernel_data,
219                                      &bias_dims,
220                                      bias_data,
221                                      &output_dims,
222                                      output);
223 
224     if (ctx.buf)
225     {
226         memset(ctx.buf, 0, buf_size);
227         free(ctx.buf);
228     }
229     TEST_ASSERT_EQUAL(expected, result);
230     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
231 }
232 
stride2pad1_arm_convolve_s4(void)233 void stride2pad1_arm_convolve_s4(void)
234 {
235     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
236     int8_t output[STRIDE2PAD1_INT4_DST_SIZE] = {0};
237 
238     cmsis_nn_context ctx;
239     cmsis_nn_conv_params conv_params;
240     cmsis_nn_per_channel_quant_params quant_params;
241     cmsis_nn_dims input_dims;
242     cmsis_nn_dims filter_dims;
243     cmsis_nn_dims bias_dims;
244     cmsis_nn_dims output_dims;
245 
246     const int32_t *bias_data = stride2pad1_int4_biases;
247     const int8_t *kernel_data = stride2pad1_int4_weights;
248     const int8_t *input_data = stride2pad1_int4_input;
249     const int8_t *output_ref = stride2pad1_int4_output_ref;
250     const int32_t output_ref_size = STRIDE2PAD1_INT4_DST_SIZE;
251 
252     input_dims.n = STRIDE2PAD1_INT4_INPUT_BATCHES;
253     input_dims.w = STRIDE2PAD1_INT4_INPUT_W;
254     input_dims.h = STRIDE2PAD1_INT4_INPUT_H;
255     input_dims.c = STRIDE2PAD1_INT4_IN_CH;
256     filter_dims.w = STRIDE2PAD1_INT4_FILTER_X;
257     filter_dims.h = STRIDE2PAD1_INT4_FILTER_Y;
258     output_dims.w = STRIDE2PAD1_INT4_OUTPUT_W;
259     output_dims.h = STRIDE2PAD1_INT4_OUTPUT_H;
260     output_dims.c = STRIDE2PAD1_INT4_OUT_CH;
261 
262     conv_params.padding.w = STRIDE2PAD1_INT4_PAD_X;
263     conv_params.padding.h = STRIDE2PAD1_INT4_PAD_Y;
264     conv_params.stride.w = STRIDE2PAD1_INT4_STRIDE_X;
265     conv_params.stride.h = STRIDE2PAD1_INT4_STRIDE_Y;
266     conv_params.dilation.w = STRIDE2PAD1_INT4_DILATION_X;
267     conv_params.dilation.h = STRIDE2PAD1_INT4_DILATION_Y;
268 
269     conv_params.input_offset = STRIDE2PAD1_INT4_INPUT_OFFSET;
270     conv_params.output_offset = STRIDE2PAD1_INT4_OUTPUT_OFFSET;
271     conv_params.activation.min = STRIDE2PAD1_INT4_OUT_ACTIVATION_MIN;
272     conv_params.activation.max = STRIDE2PAD1_INT4_OUT_ACTIVATION_MAX;
273     quant_params.multiplier = (int32_t *)stride2pad1_int4_output_mult;
274     quant_params.shift = (int32_t *)stride2pad1_int4_output_shift;
275 
276     int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
277     ctx.buf = malloc(buf_size);
278     ctx.size = 0;
279 
280     arm_cmsis_nn_status result = arm_convolve_s4(&ctx,
281                                                  &conv_params,
282                                                  &quant_params,
283                                                  &input_dims,
284                                                  input_data,
285                                                  &filter_dims,
286                                                  kernel_data,
287                                                  &bias_dims,
288                                                  bias_data,
289                                                  &output_dims,
290                                                  output);
291 
292     if (ctx.buf)
293     {
294         memset(ctx.buf, 0, buf_size);
295         free(ctx.buf);
296     }
297     TEST_ASSERT_EQUAL(expected, result);
298     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
299     memset(output, 0, sizeof(output));
300 
301     buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
302     ctx.buf = malloc(buf_size);
303     ctx.size = 0;
304 
305     result = arm_convolve_wrapper_s4(&ctx,
306                                      &conv_params,
307                                      &quant_params,
308                                      &input_dims,
309                                      input_data,
310                                      &filter_dims,
311                                      kernel_data,
312                                      &bias_dims,
313                                      bias_data,
314                                      &output_dims,
315                                      output);
316 
317     if (ctx.buf)
318     {
319         memset(ctx.buf, 0, buf_size);
320         free(ctx.buf);
321     }
322     TEST_ASSERT_EQUAL(expected, result);
323     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
324 }
325 
conv_2_arm_convolve_s4(void)326 void conv_2_arm_convolve_s4(void)
327 {
328     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
329     int8_t output[CONV_2_INT4_DST_SIZE] = {0};
330 
331     cmsis_nn_context ctx;
332     cmsis_nn_conv_params conv_params;
333     cmsis_nn_per_channel_quant_params quant_params;
334     cmsis_nn_dims input_dims;
335     cmsis_nn_dims filter_dims;
336     cmsis_nn_dims bias_dims;
337     cmsis_nn_dims output_dims;
338 
339     const int32_t *bias_data = conv_2_int4_biases;
340     const int8_t *kernel_data = conv_2_int4_weights;
341     const int8_t *input_data = conv_2_int4_input;
342     const int8_t *output_ref = conv_2_int4_output_ref;
343     const int32_t output_ref_size = CONV_2_INT4_DST_SIZE;
344 
345     input_dims.n = CONV_2_INT4_INPUT_BATCHES;
346     input_dims.w = CONV_2_INT4_INPUT_W;
347     input_dims.h = CONV_2_INT4_INPUT_H;
348     input_dims.c = CONV_2_INT4_IN_CH;
349     filter_dims.w = CONV_2_INT4_FILTER_X;
350     filter_dims.h = CONV_2_INT4_FILTER_Y;
351     output_dims.w = CONV_2_INT4_OUTPUT_W;
352     output_dims.h = CONV_2_INT4_OUTPUT_H;
353     output_dims.c = CONV_2_INT4_OUT_CH;
354 
355     conv_params.padding.w = CONV_2_INT4_PAD_X;
356     conv_params.padding.h = CONV_2_INT4_PAD_Y;
357     conv_params.stride.w = CONV_2_INT4_STRIDE_X;
358     conv_params.stride.h = CONV_2_INT4_STRIDE_Y;
359     conv_params.dilation.w = CONV_2_INT4_DILATION_X;
360     conv_params.dilation.h = CONV_2_INT4_DILATION_Y;
361 
362     conv_params.input_offset = CONV_2_INT4_INPUT_OFFSET;
363     conv_params.output_offset = CONV_2_INT4_OUTPUT_OFFSET;
364     conv_params.activation.min = CONV_2_INT4_OUT_ACTIVATION_MIN;
365     conv_params.activation.max = CONV_2_INT4_OUT_ACTIVATION_MAX;
366     quant_params.multiplier = (int32_t *)conv_2_int4_output_mult;
367     quant_params.shift = (int32_t *)conv_2_int4_output_shift;
368 
369     int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
370     ctx.buf = malloc(buf_size);
371     ctx.size = 0;
372 
373     arm_cmsis_nn_status result = arm_convolve_s4(&ctx,
374                                                  &conv_params,
375                                                  &quant_params,
376                                                  &input_dims,
377                                                  input_data,
378                                                  &filter_dims,
379                                                  conv_2_int4_weights,
380                                                  &bias_dims,
381                                                  bias_data,
382                                                  &output_dims,
383                                                  output);
384 
385     if (ctx.buf)
386     {
387         memset(ctx.buf, 0, buf_size);
388         free(ctx.buf);
389     }
390     TEST_ASSERT_EQUAL(expected, result);
391     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
392     memset(output, 0, sizeof(output));
393 
394     buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
395     ctx.buf = malloc(buf_size);
396     ctx.size = 0;
397 
398     result = arm_convolve_wrapper_s4(&ctx,
399                                      &conv_params,
400                                      &quant_params,
401                                      &input_dims,
402                                      input_data,
403                                      &filter_dims,
404                                      kernel_data,
405                                      &bias_dims,
406                                      bias_data,
407                                      &output_dims,
408                                      output);
409 
410     if (ctx.buf)
411     {
412         memset(ctx.buf, 0, buf_size);
413         free(ctx.buf);
414     }
415     TEST_ASSERT_EQUAL(expected, result);
416     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
417 }
418 
conv_3_arm_convolve_s4(void)419 void conv_3_arm_convolve_s4(void)
420 {
421     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
422     int8_t output[CONV_3_INT4_DST_SIZE] = {0};
423 
424     cmsis_nn_context ctx;
425     cmsis_nn_conv_params conv_params;
426     cmsis_nn_per_channel_quant_params quant_params;
427     cmsis_nn_dims input_dims;
428     cmsis_nn_dims filter_dims;
429     cmsis_nn_dims bias_dims;
430     cmsis_nn_dims output_dims;
431 
432     const int32_t *bias_data = conv_3_int4_biases;
433     const int8_t *kernel_data = conv_3_int4_weights;
434     const int8_t *input_data = conv_3_int4_input;
435     const int8_t *output_ref = conv_3_int4_output_ref;
436     const int32_t output_ref_size = CONV_3_INT4_DST_SIZE;
437 
438     input_dims.n = CONV_3_INT4_INPUT_BATCHES;
439     input_dims.w = CONV_3_INT4_INPUT_W;
440     input_dims.h = CONV_3_INT4_INPUT_H;
441     input_dims.c = CONV_3_INT4_IN_CH;
442     filter_dims.w = CONV_3_INT4_FILTER_X;
443     filter_dims.h = CONV_3_INT4_FILTER_Y;
444     output_dims.w = CONV_3_INT4_OUTPUT_W;
445     output_dims.h = CONV_3_INT4_OUTPUT_H;
446     output_dims.c = CONV_3_INT4_OUT_CH;
447 
448     conv_params.padding.w = CONV_3_INT4_PAD_X;
449     conv_params.padding.h = CONV_3_INT4_PAD_Y;
450     conv_params.stride.w = CONV_3_INT4_STRIDE_X;
451     conv_params.stride.h = CONV_3_INT4_STRIDE_Y;
452     conv_params.dilation.w = CONV_3_INT4_DILATION_X;
453     conv_params.dilation.h = CONV_3_INT4_DILATION_Y;
454 
455     conv_params.input_offset = CONV_3_INT4_INPUT_OFFSET;
456     conv_params.output_offset = CONV_3_INT4_OUTPUT_OFFSET;
457     conv_params.activation.min = CONV_3_INT4_OUT_ACTIVATION_MIN;
458     conv_params.activation.max = CONV_3_INT4_OUT_ACTIVATION_MAX;
459     quant_params.multiplier = (int32_t *)conv_3_int4_output_mult;
460     quant_params.shift = (int32_t *)conv_3_int4_output_shift;
461 
462     int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
463     ctx.buf = malloc(buf_size);
464     ctx.size = 0;
465 
466     arm_cmsis_nn_status result = arm_convolve_s4(&ctx,
467                                                  &conv_params,
468                                                  &quant_params,
469                                                  &input_dims,
470                                                  input_data,
471                                                  &filter_dims,
472                                                  conv_3_int4_weights,
473                                                  &bias_dims,
474                                                  bias_data,
475                                                  &output_dims,
476                                                  output);
477 
478     if (ctx.buf)
479     {
480         memset(ctx.buf, 0, buf_size);
481         free(ctx.buf);
482     }
483     TEST_ASSERT_EQUAL(expected, result);
484     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
485     memset(output, 0, sizeof(output));
486 
487     buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
488     ctx.buf = malloc(buf_size);
489     ctx.size = 0;
490 
491     result = arm_convolve_wrapper_s4(&ctx,
492                                      &conv_params,
493                                      &quant_params,
494                                      &input_dims,
495                                      input_data,
496                                      &filter_dims,
497                                      kernel_data,
498                                      &bias_dims,
499                                      bias_data,
500                                      &output_dims,
501                                      output);
502 
503     if (ctx.buf)
504     {
505         memset(ctx.buf, 0, buf_size);
506         free(ctx.buf);
507     }
508     TEST_ASSERT_EQUAL(expected, result);
509     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
510 }
511 
conv_4_arm_convolve_s4(void)512 void conv_4_arm_convolve_s4(void)
513 {
514     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
515     int8_t output[CONV_4_INT4_DST_SIZE] = {0};
516 
517     cmsis_nn_context ctx;
518     cmsis_nn_conv_params conv_params;
519     cmsis_nn_per_channel_quant_params quant_params;
520     cmsis_nn_dims input_dims;
521     cmsis_nn_dims filter_dims;
522     cmsis_nn_dims bias_dims;
523     cmsis_nn_dims output_dims;
524 
525     const int32_t *bias_data = conv_4_int4_biases;
526     const int8_t *kernel_data = conv_4_int4_weights;
527     const int8_t *input_data = conv_4_int4_input;
528     const int8_t *output_ref = conv_4_int4_output_ref;
529     const int32_t output_ref_size = CONV_4_INT4_DST_SIZE;
530 
531     input_dims.n = CONV_4_INT4_INPUT_BATCHES;
532     input_dims.w = CONV_4_INT4_INPUT_W;
533     input_dims.h = CONV_4_INT4_INPUT_H;
534     input_dims.c = CONV_4_INT4_IN_CH;
535     filter_dims.w = CONV_4_INT4_FILTER_X;
536     filter_dims.h = CONV_4_INT4_FILTER_Y;
537     output_dims.w = CONV_4_INT4_OUTPUT_W;
538     output_dims.h = CONV_4_INT4_OUTPUT_H;
539     output_dims.c = CONV_4_INT4_OUT_CH;
540 
541     conv_params.padding.w = CONV_4_INT4_PAD_X;
542     conv_params.padding.h = CONV_4_INT4_PAD_Y;
543     conv_params.stride.w = CONV_4_INT4_STRIDE_X;
544     conv_params.stride.h = CONV_4_INT4_STRIDE_Y;
545     conv_params.dilation.w = CONV_4_INT4_DILATION_X;
546     conv_params.dilation.h = CONV_4_INT4_DILATION_Y;
547 
548     conv_params.input_offset = CONV_4_INT4_INPUT_OFFSET;
549     conv_params.output_offset = CONV_4_INT4_OUTPUT_OFFSET;
550     conv_params.activation.min = CONV_4_INT4_OUT_ACTIVATION_MIN;
551     conv_params.activation.max = CONV_4_INT4_OUT_ACTIVATION_MAX;
552     quant_params.multiplier = (int32_t *)conv_4_int4_output_mult;
553     quant_params.shift = (int32_t *)conv_4_int4_output_shift;
554 
555     int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
556     ctx.buf = malloc(buf_size);
557     ctx.size = 0;
558 
559     arm_cmsis_nn_status result = arm_convolve_s4(&ctx,
560                                                  &conv_params,
561                                                  &quant_params,
562                                                  &input_dims,
563                                                  input_data,
564                                                  &filter_dims,
565                                                  conv_4_int4_weights,
566                                                  &bias_dims,
567                                                  bias_data,
568                                                  &output_dims,
569                                                  output);
570 
571     if (ctx.buf)
572     {
573         memset(ctx.buf, 0, buf_size);
574         free(ctx.buf);
575     }
576     TEST_ASSERT_EQUAL(expected, result);
577     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
578     memset(output, 0, sizeof(output));
579 
580     buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
581     ctx.buf = malloc(buf_size);
582     ctx.size = 0;
583 
584     result = arm_convolve_wrapper_s4(&ctx,
585                                      &conv_params,
586                                      &quant_params,
587                                      &input_dims,
588                                      input_data,
589                                      &filter_dims,
590                                      kernel_data,
591                                      &bias_dims,
592                                      bias_data,
593                                      &output_dims,
594                                      output);
595 
596     if (ctx.buf)
597     {
598         memset(ctx.buf, 0, buf_size);
599         free(ctx.buf);
600     }
601     TEST_ASSERT_EQUAL(expected, result);
602     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
603 }
604 
conv_out_activation_arm_convolve_s4(void)605 void conv_out_activation_arm_convolve_s4(void)
606 {
607     int8_t output[CONV_OUT_ACTIVATION_INT4_DST_SIZE] = {0};
608 
609     cmsis_nn_context ctx;
610     cmsis_nn_conv_params conv_params;
611     cmsis_nn_per_channel_quant_params quant_params;
612     cmsis_nn_dims input_dims;
613     cmsis_nn_dims filter_dims;
614     cmsis_nn_dims bias_dims;
615     cmsis_nn_dims output_dims;
616 
617     const int32_t *bias_data = conv_out_activation_int4_biases;
618     const int8_t *kernel_data = conv_out_activation_int4_weights;
619     const int8_t *input_data = conv_out_activation_int4_input;
620     const int8_t *output_ref = conv_out_activation_int4_output_ref;
621     const int32_t output_ref_size = CONV_OUT_ACTIVATION_INT4_DST_SIZE;
622 
623     input_dims.n = CONV_OUT_ACTIVATION_INT4_INPUT_BATCHES;
624     input_dims.w = CONV_OUT_ACTIVATION_INT4_INPUT_W;
625     input_dims.h = CONV_OUT_ACTIVATION_INT4_INPUT_H;
626     input_dims.c = CONV_OUT_ACTIVATION_INT4_IN_CH;
627     filter_dims.w = CONV_OUT_ACTIVATION_INT4_FILTER_X;
628     filter_dims.h = CONV_OUT_ACTIVATION_INT4_FILTER_Y;
629     output_dims.w = CONV_OUT_ACTIVATION_INT4_OUTPUT_W;
630     output_dims.h = CONV_OUT_ACTIVATION_INT4_OUTPUT_H;
631     output_dims.c = CONV_OUT_ACTIVATION_INT4_OUT_CH;
632 
633     conv_params.padding.w = CONV_OUT_ACTIVATION_INT4_PAD_X;
634     conv_params.padding.h = CONV_OUT_ACTIVATION_INT4_PAD_Y;
635     conv_params.stride.w = CONV_OUT_ACTIVATION_INT4_STRIDE_X;
636     conv_params.stride.h = CONV_OUT_ACTIVATION_INT4_STRIDE_Y;
637     conv_params.dilation.w = CONV_OUT_ACTIVATION_INT4_DILATION_X;
638     conv_params.dilation.h = CONV_OUT_ACTIVATION_INT4_DILATION_Y;
639 
640     conv_params.input_offset = CONV_OUT_ACTIVATION_INT4_INPUT_OFFSET;
641     conv_params.output_offset = CONV_OUT_ACTIVATION_INT4_OUTPUT_OFFSET;
642     conv_params.activation.min = CONV_OUT_ACTIVATION_INT4_OUT_ACTIVATION_MIN;
643     conv_params.activation.max = CONV_OUT_ACTIVATION_INT4_OUT_ACTIVATION_MAX;
644     quant_params.multiplier = (int32_t *)conv_out_activation_int4_output_mult;
645     quant_params.shift = (int32_t *)conv_out_activation_int4_output_shift;
646 
647     int32_t buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
648     ctx.buf = malloc(buf_size);
649 
650     arm_cmsis_nn_status result = arm_convolve_wrapper_s4(&ctx,
651                                                          &conv_params,
652                                                          &quant_params,
653                                                          &input_dims,
654                                                          input_data,
655                                                          &filter_dims,
656                                                          kernel_data,
657                                                          &bias_dims,
658                                                          bias_data,
659                                                          &output_dims,
660                                                          output);
661     if (ctx.buf)
662     {
663         memset(ctx.buf, 0, buf_size);
664         free(ctx.buf);
665     }
666     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
667     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
668 }
669 
conv_2x2_dilation_arm_convolve_s4(void)670 void conv_2x2_dilation_arm_convolve_s4(void)
671 {
672     int8_t output[CONV_2X2_DILATION_INT4_DST_SIZE] = {0};
673 
674     cmsis_nn_context ctx;
675     cmsis_nn_conv_params conv_params;
676     cmsis_nn_per_channel_quant_params quant_params;
677     cmsis_nn_dims input_dims;
678     cmsis_nn_dims filter_dims;
679     cmsis_nn_dims bias_dims;
680     cmsis_nn_dims output_dims;
681 
682     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
683     const int32_t *bias_data = conv_2x2_dilation_int4_biases;
684     const int8_t *kernel_data = conv_2x2_dilation_int4_weights;
685     const int8_t *input_data = conv_2x2_dilation_int4_input;
686     const int8_t *output_ref = conv_2x2_dilation_int4_output_ref;
687     const int32_t output_ref_size = CONV_2X2_DILATION_INT4_DST_SIZE;
688 
689     input_dims.n = CONV_2X2_DILATION_INT4_INPUT_BATCHES;
690     input_dims.w = CONV_2X2_DILATION_INT4_INPUT_W;
691     input_dims.h = CONV_2X2_DILATION_INT4_INPUT_H;
692     input_dims.c = CONV_2X2_DILATION_INT4_IN_CH;
693     filter_dims.w = CONV_2X2_DILATION_INT4_FILTER_X;
694     filter_dims.h = CONV_2X2_DILATION_INT4_FILTER_Y;
695     output_dims.w = CONV_2X2_DILATION_INT4_OUTPUT_W;
696     output_dims.h = CONV_2X2_DILATION_INT4_OUTPUT_H;
697     output_dims.c = CONV_2X2_DILATION_INT4_OUT_CH;
698 
699     conv_params.padding.w = CONV_2X2_DILATION_INT4_PAD_X;
700     conv_params.padding.h = CONV_2X2_DILATION_INT4_PAD_Y;
701     conv_params.stride.w = CONV_2X2_DILATION_INT4_STRIDE_X;
702     conv_params.stride.h = CONV_2X2_DILATION_INT4_STRIDE_Y;
703     conv_params.dilation.w = CONV_2X2_DILATION_INT4_DILATION_X;
704     conv_params.dilation.h = CONV_2X2_DILATION_INT4_DILATION_Y;
705 
706     conv_params.input_offset = CONV_2X2_DILATION_INT4_INPUT_OFFSET;
707     conv_params.output_offset = CONV_2X2_DILATION_INT4_OUTPUT_OFFSET;
708     conv_params.activation.min = CONV_2X2_DILATION_INT4_OUT_ACTIVATION_MIN;
709     conv_params.activation.max = CONV_2X2_DILATION_INT4_OUT_ACTIVATION_MAX;
710     quant_params.multiplier = (int32_t *)conv_2x2_dilation_int4_output_mult;
711     quant_params.shift = (int32_t *)conv_2x2_dilation_int4_output_shift;
712 
713     int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
714     ctx.buf = malloc(buf_size);
715     ctx.size = 0;
716 
717     arm_cmsis_nn_status result = arm_convolve_s4(&ctx,
718                                                  &conv_params,
719                                                  &quant_params,
720                                                  &input_dims,
721                                                  input_data,
722                                                  &filter_dims,
723                                                  kernel_data,
724                                                  &bias_dims,
725                                                  bias_data,
726                                                  &output_dims,
727                                                  output);
728 
729     if (ctx.buf)
730     {
731         memset(ctx.buf, 0, buf_size);
732         free(ctx.buf);
733     }
734     TEST_ASSERT_EQUAL(expected, result);
735     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
736     memset(output, 0, sizeof(output));
737 
738     buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
739     ctx.buf = malloc(buf_size);
740     ctx.size = 0;
741 
742     result = arm_convolve_wrapper_s4(&ctx,
743                                      &conv_params,
744                                      &quant_params,
745                                      &input_dims,
746                                      input_data,
747                                      &filter_dims,
748                                      kernel_data,
749                                      &bias_dims,
750                                      bias_data,
751                                      &output_dims,
752                                      output);
753 
754     if (ctx.buf)
755     {
756         memset(ctx.buf, 0, buf_size);
757         free(ctx.buf);
758     }
759     TEST_ASSERT_EQUAL(expected, result);
760     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
761 }
762 
conv_2x2_dilation_5x5_input_arm_convolve_s4(void)763 void conv_2x2_dilation_5x5_input_arm_convolve_s4(void)
764 {
765     int8_t output[CONV_2X2_DILATION_5X5_INPUT_INT4_DST_SIZE] = {0};
766 
767     cmsis_nn_context ctx;
768     cmsis_nn_conv_params conv_params;
769     cmsis_nn_per_channel_quant_params quant_params;
770     cmsis_nn_dims input_dims;
771     cmsis_nn_dims filter_dims;
772     cmsis_nn_dims bias_dims;
773     cmsis_nn_dims output_dims;
774 
775     const int32_t *bias_data = conv_2x2_dilation_5x5_input_int4_biases;
776     const int8_t *kernel_data = conv_2x2_dilation_5x5_input_int4_weights;
777     const int8_t *input_data = conv_2x2_dilation_5x5_input_int4_input;
778     const int8_t *output_ref = conv_2x2_dilation_5x5_input_int4_output_ref;
779     const int32_t output_ref_size = CONV_2X2_DILATION_5X5_INPUT_INT4_DST_SIZE;
780     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
781 
782     input_dims.n = CONV_2X2_DILATION_5X5_INPUT_INT4_INPUT_BATCHES;
783     input_dims.w = CONV_2X2_DILATION_5X5_INPUT_INT4_INPUT_W;
784     input_dims.h = CONV_2X2_DILATION_5X5_INPUT_INT4_INPUT_H;
785     input_dims.c = CONV_2X2_DILATION_5X5_INPUT_INT4_IN_CH;
786     filter_dims.w = CONV_2X2_DILATION_5X5_INPUT_INT4_FILTER_X;
787     filter_dims.h = CONV_2X2_DILATION_5X5_INPUT_INT4_FILTER_Y;
788     output_dims.w = CONV_2X2_DILATION_5X5_INPUT_INT4_OUTPUT_W;
789     output_dims.h = CONV_2X2_DILATION_5X5_INPUT_INT4_OUTPUT_H;
790     output_dims.c = CONV_2X2_DILATION_5X5_INPUT_INT4_OUT_CH;
791 
792     conv_params.padding.w = CONV_2X2_DILATION_5X5_INPUT_INT4_PAD_X;
793     conv_params.padding.h = CONV_2X2_DILATION_5X5_INPUT_INT4_PAD_Y;
794     conv_params.stride.w = CONV_2X2_DILATION_5X5_INPUT_INT4_STRIDE_X;
795     conv_params.stride.h = CONV_2X2_DILATION_5X5_INPUT_INT4_STRIDE_Y;
796     conv_params.dilation.w = CONV_2X2_DILATION_5X5_INPUT_INT4_DILATION_X;
797     conv_params.dilation.h = CONV_2X2_DILATION_5X5_INPUT_INT4_DILATION_Y;
798 
799     conv_params.input_offset = CONV_2X2_DILATION_5X5_INPUT_INT4_INPUT_OFFSET;
800     conv_params.output_offset = CONV_2X2_DILATION_5X5_INPUT_INT4_OUTPUT_OFFSET;
801     conv_params.activation.min = CONV_2X2_DILATION_5X5_INPUT_INT4_OUT_ACTIVATION_MIN;
802     conv_params.activation.max = CONV_2X2_DILATION_5X5_INPUT_INT4_OUT_ACTIVATION_MAX;
803     quant_params.multiplier = (int32_t *)conv_2x2_dilation_5x5_input_int4_output_mult;
804     quant_params.shift = (int32_t *)conv_2x2_dilation_5x5_input_int4_output_shift;
805 
806     int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
807     ctx.buf = malloc(buf_size);
808 
809     arm_cmsis_nn_status result = arm_convolve_s4(&ctx,
810                                                  &conv_params,
811                                                  &quant_params,
812                                                  &input_dims,
813                                                  input_data,
814                                                  &filter_dims,
815                                                  kernel_data,
816                                                  &bias_dims,
817                                                  bias_data,
818                                                  &output_dims,
819                                                  output);
820     if (ctx.buf)
821     {
822         memset(ctx.buf, 0, buf_size);
823         free(ctx.buf);
824     }
825     TEST_ASSERT_EQUAL(expected, result);
826     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
827     memset(output, 0, sizeof(output));
828 
829     buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
830     ctx.buf = malloc(buf_size);
831     ctx.size = 0;
832 
833     result = arm_convolve_wrapper_s4(&ctx,
834                                      &conv_params,
835                                      &quant_params,
836                                      &input_dims,
837                                      input_data,
838                                      &filter_dims,
839                                      kernel_data,
840                                      &bias_dims,
841                                      bias_data,
842                                      &output_dims,
843                                      output);
844 
845     if (ctx.buf)
846     {
847         memset(ctx.buf, 0, buf_size);
848         free(ctx.buf);
849     }
850     TEST_ASSERT_EQUAL(expected, result);
851     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
852 }
853 
conv_3x3_dilation_5x5_input_arm_convolve_s4(void)854 void conv_3x3_dilation_5x5_input_arm_convolve_s4(void)
855 {
856     int8_t output[CONV_3X3_DILATION_5X5_INPUT_INT4_DST_SIZE] = {0};
857 
858     cmsis_nn_context ctx;
859     cmsis_nn_conv_params conv_params;
860     cmsis_nn_per_channel_quant_params quant_params;
861     cmsis_nn_dims input_dims;
862     cmsis_nn_dims filter_dims;
863     cmsis_nn_dims bias_dims;
864     cmsis_nn_dims output_dims;
865 
866     const int32_t *bias_data = conv_3x3_dilation_5x5_input_int4_biases;
867     const int8_t *kernel_data = conv_3x3_dilation_5x5_input_int4_weights;
868     const int8_t *input_data = conv_3x3_dilation_5x5_input_int4_input;
869     const int8_t *output_ref = conv_3x3_dilation_5x5_input_int4_output_ref;
870     const int32_t output_ref_size = CONV_3X3_DILATION_5X5_INPUT_INT4_DST_SIZE;
871     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
872 
873     input_dims.n = CONV_3X3_DILATION_5X5_INPUT_INT4_INPUT_BATCHES;
874     input_dims.w = CONV_3X3_DILATION_5X5_INPUT_INT4_INPUT_W;
875     input_dims.h = CONV_3X3_DILATION_5X5_INPUT_INT4_INPUT_H;
876     input_dims.c = CONV_3X3_DILATION_5X5_INPUT_INT4_IN_CH;
877     filter_dims.w = CONV_3X3_DILATION_5X5_INPUT_INT4_FILTER_X;
878     filter_dims.h = CONV_3X3_DILATION_5X5_INPUT_INT4_FILTER_Y;
879     output_dims.w = CONV_3X3_DILATION_5X5_INPUT_INT4_OUTPUT_W;
880     output_dims.h = CONV_3X3_DILATION_5X5_INPUT_INT4_OUTPUT_H;
881     output_dims.c = CONV_3X3_DILATION_5X5_INPUT_INT4_OUT_CH;
882 
883     conv_params.padding.w = CONV_3X3_DILATION_5X5_INPUT_INT4_PAD_X;
884     conv_params.padding.h = CONV_3X3_DILATION_5X5_INPUT_INT4_PAD_Y;
885     conv_params.stride.w = CONV_3X3_DILATION_5X5_INPUT_INT4_STRIDE_X;
886     conv_params.stride.h = CONV_3X3_DILATION_5X5_INPUT_INT4_STRIDE_Y;
887     conv_params.dilation.w = CONV_3X3_DILATION_5X5_INPUT_INT4_DILATION_X;
888     conv_params.dilation.h = CONV_3X3_DILATION_5X5_INPUT_INT4_DILATION_Y;
889 
890     conv_params.input_offset = CONV_3X3_DILATION_5X5_INPUT_INT4_INPUT_OFFSET;
891     conv_params.output_offset = CONV_3X3_DILATION_5X5_INPUT_INT4_OUTPUT_OFFSET;
892     conv_params.activation.min = CONV_3X3_DILATION_5X5_INPUT_INT4_OUT_ACTIVATION_MIN;
893     conv_params.activation.max = CONV_3X3_DILATION_5X5_INPUT_INT4_OUT_ACTIVATION_MAX;
894     quant_params.multiplier = (int32_t *)conv_3x3_dilation_5x5_input_int4_output_mult;
895     quant_params.shift = (int32_t *)conv_3x3_dilation_5x5_input_int4_output_shift;
896 
897     int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
898     ctx.buf = malloc(buf_size);
899 
900     arm_cmsis_nn_status result = arm_convolve_s4(&ctx,
901                                                  &conv_params,
902                                                  &quant_params,
903                                                  &input_dims,
904                                                  input_data,
905                                                  &filter_dims,
906                                                  kernel_data,
907                                                  &bias_dims,
908                                                  bias_data,
909                                                  &output_dims,
910                                                  output);
911     if (ctx.buf)
912     {
913         memset(ctx.buf, 0, buf_size);
914         free(ctx.buf);
915     }
916     TEST_ASSERT_EQUAL(expected, result);
917     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
918     memset(output, 0, sizeof(output));
919 
920     buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
921     ctx.buf = malloc(buf_size);
922     ctx.size = 0;
923 
924     result = arm_convolve_wrapper_s4(&ctx,
925                                      &conv_params,
926                                      &quant_params,
927                                      &input_dims,
928                                      input_data,
929                                      &filter_dims,
930                                      kernel_data,
931                                      &bias_dims,
932                                      bias_data,
933                                      &output_dims,
934                                      output);
935 
936     if (ctx.buf)
937     {
938         memset(ctx.buf, 0, buf_size);
939         free(ctx.buf);
940     }
941     TEST_ASSERT_EQUAL(expected, result);
942     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
943 }
944 
conv_2x3_dilation_arm_convolve_s4(void)945 void conv_2x3_dilation_arm_convolve_s4(void)
946 {
947     int8_t output[CONV_2X3_DILATION_INT4_DST_SIZE] = {0};
948 
949     cmsis_nn_context ctx;
950     cmsis_nn_conv_params conv_params;
951     cmsis_nn_per_channel_quant_params quant_params;
952     cmsis_nn_dims input_dims;
953     cmsis_nn_dims filter_dims;
954     cmsis_nn_dims bias_dims;
955     cmsis_nn_dims output_dims;
956 
957     const int32_t *bias_data = conv_2x3_dilation_int4_biases;
958     const int8_t *kernel_data = conv_2x3_dilation_int4_weights;
959     const int8_t *input_data = conv_2x3_dilation_int4_input;
960     const int8_t *output_ref = conv_2x3_dilation_int4_output_ref;
961     const int32_t output_ref_size = CONV_2X3_DILATION_INT4_DST_SIZE;
962     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
963 
964     input_dims.n = CONV_2X3_DILATION_INT4_INPUT_BATCHES;
965     input_dims.w = CONV_2X3_DILATION_INT4_INPUT_W;
966     input_dims.h = CONV_2X3_DILATION_INT4_INPUT_H;
967     input_dims.c = CONV_2X3_DILATION_INT4_IN_CH;
968     filter_dims.w = CONV_2X3_DILATION_INT4_FILTER_X;
969     filter_dims.h = CONV_2X3_DILATION_INT4_FILTER_Y;
970     output_dims.w = CONV_2X3_DILATION_INT4_OUTPUT_W;
971     output_dims.h = CONV_2X3_DILATION_INT4_OUTPUT_H;
972     output_dims.c = CONV_2X3_DILATION_INT4_OUT_CH;
973 
974     conv_params.padding.w = CONV_2X3_DILATION_INT4_PAD_X;
975     conv_params.padding.h = CONV_2X3_DILATION_INT4_PAD_Y;
976     conv_params.stride.w = CONV_2X3_DILATION_INT4_STRIDE_X;
977     conv_params.stride.h = CONV_2X3_DILATION_INT4_STRIDE_Y;
978     conv_params.dilation.w = CONV_2X3_DILATION_INT4_DILATION_X;
979     conv_params.dilation.h = CONV_2X3_DILATION_INT4_DILATION_Y;
980 
981     conv_params.input_offset = CONV_2X3_DILATION_INT4_INPUT_OFFSET;
982     conv_params.output_offset = CONV_2X3_DILATION_INT4_OUTPUT_OFFSET;
983     conv_params.activation.min = CONV_2X3_DILATION_INT4_OUT_ACTIVATION_MIN;
984     conv_params.activation.max = CONV_2X3_DILATION_INT4_OUT_ACTIVATION_MAX;
985     quant_params.multiplier = (int32_t *)conv_2x3_dilation_int4_output_mult;
986     quant_params.shift = (int32_t *)conv_2x3_dilation_int4_output_shift;
987 
988     int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
989     ctx.buf = malloc(buf_size);
990 
991     arm_cmsis_nn_status result = arm_convolve_s4(&ctx,
992                                                  &conv_params,
993                                                  &quant_params,
994                                                  &input_dims,
995                                                  input_data,
996                                                  &filter_dims,
997                                                  kernel_data,
998                                                  &bias_dims,
999                                                  bias_data,
1000                                                  &output_dims,
1001                                                  output);
1002     if (ctx.buf)
1003     {
1004         memset(ctx.buf, 0, buf_size);
1005         free(ctx.buf);
1006     }
1007     TEST_ASSERT_EQUAL(expected, result);
1008     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
1009     memset(output, 0, sizeof(output));
1010 
1011     buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
1012     ctx.buf = malloc(buf_size);
1013     ctx.size = 0;
1014 
1015     result = arm_convolve_wrapper_s4(&ctx,
1016                                      &conv_params,
1017                                      &quant_params,
1018                                      &input_dims,
1019                                      input_data,
1020                                      &filter_dims,
1021                                      kernel_data,
1022                                      &bias_dims,
1023                                      bias_data,
1024                                      &output_dims,
1025                                      output);
1026 
1027     if (ctx.buf)
1028     {
1029         memset(ctx.buf, 0, buf_size);
1030         free(ctx.buf);
1031     }
1032     TEST_ASSERT_EQUAL(expected, result);
1033     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
1034 }
1035 
conv_3x2_dilation_arm_convolve_s4(void)1036 void conv_3x2_dilation_arm_convolve_s4(void)
1037 {
1038     int8_t output[CONV_3X2_DILATION_INT4_DST_SIZE] = {0};
1039 
1040     cmsis_nn_context ctx;
1041     cmsis_nn_conv_params conv_params;
1042     cmsis_nn_per_channel_quant_params quant_params;
1043     cmsis_nn_dims input_dims;
1044     cmsis_nn_dims filter_dims;
1045     cmsis_nn_dims bias_dims;
1046     cmsis_nn_dims output_dims;
1047 
1048     const int32_t *bias_data = conv_3x2_dilation_int4_biases;
1049     const int8_t *kernel_data = conv_3x2_dilation_int4_weights;
1050     const int8_t *input_data = conv_3x2_dilation_int4_input;
1051     const int8_t *output_ref = conv_3x2_dilation_int4_output_ref;
1052     const int32_t output_ref_size = CONV_3X2_DILATION_INT4_DST_SIZE;
1053     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
1054 
1055     input_dims.n = CONV_3X2_DILATION_INT4_INPUT_BATCHES;
1056     input_dims.w = CONV_3X2_DILATION_INT4_INPUT_W;
1057     input_dims.h = CONV_3X2_DILATION_INT4_INPUT_H;
1058     input_dims.c = CONV_3X2_DILATION_INT4_IN_CH;
1059     filter_dims.w = CONV_3X2_DILATION_INT4_FILTER_X;
1060     filter_dims.h = CONV_3X2_DILATION_INT4_FILTER_Y;
1061     output_dims.w = CONV_3X2_DILATION_INT4_OUTPUT_W;
1062     output_dims.h = CONV_3X2_DILATION_INT4_OUTPUT_H;
1063     output_dims.c = CONV_3X2_DILATION_INT4_OUT_CH;
1064 
1065     conv_params.padding.w = CONV_3X2_DILATION_INT4_PAD_X;
1066     conv_params.padding.h = CONV_3X2_DILATION_INT4_PAD_Y;
1067     conv_params.stride.w = CONV_3X2_DILATION_INT4_STRIDE_X;
1068     conv_params.stride.h = CONV_3X2_DILATION_INT4_STRIDE_Y;
1069     conv_params.dilation.w = CONV_3X2_DILATION_INT4_DILATION_X;
1070     conv_params.dilation.h = CONV_3X2_DILATION_INT4_DILATION_Y;
1071 
1072     conv_params.input_offset = CONV_3X2_DILATION_INT4_INPUT_OFFSET;
1073     conv_params.output_offset = CONV_3X2_DILATION_INT4_OUTPUT_OFFSET;
1074     conv_params.activation.min = CONV_3X2_DILATION_INT4_OUT_ACTIVATION_MIN;
1075     conv_params.activation.max = CONV_3X2_DILATION_INT4_OUT_ACTIVATION_MAX;
1076     quant_params.multiplier = (int32_t *)conv_3x2_dilation_int4_output_mult;
1077     quant_params.shift = (int32_t *)conv_3x2_dilation_int4_output_shift;
1078 
1079     int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
1080     ctx.buf = malloc(buf_size);
1081 
1082     arm_cmsis_nn_status result = arm_convolve_s4(&ctx,
1083                                                  &conv_params,
1084                                                  &quant_params,
1085                                                  &input_dims,
1086                                                  input_data,
1087                                                  &filter_dims,
1088                                                  kernel_data,
1089                                                  &bias_dims,
1090                                                  bias_data,
1091                                                  &output_dims,
1092                                                  output);
1093     if (ctx.buf)
1094     {
1095         memset(ctx.buf, 0, buf_size);
1096         free(ctx.buf);
1097     }
1098     TEST_ASSERT_EQUAL(expected, result);
1099     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
1100     memset(output, 0, sizeof(output));
1101 
1102     buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
1103     ctx.buf = malloc(buf_size);
1104     ctx.size = 0;
1105 
1106     result = arm_convolve_wrapper_s4(&ctx,
1107                                      &conv_params,
1108                                      &quant_params,
1109                                      &input_dims,
1110                                      input_data,
1111                                      &filter_dims,
1112                                      kernel_data,
1113                                      &bias_dims,
1114                                      bias_data,
1115                                      &output_dims,
1116                                      output);
1117 
1118     if (ctx.buf)
1119     {
1120         memset(ctx.buf, 0, buf_size);
1121         free(ctx.buf);
1122     }
1123     TEST_ASSERT_EQUAL(expected, result);
1124     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
1125 }
1126 
conv_dilation_golden_arm_convolve_s4(void)1127 void conv_dilation_golden_arm_convolve_s4(void)
1128 {
1129     int8_t output[CONV_DILATION_GOLDEN_INT4_DST_SIZE] = {0};
1130 
1131     cmsis_nn_context ctx;
1132     cmsis_nn_conv_params conv_params;
1133     cmsis_nn_per_channel_quant_params quant_params;
1134     cmsis_nn_dims input_dims;
1135     cmsis_nn_dims filter_dims;
1136     cmsis_nn_dims bias_dims;
1137     cmsis_nn_dims output_dims;
1138 
1139     const int32_t *bias_data = conv_dilation_golden_int4_biases;
1140     const int8_t *kernel_data = conv_dilation_golden_int4_weights;
1141     const int8_t *input_data = conv_dilation_golden_int4_input;
1142     const int8_t *output_ref = conv_dilation_golden_int4_output_ref;
1143     const int32_t output_ref_size = CONV_DILATION_GOLDEN_INT4_DST_SIZE;
1144     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
1145 
1146     input_dims.n = CONV_DILATION_GOLDEN_INT4_INPUT_BATCHES;
1147     input_dims.w = CONV_DILATION_GOLDEN_INT4_INPUT_W;
1148     input_dims.h = CONV_DILATION_GOLDEN_INT4_INPUT_H;
1149     input_dims.c = CONV_DILATION_GOLDEN_INT4_IN_CH;
1150     filter_dims.w = CONV_DILATION_GOLDEN_INT4_FILTER_X;
1151     filter_dims.h = CONV_DILATION_GOLDEN_INT4_FILTER_Y;
1152     output_dims.w = CONV_DILATION_GOLDEN_INT4_OUTPUT_W;
1153     output_dims.h = CONV_DILATION_GOLDEN_INT4_OUTPUT_H;
1154     output_dims.c = CONV_DILATION_GOLDEN_INT4_OUT_CH;
1155 
1156     conv_params.padding.w = CONV_DILATION_GOLDEN_INT4_PAD_X;
1157     conv_params.padding.h = CONV_DILATION_GOLDEN_INT4_PAD_Y;
1158     conv_params.stride.w = CONV_DILATION_GOLDEN_INT4_STRIDE_X;
1159     conv_params.stride.h = CONV_DILATION_GOLDEN_INT4_STRIDE_Y;
1160     conv_params.dilation.w = CONV_DILATION_GOLDEN_INT4_DILATION_X;
1161     conv_params.dilation.h = CONV_DILATION_GOLDEN_INT4_DILATION_Y;
1162 
1163     conv_params.input_offset = CONV_DILATION_GOLDEN_INT4_INPUT_OFFSET;
1164     conv_params.output_offset = CONV_DILATION_GOLDEN_INT4_OUTPUT_OFFSET;
1165     conv_params.activation.min = CONV_DILATION_GOLDEN_INT4_OUT_ACTIVATION_MIN;
1166     conv_params.activation.max = CONV_DILATION_GOLDEN_INT4_OUT_ACTIVATION_MAX;
1167     quant_params.multiplier = (int32_t *)conv_dilation_golden_int4_output_mult;
1168     quant_params.shift = (int32_t *)conv_dilation_golden_int4_output_shift;
1169 
1170     int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
1171     ctx.buf = malloc(buf_size);
1172 
1173     arm_cmsis_nn_status result = arm_convolve_s4(&ctx,
1174                                                  &conv_params,
1175                                                  &quant_params,
1176                                                  &input_dims,
1177                                                  input_data,
1178                                                  &filter_dims,
1179                                                  kernel_data,
1180                                                  &bias_dims,
1181                                                  bias_data,
1182                                                  &output_dims,
1183                                                  output);
1184     if (ctx.buf)
1185     {
1186         memset(ctx.buf, 0, buf_size);
1187         free(ctx.buf);
1188     }
1189     TEST_ASSERT_EQUAL(expected, result);
1190     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
1191     memset(output, 0, sizeof(output));
1192 
1193     buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
1194     ctx.buf = malloc(buf_size);
1195     ctx.size = 0;
1196 
1197     result = arm_convolve_wrapper_s4(&ctx,
1198                                      &conv_params,
1199                                      &quant_params,
1200                                      &input_dims,
1201                                      input_data,
1202                                      &filter_dims,
1203                                      kernel_data,
1204                                      &bias_dims,
1205                                      bias_data,
1206                                      &output_dims,
1207                                      output);
1208 
1209     if (ctx.buf)
1210     {
1211         memset(ctx.buf, 0, buf_size);
1212         free(ctx.buf);
1213     }
1214     TEST_ASSERT_EQUAL(expected, result);
1215     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
1216 }
1217 
conv_5_arm_convolve_s4(void)1218 void conv_5_arm_convolve_s4(void)
1219 {
1220     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
1221     int8_t output[CONV_5_INT4_DST_SIZE] = {0};
1222 
1223     cmsis_nn_context ctx;
1224     cmsis_nn_conv_params conv_params;
1225     cmsis_nn_per_channel_quant_params quant_params;
1226     cmsis_nn_dims input_dims;
1227     cmsis_nn_dims filter_dims;
1228     cmsis_nn_dims bias_dims;
1229     cmsis_nn_dims output_dims;
1230 
1231     const int32_t *bias_data = conv_5_int4_biases;
1232     const int8_t *kernel_data = conv_5_int4_weights;
1233     const int8_t *input_data = conv_5_int4_input;
1234     const int8_t *output_ref = conv_5_int4_output_ref;
1235     const int32_t output_ref_size = CONV_5_INT4_DST_SIZE;
1236 
1237     input_dims.n = CONV_5_INT4_INPUT_BATCHES;
1238     input_dims.w = CONV_5_INT4_INPUT_W;
1239     input_dims.h = CONV_5_INT4_INPUT_H;
1240     input_dims.c = CONV_5_INT4_IN_CH;
1241     filter_dims.w = CONV_5_INT4_FILTER_X;
1242     filter_dims.h = CONV_5_INT4_FILTER_Y;
1243     output_dims.w = CONV_5_INT4_OUTPUT_W;
1244     output_dims.h = CONV_5_INT4_OUTPUT_H;
1245     output_dims.c = CONV_5_INT4_OUT_CH;
1246 
1247     conv_params.padding.w = CONV_5_INT4_PAD_X;
1248     conv_params.padding.h = CONV_5_INT4_PAD_Y;
1249     conv_params.stride.w = CONV_5_INT4_STRIDE_X;
1250     conv_params.stride.h = CONV_5_INT4_STRIDE_Y;
1251     conv_params.dilation.w = CONV_5_INT4_DILATION_X;
1252     conv_params.dilation.h = CONV_5_INT4_DILATION_Y;
1253 
1254     conv_params.input_offset = CONV_5_INT4_INPUT_OFFSET;
1255     conv_params.output_offset = CONV_5_INT4_OUTPUT_OFFSET;
1256     conv_params.activation.min = CONV_5_INT4_OUT_ACTIVATION_MIN;
1257     conv_params.activation.max = CONV_5_INT4_OUT_ACTIVATION_MAX;
1258     quant_params.multiplier = (int32_t *)conv_5_int4_output_mult;
1259     quant_params.shift = (int32_t *)conv_5_int4_output_shift;
1260 
1261     int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
1262     ctx.buf = malloc(buf_size);
1263     ctx.size = 0;
1264 
1265     arm_cmsis_nn_status result = arm_convolve_s4(&ctx,
1266                                                  &conv_params,
1267                                                  &quant_params,
1268                                                  &input_dims,
1269                                                  input_data,
1270                                                  &filter_dims,
1271                                                  conv_5_int4_weights,
1272                                                  &bias_dims,
1273                                                  bias_data,
1274                                                  &output_dims,
1275                                                  output);
1276 
1277     if (ctx.buf)
1278     {
1279         memset(ctx.buf, 0, buf_size);
1280         free(ctx.buf);
1281     }
1282     TEST_ASSERT_EQUAL(expected, result);
1283     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
1284     memset(output, 0, sizeof(output));
1285 
1286     buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
1287     ctx.buf = malloc(buf_size);
1288     ctx.size = 0;
1289 
1290     result = arm_convolve_wrapper_s4(&ctx,
1291                                      &conv_params,
1292                                      &quant_params,
1293                                      &input_dims,
1294                                      input_data,
1295                                      &filter_dims,
1296                                      kernel_data,
1297                                      &bias_dims,
1298                                      bias_data,
1299                                      &output_dims,
1300                                      output);
1301 
1302     if (ctx.buf)
1303     {
1304         memset(ctx.buf, 0, buf_size);
1305         free(ctx.buf);
1306     }
1307     TEST_ASSERT_EQUAL(expected, result);
1308     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
1309 }
1310 
buffer_size_arm_convolve_s4(void)1311 void buffer_size_arm_convolve_s4(void)
1312 {
1313     cmsis_nn_conv_params conv_params;
1314     cmsis_nn_dims input_dims;
1315     cmsis_nn_dims filter_dims;
1316     cmsis_nn_dims output_dims;
1317 
1318     input_dims.n = CONV_5_INT4_INPUT_BATCHES;
1319     input_dims.w = CONV_5_INT4_INPUT_W;
1320     input_dims.h = CONV_5_INT4_INPUT_H;
1321     input_dims.c = CONV_5_INT4_IN_CH;
1322     filter_dims.w = CONV_5_INT4_FILTER_X;
1323     filter_dims.h = CONV_5_INT4_FILTER_Y;
1324     output_dims.w = CONV_5_INT4_OUTPUT_W;
1325     output_dims.h = CONV_5_INT4_OUTPUT_H;
1326     output_dims.c = CONV_5_INT4_OUT_CH;
1327 
1328     conv_params.padding.w = CONV_5_INT4_PAD_X;
1329     conv_params.padding.h = CONV_5_INT4_PAD_Y;
1330     conv_params.stride.w = CONV_5_INT4_STRIDE_X;
1331     conv_params.stride.h = CONV_5_INT4_STRIDE_Y;
1332     conv_params.dilation.w = CONV_5_INT4_DILATION_X;
1333     conv_params.dilation.h = CONV_5_INT4_DILATION_Y;
1334 
1335     conv_params.input_offset = CONV_5_INT4_INPUT_OFFSET;
1336     conv_params.output_offset = CONV_5_INT4_OUTPUT_OFFSET;
1337     conv_params.activation.min = CONV_5_INT4_OUT_ACTIVATION_MIN;
1338     conv_params.activation.max = CONV_5_INT4_OUT_ACTIVATION_MAX;
1339 
1340     const int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
1341     const int32_t wrapper_buf_size =
1342         arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
1343 
1344     TEST_ASSERT_EQUAL(wrapper_buf_size, buf_size);
1345 }
1346 
/*
 * On MVE builds, the generic wrapper's scratch-size query must agree with the
 * MVE-specific query for the conv_5 int4 configuration. On other builds the
 * body is compiled out and the test does nothing.
 */
void buffer_size_mve_arm_convolve_s4(void)
{
#if defined(ARM_MATH_MVEI)
    cmsis_nn_dims input_dims;
    input_dims.n = CONV_5_INT4_INPUT_BATCHES;
    input_dims.h = CONV_5_INT4_INPUT_H;
    input_dims.w = CONV_5_INT4_INPUT_W;
    input_dims.c = CONV_5_INT4_IN_CH;

    cmsis_nn_dims filter_dims;
    filter_dims.h = CONV_5_INT4_FILTER_Y;
    filter_dims.w = CONV_5_INT4_FILTER_X;

    cmsis_nn_dims output_dims;
    output_dims.h = CONV_5_INT4_OUTPUT_H;
    output_dims.w = CONV_5_INT4_OUTPUT_W;
    output_dims.c = CONV_5_INT4_OUT_CH;

    cmsis_nn_conv_params conv_params;
    conv_params.stride.h = CONV_5_INT4_STRIDE_Y;
    conv_params.stride.w = CONV_5_INT4_STRIDE_X;
    conv_params.padding.h = CONV_5_INT4_PAD_Y;
    conv_params.padding.w = CONV_5_INT4_PAD_X;
    conv_params.dilation.h = CONV_5_INT4_DILATION_Y;
    conv_params.dilation.w = CONV_5_INT4_DILATION_X;
    conv_params.input_offset = CONV_5_INT4_INPUT_OFFSET;
    conv_params.output_offset = CONV_5_INT4_OUTPUT_OFFSET;
    conv_params.activation.min = CONV_5_INT4_OUT_ACTIVATION_MIN;
    conv_params.activation.max = CONV_5_INT4_OUT_ACTIVATION_MAX;

    const int32_t generic_size =
        arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
    const int32_t mve_size =
        arm_convolve_wrapper_s4_get_buffer_size_mve(&conv_params, &input_dims, &filter_dims, &output_dims);

    TEST_ASSERT_EQUAL(generic_size, mve_size);
#endif
}
1385 
/*
 * On DSP-only builds (ARM_MATH_DSP defined, ARM_MATH_MVEI not defined), checks
 * that the generic arm_convolve_wrapper_s4 scratch-buffer query and the
 * DSP-specific arm_convolve_wrapper_s4_get_buffer_size_dsp query agree for the
 * conv_5_int4 parameter set.
 * On any other build configuration the body is compiled out.
 */
void buffer_size_dsp_arm_convolve_s4(void)
{
#if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
    cmsis_nn_dims input_dims = {.n = CONV_5_INT4_INPUT_BATCHES,
                                .w = CONV_5_INT4_INPUT_W,
                                .h = CONV_5_INT4_INPUT_H,
                                .c = CONV_5_INT4_IN_CH};
    cmsis_nn_dims filter_dims = {.w = CONV_5_INT4_FILTER_X, .h = CONV_5_INT4_FILTER_Y};
    cmsis_nn_dims output_dims = {.w = CONV_5_INT4_OUTPUT_W, .h = CONV_5_INT4_OUTPUT_H, .c = CONV_5_INT4_OUT_CH};

    cmsis_nn_conv_params conv_params = {
        .input_offset = CONV_5_INT4_INPUT_OFFSET,
        .output_offset = CONV_5_INT4_OUTPUT_OFFSET,
        .padding = {.w = CONV_5_INT4_PAD_X, .h = CONV_5_INT4_PAD_Y},
        .stride = {.w = CONV_5_INT4_STRIDE_X, .h = CONV_5_INT4_STRIDE_Y},
        .dilation = {.w = CONV_5_INT4_DILATION_X, .h = CONV_5_INT4_DILATION_Y},
        .activation = {.min = CONV_5_INT4_OUT_ACTIVATION_MIN, .max = CONV_5_INT4_OUT_ACTIVATION_MAX},
    };

    /* Query the generic entry point first, then the DSP-specific one. */
    const int32_t generic_size =
        arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
    const int32_t dsp_size =
        arm_convolve_wrapper_s4_get_buffer_size_dsp(&conv_params, &input_dims, &filter_dims, &output_dims);

    TEST_ASSERT_EQUAL(generic_size, dsp_size);
#endif
}
1424 
conv_1_x_n_1_arm_convolve_s4(void)1425 void conv_1_x_n_1_arm_convolve_s4(void)
1426 {
1427     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
1428     int8_t output[CONV_1_X_N_1_INT4_DST_SIZE] = {0};
1429 
1430     cmsis_nn_context ctx;
1431     cmsis_nn_conv_params conv_params;
1432     cmsis_nn_per_channel_quant_params quant_params;
1433     cmsis_nn_dims input_dims;
1434     cmsis_nn_dims filter_dims;
1435     cmsis_nn_dims bias_dims;
1436     cmsis_nn_dims output_dims;
1437 
1438     const int32_t *bias_data = conv_1_x_n_1_int4_biases;
1439     const int8_t *kernel_data = conv_1_x_n_1_int4_weights;
1440     const int8_t *input_data = conv_1_x_n_1_int4_input;
1441     const int8_t *output_ref = conv_1_x_n_1_int4_output_ref;
1442     const int32_t output_ref_size = CONV_1_X_N_1_INT4_DST_SIZE;
1443 
1444     input_dims.n = CONV_1_X_N_1_INT4_INPUT_BATCHES;
1445     input_dims.w = CONV_1_X_N_1_INT4_INPUT_W;
1446     input_dims.h = CONV_1_X_N_1_INT4_INPUT_H;
1447     input_dims.c = CONV_1_X_N_1_INT4_IN_CH;
1448     filter_dims.w = CONV_1_X_N_1_INT4_FILTER_X;
1449     filter_dims.h = CONV_1_X_N_1_INT4_FILTER_Y;
1450     output_dims.w = CONV_1_X_N_1_INT4_OUTPUT_W;
1451     output_dims.h = CONV_1_X_N_1_INT4_OUTPUT_H;
1452     output_dims.c = CONV_1_X_N_1_INT4_OUT_CH;
1453 
1454     conv_params.padding.w = CONV_1_X_N_1_INT4_PAD_X;
1455     conv_params.padding.h = CONV_1_X_N_1_INT4_PAD_Y;
1456     conv_params.stride.w = CONV_1_X_N_1_INT4_STRIDE_X;
1457     conv_params.stride.h = CONV_1_X_N_1_INT4_STRIDE_Y;
1458     conv_params.dilation.w = CONV_1_X_N_1_INT4_DILATION_X;
1459     conv_params.dilation.h = CONV_1_X_N_1_INT4_DILATION_Y;
1460 
1461     conv_params.input_offset = CONV_1_X_N_1_INT4_INPUT_OFFSET;
1462     conv_params.output_offset = CONV_1_X_N_1_INT4_OUTPUT_OFFSET;
1463     conv_params.activation.min = CONV_1_X_N_1_INT4_OUT_ACTIVATION_MIN;
1464     conv_params.activation.max = CONV_1_X_N_1_INT4_OUT_ACTIVATION_MAX;
1465     quant_params.multiplier = (int32_t *)conv_1_x_n_1_int4_output_mult;
1466     quant_params.shift = (int32_t *)conv_1_x_n_1_int4_output_shift;
1467 
1468     int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
1469     ctx.buf = malloc(buf_size);
1470     ctx.size = 0;
1471 
1472     arm_cmsis_nn_status result = arm_convolve_wrapper_s4(&ctx,
1473                                                          &conv_params,
1474                                                          &quant_params,
1475                                                          &input_dims,
1476                                                          input_data,
1477                                                          &filter_dims,
1478                                                          kernel_data,
1479                                                          &bias_dims,
1480                                                          bias_data,
1481                                                          &output_dims,
1482                                                          output);
1483 
1484     if (ctx.buf)
1485     {
1486         memset(ctx.buf, 0, buf_size);
1487         free(ctx.buf);
1488     }
1489     TEST_ASSERT_EQUAL(expected, result);
1490     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
1491     memset(output, 0, sizeof(output));
1492 
1493     buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
1494     ctx.buf = malloc(buf_size);
1495 
1496     result = arm_convolve_1_x_n_s4(&ctx,
1497                                    &conv_params,
1498                                    &quant_params,
1499                                    &input_dims,
1500                                    input_data,
1501                                    &filter_dims,
1502                                    kernel_data,
1503                                    &bias_dims,
1504                                    bias_data,
1505                                    &output_dims,
1506                                    output);
1507     if (ctx.buf)
1508     {
1509         memset(ctx.buf, 0, buf_size);
1510         free(ctx.buf);
1511     }
1512     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
1513     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
1514 }
1515 
conv_1_x_n_2_arm_convolve_s4(void)1516 void conv_1_x_n_2_arm_convolve_s4(void)
1517 {
1518     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
1519     int8_t output[CONV_1_X_N_2_INT4_DST_SIZE] = {0};
1520 
1521     cmsis_nn_context ctx;
1522     cmsis_nn_conv_params conv_params;
1523     cmsis_nn_per_channel_quant_params quant_params;
1524     cmsis_nn_dims input_dims;
1525     cmsis_nn_dims filter_dims;
1526     cmsis_nn_dims bias_dims;
1527     cmsis_nn_dims output_dims;
1528 
1529     const int32_t *bias_data = conv_1_x_n_2_int4_biases;
1530     const int8_t *kernel_data = conv_1_x_n_2_int4_weights;
1531     const int8_t *input_data = conv_1_x_n_2_int4_input;
1532     const int8_t *output_ref = conv_1_x_n_2_int4_output_ref;
1533     const int32_t output_ref_size = CONV_1_X_N_2_INT4_DST_SIZE;
1534 
1535     input_dims.n = CONV_1_X_N_2_INT4_INPUT_BATCHES;
1536     input_dims.w = CONV_1_X_N_2_INT4_INPUT_W;
1537     input_dims.h = CONV_1_X_N_2_INT4_INPUT_H;
1538     input_dims.c = CONV_1_X_N_2_INT4_IN_CH;
1539     filter_dims.w = CONV_1_X_N_2_INT4_FILTER_X;
1540     filter_dims.h = CONV_1_X_N_2_INT4_FILTER_Y;
1541     output_dims.w = CONV_1_X_N_2_INT4_OUTPUT_W;
1542     output_dims.h = CONV_1_X_N_2_INT4_OUTPUT_H;
1543     output_dims.c = CONV_1_X_N_2_INT4_OUT_CH;
1544 
1545     conv_params.padding.w = CONV_1_X_N_2_INT4_PAD_X;
1546     conv_params.padding.h = CONV_1_X_N_2_INT4_PAD_Y;
1547     conv_params.stride.w = CONV_1_X_N_2_INT4_STRIDE_X;
1548     conv_params.stride.h = CONV_1_X_N_2_INT4_STRIDE_Y;
1549     conv_params.dilation.w = CONV_1_X_N_2_INT4_DILATION_X;
1550     conv_params.dilation.h = CONV_1_X_N_2_INT4_DILATION_Y;
1551 
1552     conv_params.input_offset = CONV_1_X_N_2_INT4_INPUT_OFFSET;
1553     conv_params.output_offset = CONV_1_X_N_2_INT4_OUTPUT_OFFSET;
1554     conv_params.activation.min = CONV_1_X_N_2_INT4_OUT_ACTIVATION_MIN;
1555     conv_params.activation.max = CONV_1_X_N_2_INT4_OUT_ACTIVATION_MAX;
1556     quant_params.multiplier = (int32_t *)conv_1_x_n_2_int4_output_mult;
1557     quant_params.shift = (int32_t *)conv_1_x_n_2_int4_output_shift;
1558 
1559     int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
1560     ctx.buf = malloc(buf_size);
1561     ctx.size = 0;
1562 
1563     arm_cmsis_nn_status result = arm_convolve_wrapper_s4(&ctx,
1564                                                          &conv_params,
1565                                                          &quant_params,
1566                                                          &input_dims,
1567                                                          input_data,
1568                                                          &filter_dims,
1569                                                          kernel_data,
1570                                                          &bias_dims,
1571                                                          bias_data,
1572                                                          &output_dims,
1573                                                          output);
1574 
1575     if (ctx.buf)
1576     {
1577         memset(ctx.buf, 0, buf_size);
1578         free(ctx.buf);
1579     }
1580     TEST_ASSERT_EQUAL(expected, result);
1581     memset(output, 0, sizeof(output));
1582 
1583     buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
1584 
1585     ctx.buf = malloc(buf_size);
1586 
1587     result = arm_convolve_1_x_n_s4(&ctx,
1588                                    &conv_params,
1589                                    &quant_params,
1590                                    &input_dims,
1591                                    input_data,
1592                                    &filter_dims,
1593                                    kernel_data,
1594                                    &bias_dims,
1595                                    bias_data,
1596                                    &output_dims,
1597                                    output);
1598     if (ctx.buf)
1599     {
1600         memset(ctx.buf, 0, buf_size);
1601         free(ctx.buf);
1602     }
1603     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
1604     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
1605 }
1606 
conv_1_x_n_3_arm_convolve_s4(void)1607 void conv_1_x_n_3_arm_convolve_s4(void)
1608 {
1609     int8_t output[CONV_1_X_N_3_INT4_DST_SIZE] = {0};
1610 
1611     cmsis_nn_context ctx;
1612     cmsis_nn_conv_params conv_params;
1613     cmsis_nn_per_channel_quant_params quant_params;
1614     cmsis_nn_dims input_dims;
1615     cmsis_nn_dims filter_dims;
1616     cmsis_nn_dims bias_dims;
1617     cmsis_nn_dims output_dims;
1618 
1619     const int32_t *bias_data = conv_1_x_n_3_int4_biases;
1620     const int8_t *kernel_data = conv_1_x_n_3_int4_weights;
1621     const int8_t *input_data = conv_1_x_n_3_int4_input;
1622     const int8_t *output_ref = conv_1_x_n_3_int4_output_ref;
1623     const int32_t output_ref_size = CONV_1_X_N_3_INT4_DST_SIZE;
1624 
1625     input_dims.n = CONV_1_X_N_3_INT4_INPUT_BATCHES;
1626     input_dims.w = CONV_1_X_N_3_INT4_INPUT_W;
1627     input_dims.h = CONV_1_X_N_3_INT4_INPUT_H;
1628     input_dims.c = CONV_1_X_N_3_INT4_IN_CH;
1629     filter_dims.w = CONV_1_X_N_3_INT4_FILTER_X;
1630     filter_dims.h = CONV_1_X_N_3_INT4_FILTER_Y;
1631     output_dims.w = CONV_1_X_N_3_INT4_OUTPUT_W;
1632     output_dims.h = CONV_1_X_N_3_INT4_OUTPUT_H;
1633     output_dims.c = CONV_1_X_N_3_INT4_OUT_CH;
1634 
1635     conv_params.padding.w = CONV_1_X_N_3_INT4_PAD_X;
1636     conv_params.padding.h = CONV_1_X_N_3_INT4_PAD_Y;
1637     conv_params.stride.w = CONV_1_X_N_3_INT4_STRIDE_X;
1638     conv_params.stride.h = CONV_1_X_N_3_INT4_STRIDE_Y;
1639     conv_params.dilation.w = CONV_1_X_N_3_INT4_DILATION_X;
1640     conv_params.dilation.h = CONV_1_X_N_3_INT4_DILATION_Y;
1641 
1642     conv_params.input_offset = CONV_1_X_N_3_INT4_INPUT_OFFSET;
1643     conv_params.output_offset = CONV_1_X_N_3_INT4_OUTPUT_OFFSET;
1644     conv_params.activation.min = CONV_1_X_N_3_INT4_OUT_ACTIVATION_MIN;
1645     conv_params.activation.max = CONV_1_X_N_3_INT4_OUT_ACTIVATION_MAX;
1646     quant_params.multiplier = (int32_t *)conv_1_x_n_3_int4_output_mult;
1647     quant_params.shift = (int32_t *)conv_1_x_n_3_int4_output_shift;
1648 
1649     int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
1650     ctx.buf = malloc(buf_size);
1651     ctx.size = 0;
1652 
1653     arm_cmsis_nn_status result = arm_convolve_wrapper_s4(&ctx,
1654                                                          &conv_params,
1655                                                          &quant_params,
1656                                                          &input_dims,
1657                                                          input_data,
1658                                                          &filter_dims,
1659                                                          kernel_data,
1660                                                          &bias_dims,
1661                                                          bias_data,
1662                                                          &output_dims,
1663                                                          output);
1664 
1665     if (ctx.buf)
1666     {
1667         memset(ctx.buf, 0, buf_size);
1668         free(ctx.buf);
1669     }
1670     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
1671     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
1672     memset(output, 0, sizeof(output));
1673 
1674     buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
1675     ctx.buf = malloc(buf_size);
1676 
1677     result = arm_convolve_1_x_n_s4(&ctx,
1678                                    &conv_params,
1679                                    &quant_params,
1680                                    &input_dims,
1681                                    input_data,
1682                                    &filter_dims,
1683                                    kernel_data,
1684                                    &bias_dims,
1685                                    bias_data,
1686                                    &output_dims,
1687                                    output);
1688     if (ctx.buf)
1689     {
1690         memset(ctx.buf, 0, buf_size);
1691         free(ctx.buf);
1692     }
1693     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
1694     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
1695 }
1696 
conv_1_x_n_4_arm_convolve_s4(void)1697 void conv_1_x_n_4_arm_convolve_s4(void)
1698 {
1699     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
1700     int8_t output[CONV_1_X_N_4_INT4_DST_SIZE] = {0};
1701 
1702     cmsis_nn_context ctx;
1703     cmsis_nn_conv_params conv_params;
1704     cmsis_nn_per_channel_quant_params quant_params;
1705     cmsis_nn_dims input_dims;
1706     cmsis_nn_dims filter_dims;
1707     cmsis_nn_dims bias_dims;
1708     cmsis_nn_dims output_dims;
1709 
1710     const int32_t *bias_data = conv_1_x_n_4_int4_biases;
1711     const int8_t *kernel_data = conv_1_x_n_4_int4_weights;
1712     const int8_t *input_data = conv_1_x_n_4_int4_input;
1713     const int8_t *output_ref = conv_1_x_n_4_int4_output_ref;
1714     const int32_t output_ref_size = CONV_1_X_N_4_INT4_DST_SIZE;
1715 
1716     input_dims.n = CONV_1_X_N_4_INT4_INPUT_BATCHES;
1717     input_dims.w = CONV_1_X_N_4_INT4_INPUT_W;
1718     input_dims.h = CONV_1_X_N_4_INT4_INPUT_H;
1719     input_dims.c = CONV_1_X_N_4_INT4_IN_CH;
1720     filter_dims.w = CONV_1_X_N_4_INT4_FILTER_X;
1721     filter_dims.h = CONV_1_X_N_4_INT4_FILTER_Y;
1722     output_dims.w = CONV_1_X_N_4_INT4_OUTPUT_W;
1723     output_dims.h = CONV_1_X_N_4_INT4_OUTPUT_H;
1724     output_dims.c = CONV_1_X_N_4_INT4_OUT_CH;
1725 
1726     conv_params.padding.w = CONV_1_X_N_4_INT4_PAD_X;
1727     conv_params.padding.h = CONV_1_X_N_4_INT4_PAD_Y;
1728     conv_params.stride.w = CONV_1_X_N_4_INT4_STRIDE_X;
1729     conv_params.stride.h = CONV_1_X_N_4_INT4_STRIDE_Y;
1730     conv_params.dilation.w = CONV_1_X_N_4_INT4_DILATION_X;
1731     conv_params.dilation.h = CONV_1_X_N_4_INT4_DILATION_Y;
1732 
1733     conv_params.input_offset = CONV_1_X_N_4_INT4_INPUT_OFFSET;
1734     conv_params.output_offset = CONV_1_X_N_4_INT4_OUTPUT_OFFSET;
1735     conv_params.activation.min = CONV_1_X_N_4_INT4_OUT_ACTIVATION_MIN;
1736     conv_params.activation.max = CONV_1_X_N_4_INT4_OUT_ACTIVATION_MAX;
1737     quant_params.multiplier = (int32_t *)conv_1_x_n_4_int4_output_mult;
1738     quant_params.shift = (int32_t *)conv_1_x_n_4_int4_output_shift;
1739 
1740     int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
1741 
1742     ctx.buf = malloc(buf_size);
1743     ctx.size = 0;
1744 
1745     arm_cmsis_nn_status result = arm_convolve_wrapper_s4(&ctx,
1746                                                          &conv_params,
1747                                                          &quant_params,
1748                                                          &input_dims,
1749                                                          input_data,
1750                                                          &filter_dims,
1751                                                          kernel_data,
1752                                                          &bias_dims,
1753                                                          bias_data,
1754                                                          &output_dims,
1755                                                          output);
1756 
1757     if (ctx.buf)
1758     {
1759         memset(ctx.buf, 0, buf_size);
1760         free(ctx.buf);
1761     }
1762     TEST_ASSERT_EQUAL(expected, result);
1763     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
1764     memset(output, 0, sizeof(output));
1765 
1766     buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
1767 
1768     ctx.buf = malloc(buf_size);
1769 
1770     result = arm_convolve_1_x_n_s4(&ctx,
1771                                    &conv_params,
1772                                    &quant_params,
1773                                    &input_dims,
1774                                    input_data,
1775                                    &filter_dims,
1776                                    kernel_data,
1777                                    &bias_dims,
1778                                    bias_data,
1779                                    &output_dims,
1780                                    output);
1781     if (ctx.buf)
1782     {
1783         memset(ctx.buf, 0, buf_size);
1784         free(ctx.buf);
1785     }
1786     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
1787     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
1788 }
1789 
conv_1_x_n_5_arm_convolve_s4(void)1790 void conv_1_x_n_5_arm_convolve_s4(void)
1791 {
1792     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
1793     int8_t output[CONV_1_X_N_5_INT4_DST_SIZE] = {0};
1794 
1795     cmsis_nn_context ctx;
1796     cmsis_nn_conv_params conv_params;
1797     cmsis_nn_per_channel_quant_params quant_params;
1798     cmsis_nn_dims input_dims;
1799     cmsis_nn_dims filter_dims;
1800     cmsis_nn_dims bias_dims;
1801     cmsis_nn_dims output_dims;
1802 
1803     const int32_t *bias_data = conv_1_x_n_5_int4_biases;
1804     const int8_t *kernel_data = conv_1_x_n_5_int4_weights;
1805     const int8_t *input_data = conv_1_x_n_5_int4_input;
1806     const int8_t *output_ref = conv_1_x_n_5_int4_output_ref;
1807     const int32_t output_ref_size = CONV_1_X_N_5_INT4_DST_SIZE;
1808 
1809     input_dims.n = CONV_1_X_N_5_INT4_INPUT_BATCHES;
1810     input_dims.w = CONV_1_X_N_5_INT4_INPUT_W;
1811     input_dims.h = CONV_1_X_N_5_INT4_INPUT_H;
1812     input_dims.c = CONV_1_X_N_5_INT4_IN_CH;
1813     filter_dims.w = CONV_1_X_N_5_INT4_FILTER_X;
1814     filter_dims.h = CONV_1_X_N_5_INT4_FILTER_Y;
1815     output_dims.w = CONV_1_X_N_5_INT4_OUTPUT_W;
1816     output_dims.h = CONV_1_X_N_5_INT4_OUTPUT_H;
1817     output_dims.c = CONV_1_X_N_5_INT4_OUT_CH;
1818 
1819     conv_params.padding.w = CONV_1_X_N_5_INT4_PAD_X;
1820     conv_params.padding.h = CONV_1_X_N_5_INT4_PAD_Y;
1821     conv_params.stride.w = CONV_1_X_N_5_INT4_STRIDE_X;
1822     conv_params.stride.h = CONV_1_X_N_5_INT4_STRIDE_Y;
1823     conv_params.dilation.w = CONV_1_X_N_5_INT4_DILATION_X;
1824     conv_params.dilation.h = CONV_1_X_N_5_INT4_DILATION_Y;
1825 
1826     conv_params.input_offset = CONV_1_X_N_5_INT4_INPUT_OFFSET;
1827     conv_params.output_offset = CONV_1_X_N_5_INT4_OUTPUT_OFFSET;
1828     conv_params.activation.min = CONV_1_X_N_5_INT4_OUT_ACTIVATION_MIN;
1829     conv_params.activation.max = CONV_1_X_N_5_INT4_OUT_ACTIVATION_MAX;
1830     quant_params.multiplier = (int32_t *)conv_1_x_n_5_int4_output_mult;
1831     quant_params.shift = (int32_t *)conv_1_x_n_5_int4_output_shift;
1832 
1833     int32_t buf_size = arm_convolve_s4_get_buffer_size(&input_dims, &filter_dims);
1834     ctx.buf = malloc(buf_size);
1835     ctx.size = 0;
1836 
1837     arm_cmsis_nn_status result = arm_convolve_wrapper_s4(&ctx,
1838                                                          &conv_params,
1839                                                          &quant_params,
1840                                                          &input_dims,
1841                                                          input_data,
1842                                                          &filter_dims,
1843                                                          kernel_data,
1844                                                          &bias_dims,
1845                                                          bias_data,
1846                                                          &output_dims,
1847                                                          output);
1848 
1849     if (ctx.buf)
1850     {
1851         memset(ctx.buf, 0, buf_size);
1852         free(ctx.buf);
1853     }
1854     TEST_ASSERT_EQUAL(expected, result);
1855     memset(output, 0, sizeof(output));
1856 
1857     buf_size = arm_convolve_wrapper_s4_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
1858     ctx.buf = malloc(buf_size);
1859 
1860     result = arm_convolve_1_x_n_s4(&ctx,
1861                                    &conv_params,
1862                                    &quant_params,
1863                                    &input_dims,
1864                                    input_data,
1865                                    &filter_dims,
1866                                    kernel_data,
1867                                    &bias_dims,
1868                                    bias_data,
1869                                    &output_dims,
1870                                    output);
1871     if (ctx.buf)
1872     {
1873         memset(ctx.buf, 0, buf_size);
1874         free(ctx.buf);
1875     }
1876     TEST_ASSERT_EQUAL(ARM_CMSIS_NN_SUCCESS, result);
1877     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
1878 }
1879