1 /*
2 * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates
3 * <open-source-office@arm.com>
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 *
7 * Licensed under the Apache License, Version 2.0 (the License); you may
8 * not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 * www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
15 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 */
19
20 /* ----------------------------------------------------------------------
21 * Project: CMSIS NN Library
22 * Title: arm_nn_mat_mult_kernel_s16.c
23 * Description: Matrix-multiplication function for 16 bits convolution
24 *
25 * $Date: 12 April 2024
26 * $Revision: V.3.0.0
27 *
28 * Target : Arm(R) M-Profile Architecture
29 * -------------------------------------------------------------------- */
30
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33
34 /**
35 * @ingroup groupSupport
36 */
37
38 /**
39 * @addtogroup supportConvolution
40 * @{
41 */
42
43 /*
44 * Matrix-multiplication function for convolution with per-channel requantization.
45 *
 * Refer to the header file for details.
47 *
48 */
arm_nn_mat_mult_kernel_s16(const int8_t * input_a,const int16_t * input_b,const int32_t output_ch,const int32_t * out_shift,const int32_t * out_mult,const int32_t activation_min,const int32_t activation_max,const int32_t num_col_a,const cmsis_nn_bias_data * const bias_data,int16_t * out_0)49 int16_t *arm_nn_mat_mult_kernel_s16(const int8_t *input_a,
50 const int16_t *input_b,
51 const int32_t output_ch,
52 const int32_t *out_shift,
53 const int32_t *out_mult,
54 const int32_t activation_min,
55 const int32_t activation_max,
56 const int32_t num_col_a,
57 const cmsis_nn_bias_data *const bias_data,
58 int16_t *out_0)
59 {
60 #if !defined(ARM_MATH_MVEI)
61 const int64_t *bias_s64 = (const int64_t *)bias_data->data;
62 const int32_t *bias_s32 = (const int32_t *)bias_data->data;
63 const bool is_int32_bias = bias_data->is_int32_bias;
64
65 const int32_t num_col_a_fast = is_int32_bias ? num_col_a : (num_col_a > MAX_COL_COUNT ? MAX_COL_COUNT : num_col_a);
66 const int32_t num_col_a_slow = num_col_a - MAX_COL_COUNT;
67
68 int16_t *out_1 = out_0 + output_ch;
69 int32_t row_count = output_ch / 2;
70 const int8_t *ip_a0 = input_a;
71
72 /* This loop over rows in A */
73 while (row_count)
74 {
75 /* Setup pointers for B */
76 const int16_t *ip_b0 = input_b;
77 const int16_t *ip_b1 = ip_b0 + num_col_a;
78
79 /* Align the second pointer for A */
80 const int8_t *ip_a1 = ip_a0 + num_col_a;
81
82 /* Init accumulator for channel N and N + 1 */
83 int32_t ch_0_out_0 = 0;
84 int32_t ch_0_out_1 = 0;
85 int32_t ch_1_out_0 = 0;
86 int32_t ch_1_out_1 = 0;
87
88 #if defined(ARM_MATH_DSP)
89 uint16_t col_count = num_col_a_fast / 4;
90
91 /* Accumulate over the vector */
92 while (col_count)
93 {
94 int32_t a01, a02, a11, a12;
95 int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
96 int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);
97
98 ip_a0 = read_and_pad(ip_a0, &a01, &a02);
99 ip_a1 = read_and_pad(ip_a1, &a11, &a12);
100
101 ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
102 ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);
103 ch_1_out_0 = SMLAD(a11, b0, ch_1_out_0);
104 ch_1_out_1 = SMLAD(a11, b1, ch_1_out_1);
105
106 b0 = arm_nn_read_q15x2_ia(&ip_b0);
107 b1 = arm_nn_read_q15x2_ia(&ip_b1);
108
109 ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
110 ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);
111 ch_1_out_0 = SMLAD(a12, b0, ch_1_out_0);
112 ch_1_out_1 = SMLAD(a12, b1, ch_1_out_1);
113
114 col_count--;
115 }
116 col_count = num_col_a_fast & 0x3;
117 #else
118 int32_t col_count = num_col_a_fast;
119 #endif
120
121 while (col_count)
122 {
123 int8_t a0 = *ip_a0++;
124 int16_t b0 = *ip_b0++;
125 int8_t a1 = *ip_a1++;
126 int16_t b1 = *ip_b1++;
127
128 ch_0_out_0 += a0 * b0;
129 ch_0_out_1 += a0 * b1;
130 ch_1_out_0 += a1 * b0;
131 ch_1_out_1 += a1 * b1;
132 col_count--;
133 }
134
135 if (is_int32_bias)
136 {
137 if (bias_s32)
138 {
139 ch_0_out_0 += *bias_s32;
140 ch_0_out_1 += *bias_s32++;
141 ch_1_out_0 += *bias_s32;
142 ch_1_out_1 += *bias_s32++;
143 }
144
145 ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
146 ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
147 out_mult++;
148 out_shift++;
149
150 ch_0_out_0 = MAX(ch_0_out_0, activation_min);
151 ch_0_out_0 = MIN(ch_0_out_0, activation_max);
152 *out_0++ = (int16_t)ch_0_out_0;
153
154 ch_0_out_1 = MAX(ch_0_out_1, activation_min);
155 ch_0_out_1 = MIN(ch_0_out_1, activation_max);
156 *out_1++ = (int16_t)ch_0_out_1;
157
158 ch_1_out_0 = arm_nn_requantize(ch_1_out_0, *out_mult, *out_shift);
159 ch_1_out_1 = arm_nn_requantize(ch_1_out_1, *out_mult, *out_shift);
160 out_mult++;
161 out_shift++;
162
163 ch_1_out_0 = MAX(ch_1_out_0, activation_min);
164 ch_1_out_0 = MIN(ch_1_out_0, activation_max);
165 *out_0++ = (int16_t)ch_1_out_0;
166
167 ch_1_out_1 = MAX(ch_1_out_1, activation_min);
168 ch_1_out_1 = MIN(ch_1_out_1, activation_max);
169 *out_1++ = (int16_t)ch_1_out_1;
170 }
171 else
172 {
173 int64_t ch_0_out_0_s64 = ch_0_out_0;
174 int64_t ch_0_out_1_s64 = ch_0_out_1;
175 int64_t ch_1_out_0_s64 = ch_1_out_0;
176 int64_t ch_1_out_1_s64 = ch_1_out_1;
177
178 if (num_col_a > MAX_COL_COUNT)
179 {
180 col_count = num_col_a_slow;
181 while (col_count)
182 {
183 int8_t a0 = *ip_a0++;
184 int16_t b0 = *ip_b0++;
185 int8_t a1 = *ip_a1++;
186 int16_t b1 = *ip_b1++;
187
188 ch_0_out_0_s64 += a0 * b0;
189 ch_0_out_1_s64 += a0 * b1;
190 ch_1_out_0_s64 += a1 * b0;
191 ch_1_out_1_s64 += a1 * b1;
192 col_count--;
193 }
194 }
195
196 if (bias_s64)
197 {
198 ch_0_out_0_s64 += *bias_s64;
199 ch_0_out_1_s64 += *bias_s64++;
200 ch_1_out_0_s64 += *bias_s64;
201 ch_1_out_1_s64 += *bias_s64++;
202 }
203
204 int32_t reduced_multiplier = REDUCE_MULTIPLIER(*out_mult);
205 ch_0_out_0 = arm_nn_requantize_s64(ch_0_out_0_s64, reduced_multiplier, *out_shift);
206 ch_0_out_1 = arm_nn_requantize_s64(ch_0_out_1_s64, reduced_multiplier, *out_shift);
207 out_mult++;
208 out_shift++;
209
210 reduced_multiplier = REDUCE_MULTIPLIER(*out_mult);
211 ch_1_out_0 = arm_nn_requantize_s64(ch_1_out_0_s64, reduced_multiplier, *out_shift);
212 ch_1_out_1 = arm_nn_requantize_s64(ch_1_out_1_s64, reduced_multiplier, *out_shift);
213
214 ch_0_out_0 = MAX(ch_0_out_0, activation_min);
215 ch_0_out_0 = MIN(ch_0_out_0, activation_max);
216 *out_0++ = (int16_t)ch_0_out_0;
217
218 ch_0_out_1 = MAX(ch_0_out_1, activation_min);
219 ch_0_out_1 = MIN(ch_0_out_1, activation_max);
220 *out_1++ = (int16_t)ch_0_out_1;
221
222 ch_1_out_0 = MAX(ch_1_out_0, activation_min);
223 ch_1_out_0 = MIN(ch_1_out_0, activation_max);
224 *out_0++ = (int16_t)ch_1_out_0;
225
226 ch_1_out_1 = MAX(ch_1_out_1, activation_min);
227 ch_1_out_1 = MIN(ch_1_out_1, activation_max);
228 *out_1++ = (int16_t)ch_1_out_1;
229
230 out_mult++;
231 out_shift++;
232 }
233
234 /* Skip row */
235 ip_a0 += num_col_a;
236 row_count--;
237 }
238
239 /* Compute the last odd numbered row if any */
240 if (output_ch & 0x1)
241 {
242 /* Setup pointers for B */
243 const int16_t *ip_b0 = input_b;
244 const int16_t *ip_b1 = ip_b0 + num_col_a;
245
246 int32_t ch_0_out_0 = 0;
247 int32_t ch_0_out_1 = 0;
248
249 #if defined(ARM_MATH_DSP)
250 uint16_t col_count = num_col_a_fast >> 2;
251 while (col_count)
252 {
253 int32_t a01, a02;
254 int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
255 int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);
256
257 ip_a0 = read_and_pad(ip_a0, &a01, &a02);
258
259 ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
260 ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);
261
262 b0 = arm_nn_read_q15x2_ia(&ip_b0);
263 b1 = arm_nn_read_q15x2_ia(&ip_b1);
264 ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
265 ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);
266
267 col_count--;
268 }
269 col_count = num_col_a & 0x3;
270 #else
271 int32_t col_count = num_col_a_fast;
272 #endif
273 while (col_count)
274 {
275 int8_t a0 = *ip_a0++;
276 int16_t b0 = *ip_b0++;
277 int16_t b1 = *ip_b1++;
278
279 ch_0_out_0 += a0 * b0;
280 ch_0_out_1 += a0 * b1;
281 col_count--;
282 }
283
284 if (is_int32_bias)
285 {
286 if (bias_s32)
287 {
288 ch_0_out_0 += *bias_s32;
289 ch_0_out_1 += *bias_s32++;
290 }
291
292 ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
293 ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
294 out_mult++;
295 out_shift++;
296
297 ch_0_out_0 = MAX(ch_0_out_0, activation_min);
298 ch_0_out_0 = MIN(ch_0_out_0, activation_max);
299 *out_0++ = (int16_t)ch_0_out_0;
300
301 ch_0_out_1 = MAX(ch_0_out_1, activation_min);
302 ch_0_out_1 = MIN(ch_0_out_1, activation_max);
303 *out_1++ = (int16_t)ch_0_out_1;
304 }
305 else
306 {
307 int64_t ch_0_out_0_s64 = ch_0_out_0;
308 int64_t ch_0_out_1_s64 = ch_0_out_1;
309
310 if (num_col_a > MAX_COL_COUNT)
311 {
312 col_count = num_col_a_slow;
313 while (col_count)
314 {
315 int8_t a0 = *ip_a0++;
316 int16_t b0 = *ip_b0++;
317 int16_t b1 = *ip_b1++;
318
319 ch_0_out_0_s64 += a0 * b0;
320 ch_0_out_1_s64 += a0 * b1;
321 col_count--;
322 }
323 }
324
325 if (bias_s64)
326 {
327 ch_0_out_0_s64 += *bias_s64;
328 ch_0_out_1_s64 += *bias_s64++;
329 }
330
331 int32_t reduced_multiplier = REDUCE_MULTIPLIER(*out_mult);
332 ch_0_out_0 = arm_nn_requantize_s64(ch_0_out_0_s64, reduced_multiplier, *out_shift);
333 ch_0_out_1 = arm_nn_requantize_s64(ch_0_out_1_s64, reduced_multiplier, *out_shift);
334
335 ch_0_out_0 = MAX(ch_0_out_0, activation_min);
336 ch_0_out_0 = MIN(ch_0_out_0, activation_max);
337 *out_0++ = (int16_t)ch_0_out_0;
338
339 ch_0_out_1 = MAX(ch_0_out_1, activation_min);
340 ch_0_out_1 = MIN(ch_0_out_1, activation_max);
341 *out_1++ = (int16_t)ch_0_out_1;
342 out_mult++;
343 out_shift++;
344 }
345 }
346
347 out_0 += output_ch;
348
349 /* Return the new output pointer with offset */
350 return out_0;
351 #else
352 (void)input_a;
353 (void)input_b;
354 (void)output_ch;
355 (void)out_shift;
356 (void)out_mult;
357 (void)activation_min;
358 (void)activation_max;
359 (void)num_col_a;
360 (void)bias_data;
361 (void)out_0;
362
363 return NULL;
364 #endif
365 }
366
367 /**
368 * @} end of Doxygen group
369 */
370