/*
 * SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates
 * <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_kernel_s16.c
 * Description:  Matrix-multiplication function for convolution
 *
 * $Date:        5 January 2023
 * $Revision:    V.1.2.0
 *
 * Target : Arm(R) M-Profile Architecture
 * -------------------------------------------------------------------- */

#include "arm_nnfunctions.h"
#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportConvolution
 * @{
 */

/*
 * Matrix-multiplication function for convolution with per-channel requantization.
 *
 * Refer header file for details.
 *
 */

/*
 * Matrix-multiplication kernel for s16 convolution with per-channel
 * requantization.
 *
 * Processes two rows of input_a (two output channels) per outer-loop
 * iteration against two columns of input_b, producing two requantized and
 * clamped s16 results per channel.
 *
 * input_a        packed s8 filter matrix; one row of num_col_a weights per
 *                output channel
 * input_b        s16 input buffer holding two columns, each num_col_a long
 * output_ch      number of rows in input_a (number of output channels)
 * out_shift      per-channel requantization shifts
 * out_mult       per-channel requantization multipliers
 * activation_min lower clamp bound for the s16 output
 * activation_max upper clamp bound for the s16 output
 * num_col_a      number of columns in a row of input_a (and in a column of
 *                input_b)
 * output_bias    optional per-channel s64 bias; may be NULL
 * out_0          output buffer; the second column's results are written
 *                starting at out_0 + output_ch
 *
 * Returns the output pointer advanced past both columns' results, or NULL
 * when the DSP-extension path is not compiled in.
 */
int16_t *arm_nn_mat_mult_kernel_s16(const int8_t *input_a,
                                    const int16_t *input_b,
                                    const int32_t output_ch,
                                    const int32_t *out_shift,
                                    const int32_t *out_mult,
                                    const int16_t activation_min,
                                    const int16_t activation_max,
                                    const int32_t num_col_a,
                                    const int64_t *const output_bias,
                                    int16_t *out_0)
{

#if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
    /* set up the second output pointers */
    int16_t *out_1 = out_0 + output_ch;
    const int64_t *bias = output_bias;
    uint16_t row_count = output_ch / 2;
    const int8_t *ip_a0 = input_a;

    /* this loop over rows in A */
    while (row_count)
    {
        /* setup pointers for B */
        const int16_t *ip_b0 = input_b;
        const int16_t *ip_b1 = ip_b0 + num_col_a;

        /* align the second pointer for A */
        const int8_t *ip_a1 = ip_a0 + num_col_a;

        /* Init accumulator for channel N and N + 1 */
        int32_t ch_0_out_0 = 0;
        int32_t ch_0_out_1 = 0;
        int32_t ch_1_out_0 = 0;
        int32_t ch_1_out_1 = 0;

        uint16_t col_count = num_col_a / 4;
        /* accumulate over the vector: 4 columns per iteration, with each
         * SMLAD folding two 16x16 multiply-accumulates into the 32-bit sum */
        while (col_count)
        {
            int32_t a01, a02, a11, a12;
            int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            /* sign-extend four s8 weights into two packed s16 pairs */
            ip_a0 = read_and_pad(ip_a0, &a01, &a02);
            ip_a1 = read_and_pad(ip_a1, &a11, &a12);

            ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a11, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a11, b1, ch_1_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a12, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a12, b1, ch_1_out_1);

            col_count--;
        } /* while over col_count */
        /* handle the 0..3 leftover columns scalar-wise */
        col_count = num_col_a & 0x3;
        while (col_count)
        {
            int8_t a0 = *ip_a0++;
            int16_t b0 = *ip_b0++;
            int8_t a1 = *ip_a1++;
            int16_t b1 = *ip_b1++;

            ch_0_out_0 += a0 * b0;
            ch_0_out_1 += a0 * b1;
            ch_1_out_0 += a1 * b0;
            ch_1_out_1 += a1 * b1;
            col_count--;
        } /* while over col_count */
        /* channel N: bias path accumulates in 64-bit before requantizing */
        if (bias)
        {
            int32_t reduced_multiplier = REDUCE_MULTIPLIER(*out_mult);
            int64_t acc_64 = ch_0_out_0 + *bias;
            ch_0_out_0 = arm_nn_requantize_s64(acc_64, reduced_multiplier, *out_shift);
            acc_64 = ch_0_out_1 + *bias++;
            ch_0_out_1 = arm_nn_requantize_s64(acc_64, reduced_multiplier, *out_shift);
            out_mult++;
        }
        else
        {
            ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
            ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
            out_mult++;
        }
        ch_0_out_0 = MAX(ch_0_out_0, activation_min);
        ch_0_out_0 = MIN(ch_0_out_0, activation_max);
        *out_0++ = (int16_t)ch_0_out_0;

        ch_0_out_1 = MAX(ch_0_out_1, activation_min);
        ch_0_out_1 = MIN(ch_0_out_1, activation_max);
        *out_1++ = (int16_t)ch_0_out_1;
        out_shift++;

        /* channel N + 1: same requantize/clamp sequence with the next
         * per-channel multiplier and shift */
        if (bias)
        {
            int32_t reduced_multiplier = REDUCE_MULTIPLIER(*out_mult);
            int64_t acc_64 = ch_1_out_0 + *bias;
            ch_1_out_0 = arm_nn_requantize_s64(acc_64, reduced_multiplier, *out_shift);
            acc_64 = ch_1_out_1 + *bias++;
            ch_1_out_1 = arm_nn_requantize_s64(acc_64, reduced_multiplier, *out_shift);
            out_mult++;
        }
        else
        {
            ch_1_out_0 = arm_nn_requantize(ch_1_out_0, *out_mult, *out_shift);
            ch_1_out_1 = arm_nn_requantize(ch_1_out_1, *out_mult, *out_shift);
            out_mult++;
        }
        ch_1_out_0 = MAX(ch_1_out_0, activation_min);
        ch_1_out_0 = MIN(ch_1_out_0, activation_max);
        *out_0++ = (int16_t)ch_1_out_0;

        ch_1_out_1 = MAX(ch_1_out_1, activation_min);
        ch_1_out_1 = MIN(ch_1_out_1, activation_max);
        *out_1++ = (int16_t)ch_1_out_1;
        out_shift++;

        /* skip row: ip_a0 already advanced past row N, jump over row N + 1 */
        ip_a0 += num_col_a;
        row_count--;
    }

    /* compute the last odd numbered row if any */
    if (output_ch & 0x1)
    {
        /* setup pointers for B */
        const int16_t *ip_b0 = input_b;
        const int16_t *ip_b1 = ip_b0 + num_col_a;

        int32_t ch_0_out_0 = 0;
        int32_t ch_0_out_1 = 0;

        uint16_t col_count = num_col_a >> 2;
        while (col_count)
        {
            int32_t a01, a02;
            int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ip_a0 = read_and_pad(ip_a0, &a01, &a02);

            ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);
            ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);

            col_count--;
        }
        col_count = num_col_a & 0x3;
        while (col_count)
        {
            int8_t a0 = *ip_a0++;
            int16_t b0 = *ip_b0++;
            int16_t b1 = *ip_b1++;

            ch_0_out_0 += a0 * b0;
            ch_0_out_1 += a0 * b1;
            col_count--;
        }
        if (bias)
        {
            int32_t reduced_multiplier = REDUCE_MULTIPLIER(*out_mult);
            int64_t acc_64 = ch_0_out_0 + *bias;
            ch_0_out_0 = arm_nn_requantize_s64(acc_64, reduced_multiplier, *out_shift);
            acc_64 = ch_0_out_1 + *bias++;
            ch_0_out_1 = arm_nn_requantize_s64(acc_64, reduced_multiplier, *out_shift);
        }
        else
        {
            ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
            ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
        }
        ch_0_out_0 = MAX(ch_0_out_0, activation_min);
        ch_0_out_0 = MIN(ch_0_out_0, activation_max);
        *out_0++ = (int16_t)ch_0_out_0;

        ch_0_out_1 = MAX(ch_0_out_1, activation_min);
        ch_0_out_1 = MIN(ch_0_out_1, activation_max);
        *out_1++ = (int16_t)ch_0_out_1;
        out_mult++;
        out_shift++;
    }

    /* advance past the second column's results so the caller continues
     * writing after both output rows */
    out_0 += output_ch;

    /* return the new output pointer with offset */
    return out_0;
#else
    (void)input_a;
    (void)input_b;
    (void)output_ch;
    (void)out_shift;
    (void)out_mult;
    (void)activation_min;
    (void)activation_max;
    (void)num_col_a;
    (void)output_bias;
    (void)out_0;
    /* To be completed */
    return NULL;
#endif
}

/**
 * @} end of Doxygen group
 */