1 /*
2  * SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include <arm_nnfunctions.h>
20 #include <unity.h>
21 
22 #include "../TestData/dw_int16xint8/test_data.h"
23 #include "../TestData/dw_int16xint8_dilation/test_data.h"
24 #include "../TestData/dw_int16xint8_mult4/test_data.h"
25 #include "../Utils/validate.h"
26 
dw_int16xint8_arm_depthwise_conv_s16(void)27 void dw_int16xint8_arm_depthwise_conv_s16(void)
28 {
29     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
30     int16_t output[DW_INT16XINT8_DST_SIZE] = {0};
31 
32     cmsis_nn_context ctx;
33     cmsis_nn_dw_conv_params dw_conv_params;
34     cmsis_nn_per_channel_quant_params quant_params;
35     cmsis_nn_dims input_dims;
36     cmsis_nn_dims filter_dims;
37     cmsis_nn_dims bias_dims = {};
38     cmsis_nn_dims output_dims;
39 
40     const int64_t *bias_data = dw_int16xint8_biases;
41     const int16_t *input_data = dw_int16xint8_input;
42     const int8_t *kernel_data = dw_int16xint8_weights;
43     const int16_t *output_ref = dw_int16xint8_output_ref;
44     const int32_t output_ref_size = DW_INT16XINT8_DST_SIZE;
45 
46     input_dims.n = DW_INT16XINT8_INPUT_BATCHES;
47     input_dims.w = DW_INT16XINT8_INPUT_W;
48     input_dims.h = DW_INT16XINT8_INPUT_H;
49     input_dims.c = DW_INT16XINT8_IN_CH;
50     filter_dims.w = DW_INT16XINT8_FILTER_X;
51     filter_dims.h = DW_INT16XINT8_FILTER_Y;
52     output_dims.w = DW_INT16XINT8_OUTPUT_W;
53     output_dims.h = DW_INT16XINT8_OUTPUT_H;
54     output_dims.c = DW_INT16XINT8_OUT_CH;
55 
56     dw_conv_params.padding.w = DW_INT16XINT8_PAD_X;
57     dw_conv_params.padding.h = DW_INT16XINT8_PAD_Y;
58     dw_conv_params.stride.w = DW_INT16XINT8_STRIDE_X;
59     dw_conv_params.stride.h = DW_INT16XINT8_STRIDE_Y;
60     dw_conv_params.dilation.w = DW_INT16XINT8_DILATION_X;
61     dw_conv_params.dilation.h = DW_INT16XINT8_DILATION_Y;
62 
63     dw_conv_params.ch_mult = DW_INT16XINT8_CH_MULT;
64 
65     dw_conv_params.input_offset = DW_INT16XINT8_INPUT_OFFSET;
66     dw_conv_params.output_offset = DW_INT16XINT8_OUTPUT_OFFSET;
67     dw_conv_params.activation.min = DW_INT16XINT8_OUT_ACTIVATION_MIN;
68     dw_conv_params.activation.max = DW_INT16XINT8_OUT_ACTIVATION_MAX;
69     quant_params.multiplier = (int32_t *)dw_int16xint8_output_mult;
70     quant_params.shift = (int32_t *)dw_int16xint8_output_shift;
71 
72     ctx.buf = NULL;
73     ctx.size = 0;
74 
75     arm_cmsis_nn_status result = arm_depthwise_conv_s16(&ctx,
76                                                         &dw_conv_params,
77                                                         &quant_params,
78                                                         &input_dims,
79                                                         input_data,
80                                                         &filter_dims,
81                                                         dw_int16xint8_weights,
82                                                         &bias_dims,
83                                                         bias_data,
84                                                         &output_dims,
85                                                         output);
86     if (ctx.buf)
87     {
88         // The caller is responsible to clear the scratch buffers for security reasons if applicable.
89         memset(ctx.buf, 0, ctx.size);
90         free(ctx.buf);
91     }
92     TEST_ASSERT_EQUAL(expected, result);
93     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
94     memset(output, 0, sizeof(output));
95 
96     int buf_size =
97         arm_depthwise_conv_wrapper_s16_get_buffer_size(&dw_conv_params, &input_dims, &filter_dims, &output_dims);
98 
99     TEST_ASSERT_EQUAL(buf_size, 0);
100 
101     ctx.buf = malloc(buf_size);
102 
103     result = arm_depthwise_conv_wrapper_s16(&ctx,
104                                             &dw_conv_params,
105                                             &quant_params,
106                                             &input_dims,
107                                             input_data,
108                                             &filter_dims,
109                                             kernel_data,
110                                             &bias_dims,
111                                             bias_data,
112                                             &output_dims,
113                                             output);
114 
115     if (ctx.buf)
116     {
117         memset(ctx.buf, 0, buf_size);
118         free(ctx.buf);
119     }
120     TEST_ASSERT_EQUAL(expected, result);
121     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
122 }
123 
dw_int16xint8_dilation_arm_depthwise_conv_s16(void)124 void dw_int16xint8_dilation_arm_depthwise_conv_s16(void)
125 {
126     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
127     int16_t output[DW_INT16XINT8_DILATION_DST_SIZE] = {0};
128 
129     cmsis_nn_context ctx;
130     cmsis_nn_dw_conv_params dw_conv_params;
131     cmsis_nn_per_channel_quant_params quant_params;
132     cmsis_nn_dims input_dims;
133     cmsis_nn_dims filter_dims;
134     cmsis_nn_dims bias_dims = {};
135     cmsis_nn_dims output_dims;
136 
137     const int64_t *bias_data = dw_int16xint8_dilation_biases;
138     const int16_t *input_data = dw_int16xint8_dilation_input;
139     const int8_t *kernel_data = dw_int16xint8_dilation_weights;
140     const int16_t *output_ref = dw_int16xint8_dilation_output_ref;
141     const int32_t output_ref_size = DW_INT16XINT8_DILATION_DST_SIZE;
142 
143     input_dims.n = DW_INT16XINT8_DILATION_INPUT_BATCHES;
144     input_dims.w = DW_INT16XINT8_DILATION_INPUT_W;
145     input_dims.h = DW_INT16XINT8_DILATION_INPUT_H;
146     input_dims.c = DW_INT16XINT8_DILATION_IN_CH;
147     filter_dims.w = DW_INT16XINT8_DILATION_FILTER_X;
148     filter_dims.h = DW_INT16XINT8_DILATION_FILTER_Y;
149     output_dims.w = DW_INT16XINT8_DILATION_OUTPUT_W;
150     output_dims.h = DW_INT16XINT8_DILATION_OUTPUT_H;
151     output_dims.c = DW_INT16XINT8_DILATION_OUT_CH;
152 
153     dw_conv_params.padding.w = DW_INT16XINT8_DILATION_PAD_X;
154     dw_conv_params.padding.h = DW_INT16XINT8_DILATION_PAD_Y;
155     dw_conv_params.stride.w = DW_INT16XINT8_DILATION_STRIDE_X;
156     dw_conv_params.stride.h = DW_INT16XINT8_DILATION_STRIDE_Y;
157     dw_conv_params.dilation.w = DW_INT16XINT8_DILATION_DILATION_X;
158     dw_conv_params.dilation.h = DW_INT16XINT8_DILATION_DILATION_Y;
159 
160     dw_conv_params.ch_mult = DW_INT16XINT8_DILATION_CH_MULT;
161 
162     dw_conv_params.input_offset = DW_INT16XINT8_DILATION_INPUT_OFFSET;
163     dw_conv_params.output_offset = DW_INT16XINT8_DILATION_OUTPUT_OFFSET;
164     dw_conv_params.activation.min = DW_INT16XINT8_DILATION_OUT_ACTIVATION_MIN;
165     dw_conv_params.activation.max = DW_INT16XINT8_DILATION_OUT_ACTIVATION_MAX;
166     quant_params.multiplier = (int32_t *)dw_int16xint8_dilation_output_mult;
167     quant_params.shift = (int32_t *)dw_int16xint8_dilation_output_shift;
168 
169     ctx.buf = NULL;
170     ctx.size = 0;
171 
172     arm_cmsis_nn_status result = arm_depthwise_conv_s16(&ctx,
173                                                         &dw_conv_params,
174                                                         &quant_params,
175                                                         &input_dims,
176                                                         input_data,
177                                                         &filter_dims,
178                                                         dw_int16xint8_dilation_weights,
179                                                         &bias_dims,
180                                                         bias_data,
181                                                         &output_dims,
182                                                         output);
183 
184     if (ctx.buf)
185     {
186         memset(ctx.buf, 0, ctx.size);
187         free(ctx.buf);
188     }
189     TEST_ASSERT_EQUAL(expected, result);
190     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
191     memset(output, 0, sizeof(output));
192 
193     int buf_size =
194         arm_depthwise_conv_wrapper_s16_get_buffer_size(&dw_conv_params, &input_dims, &filter_dims, &output_dims);
195 
196     TEST_ASSERT_EQUAL(buf_size, 0);
197 
198     ctx.buf = malloc(buf_size);
199 
200     result = arm_depthwise_conv_wrapper_s16(&ctx,
201                                             &dw_conv_params,
202                                             &quant_params,
203                                             &input_dims,
204                                             input_data,
205                                             &filter_dims,
206                                             kernel_data,
207                                             &bias_dims,
208                                             bias_data,
209                                             &output_dims,
210                                             output);
211 
212     if (ctx.buf)
213     {
214         memset(ctx.buf, 0, buf_size);
215         free(ctx.buf);
216     }
217     TEST_ASSERT_EQUAL(expected, result);
218     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
219 }
220 
dw_int16xint8_mult4_arm_depthwise_conv_s16(void)221 void dw_int16xint8_mult4_arm_depthwise_conv_s16(void)
222 {
223     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
224     int16_t output[DW_INT16XINT8_MULT4_DST_SIZE] = {0};
225 
226     cmsis_nn_context ctx;
227     cmsis_nn_dw_conv_params dw_conv_params;
228     cmsis_nn_per_channel_quant_params quant_params;
229     cmsis_nn_dims input_dims;
230     cmsis_nn_dims filter_dims;
231     cmsis_nn_dims bias_dims = {};
232     cmsis_nn_dims output_dims;
233 
234     const int64_t *bias_data = dw_int16xint8_mult4_biases;
235     const int16_t *input_data = dw_int16xint8_mult4_input;
236     const int8_t *kernel_data = dw_int16xint8_mult4_weights;
237     const int16_t *output_ref = dw_int16xint8_mult4_output_ref;
238     const int32_t output_ref_size = DW_INT16XINT8_MULT4_DST_SIZE;
239 
240     input_dims.n = DW_INT16XINT8_MULT4_INPUT_BATCHES;
241     input_dims.w = DW_INT16XINT8_MULT4_INPUT_W;
242     input_dims.h = DW_INT16XINT8_MULT4_INPUT_H;
243     input_dims.c = DW_INT16XINT8_MULT4_IN_CH;
244     filter_dims.w = DW_INT16XINT8_MULT4_FILTER_X;
245     filter_dims.h = DW_INT16XINT8_MULT4_FILTER_Y;
246     output_dims.w = DW_INT16XINT8_MULT4_OUTPUT_W;
247     output_dims.h = DW_INT16XINT8_MULT4_OUTPUT_H;
248     output_dims.c = DW_INT16XINT8_MULT4_OUT_CH;
249 
250     dw_conv_params.padding.w = DW_INT16XINT8_MULT4_PAD_X;
251     dw_conv_params.padding.h = DW_INT16XINT8_MULT4_PAD_Y;
252     dw_conv_params.stride.w = DW_INT16XINT8_MULT4_STRIDE_X;
253     dw_conv_params.stride.h = DW_INT16XINT8_MULT4_STRIDE_Y;
254     dw_conv_params.dilation.w = DW_INT16XINT8_MULT4_DILATION_X;
255     dw_conv_params.dilation.h = DW_INT16XINT8_MULT4_DILATION_Y;
256 
257     dw_conv_params.ch_mult = DW_INT16XINT8_MULT4_CH_MULT;
258 
259     dw_conv_params.input_offset = DW_INT16XINT8_MULT4_INPUT_OFFSET;
260     dw_conv_params.output_offset = DW_INT16XINT8_MULT4_OUTPUT_OFFSET;
261     dw_conv_params.activation.min = DW_INT16XINT8_MULT4_OUT_ACTIVATION_MIN;
262     dw_conv_params.activation.max = DW_INT16XINT8_MULT4_OUT_ACTIVATION_MAX;
263     quant_params.multiplier = (int32_t *)dw_int16xint8_mult4_output_mult;
264     quant_params.shift = (int32_t *)dw_int16xint8_mult4_output_shift;
265 
266     ctx.buf = NULL;
267     ctx.size = 0;
268 
269     arm_cmsis_nn_status result = arm_depthwise_conv_s16(&ctx,
270                                                         &dw_conv_params,
271                                                         &quant_params,
272                                                         &input_dims,
273                                                         input_data,
274                                                         &filter_dims,
275                                                         dw_int16xint8_mult4_weights,
276                                                         &bias_dims,
277                                                         bias_data,
278                                                         &output_dims,
279                                                         output);
280 
281     if (ctx.buf)
282     {
283         memset(ctx.buf, 0, ctx.size);
284         free(ctx.buf);
285     }
286     TEST_ASSERT_EQUAL(expected, result);
287     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
288     memset(output, 0, sizeof(output));
289 
290     int buf_size =
291         arm_depthwise_conv_wrapper_s16_get_buffer_size(&dw_conv_params, &input_dims, &filter_dims, &output_dims);
292 
293     TEST_ASSERT_EQUAL(buf_size, 0);
294 
295     ctx.buf = malloc(buf_size);
296 
297     result = arm_depthwise_conv_wrapper_s16(&ctx,
298                                             &dw_conv_params,
299                                             &quant_params,
300                                             &input_dims,
301                                             input_data,
302                                             &filter_dims,
303                                             kernel_data,
304                                             &bias_dims,
305                                             bias_data,
306                                             &output_dims,
307                                             output);
308 
309     if (ctx.buf)
310     {
311         memset(ctx.buf, 0, buf_size);
312         free(ctx.buf);
313     }
314     TEST_ASSERT_EQUAL(expected, result);
315     TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
316 }
317 
arm_depthwise_conv_wrapper_s16_buffer(void)318 void arm_depthwise_conv_wrapper_s16_buffer(void)
319 {
320     cmsis_nn_dims input_dims;
321     cmsis_nn_dims filter_dims;
322     cmsis_nn_dims output_dims;
323 
324     cmsis_nn_dw_conv_params dw_conv_params;
325     input_dims.n = DW_INT16XINT8_MULT4_INPUT_BATCHES;
326     input_dims.w = DW_INT16XINT8_MULT4_INPUT_W;
327     input_dims.h = DW_INT16XINT8_MULT4_INPUT_H;
328     input_dims.c = DW_INT16XINT8_MULT4_IN_CH;
329     filter_dims.w = DW_INT16XINT8_MULT4_FILTER_X;
330     filter_dims.h = DW_INT16XINT8_MULT4_FILTER_Y;
331 
332     output_dims.w = DW_INT16XINT8_MULT4_OUTPUT_W;
333     output_dims.h = DW_INT16XINT8_MULT4_OUTPUT_H;
334     output_dims.c = input_dims.c;
335 
336     dw_conv_params.padding.w = DW_INT16XINT8_MULT4_PAD_X;
337     dw_conv_params.padding.h = DW_INT16XINT8_MULT4_PAD_Y;
338     dw_conv_params.stride.w = DW_INT16XINT8_MULT4_STRIDE_X;
339     dw_conv_params.stride.h = DW_INT16XINT8_MULT4_STRIDE_Y;
340     dw_conv_params.dilation.w = DW_INT16XINT8_MULT4_DILATION_X;
341     dw_conv_params.dilation.h = DW_INT16XINT8_MULT4_DILATION_Y;
342     dw_conv_params.ch_mult = output_dims.c / input_dims.c;
343 
344     int32_t size =
345         arm_depthwise_conv_wrapper_s16_get_buffer_size(&dw_conv_params, &input_dims, &filter_dims, &output_dims);
346 
347 #if defined(ARM_MATH_DSP)
348     TEST_ASSERT_TRUE(size > 0);
349 #else
350     TEST_ASSERT_TRUE(size == 0);
351 #endif
352     input_dims.c = 513;
353     output_dims.c = input_dims.c;
354     dw_conv_params.ch_mult = output_dims.c / input_dims.c;
355     size = arm_depthwise_conv_wrapper_s16_get_buffer_size(&dw_conv_params, &input_dims, &filter_dims, &output_dims);
356 
357 #if defined(ARM_MATH_DSP)
358     TEST_ASSERT_TRUE(size > 0);
359 #else
360     TEST_ASSERT_TRUE(size == 0);
361 #endif
362 }
363 
/*
 * Unity test, compiled only when ARM_MATH_MVEI is defined (otherwise the body
 * is empty).
 *
 * Using the mult4 data set parameters, checks that
 * arm_depthwise_conv_wrapper_s16_get_buffer_size() and
 * arm_depthwise_conv_wrapper_s16_get_buffer_size_mve() report the same size.
 */
