1 /*
2  * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nn_mat_mult_kernel_row_offset_s8_s16.c
22  * Description:  Matrix-multiplication function for grouped convolution
23  *
24  * $Date:        04 January 2024
25  * $Revision:    V.1.0.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  * -------------------------------------------------------------------- */
29 
30 #include "arm_nnfunctions.h"
31 #include "arm_nnsupportfunctions.h"
32 /*
33  * Matrix-multiplication function for convolution with per-channel requantization, supporting an address offset between
34  * rows.
35  *
 * Refer to the header file for details.
37  *
38  */
39 
#if !defined(ARM_MATH_MVEI)
/*
 * Requantize one int32 accumulator with a per-channel multiplier and shift,
 * add the output offset, clamp the result to [act_min, act_max] and narrow it
 * to int8.
 */
static inline int8_t row_offset_requantize_clamp(const int32_t acc,
                                                 const int32_t mult,
                                                 const int32_t shift,
                                                 const int32_t offset,
                                                 const int32_t act_min,
                                                 const int32_t act_max)
{
    int32_t result = arm_nn_requantize(acc, mult, shift) + offset;
    result = MAX(result, act_min);
    result = MIN(result, act_max);
    return (int8_t)result;
}
#endif

/*
 * Multiply every int8 row of A (one row per output channel) with two int16
 * columns of B and store the requantized int8 results.
 *
 * Parameters:
 *   input_a            Rows of A, num_col_a int8 values each, stored back to
 *                      back; row i belongs to output channel i.
 *   input_b            First int16 column of B. The second column starts
 *                      aligned_num_col_a elements after it. num_col_a values
 *                      are consumed from each column.
 *   output_ch          Number of rows of A, i.e. number of output channels.
 *   out_shift          Per-channel requantization shift, indexed by channel.
 *   out_mult           Per-channel requantization multiplier, indexed by channel.
 *   out_offset         Added to each requantized value before clamping.
 *   activation_min     Lower clamp bound for the outputs.
 *   activation_max     Upper clamp bound for the outputs.
 *   num_col_a          Length of each dot product (columns of A).
 *   aligned_num_col_a  Element distance between the two columns of B.
 *   output_bias        Optional per-channel bias used to seed the
 *                      accumulators; may be NULL (accumulators start at 0).
 *   row_address_offset Distance, in int8 elements, from the outputs of the
 *                      first column of B to those of the second column.
 *   out_0              Destination of the output_ch results for the first
 *                      column of B; the second column's results go to
 *                      out_0 + row_address_offset.
 *
 * Returns out_0 + 2 * row_address_offset (presumably where the caller writes
 * the next pair of output rows - confirm against the caller).
 * When ARM_MATH_MVEI is defined nothing is computed here and NULL is returned.
 *
 * NOTE(review): in the ARM_MATH_DSP path A is read four values at a time via
 * read_and_pad_reordered and paired with consecutive int16 pairs of B. This
 * suggests B must be laid out in the matching lane order. TODO: confirm against
 * the code that produces input_b.
 */
