1 /*
2 * SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_nn_mat_mult_kernel_s4_s16.c
22 * Description: Matrix-multiplication function for convolution
23 *
24 * $Date: 01 November 2023
25 * $Revision: V.1.0.0
26 *
27 * Target : Arm(R) M-Profile Architecture
28 * -------------------------------------------------------------------- */
29
30 #include "arm_nnsupportfunctions.h"
31
/*
 * Matrix-multiplication function for convolution with per-channel requantization and 4-bit weights.
 *
 * Refer to the header file for details.
 *
 */
38
/* Sign-extend the low nibble of a byte that packs two signed 4-bit weights. */
static inline int8_t mat_mult_s4_s16_low_nibble(const int8_t packed)
{
    /* Shift through uint8_t: left-shifting a negative signed value is undefined behaviour in C. */
    return (int8_t)((int8_t)((uint8_t)packed << 4) >> 4);
}

/* Sign-extend the high nibble of a byte that packs two signed 4-bit weights. */
static inline int8_t mat_mult_s4_s16_high_nibble(const int8_t packed)
{
    return (int8_t)(packed >> 4);
}

/* Requantize an accumulator, add the output offset and clamp it to the activation range. */
static inline int8_t mat_mult_s4_s16_requantize(int32_t acc,
                                                const int32_t mult,
                                                const int32_t shift,
                                                const int32_t out_offset,
                                                const int32_t activation_min,
                                                const int32_t activation_max)
{
    acc = arm_nn_requantize(acc, mult, shift);
    acc += out_offset;
    acc = MAX(acc, activation_min);
    acc = MIN(acc, activation_max);
    return (int8_t)acc;
}

/*
 * Multiply a matrix of packed signed 4-bit weights with two int16 input vectors and
 * write the requantized, clamped int8 results.
 *
 * packed_input_a  Weights: output_ch rows of num_col_a 4-bit values, two per byte with the
 *                 low nibble first. Rows are packed back to back, so with an odd num_col_a
 *                 every second row starts in the high nibble of a byte.
 * input_b         Two int16 vectors of num_col_a elements each, the second directly after
 *                 the first.
 * output_ch       Number of weight rows, i.e. results produced per input vector.
 * out_shift       Per-row shift passed to arm_nn_requantize.
 * out_mult        Per-row multiplier passed to arm_nn_requantize.
 * out_offset      Value added to every result after requantization.
 * activation_min  Lower clamp bound applied after the offset.
 * activation_max  Upper clamp bound applied after the offset.
 * num_col_a       Number of 4-bit weights per row (and elements per input vector).
 * output_bias     Per-row bias used to initialise the accumulators; may be NULL.
 * out_0           Destination. Results for the first vector are written at out_0, results
 *                 for the second vector at out_0 + output_ch.
 *
 * Returns the initial out_0 advanced by 2 * output_ch, i.e. the position after both
 * result vectors.
 */