void buffer_size_mve_arm_depthwise_conv_s16(void)
{
#if defined(ARM_MATH_MVEI)
    cmsis_nn_dims in_dims;
    cmsis_nn_dims kernel_dims;
    cmsis_nn_dims out_dims;
    cmsis_nn_dw_conv_params params;

    // Tensor shapes.
    in_dims.n = DW_INT16XINT8_MULT4_INPUT_BATCHES;
    in_dims.w = DW_INT16XINT8_MULT4_INPUT_W;
    in_dims.h = DW_INT16XINT8_MULT4_INPUT_H;
    in_dims.c = DW_INT16XINT8_MULT4_IN_CH;

    kernel_dims.w = DW_INT16XINT8_MULT4_FILTER_X;
    kernel_dims.h = DW_INT16XINT8_MULT4_FILTER_Y;

    out_dims.w = DW_INT16XINT8_MULT4_OUTPUT_W;
    out_dims.h = DW_INT16XINT8_MULT4_OUTPUT_H;
    out_dims.c = DW_INT16XINT8_MULT4_OUT_CH;

    // Convolution parameters.
    params.ch_mult = DW_INT16XINT8_MULT4_CH_MULT;
    params.input_offset = DW_INT16XINT8_MULT4_INPUT_OFFSET;
    params.output_offset = DW_INT16XINT8_MULT4_OUTPUT_OFFSET;
    params.padding.w = DW_INT16XINT8_MULT4_PAD_X;
    params.padding.h = DW_INT16XINT8_MULT4_PAD_Y;
    params.stride.w = DW_INT16XINT8_MULT4_STRIDE_X;
    params.stride.h = DW_INT16XINT8_MULT4_STRIDE_Y;
    params.dilation.w = DW_INT16XINT8_MULT4_DILATION_X;
    params.dilation.h = DW_INT16XINT8_MULT4_DILATION_Y;
    params.activation.min = DW_INT16XINT8_MULT4_OUT_ACTIVATION_MIN;
    params.activation.max = DW_INT16XINT8_MULT4_OUT_ACTIVATION_MAX;

    // Query the generic entry point first, then the MVE-specific one.
    const int32_t generic_size =
        arm_depthwise_conv_wrapper_s16_get_buffer_size(&params, &in_dims, &kernel_dims, &out_dims);
    const int32_t mve_size =
        arm_depthwise_conv_wrapper_s16_get_buffer_size_mve(&params, &in_dims, &kernel_dims, &out_dims);

    TEST_ASSERT_EQUAL(generic_size, mve_size);
#endif
}
402 
/*
 * Unity test, compiled only when ARM_MATH_DSP is defined and ARM_MATH_MVEI is
 * not (otherwise the body is empty).
 *
 * Using the mult4 data set parameters, checks that
 * arm_depthwise_conv_wrapper_s16_get_buffer_size() and
 * arm_depthwise_conv_wrapper_s16_get_buffer_size_dsp() report the same size.
 */
