1 /*
2 * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_nn_mat_mult_kernel_row_offset_s8_s16.c
22 * Description: Matrix-multiplication function for grouped convolution
23 *
24 * $Date: 04 January 2024
25 * $Revision: V.1.0.0
26 *
27 * Target : Arm(R) M-Profile Architecture
28 * -------------------------------------------------------------------- */
29
30 #include "arm_nnfunctions.h"
31 #include "arm_nnsupportfunctions.h"
32 /*
33 * Matrix-multiplication function for convolution with per-channel requantization, supporting an address offset between
34 * rows.
35 *
36 * Refer header file for details.
37 *
38 */
39
#if !defined(ARM_MATH_MVEI)
/*
 * Requantize one int32 accumulator to int8 and return the narrowed result.
 *
 * Steps:
 *   1. Apply the per-channel multiplier and shift via arm_nn_requantize().
 *   2. Add the output offset.
 *   3. Clamp to [activation_min, activation_max].
 *   4. Narrow to int8.
 *
 * Used for every output element written by
 * arm_nn_mat_mult_kernel_row_offset_s8_s16().
 */
static inline int8_t requantize_clamp_to_s8(int32_t acc,
                                            const int32_t mult,
                                            const int32_t shift,
                                            const int32_t out_offset,
                                            const int16_t activation_min,
                                            const int16_t activation_max)
{
    acc = arm_nn_requantize(acc, mult, shift);
    acc += out_offset;
    acc = MAX(acc, activation_min);
    acc = MIN(acc, activation_max);
    return (int8_t)acc;
}
#endif /* !ARM_MATH_MVEI */

/*
 * Multiply an int8 matrix A by two int16 columns of B and write requantized int8 results.
 * A has output_ch rows of num_col_a elements, stored back to back.
 *
 * For every output channel ch the kernel computes two dot products:
 *   out_0[ch]                      <- A[ch] . B column 0 (input_b[0 .. num_col_a - 1])
 *   out_0[row_address_offset + ch] <- A[ch] . B column 1 (input_b[aligned_num_col_a ..
 *                                                        aligned_num_col_a + num_col_a - 1])
 *
 * Each accumulator:
 *   - starts at output_bias[ch] (or 0 when output_bias is NULL);
 *   - is requantized with out_mult[ch] / out_shift[ch];
 *   - is offset by out_offset;
 *   - is clamped to [activation_min, activation_max].
 *
 * Channels are processed two at a time. A trailing odd channel is handled separately.
 *
 * Returns out_0 + 2 * row_address_offset, the position just past both output rows.
 * When ARM_MATH_MVEI is defined this kernel is not implemented: it does no work and
 * returns NULL.
 *
 * NOTE(review): with ARM_MATH_DSP, groups of four A values are unpacked through
 * read_and_pad_reordered(). This presumably expects input_b to have been prepared by
 * the caller in the matching reordered lane layout; the last num_col_a % 4 elements
 * are consumed in natural order. Confirm against the caller.
 */