int8_t *arm_nn_mat_mult_kernel_s4_s16(const int8_t *packed_input_a,
                                      const int16_t *input_b,
                                      const uint16_t output_ch,
                                      const int32_t *out_shift,
                                      const int32_t *out_mult,
                                      const int32_t out_offset,
                                      const int32_t activation_min,
                                      const int32_t activation_max,
                                      const int32_t num_col_a,
                                      const int32_t *const output_bias,
                                      int8_t *out_0)
{
    /* set up the second output pointer */
    int8_t *out_1 = out_0 + output_ch;
    const int32_t *bias = output_bias;

    uint16_t row_count = output_ch / 4;
    const int8_t *packed_ip_a0 = packed_input_a;

    /* Loop over rows of A, four rows per iteration: rows N and N + 2 first, then N + 1 and N + 3. */
    while (row_count)
    {
        /* High nibbles that belong to rows N + 1 / N + 3 when num_col_a is odd */
        int8_t spillover0 = 0;
        int8_t spillover1 = 0;

        /* setup pointers for B */
        const int16_t *ip_b0 = input_b;
        const int16_t *ip_b1 = ip_b0 + num_col_a;

        /* Align the second pointer for A.
         * num_col_a bytes hold two rows, so this skips a row and ensures that
         * spilled rows don't offset the symmetry.
         */
        const int8_t *packed_ip_a1 = packed_ip_a0 + num_col_a;

        int32_t ch_0_out_0 = 0;
        int32_t ch_0_out_1 = 0;
        int32_t ch_1_out_0 = 0;
        int32_t ch_1_out_1 = 0;

        /* Init accumulators with bias for rows N and N + 2, then step to row N + 1 */
        if (bias)
        {
            ch_0_out_0 = bias[0];
            ch_0_out_1 = bias[0];
            ch_1_out_0 = bias[2];
            ch_1_out_1 = bias[2];
            bias += 1;
        }

#if defined(ARM_MATH_DSP)
        int32_t col_count = num_col_a / 4;
        /* accumulate over the vector, four weights per row and iteration */
        while (col_count)
        {
            int32_t a01, a02, a11, a12;
            int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            read_and_pad_s4_ordered(packed_ip_a0, &a01, &a02);
            read_and_pad_s4_ordered(packed_ip_a1, &a11, &a12);
            packed_ip_a0 += 2;
            packed_ip_a1 += 2;

            ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a11, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a11, b1, ch_1_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a12, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a12, b1, ch_1_out_1);

            col_count--;
        } /* while over col_count */
        col_count = (num_col_a & 0x3) >> 1;
#else
        int32_t col_count = num_col_a >> 1;
#endif
        /* remaining full bytes, two weights per row and iteration */
        while (col_count)
        {
            const int8_t lower_a0 = mat_mult_s4_s16_low_nibble(packed_ip_a0[0]);
            const int8_t higher_a0 = mat_mult_s4_s16_high_nibble(packed_ip_a0[0]);
            const int8_t lower_a1 = mat_mult_s4_s16_low_nibble(packed_ip_a1[0]);
            const int8_t higher_a1 = mat_mult_s4_s16_high_nibble(packed_ip_a1[0]);
            int16_t b0 = *ip_b0++;
            int16_t b1 = *ip_b1++;

            packed_ip_a0++;
            packed_ip_a1++;

            ch_0_out_0 += lower_a0 * b0;
            ch_0_out_1 += lower_a0 * b1;
            ch_1_out_0 += lower_a1 * b0;
            ch_1_out_1 += lower_a1 * b1;

            b0 = *ip_b0++;
            b1 = *ip_b1++;

            ch_0_out_0 += higher_a0 * b0;
            ch_0_out_1 += higher_a0 * b1;
            ch_1_out_0 += higher_a1 * b0;
            ch_1_out_1 += higher_a1 * b1;

            col_count--;
        } /* while over col_count */

        /* left over column: the low nibble ends this row, the high nibble starts the next one */
        if (num_col_a % 2)
        {
            const int8_t lower_a0 = mat_mult_s4_s16_low_nibble(packed_ip_a0[0]);
            const int8_t lower_a1 = mat_mult_s4_s16_low_nibble(packed_ip_a1[0]);
            const int16_t b0 = *ip_b0++;
            const int16_t b1 = *ip_b1++;

            spillover0 = mat_mult_s4_s16_high_nibble(packed_ip_a0[0]);
            spillover1 = mat_mult_s4_s16_high_nibble(packed_ip_a1[0]);

            packed_ip_a0++;
            packed_ip_a1++;

            ch_0_out_0 += lower_a0 * b0;
            ch_0_out_1 += lower_a0 * b1;
            ch_1_out_0 += lower_a1 * b0;
            ch_1_out_1 += lower_a1 * b1;
        }

        /* Store rows N and N + 2 for both input vectors, then step every per-row pointer to row N + 1 */
        out_0[0] = mat_mult_s4_s16_requantize(
            ch_0_out_0, out_mult[0], out_shift[0], out_offset, activation_min, activation_max);
        out_1[0] = mat_mult_s4_s16_requantize(
            ch_0_out_1, out_mult[0], out_shift[0], out_offset, activation_min, activation_max);
        out_0[2] = mat_mult_s4_s16_requantize(
            ch_1_out_0, out_mult[2], out_shift[2], out_offset, activation_min, activation_max);
        out_1[2] = mat_mult_s4_s16_requantize(
            ch_1_out_1, out_mult[2], out_shift[2], out_offset, activation_min, activation_max);
        out_0 += 1;
        out_1 += 1;
        out_mult += 1;
        out_shift += 1;

        /* setup pointers for B */
        ip_b0 = input_b;
        ip_b1 = ip_b0 + num_col_a;

        /* Align the second pointer for A.
         * This again skips a row so that spilled rows don't offset the symmetry.
         */
        packed_ip_a1 = packed_ip_a0 + num_col_a;

        ch_0_out_0 = 0;
        ch_0_out_1 = 0;
        ch_1_out_0 = 0;
        ch_1_out_1 = 0;

        /* Init accumulators with bias for rows N + 1 and N + 3, then step to row N + 4 */
        if (bias)
        {
            ch_0_out_0 = bias[0];
            ch_0_out_1 = bias[0];
            ch_1_out_0 = bias[2];
            ch_1_out_1 = bias[2];
            bias += 3;
        }

        /* The first weight of these rows was the high nibble saved during the previous pass */
        if (num_col_a % 2)
        {
            const int16_t b0 = *ip_b0++;
            const int16_t b1 = *ip_b1++;

            ch_0_out_0 += spillover0 * b0;
            ch_0_out_1 += spillover0 * b1;
            ch_1_out_0 += spillover1 * b0;
            ch_1_out_1 += spillover1 * b1;
        }

#if defined(ARM_MATH_DSP)
        col_count = num_col_a / 4;
        /* accumulate over the vector, four weights per row and iteration */
        while (col_count)
        {
            int32_t a01, a02, a11, a12;
            int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            read_and_pad_s4_ordered(packed_ip_a0, &a01, &a02);
            read_and_pad_s4_ordered(packed_ip_a1, &a11, &a12);
            packed_ip_a0 += 2;
            packed_ip_a1 += 2;

            ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a11, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a11, b1, ch_1_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a12, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a12, b1, ch_1_out_1);

            col_count--;
        } /* while over col_count */
        col_count = (num_col_a & 0x3) >> 1;
