1 /*
2 * SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include <arm_nnfunctions.h>
20 #include <unity.h>
21
22 #include "../TestData/dw_int16xint8/test_data.h"
23 #include "../TestData/dw_int16xint8_dilation/test_data.h"
24 #include "../TestData/dw_int16xint8_mult4/test_data.h"
25 #include "../Utils/validate.h"
26
dw_int16xint8_arm_depthwise_conv_s16(void)27 void dw_int16xint8_arm_depthwise_conv_s16(void)
28 {
29 const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
30 int16_t output[DW_INT16XINT8_DST_SIZE] = {0};
31
32 cmsis_nn_context ctx;
33 cmsis_nn_dw_conv_params dw_conv_params;
34 cmsis_nn_per_channel_quant_params quant_params;
35 cmsis_nn_dims input_dims;
36 cmsis_nn_dims filter_dims;
37 cmsis_nn_dims bias_dims = {};
38 cmsis_nn_dims output_dims;
39
40 const int64_t *bias_data = dw_int16xint8_biases;
41 const int16_t *input_data = dw_int16xint8_input;
42 const int8_t *kernel_data = dw_int16xint8_weights;
43 const int16_t *output_ref = dw_int16xint8_output_ref;
44 const int32_t output_ref_size = DW_INT16XINT8_DST_SIZE;
45
46 input_dims.n = DW_INT16XINT8_INPUT_BATCHES;
47 input_dims.w = DW_INT16XINT8_INPUT_W;
48 input_dims.h = DW_INT16XINT8_INPUT_H;
49 input_dims.c = DW_INT16XINT8_IN_CH;
50 filter_dims.w = DW_INT16XINT8_FILTER_X;
51 filter_dims.h = DW_INT16XINT8_FILTER_Y;
52 output_dims.w = DW_INT16XINT8_OUTPUT_W;
53 output_dims.h = DW_INT16XINT8_OUTPUT_H;
54 output_dims.c = DW_INT16XINT8_OUT_CH;
55
56 dw_conv_params.padding.w = DW_INT16XINT8_PAD_X;
57 dw_conv_params.padding.h = DW_INT16XINT8_PAD_Y;
58 dw_conv_params.stride.w = DW_INT16XINT8_STRIDE_X;
59 dw_conv_params.stride.h = DW_INT16XINT8_STRIDE_Y;
60 dw_conv_params.dilation.w = DW_INT16XINT8_DILATION_X;
61 dw_conv_params.dilation.h = DW_INT16XINT8_DILATION_Y;
62
63 dw_conv_params.ch_mult = DW_INT16XINT8_CH_MULT;
64
65 dw_conv_params.input_offset = DW_INT16XINT8_INPUT_OFFSET;
66 dw_conv_params.output_offset = DW_INT16XINT8_OUTPUT_OFFSET;
67 dw_conv_params.activation.min = DW_INT16XINT8_OUT_ACTIVATION_MIN;
68 dw_conv_params.activation.max = DW_INT16XINT8_OUT_ACTIVATION_MAX;
69 quant_params.multiplier = (int32_t *)dw_int16xint8_output_mult;
70 quant_params.shift = (int32_t *)dw_int16xint8_output_shift;
71
72 ctx.buf = NULL;
73 ctx.size = 0;
74
75 arm_cmsis_nn_status result = arm_depthwise_conv_s16(&ctx,
76 &dw_conv_params,
77 &quant_params,
78 &input_dims,
79 input_data,
80 &filter_dims,
81 dw_int16xint8_weights,
82 &bias_dims,
83 bias_data,
84 &output_dims,
85 output);
86 if (ctx.buf)
87 {
88 // The caller is responsible to clear the scratch buffers for security reasons if applicable.
89 memset(ctx.buf, 0, ctx.size);
90 free(ctx.buf);
91 }
92 TEST_ASSERT_EQUAL(expected, result);
93 TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
94 memset(output, 0, sizeof(output));
95
96 int buf_size =
97 arm_depthwise_conv_wrapper_s16_get_buffer_size(&dw_conv_params, &input_dims, &filter_dims, &output_dims);
98
99 TEST_ASSERT_EQUAL(buf_size, 0);
100
101 ctx.buf = malloc(buf_size);
102
103 result = arm_depthwise_conv_wrapper_s16(&ctx,
104 &dw_conv_params,
105 &quant_params,
106 &input_dims,
107 input_data,
108 &filter_dims,
109 kernel_data,
110 &bias_dims,
111 bias_data,
112 &output_dims,
113 output);
114
115 if (ctx.buf)
116 {
117 memset(ctx.buf, 0, buf_size);
118 free(ctx.buf);
119 }
120 TEST_ASSERT_EQUAL(expected, result);
121 TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
122 }
123
dw_int16xint8_dilation_arm_depthwise_conv_s16(void)124 void dw_int16xint8_dilation_arm_depthwise_conv_s16(void)
125 {
126 const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
127 int16_t output[DW_INT16XINT8_DILATION_DST_SIZE] = {0};
128
129 cmsis_nn_context ctx;
130 cmsis_nn_dw_conv_params dw_conv_params;
131 cmsis_nn_per_channel_quant_params quant_params;
132 cmsis_nn_dims input_dims;
133 cmsis_nn_dims filter_dims;
134 cmsis_nn_dims bias_dims = {};
135 cmsis_nn_dims output_dims;
136
137 const int64_t *bias_data = dw_int16xint8_dilation_biases;
138 const int16_t *input_data = dw_int16xint8_dilation_input;
139 const int8_t *kernel_data = dw_int16xint8_dilation_weights;
140 const int16_t *output_ref = dw_int16xint8_dilation_output_ref;
141 const int32_t output_ref_size = DW_INT16XINT8_DILATION_DST_SIZE;
142
143 input_dims.n = DW_INT16XINT8_DILATION_INPUT_BATCHES;
144 input_dims.w = DW_INT16XINT8_DILATION_INPUT_W;
145 input_dims.h = DW_INT16XINT8_DILATION_INPUT_H;
146 input_dims.c = DW_INT16XINT8_DILATION_IN_CH;
147 filter_dims.w = DW_INT16XINT8_DILATION_FILTER_X;
148 filter_dims.h = DW_INT16XINT8_DILATION_FILTER_Y;
149 output_dims.w = DW_INT16XINT8_DILATION_OUTPUT_W;
150 output_dims.h = DW_INT16XINT8_DILATION_OUTPUT_H;
151 output_dims.c = DW_INT16XINT8_DILATION_OUT_CH;
152
153 dw_conv_params.padding.w = DW_INT16XINT8_DILATION_PAD_X;
154 dw_conv_params.padding.h = DW_INT16XINT8_DILATION_PAD_Y;
155 dw_conv_params.stride.w = DW_INT16XINT8_DILATION_STRIDE_X;
156 dw_conv_params.stride.h = DW_INT16XINT8_DILATION_STRIDE_Y;
157 dw_conv_params.dilation.w = DW_INT16XINT8_DILATION_DILATION_X;
158 dw_conv_params.dilation.h = DW_INT16XINT8_DILATION_DILATION_Y;
159
160 dw_conv_params.ch_mult = DW_INT16XINT8_DILATION_CH_MULT;
161
162 dw_conv_params.input_offset = DW_INT16XINT8_DILATION_INPUT_OFFSET;
163 dw_conv_params.output_offset = DW_INT16XINT8_DILATION_OUTPUT_OFFSET;
164 dw_conv_params.activation.min = DW_INT16XINT8_DILATION_OUT_ACTIVATION_MIN;
165 dw_conv_params.activation.max = DW_INT16XINT8_DILATION_OUT_ACTIVATION_MAX;
166 quant_params.multiplier = (int32_t *)dw_int16xint8_dilation_output_mult;
167 quant_params.shift = (int32_t *)dw_int16xint8_dilation_output_shift;
168
169 ctx.buf = NULL;
170 ctx.size = 0;
171
172 arm_cmsis_nn_status result = arm_depthwise_conv_s16(&ctx,
173 &dw_conv_params,
174 &quant_params,
175 &input_dims,
176 input_data,
177 &filter_dims,
178 dw_int16xint8_dilation_weights,
179 &bias_dims,
180 bias_data,
181 &output_dims,
182 output);
183
184 if (ctx.buf)
185 {
186 memset(ctx.buf, 0, ctx.size);
187 free(ctx.buf);
188 }
189 TEST_ASSERT_EQUAL(expected, result);
190 TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
191 memset(output, 0, sizeof(output));
192
193 int buf_size =
194 arm_depthwise_conv_wrapper_s16_get_buffer_size(&dw_conv_params, &input_dims, &filter_dims, &output_dims);
195
196 TEST_ASSERT_EQUAL(buf_size, 0);
197
198 ctx.buf = malloc(buf_size);
199
200 result = arm_depthwise_conv_wrapper_s16(&ctx,
201 &dw_conv_params,
202 &quant_params,
203 &input_dims,
204 input_data,
205 &filter_dims,
206 kernel_data,
207 &bias_dims,
208 bias_data,
209 &output_dims,
210 output);
211
212 if (ctx.buf)
213 {
214 memset(ctx.buf, 0, buf_size);
215 free(ctx.buf);
216 }
217 TEST_ASSERT_EQUAL(expected, result);
218 TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
219 }
220
dw_int16xint8_mult4_arm_depthwise_conv_s16(void)221 void dw_int16xint8_mult4_arm_depthwise_conv_s16(void)
222 {
223 const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
224 int16_t output[DW_INT16XINT8_MULT4_DST_SIZE] = {0};
225
226 cmsis_nn_context ctx;
227 cmsis_nn_dw_conv_params dw_conv_params;
228 cmsis_nn_per_channel_quant_params quant_params;
229 cmsis_nn_dims input_dims;
230 cmsis_nn_dims filter_dims;
231 cmsis_nn_dims bias_dims = {};
232 cmsis_nn_dims output_dims;
233
234 const int64_t *bias_data = dw_int16xint8_mult4_biases;
235 const int16_t *input_data = dw_int16xint8_mult4_input;
236 const int8_t *kernel_data = dw_int16xint8_mult4_weights;
237 const int16_t *output_ref = dw_int16xint8_mult4_output_ref;
238 const int32_t output_ref_size = DW_INT16XINT8_MULT4_DST_SIZE;
239
240 input_dims.n = DW_INT16XINT8_MULT4_INPUT_BATCHES;
241 input_dims.w = DW_INT16XINT8_MULT4_INPUT_W;
242 input_dims.h = DW_INT16XINT8_MULT4_INPUT_H;
243 input_dims.c = DW_INT16XINT8_MULT4_IN_CH;
244 filter_dims.w = DW_INT16XINT8_MULT4_FILTER_X;
245 filter_dims.h = DW_INT16XINT8_MULT4_FILTER_Y;
246 output_dims.w = DW_INT16XINT8_MULT4_OUTPUT_W;
247 output_dims.h = DW_INT16XINT8_MULT4_OUTPUT_H;
248 output_dims.c = DW_INT16XINT8_MULT4_OUT_CH;
249
250 dw_conv_params.padding.w = DW_INT16XINT8_MULT4_PAD_X;
251 dw_conv_params.padding.h = DW_INT16XINT8_MULT4_PAD_Y;
252 dw_conv_params.stride.w = DW_INT16XINT8_MULT4_STRIDE_X;
253 dw_conv_params.stride.h = DW_INT16XINT8_MULT4_STRIDE_Y;
254 dw_conv_params.dilation.w = DW_INT16XINT8_MULT4_DILATION_X;
255 dw_conv_params.dilation.h = DW_INT16XINT8_MULT4_DILATION_Y;
256
257 dw_conv_params.ch_mult = DW_INT16XINT8_MULT4_CH_MULT;
258
259 dw_conv_params.input_offset = DW_INT16XINT8_MULT4_INPUT_OFFSET;
260 dw_conv_params.output_offset = DW_INT16XINT8_MULT4_OUTPUT_OFFSET;
261 dw_conv_params.activation.min = DW_INT16XINT8_MULT4_OUT_ACTIVATION_MIN;
262 dw_conv_params.activation.max = DW_INT16XINT8_MULT4_OUT_ACTIVATION_MAX;
263 quant_params.multiplier = (int32_t *)dw_int16xint8_mult4_output_mult;
264 quant_params.shift = (int32_t *)dw_int16xint8_mult4_output_shift;
265
266 ctx.buf = NULL;
267 ctx.size = 0;
268
269 arm_cmsis_nn_status result = arm_depthwise_conv_s16(&ctx,
270 &dw_conv_params,
271 &quant_params,
272 &input_dims,
273 input_data,
274 &filter_dims,
275 dw_int16xint8_mult4_weights,
276 &bias_dims,
277 bias_data,
278 &output_dims,
279 output);
280
281 if (ctx.buf)
282 {
283 memset(ctx.buf, 0, ctx.size);
284 free(ctx.buf);
285 }
286 TEST_ASSERT_EQUAL(expected, result);
287 TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
288 memset(output, 0, sizeof(output));
289
290 int buf_size =
291 arm_depthwise_conv_wrapper_s16_get_buffer_size(&dw_conv_params, &input_dims, &filter_dims, &output_dims);
292
293 TEST_ASSERT_EQUAL(buf_size, 0);
294
295 ctx.buf = malloc(buf_size);
296
297 result = arm_depthwise_conv_wrapper_s16(&ctx,
298 &dw_conv_params,
299 &quant_params,
300 &input_dims,
301 input_data,
302 &filter_dims,
303 kernel_data,
304 &bias_dims,
305 bias_data,
306 &output_dims,
307 output);
308
309 if (ctx.buf)
310 {
311 memset(ctx.buf, 0, buf_size);
312 free(ctx.buf);
313 }
314 TEST_ASSERT_EQUAL(expected, result);
315 TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size));
316 }
317
arm_depthwise_conv_wrapper_s16_buffer(void)318 void arm_depthwise_conv_wrapper_s16_buffer(void)
319 {
320 cmsis_nn_dims input_dims;
321 cmsis_nn_dims filter_dims;
322 cmsis_nn_dims output_dims;
323
324 cmsis_nn_dw_conv_params dw_conv_params;
325 input_dims.n = DW_INT16XINT8_MULT4_INPUT_BATCHES;
326 input_dims.w = DW_INT16XINT8_MULT4_INPUT_W;
327 input_dims.h = DW_INT16XINT8_MULT4_INPUT_H;
328 input_dims.c = DW_INT16XINT8_MULT4_IN_CH;
329 filter_dims.w = DW_INT16XINT8_MULT4_FILTER_X;
330 filter_dims.h = DW_INT16XINT8_MULT4_FILTER_Y;
331
332 output_dims.w = DW_INT16XINT8_MULT4_OUTPUT_W;
333 output_dims.h = DW_INT16XINT8_MULT4_OUTPUT_H;
334 output_dims.c = input_dims.c;
335
336 dw_conv_params.padding.w = DW_INT16XINT8_MULT4_PAD_X;
337 dw_conv_params.padding.h = DW_INT16XINT8_MULT4_PAD_Y;
338 dw_conv_params.stride.w = DW_INT16XINT8_MULT4_STRIDE_X;
339 dw_conv_params.stride.h = DW_INT16XINT8_MULT4_STRIDE_Y;
340 dw_conv_params.dilation.w = DW_INT16XINT8_MULT4_DILATION_X;
341 dw_conv_params.dilation.h = DW_INT16XINT8_MULT4_DILATION_Y;
342 dw_conv_params.ch_mult = output_dims.c / input_dims.c;
343
344 int32_t size =
345 arm_depthwise_conv_wrapper_s16_get_buffer_size(&dw_conv_params, &input_dims, &filter_dims, &output_dims);
346
347 #if defined(ARM_MATH_DSP)
348 TEST_ASSERT_TRUE(size > 0);
349 #else
350 TEST_ASSERT_TRUE(size == 0);
351 #endif
352 input_dims.c = 513;
353 output_dims.c = input_dims.c;
354 dw_conv_params.ch_mult = output_dims.c / input_dims.c;
355 size = arm_depthwise_conv_wrapper_s16_get_buffer_size(&dw_conv_params, &input_dims, &filter_dims, &output_dims);
356
357 #if defined(ARM_MATH_DSP)
358 TEST_ASSERT_TRUE(size > 0);
359 #else
360 TEST_ASSERT_TRUE(size == 0);
361 #endif
362 }
363
/**
 * Unity test: on builds with ARM_MATH_MVEI defined, the size returned by
 * arm_depthwise_conv_wrapper_s16_get_buffer_size() must equal the size
 * returned by arm_depthwise_conv_wrapper_s16_get_buffer_size_mve() for the
 * dw_int16xint8_mult4 configuration.
 *
 * When ARM_MATH_MVEI is not defined the body is compiled out and the test
 * does nothing.
 */