void buffer_size_dsp_arm_depthwise_conv_s16(void)
{
#if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
    cmsis_nn_dims in_dims;
    cmsis_nn_dims kernel_dims;
    cmsis_nn_dims out_dims;
    cmsis_nn_dw_conv_params params;

    // Tensor shapes.
    in_dims.n = DW_INT16XINT8_MULT4_INPUT_BATCHES;
    in_dims.w = DW_INT16XINT8_MULT4_INPUT_W;
    in_dims.h = DW_INT16XINT8_MULT4_INPUT_H;
    in_dims.c = DW_INT16XINT8_MULT4_IN_CH;

    kernel_dims.w = DW_INT16XINT8_MULT4_FILTER_X;
    kernel_dims.h = DW_INT16XINT8_MULT4_FILTER_Y;

    out_dims.w = DW_INT16XINT8_MULT4_OUTPUT_W;
    out_dims.h = DW_INT16XINT8_MULT4_OUTPUT_H;
    out_dims.c = DW_INT16XINT8_MULT4_OUT_CH;

    // Convolution parameters.
    params.ch_mult = DW_INT16XINT8_MULT4_CH_MULT;
    params.input_offset = DW_INT16XINT8_MULT4_INPUT_OFFSET;
    params.output_offset = DW_INT16XINT8_MULT4_OUTPUT_OFFSET;
    params.padding.w = DW_INT16XINT8_MULT4_PAD_X;
    params.padding.h = DW_INT16XINT8_MULT4_PAD_Y;
    params.stride.w = DW_INT16XINT8_MULT4_STRIDE_X;
    params.stride.h = DW_INT16XINT8_MULT4_STRIDE_Y;
    params.dilation.w = DW_INT16XINT8_MULT4_DILATION_X;
    params.dilation.h = DW_INT16XINT8_MULT4_DILATION_Y;
    params.activation.min = DW_INT16XINT8_MULT4_OUT_ACTIVATION_MIN;
    params.activation.max = DW_INT16XINT8_MULT4_OUT_ACTIVATION_MAX;

    // Query the generic entry point first, then the DSP-specific one.
    const int32_t generic_size =
        arm_depthwise_conv_wrapper_s16_get_buffer_size(&params, &in_dims, &kernel_dims, &out_dims);
    const int32_t dsp_size =
        arm_depthwise_conv_wrapper_s16_get_buffer_size_dsp(&params, &in_dims, &kernel_dims, &out_dims);

    TEST_ASSERT_EQUAL(generic_size, dsp_size);
#endif
}
443