#else
        col_count = num_col_a >> 1;
#endif
        /* remaining full bytes, two weights per row and iteration */
        while (col_count)
        {
            const int8_t lower_a0 = mat_mult_s4_s16_low_nibble(packed_ip_a0[0]);
            const int8_t higher_a0 = mat_mult_s4_s16_high_nibble(packed_ip_a0[0]);
            const int8_t lower_a1 = mat_mult_s4_s16_low_nibble(packed_ip_a1[0]);
            const int8_t higher_a1 = mat_mult_s4_s16_high_nibble(packed_ip_a1[0]);
            int16_t b0 = *ip_b0++;
            int16_t b1 = *ip_b1++;

            packed_ip_a0++;
            packed_ip_a1++;

            ch_0_out_0 += lower_a0 * b0;
            ch_0_out_1 += lower_a0 * b1;
            ch_1_out_0 += lower_a1 * b0;
            ch_1_out_1 += lower_a1 * b1;

            b0 = *ip_b0++;
            b1 = *ip_b1++;

            ch_0_out_0 += higher_a0 * b0;
            ch_0_out_1 += higher_a0 * b1;
            ch_1_out_0 += higher_a1 * b0;
            ch_1_out_1 += higher_a1 * b1;

            col_count--;
        } /* while over col_count */

        /* Store rows N + 1 and N + 3 for both input vectors, then step every per-row pointer to row N + 4 */
        out_0[0] = mat_mult_s4_s16_requantize(
            ch_0_out_0, out_mult[0], out_shift[0], out_offset, activation_min, activation_max);
        out_1[0] = mat_mult_s4_s16_requantize(
            ch_0_out_1, out_mult[0], out_shift[0], out_offset, activation_min, activation_max);
        out_0[2] = mat_mult_s4_s16_requantize(
            ch_1_out_0, out_mult[2], out_shift[2], out_offset, activation_min, activation_max);
        out_1[2] = mat_mult_s4_s16_requantize(
            ch_1_out_1, out_mult[2], out_shift[2], out_offset, activation_min, activation_max);
        out_0 += 3;
        out_1 += 3;
        out_mult += 3;
        out_shift += 3;

        /* packed_ip_a0 has walked rows N and N + 1; skip the 2 rows already handled via packed_ip_a1 */
        packed_ip_a0 += num_col_a;
        row_count--;
    }

    /* compute the 0 - 3 remaining rows, if any, one row at a time */
    int16_t left_over_rows = 0;
    while (left_over_rows < output_ch % 4)
    {
        /* setup pointers for B */
        const int16_t *ip_b0 = input_b;
        const int16_t *ip_b1 = ip_b0 + num_col_a;

        int32_t ch_0_out_0 = 0;
        int32_t ch_0_out_1 = 0;

        /* load the bias */
        if (bias)
        {
            ch_0_out_0 = *bias;
            ch_0_out_1 = *bias++;
        }

        /* With an odd num_col_a the second left-over row starts in the high nibble
         * of the byte whose low nibble ended the first left-over row.
         */
        if (left_over_rows == 1 && num_col_a % 2)
        {
            const int16_t b0 = *ip_b0++;
            const int16_t b1 = *ip_b1++;
            const int8_t spilled_column = mat_mult_s4_s16_high_nibble(packed_ip_a0[0]);

            ++packed_ip_a0;

            ch_0_out_0 += spilled_column * b0;
            ch_0_out_1 += spilled_column * b1;
        }

#if defined(ARM_MATH_DSP)
        int32_t col_count = num_col_a / 4;
        while (col_count)
        {
            int32_t a01, a02;
            int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            read_and_pad_s4_ordered(packed_ip_a0, &a01, &a02);
            packed_ip_a0 += 2;

            ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);
            ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);

            col_count--;
        }
        col_count = (num_col_a & 0x3) >> 1;