void buffer_size_mve_arm_depthwise_conv_s16(void)
{
#if defined(ARM_MATH_MVEI)
    // Zero-initialise so that members never assigned below (filter_dims.n,
    // filter_dims.c, output_dims.n) are not indeterminate when passed on.
    cmsis_nn_dw_conv_params conv_params = {0};
    cmsis_nn_dims input_dims = {0};
    cmsis_nn_dims filter_dims = {0};
    cmsis_nn_dims output_dims = {0};

    input_dims.n = DW_INT16XINT8_MULT4_INPUT_BATCHES;
    input_dims.w = DW_INT16XINT8_MULT4_INPUT_W;
    input_dims.h = DW_INT16XINT8_MULT4_INPUT_H;
    input_dims.c = DW_INT16XINT8_MULT4_IN_CH;
    filter_dims.w = DW_INT16XINT8_MULT4_FILTER_X;
    filter_dims.h = DW_INT16XINT8_MULT4_FILTER_Y;
    output_dims.w = DW_INT16XINT8_MULT4_OUTPUT_W;
    output_dims.h = DW_INT16XINT8_MULT4_OUTPUT_H;
    output_dims.c = DW_INT16XINT8_MULT4_OUT_CH;

    conv_params.padding.w = DW_INT16XINT8_MULT4_PAD_X;
    conv_params.padding.h = DW_INT16XINT8_MULT4_PAD_Y;
    conv_params.stride.w = DW_INT16XINT8_MULT4_STRIDE_X;
    conv_params.stride.h = DW_INT16XINT8_MULT4_STRIDE_Y;
    conv_params.dilation.w = DW_INT16XINT8_MULT4_DILATION_X;
    conv_params.dilation.h = DW_INT16XINT8_MULT4_DILATION_Y;
    conv_params.ch_mult = DW_INT16XINT8_MULT4_CH_MULT;
    conv_params.input_offset = DW_INT16XINT8_MULT4_INPUT_OFFSET;
    conv_params.output_offset = DW_INT16XINT8_MULT4_OUTPUT_OFFSET;
    conv_params.activation.min = DW_INT16XINT8_MULT4_OUT_ACTIVATION_MIN;
    conv_params.activation.max = DW_INT16XINT8_MULT4_OUT_ACTIVATION_MAX;

    const int32_t wrapper_buf_size =
        arm_depthwise_conv_wrapper_s16_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
    const int32_t mve_wrapper_buf_size =
        arm_depthwise_conv_wrapper_s16_get_buffer_size_mve(&conv_params, &input_dims, &filter_dims, &output_dims);

    // Unity argument order is (expected, actual): the MVE-specific size is the
    // reference, the generic wrapper result is the value under test.
    TEST_ASSERT_EQUAL(mve_wrapper_buf_size, wrapper_buf_size);
#endif
}
402
/**
 * Unity test: on builds with ARM_MATH_DSP defined and ARM_MATH_MVEI not
 * defined, the size returned by
 * arm_depthwise_conv_wrapper_s16_get_buffer_size() must equal the size
 * returned by arm_depthwise_conv_wrapper_s16_get_buffer_size_dsp() for the
 * dw_int16xint8_mult4 configuration.
 *
 * On any other build the body is compiled out and the test does nothing.
 */