int8_t *arm_nn_mat_mult_kernel_row_offset_s8_s16(const int8_t *input_a,
                                                 const int16_t *input_b,
                                                 const uint16_t output_ch,
                                                 const int32_t *out_shift,
                                                 const int32_t *out_mult,
                                                 const int32_t out_offset,
                                                 const int16_t activation_min,
                                                 const int16_t activation_max,
                                                 const int32_t num_col_a,
                                                 const int32_t aligned_num_col_a,
                                                 const int32_t *const output_bias,
                                                 const int32_t row_address_offset,
                                                 int8_t *out_0)
{
#if !defined(ARM_MATH_MVEI)
    /* Output row for B column 0, and row_address_offset further on, for B column 1. */
    int8_t *const dst_col0 = out_0;
    int8_t *const dst_col1 = out_0 + row_address_offset;
    const int16_t *const b_col1_start = input_b + aligned_num_col_a;

    /* Cursor over A; the read helpers below advance it while consuming a row. */
    const int8_t *row_a = input_a;
    int32_t ch = 0;

    /* Two output channels (two consecutive rows of A) per pass. */
    for (; ch + 1 < output_ch; ch += 2)
    {
        const int16_t *b_col0 = input_b;
        const int16_t *b_col1 = b_col1_start;
        const int8_t *row_a_next = row_a + num_col_a;

        /* acc_<channel within the pair><column of B> */
        int32_t acc_00 = 0;
        int32_t acc_01 = 0;
        int32_t acc_10 = 0;
        int32_t acc_11 = 0;
        if (output_bias != NULL)
        {
            acc_00 = acc_01 = output_bias[ch];
            acc_10 = acc_11 = output_bias[ch + 1];
        }

    #if defined(ARM_MATH_DSP)
        /* Four elements of each row per pass, two SMLADs per accumulator. */
        int32_t remaining = num_col_a / 4;
        for (; remaining != 0; remaining--)
        {
            int32_t a0_first, a0_second, a1_first, a1_second;
            int32_t b0_pair = arm_nn_read_q15x2_ia(&b_col0);
            int32_t b1_pair = arm_nn_read_q15x2_ia(&b_col1);

            row_a = read_and_pad_reordered(row_a, &a0_first, &a0_second);
            row_a_next = read_and_pad_reordered(row_a_next, &a1_first, &a1_second);

            acc_00 = SMLAD(a0_first, b0_pair, acc_00);
            acc_01 = SMLAD(a0_first, b1_pair, acc_01);
            acc_10 = SMLAD(a1_first, b0_pair, acc_10);
            acc_11 = SMLAD(a1_first, b1_pair, acc_11);

            b0_pair = arm_nn_read_q15x2_ia(&b_col0);
            b1_pair = arm_nn_read_q15x2_ia(&b_col1);

            acc_00 = SMLAD(a0_second, b0_pair, acc_00);
            acc_01 = SMLAD(a0_second, b1_pair, acc_01);
            acc_10 = SMLAD(a1_second, b0_pair, acc_10);
            acc_11 = SMLAD(a1_second, b1_pair, acc_11);
        }
        /* The leftover (num_col_a & 3) elements go through the scalar loop. */
        remaining = num_col_a & 0x3;
    #else
        int32_t remaining = num_col_a;
    #endif
        for (; remaining != 0; remaining--)
        {
            const int8_t a0 = *row_a++;
            const int16_t b0 = *b_col0++;
            const int8_t a1 = *row_a_next++;
            const int16_t b1 = *b_col1++;

            acc_00 += a0 * b0;
            acc_01 += a0 * b1;
            acc_10 += a1 * b0;
            acc_11 += a1 * b1;
        }

        dst_col0[ch] = row_offset_requantize_clamp(
            acc_00, out_mult[ch], out_shift[ch], out_offset, activation_min, activation_max);
        dst_col1[ch] = row_offset_requantize_clamp(
            acc_01, out_mult[ch], out_shift[ch], out_offset, activation_min, activation_max);
        dst_col0[ch + 1] = row_offset_requantize_clamp(
            acc_10, out_mult[ch + 1], out_shift[ch + 1], out_offset, activation_min, activation_max);
        dst_col1[ch + 1] = row_offset_requantize_clamp(
            acc_11, out_mult[ch + 1], out_shift[ch + 1], out_offset, activation_min, activation_max);

        /* row_a stopped at the start of the pair's second row (already read via row_a_next); skip it. */
        row_a += num_col_a;
    }

    /* Odd output_ch: exactly one row of A remains. */
    if (ch < output_ch)
    {
        const int16_t *b_col0 = input_b;
        const int16_t *b_col1 = b_col1_start;

        int32_t acc_0 = 0;
        int32_t acc_1 = 0;
        if (output_bias != NULL)
        {
            acc_0 = acc_1 = output_bias[ch];
        }

    #if defined(ARM_MATH_DSP)
        int32_t remaining = num_col_a >> 2;
        for (; remaining != 0; remaining--)
        {
            int32_t a_first, a_second;
            int32_t b0_pair = arm_nn_read_q15x2_ia(&b_col0);
            int32_t b1_pair = arm_nn_read_q15x2_ia(&b_col1);

            row_a = read_and_pad_reordered(row_a, &a_first, &a_second);

            acc_0 = SMLAD(a_first, b0_pair, acc_0);
            acc_1 = SMLAD(a_first, b1_pair, acc_1);

            b0_pair = arm_nn_read_q15x2_ia(&b_col0);
            b1_pair = arm_nn_read_q15x2_ia(&b_col1);

            acc_0 = SMLAD(a_second, b0_pair, acc_0);
            acc_1 = SMLAD(a_second, b1_pair, acc_1);
        }
        remaining = num_col_a & 0x3;
    #else
        int32_t remaining = num_col_a;
    #endif
        for (; remaining != 0; remaining--)
        {
            const int8_t a = *row_a++;
            const int16_t b0 = *b_col0++;
            const int16_t b1 = *b_col1++;

            acc_0 += a * b0;
            acc_1 += a * b1;
        }

        dst_col0[ch] = row_offset_requantize_clamp(
            acc_0, out_mult[ch], out_shift[ch], out_offset, activation_min, activation_max);
        dst_col1[ch] = row_offset_requantize_clamp(
            acc_1, out_mult[ch], out_shift[ch], out_offset, activation_min, activation_max);
    }

    return dst_col0 + 2 * row_address_offset;
#else
    /* No implementation for this configuration in this file. */
    (void)input_a;
    (void)input_b;
    (void)output_ch;
    (void)out_shift;
    (void)out_mult;
    (void)out_offset;
    (void)activation_min;
    (void)activation_max;
    (void)num_col_a;
    (void)aligned_num_col_a;
    (void)output_bias;
    (void)row_address_offset;
    (void)out_0;
    return NULL;
#endif
}
254