#else
        int32_t col_count = num_col_a >> 1;
#endif

        while (col_count)
        {
            const int8_t a0 = mat_mult_s4_s16_low_nibble(packed_ip_a0[0]);
            const int8_t a1 = mat_mult_s4_s16_high_nibble(packed_ip_a0[0]);
            int16_t b0 = *ip_b0++;
            int16_t b1 = *ip_b1++;

            ++packed_ip_a0;

            ch_0_out_0 += a0 * b0;
            ch_0_out_1 += a0 * b1;

            b0 = *ip_b0++;
            b1 = *ip_b1++;

            ch_0_out_0 += a1 * b0;
            ch_0_out_1 += a1 * b1;

            col_count--;
        }

        /* Trailing low nibble of a row that starts byte-aligned. The pointer is not
         * advanced, so the following row can pick up the high nibble of the same byte.
         */
        if (num_col_a % 2 && left_over_rows != 1)
        {
            const int8_t a0 = mat_mult_s4_s16_low_nibble(packed_ip_a0[0]);

            const int16_t b0 = *ip_b0++;
            const int16_t b1 = *ip_b1++;

            ch_0_out_0 += a0 * b0;
            ch_0_out_1 += a0 * b1;
        }

        *out_0++ = mat_mult_s4_s16_requantize(
            ch_0_out_0, *out_mult, *out_shift, out_offset, activation_min, activation_max);
        *out_1++ = mat_mult_s4_s16_requantize(
            ch_0_out_1, *out_mult, *out_shift, out_offset, activation_min, activation_max);
        out_mult++;
        out_shift++;

        ++left_over_rows;
    }

    /* out_0 now points at the second result vector; skip it as well */
    out_0 += output_ch;

    /* return the new output pointer with offset */
    return out_0;
}
440