void buffer_size_dsp_arm_depthwise_conv_s16(void)
{
#if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
    // Zero-initialise so that members never assigned below (filter_dims.n,
    // filter_dims.c, output_dims.n) are not indeterminate when passed on.
    cmsis_nn_dw_conv_params conv_params = {0};
    cmsis_nn_dims input_dims = {0};
    cmsis_nn_dims filter_dims = {0};
    cmsis_nn_dims output_dims = {0};

    input_dims.n = DW_INT16XINT8_MULT4_INPUT_BATCHES;
    input_dims.w = DW_INT16XINT8_MULT4_INPUT_W;
    input_dims.h = DW_INT16XINT8_MULT4_INPUT_H;
    input_dims.c = DW_INT16XINT8_MULT4_IN_CH;
    filter_dims.w = DW_INT16XINT8_MULT4_FILTER_X;
    filter_dims.h = DW_INT16XINT8_MULT4_FILTER_Y;
    output_dims.w = DW_INT16XINT8_MULT4_OUTPUT_W;
    output_dims.h = DW_INT16XINT8_MULT4_OUTPUT_H;
    output_dims.c = DW_INT16XINT8_MULT4_OUT_CH;

    conv_params.padding.w = DW_INT16XINT8_MULT4_PAD_X;
    conv_params.padding.h = DW_INT16XINT8_MULT4_PAD_Y;
    conv_params.stride.w = DW_INT16XINT8_MULT4_STRIDE_X;
    conv_params.stride.h = DW_INT16XINT8_MULT4_STRIDE_Y;
    conv_params.dilation.w = DW_INT16XINT8_MULT4_DILATION_X;
    conv_params.dilation.h = DW_INT16XINT8_MULT4_DILATION_Y;

    conv_params.ch_mult = DW_INT16XINT8_MULT4_CH_MULT;

    conv_params.input_offset = DW_INT16XINT8_MULT4_INPUT_OFFSET;
    conv_params.output_offset = DW_INT16XINT8_MULT4_OUTPUT_OFFSET;
    conv_params.activation.min = DW_INT16XINT8_MULT4_OUT_ACTIVATION_MIN;
    conv_params.activation.max = DW_INT16XINT8_MULT4_OUT_ACTIVATION_MAX;

    const int32_t wrapper_buf_size =
        arm_depthwise_conv_wrapper_s16_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
    const int32_t dsp_wrapper_buf_size =
        arm_depthwise_conv_wrapper_s16_get_buffer_size_dsp(&conv_params, &input_dims, &filter_dims, &output_dims);

    // Unity argument order is (expected, actual): the DSP-specific size is the
    // reference, the generic wrapper result is the value under test.
    TEST_ASSERT_EQUAL(dsp_wrapper_buf_size, wrapper_buf_size);
#endif
}
443