int8_t *arm_nn_mat_mult_kernel_row_offset_s8_s16(const int8_t *input_a,
                                                  const int16_t *input_b,
                                                  const uint16_t output_ch,
                                                  const int32_t *out_shift,
                                                  const int32_t *out_mult,
                                                  const int32_t out_offset,
                                                  const int16_t activation_min,
                                                  const int16_t activation_max,
                                                  const int32_t num_col_a,
                                                  const int32_t aligned_num_col_a,
                                                  const int32_t *const output_bias,
                                                  const int32_t row_address_offset,
                                                  int8_t *out_0)
{
#if !defined(ARM_MATH_MVEI)
    /* Output row receiving the results for the second column of B. */
    int8_t *out_1 = out_0 + row_address_offset;
    const int32_t *bias = output_bias;

    uint16_t row_count = output_ch / 2;
    const int8_t *ip_a0 = input_a;

    /* Process two rows of A (two output channels) per iteration. */
    while (row_count)
    {
        /* Both columns of B are re-read from their start for every pair of rows. */
        const int16_t *ip_b0 = input_b;
        const int16_t *ip_b1 = ip_b0 + aligned_num_col_a;

        /* The second row of A immediately follows the first. */
        const int8_t *ip_a1 = ip_a0 + num_col_a;

        /* ch_<c>_out_<k>: channel c of the current pair against column k of B. */
        int32_t ch_0_out_0 = 0;
        int32_t ch_0_out_1 = 0;
        int32_t ch_1_out_0 = 0;
        int32_t ch_1_out_1 = 0;

        /* Both columns of a channel share that channel's bias. */
        if (bias)
        {
            ch_0_out_0 = *bias;
            ch_0_out_1 = *bias++;
            ch_1_out_0 = *bias;
            ch_1_out_1 = *bias++;
        }

#if defined(ARM_MATH_DSP)
        /* Four columns per iteration: each SMLAD performs two 16x16 MACs. */
        int32_t col_count = num_col_a / 4;
        while (col_count)
        {
            int32_t a01, a02, a11, a12;
            int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ip_a0 = read_and_pad_reordered(ip_a0, &a01, &a02);
            ip_a1 = read_and_pad_reordered(ip_a1, &a11, &a12);

            ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a11, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a11, b1, ch_1_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a12, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a12, b1, ch_1_out_1);

            col_count--;
        }

        /* Remaining 0..3 columns fall through to the scalar loop. */
        col_count = num_col_a & 0x3;
#else
        int32_t col_count = num_col_a;
#endif
        while (col_count)
        {
            int8_t a0 = *ip_a0++;
            int16_t b0 = *ip_b0++;
            int8_t a1 = *ip_a1++;
            int16_t b1 = *ip_b1++;

            ch_0_out_0 += a0 * b0;
            ch_0_out_1 += a0 * b1;
            ch_1_out_0 += a1 * b0;
            ch_1_out_1 += a1 * b1;
            col_count--;
        }

        /* Write order (out_0, out_1, out_0, out_1) matches the original kernel. */
        *out_0++ = requantize_clamp_to_s8(
            ch_0_out_0, out_mult[0], out_shift[0], out_offset, activation_min, activation_max);
        *out_1++ = requantize_clamp_to_s8(
            ch_0_out_1, out_mult[0], out_shift[0], out_offset, activation_min, activation_max);
        *out_0++ = requantize_clamp_to_s8(
            ch_1_out_0, out_mult[1], out_shift[1], out_offset, activation_min, activation_max);
        *out_1++ = requantize_clamp_to_s8(
            ch_1_out_1, out_mult[1], out_shift[1], out_offset, activation_min, activation_max);
        out_mult += 2;
        out_shift += 2;

        /* ip_a0 now points at the second row of the pair; skip it. */
        ip_a0 += num_col_a;
        row_count--;
    }

    /* Handle the last row of A when output_ch is odd. */
    if (output_ch & 0x1)
    {
        const int16_t *ip_b0 = input_b;
        const int16_t *ip_b1 = ip_b0 + aligned_num_col_a;

        int32_t ch_0_out_0 = 0;
        int32_t ch_0_out_1 = 0;

        if (bias)
        {
            ch_0_out_0 = *bias;
            ch_0_out_1 = *bias++;
        }

#if defined(ARM_MATH_DSP)
        int32_t col_count = num_col_a >> 2;
        while (col_count)
        {
            int32_t a01, a02;
            int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ip_a0 = read_and_pad_reordered(ip_a0, &a01, &a02);

            ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);
            ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);

            col_count--;
        }
        col_count = num_col_a & 0x3;
#else
        int32_t col_count = num_col_a;
#endif
        while (col_count)
        {
            int8_t a0 = *ip_a0++;
            int16_t b0 = *ip_b0++;
            int16_t b1 = *ip_b1++;

            ch_0_out_0 += a0 * b0;
            ch_0_out_1 += a0 * b1;
            col_count--;
        }

        *out_0++ = requantize_clamp_to_s8(
            ch_0_out_0, *out_mult, *out_shift, out_offset, activation_min, activation_max);
        *out_1++ = requantize_clamp_to_s8(
            ch_0_out_1, *out_mult, *out_shift, out_offset, activation_min, activation_max);
        out_mult++;
        out_shift++;
    }

    /* out_0 has advanced by output_ch; this lands on (original out_0) + 2 * row_address_offset. */
    out_0 += 2 * row_address_offset - output_ch;

    return out_0;
#else
    /* Not implemented for MVE: silence unused-parameter warnings and return NULL. */
    (void)input_a;
    (void)input_b;
    (void)output_ch;
    (void)out_shift;
    (void)out_mult;
    (void)out_offset;
    (void)activation_min;
    (void)activation_max;
    (void)aligned_num_col_a;
    (void)num_col_a;
    (void)output_bias;
    (void)row_address_offset;
    (void)out_0;
    return NULL;
#endif
}
254