/*
 * SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_kernel_s4_s16.c
 * Description:  Matrix-multiplication function for convolution
 *
 * $Date:        01 November 2023
 * $Revision:    V.1.0.0
 *
 * Target :  Arm(R) M-Profile Architecture
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/*
 * Matrix-multiplication function for convolution with per-channel requantization and 4-bit weights.
 *
 * Refer to the header file for details.
 *
 */
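/*
 * Notes on the data layout used by this kernel (derived from the code below):
 * the s4 weights are packed two per byte, low nibble first, with rows stored
 * back to back, so a row can end in the middle of a byte when num_col_a is odd.
 * Each iteration of the main loop produces four output channels for two output
 * columns: the first pass computes channels N and N + 2, whose rows always start
 * on a byte boundary, and the second pass computes channels N + 1 and N + 3,
 * consuming any "spillover" nibble left behind by the first pass.
 */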

int8_t *arm_nn_mat_mult_kernel_s4_s16(const int8_t *packed_input_a,
                                      const int16_t *input_b,
                                      const uint16_t output_ch,
                                      const int32_t *out_shift,
                                      const int32_t *out_mult,
                                      const int32_t out_offset,
                                      const int32_t activation_min,
                                      const int32_t activation_max,
                                      const int32_t num_col_a,
                                      const int32_t *const output_bias,
                                      int8_t *out_0)
{

    /* Set up the pointer to the second output column */
    int8_t *out_1 = out_0 + output_ch;
    const int32_t *bias = output_bias;

    uint16_t row_count = output_ch / 4;
    const int8_t *packed_ip_a0 = packed_input_a;
    /* Loop over the rows of A, four output channels at a time */
    while (row_count)
    {
        int8_t spillover0 = 0;
        int8_t spillover1 = 0;
        /* setup pointers for B */
        const int16_t *ip_b0 = input_b;
        const int16_t *ip_b1 = ip_b0 + num_col_a;

        /* Align the second pointer for A.
         * This skips a row so that spilled nibbles don't offset the symmetry
         * between the two pointers.
         */
        const int8_t *packed_ip_a1 = packed_ip_a0 + num_col_a;

        int32_t ch_0_out_0 = 0;
        int32_t ch_0_out_1 = 0;
        int32_t ch_1_out_0 = 0;
        int32_t ch_1_out_1 = 0;
        /* Init accumulators with the bias for channels N and N + 2 */
        if (bias)
        {
            ch_0_out_0 = *bias;
            ch_0_out_1 = *bias;
            bias += 2;
            ch_1_out_0 = *bias;
            ch_1_out_1 = *bias--;
        }
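        /* After the stepping above, bias points at channel N + 1, ready for the
         * second pass over this group of four channels. */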

#if defined(ARM_MATH_DSP)
        int32_t col_count = num_col_a / 4;
        /* accumulate over the vector */

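        /* Four columns per iteration: read_and_pad_s4_ordered expands four packed
         * s4 weights into two words of sign-extended 16-bit pairs (see
         * arm_nnsupportfunctions.h), and each SMLAD performs two 16-bit
         * multiply-accumulates against a pair of B values. */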
        while (col_count)
        {
            int32_t a01, a02, a11, a12;
            int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            read_and_pad_s4_ordered(packed_ip_a0, &a01, &a02);
            read_and_pad_s4_ordered(packed_ip_a1, &a11, &a12);
            packed_ip_a0 += 2;
            packed_ip_a1 += 2;

            ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a11, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a11, b1, ch_1_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a12, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a12, b1, ch_1_out_1);

            col_count--;
        } /* while over col_count */
        col_count = (num_col_a & 0x3) >> 1;
#else
        int32_t col_count = num_col_a >> 1;
#endif
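        /* Each byte processed here holds two consecutive weights of the row, low
         * nibble first. The low nibble is sign-extended by shifting it into the
         * top of an int8_t and arithmetically shifting back down, e.g. the byte
         * 0xF2 unpacks to 2 (low nibble) and -1 (high nibble). */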
        while (col_count)
        {
            int8_t lower_a0 = (int8_t)(packed_ip_a0[0] << 4) >> 4;
            int8_t higher_a0 = packed_ip_a0[0] >> 4;
            int16_t b0 = *ip_b0++;

            int8_t lower_a1 = (int8_t)(packed_ip_a1[0] << 4) >> 4;
            int8_t higher_a1 = packed_ip_a1[0] >> 4;
            int16_t b1 = *ip_b1++;

            packed_ip_a0++;
            packed_ip_a1++;

            ch_0_out_0 += lower_a0 * b0;
            ch_0_out_1 += lower_a0 * b1;
            ch_1_out_0 += lower_a1 * b0;
            ch_1_out_1 += lower_a1 * b1;

            b0 = *ip_b0++;
            b1 = *ip_b1++;

            ch_0_out_0 += higher_a0 * b0;
            ch_0_out_1 += higher_a0 * b1;
            ch_1_out_0 += higher_a1 * b0;
            ch_1_out_1 += higher_a1 * b1;

            col_count--;
        } /* while over col_count */
        /* Leftover column when num_col_a is odd: the low nibble belongs to this
         * row; the high nibble is the first weight of the row handled in the
         * next pass and is saved as spillover. */
        if (num_col_a % 2)
        {
            int8_t lower_a0 = (int8_t)(packed_ip_a0[0] << 4) >> 4;
            spillover0 = packed_ip_a0[0] >> 4;
            int16_t b0 = *ip_b0++;

            int8_t lower_a1 = (int8_t)(packed_ip_a1[0] << 4) >> 4;
            spillover1 = packed_ip_a1[0] >> 4;
            int16_t b1 = *ip_b1++;

            packed_ip_a0++;
            packed_ip_a1++;

            ch_0_out_0 += lower_a0 * b0;
            ch_0_out_1 += lower_a0 * b1;
            ch_1_out_0 += lower_a1 * b0;
            ch_1_out_1 += lower_a1 * b1;
        }

        ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
        ch_0_out_0 += out_offset;
        ch_0_out_0 = MAX(ch_0_out_0, activation_min);
        ch_0_out_0 = MIN(ch_0_out_0, activation_max);
        *out_0 = (int8_t)ch_0_out_0;
        out_0 += 2;

        ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
        ch_0_out_1 += out_offset;
        ch_0_out_1 = MAX(ch_0_out_1, activation_min);
        ch_0_out_1 = MIN(ch_0_out_1, activation_max);
        *out_1 = (int8_t)ch_0_out_1;
        out_1 += 2;
        out_mult += 2;
        out_shift += 2;

        ch_1_out_0 = arm_nn_requantize(ch_1_out_0, *out_mult, *out_shift);
        ch_1_out_0 += out_offset;
        ch_1_out_0 = MAX(ch_1_out_0, activation_min);
        ch_1_out_0 = MIN(ch_1_out_0, activation_max);
        *out_0-- = (int8_t)ch_1_out_0;

        ch_1_out_1 = arm_nn_requantize(ch_1_out_1, *out_mult, *out_shift);
        ch_1_out_1 += out_offset;
        ch_1_out_1 = MAX(ch_1_out_1, activation_min);
        ch_1_out_1 = MIN(ch_1_out_1, activation_max);
        *out_1-- = (int8_t)ch_1_out_1;
        out_mult--;
        out_shift--;
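        /* Channels N and N + 2 are now stored. out_0, out_1, out_mult and
         * out_shift were stepped forward by two and back by one so that the
         * second pass below fills channels N + 1 and N + 3. */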

        /* setup pointers for B */
        ip_b0 = input_b;
        ip_b1 = ip_b0 + num_col_a;

        /* Align the second pointer for A.
         * This skips a row so that spilled nibbles don't offset the symmetry
         * between the two pointers.
         */
        packed_ip_a1 = packed_ip_a0 + num_col_a;

        ch_0_out_0 = 0;
        ch_0_out_1 = 0;
        ch_1_out_0 = 0;
        ch_1_out_1 = 0;
        /* Init accumulators with the bias for channels N + 1 and N + 3 */
        if (bias)
        {
            ch_0_out_0 = *bias;
            ch_0_out_1 = *bias;
            bias += 2;
            ch_1_out_0 = *bias;
            ch_1_out_1 = *bias++;
        }

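        /* When num_col_a is odd, the first weight of each of these rows was
         * already unpacked (as the high nibble) during the previous pass, so
         * multiply the saved spillover values with the first element of each
         * B column before entering the main loop. */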
        if (num_col_a % 2)
        {
            int16_t b0 = *ip_b0++;
            int16_t b1 = *ip_b1++;

            ch_0_out_0 += spillover0 * b0;
            ch_0_out_1 += spillover0 * b1;
            ch_1_out_0 += spillover1 * b0;
            ch_1_out_1 += spillover1 * b1;
        }

#if defined(ARM_MATH_DSP)
        col_count = num_col_a / 4;
        /* accumulate over the vector */
        while (col_count)
        {
            int32_t a01, a02, a11, a12;
            int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            read_and_pad_s4_ordered(packed_ip_a0, &a01, &a02);
            read_and_pad_s4_ordered(packed_ip_a1, &a11, &a12);
            packed_ip_a0 += 2;
            packed_ip_a1 += 2;

            ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a11, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a11, b1, ch_1_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a12, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a12, b1, ch_1_out_1);

            col_count--;
        } /* while over col_count */
        col_count = (num_col_a & 0x3) >> 1;
#else
        col_count = num_col_a >> 1;
#endif
        while (col_count)
        {
            int8_t lower_a0 = (int8_t)(packed_ip_a0[0] << 4) >> 4;
            int8_t higher_a0 = packed_ip_a0[0] >> 4;
            int16_t b0 = *ip_b0++;

            int8_t lower_a1 = (int8_t)(packed_ip_a1[0] << 4) >> 4;
            int8_t higher_a1 = packed_ip_a1[0] >> 4;
            int16_t b1 = *ip_b1++;

            packed_ip_a0++;
            packed_ip_a1++;

            ch_0_out_0 += lower_a0 * b0;
            ch_0_out_1 += lower_a0 * b1;
            ch_1_out_0 += lower_a1 * b0;
            ch_1_out_1 += lower_a1 * b1;

            b0 = *ip_b0++;
            b1 = *ip_b1++;

            ch_0_out_0 += higher_a0 * b0;
            ch_0_out_1 += higher_a0 * b1;
            ch_1_out_0 += higher_a1 * b0;
            ch_1_out_1 += higher_a1 * b1;

            col_count--;
        } /* while over col_count */

        ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
        ch_0_out_0 += out_offset;
        ch_0_out_0 = MAX(ch_0_out_0, activation_min);
        ch_0_out_0 = MIN(ch_0_out_0, activation_max);
        *out_0 = (int8_t)ch_0_out_0;
        out_0 += 2;

        ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
        ch_0_out_1 += out_offset;
        ch_0_out_1 = MAX(ch_0_out_1, activation_min);
        ch_0_out_1 = MIN(ch_0_out_1, activation_max);
        *out_1 = (int8_t)ch_0_out_1;
        out_1 += 2;
        out_mult += 2;
        out_shift += 2;

        ch_1_out_0 = arm_nn_requantize(ch_1_out_0, *out_mult, *out_shift);
        ch_1_out_0 += out_offset;
        ch_1_out_0 = MAX(ch_1_out_0, activation_min);
        ch_1_out_0 = MIN(ch_1_out_0, activation_max);
        *out_0++ = (int8_t)ch_1_out_0;

        ch_1_out_1 = arm_nn_requantize(ch_1_out_1, *out_mult, *out_shift);
        ch_1_out_1 += out_offset;
        ch_1_out_1 = MAX(ch_1_out_1, activation_min);
        ch_1_out_1 = MIN(ch_1_out_1, activation_max);
        *out_1++ = (int8_t)ch_1_out_1;
        out_mult++;
        out_shift++;

        /* Skip the two rows (N + 2 and N + 3) already consumed through packed_ip_a1 */
        packed_ip_a0 += num_col_a;
        row_count--;
    }

    /* Compute the remaining 0 - 3 rows (output channels), if any */
    int16_t left_over_rows = 0;
    while (left_over_rows < output_ch % 4)
    {
        /* setup pointers for B */
        const int16_t *ip_b0 = input_b;
        const int16_t *ip_b1 = ip_b0 + num_col_a;

        int32_t ch_0_out_0 = 0;
        int32_t ch_0_out_1 = 0;

        /* load the bias */
        if (bias)
        {
            ch_0_out_0 = *bias;
            ch_0_out_1 = *bias++;
        }

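        /* With an odd num_col_a, the second leftover row (left_over_rows == 1)
         * starts mid-byte: its first weight is the high nibble of the byte whose
         * low nibble ended the previous row, so consume it here before the
         * byte-aligned loop below. */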
        if (left_over_rows == 1 && num_col_a % 2)
        {
            int16_t b0 = *ip_b0++;
            int16_t b1 = *ip_b1++;
            int8_t spilled_column = packed_ip_a0[0] >> 4;

            ++packed_ip_a0;

            ch_0_out_0 += spilled_column * b0;
            ch_0_out_1 += spilled_column * b1;
        }

#if defined(ARM_MATH_DSP)
        int32_t col_count = num_col_a / 4;
        while (col_count)
        {
            int32_t a01, a02;
            int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            read_and_pad_s4_ordered(packed_ip_a0, &a01, &a02);
            packed_ip_a0 += 2;

            ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);
            ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);

            col_count--;
        }
        col_count = (num_col_a & 0x3) >> 1;

#else
        int32_t col_count = num_col_a >> 1;
#endif

        while (col_count)
        {
            int8_t a0 = (int8_t)(packed_ip_a0[0] << 4) >> 4;
            int8_t a1 = packed_ip_a0[0] >> 4;
            int16_t b0 = *ip_b0++;
            int16_t b1 = *ip_b1++;

            ++packed_ip_a0;

            ch_0_out_0 += a0 * b0;
            ch_0_out_1 += a0 * b1;

            b0 = *ip_b0++;
            b1 = *ip_b1++;

            ch_0_out_0 += a1 * b0;
            ch_0_out_1 += a1 * b1;

            col_count--;
        }
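        /* For a row that started on a byte boundary, the final odd column sits in
         * the low nibble of the current byte. The pointer is deliberately not
         * advanced, so the high nibble stays available as the first weight of the
         * next leftover row. */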
        if (num_col_a % 2 && left_over_rows != 1)
        {
            int8_t a0 = (int8_t)(packed_ip_a0[0] << 4) >> 4;

            int16_t b0 = *ip_b0++;
            int16_t b1 = *ip_b1++;

            ch_0_out_0 += a0 * b0;
            ch_0_out_1 += a0 * b1;
        }
        ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
        ch_0_out_0 += out_offset;
        ch_0_out_0 = MAX(ch_0_out_0, activation_min);
        ch_0_out_0 = MIN(ch_0_out_0, activation_max);
        *out_0++ = (int8_t)ch_0_out_0;

        ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
        ch_0_out_1 += out_offset;
        ch_0_out_1 = MAX(ch_0_out_1, activation_min);
        ch_0_out_1 = MIN(ch_0_out_1, activation_max);
        *out_1++ = (int8_t)ch_0_out_1;
        out_mult++;
        out_shift++;

        ++left_over_rows;
    }

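    /* out_0 has advanced past the first output column (output_ch values); skip
     * the second column, written through out_1, so the returned pointer marks
     * where the next pair of output columns starts. */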
    out_0 += output_ch;

    /* return the new output pointer with offset */
    return out_0;
}