1 /*
2 * SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_nn_mat_mult_kernel_s8_s16.c
22 * Description: Matrix-multiplication function for convolution
23 *
24 * $Date: 29 May 2023
25 * $Revision: V.2.0.0
26 *
27 * Target : Arm(R) M-Profile Architecture
28 * -------------------------------------------------------------------- */
29
30 #include "arm_nnfunctions.h"
31 #include "arm_nnsupportfunctions.h"
32
/*
 * Matrix-multiplication function for convolution with per-channel requantization.
 *
 * Refer to the header file for details.
 *
 */

#if !defined(ARM_MATH_MVEI)
/*
 * Convert one raw accumulator into an output sample: pass it through
 * arm_nn_requantize() with the given multiplier/shift, add the output offset,
 * clamp to [activation_min, activation_max] and narrow to int8_t.
 */
static inline int8_t mat_mult_s8_s16_to_output(int32_t acc,
                                               const int32_t mult,
                                               const int32_t shift,
                                               const int32_t out_offset,
                                               const int16_t activation_min,
                                               const int16_t activation_max)
{
    acc = arm_nn_requantize(acc, mult, shift);
    acc += out_offset;
    acc = MAX(acc, activation_min);
    acc = MIN(acc, activation_max);
    return (int8_t)acc;
}
#endif

/*
 * For each of the output_ch rows of input_a (num_col_a elements per row) two
 * sums of products are accumulated: one against the vector at input_b and one
 * against the vector starting aligned_num_col_a elements later. Every sum is
 * seeded with the row's bias when output_bias is non-NULL, requantized with
 * that row's out_mult/out_shift entry, offset by out_offset, clamped to the
 * activation range and stored as int8_t. Results for the first vector are
 * written from out_0, results for the second from out_0 + output_ch.
 *
 * Returns out_0 advanced by 2 * output_ch. When built with ARM_MATH_MVEI the
 * kernel is not implemented and NULL is returned.
 */
int8_t *arm_nn_mat_mult_kernel_s8_s16(const int8_t *input_a,
                                      const int16_t *input_b,
                                      const uint16_t output_ch,
                                      const int32_t *out_shift,
                                      const int32_t *out_mult,
                                      const int32_t out_offset,
                                      const int16_t activation_min,
                                      const int16_t activation_max,
                                      const int32_t num_col_a,
                                      const int32_t aligned_num_col_a,
                                      const int32_t *const output_bias,
                                      int8_t *out_0)
{
#if !defined(ARM_MATH_MVEI)
    const int32_t *bias_ptr = output_bias;
    const int8_t *a_row_0 = input_a;

    /* Results for the second vector of B land output_ch elements after the first */
    int8_t *out_1 = out_0 + output_ch;

    /* Walk the rows of A two at a time */
    for (uint16_t pairs_left = output_ch / 2; pairs_left != 0; pairs_left--)
    {
        /* Both B vectors are re-read from their start for every row pair */
        const int16_t *b_vec_0 = input_b;
        const int16_t *b_vec_1 = b_vec_0 + aligned_num_col_a;

        /* The second row of the pair directly follows the first one in A */
        const int8_t *a_row_1 = a_row_0 + num_col_a;

        /* One accumulator per (row, B vector) combination */
        int32_t acc_r0_v0 = 0;
        int32_t acc_r0_v1 = 0;
        int32_t acc_r1_v0 = 0;
        int32_t acc_r1_v1 = 0;

        /* Seed with the bias of channel N and N + 1, if a bias vector was given */
        if (bias_ptr)
        {
            acc_r0_v0 = acc_r0_v1 = bias_ptr[0];
            acc_r1_v0 = acc_r1_v1 = bias_ptr[1];
            bias_ptr += 2;
        }

    #if defined(ARM_MATH_DSP)
        /* Four elements of each row per iteration, two per SMLAD */
        for (int32_t quads_left = num_col_a / 4; quads_left != 0; quads_left--)
        {
            int32_t r0_w0, r0_w1, r1_w0, r1_w1;
            int32_t b0 = arm_nn_read_q15x2_ia(&b_vec_0);
            int32_t b1 = arm_nn_read_q15x2_ia(&b_vec_1);

            a_row_0 = read_and_pad_reordered(a_row_0, &r0_w0, &r0_w1);
            a_row_1 = read_and_pad_reordered(a_row_1, &r1_w0, &r1_w1);

            acc_r0_v0 = SMLAD(r0_w0, b0, acc_r0_v0);
            acc_r0_v1 = SMLAD(r0_w0, b1, acc_r0_v1);
            acc_r1_v0 = SMLAD(r1_w0, b0, acc_r1_v0);
            acc_r1_v1 = SMLAD(r1_w0, b1, acc_r1_v1);

            b0 = arm_nn_read_q15x2_ia(&b_vec_0);
            b1 = arm_nn_read_q15x2_ia(&b_vec_1);

            acc_r0_v0 = SMLAD(r0_w1, b0, acc_r0_v0);
            acc_r0_v1 = SMLAD(r0_w1, b1, acc_r0_v1);
            acc_r1_v0 = SMLAD(r1_w1, b0, acc_r1_v0);
            acc_r1_v1 = SMLAD(r1_w1, b1, acc_r1_v1);
        }
        /* Elements left over after the groups of four */
        int32_t tail = num_col_a & 0x3;
    #else
        int32_t tail = num_col_a;
    #endif
        /* Scalar path: one element of each row and each B vector per iteration */
        for (; tail != 0; tail--)
        {
            const int8_t a_r0 = *a_row_0++;
            const int16_t b_v0 = *b_vec_0++;
            const int8_t a_r1 = *a_row_1++;
            const int16_t b_v1 = *b_vec_1++;

            acc_r0_v0 += a_r0 * b_v0;
            acc_r0_v1 += a_r0 * b_v1;
            acc_r1_v0 += a_r1 * b_v0;
            acc_r1_v1 += a_r1 * b_v1;
        }

        /* First row of the pair: both results share one multiplier/shift entry */
        *out_0++ = mat_mult_s8_s16_to_output(
            acc_r0_v0, *out_mult, *out_shift, out_offset, activation_min, activation_max);
        *out_1++ = mat_mult_s8_s16_to_output(
            acc_r0_v1, *out_mult, *out_shift, out_offset, activation_min, activation_max);
        out_mult++;
        out_shift++;

        /* Second row of the pair uses the next multiplier/shift entry */
        *out_0++ = mat_mult_s8_s16_to_output(
            acc_r1_v0, *out_mult, *out_shift, out_offset, activation_min, activation_max);
        *out_1++ = mat_mult_s8_s16_to_output(
            acc_r1_v1, *out_mult, *out_shift, out_offset, activation_min, activation_max);
        out_mult++;
        out_shift++;

        /* a_row_0 has consumed its own row; step over the second row of the pair */
        a_row_0 += num_col_a;
    }

    /* A single row is left when the number of output channels is odd */
    if (output_ch & 0x1)
    {
        const int16_t *b_vec_0 = input_b;
        const int16_t *b_vec_1 = b_vec_0 + aligned_num_col_a;

        int32_t acc_v0 = 0;
        int32_t acc_v1 = 0;

        /* Seed both accumulators with the bias of the last channel */
        if (bias_ptr)
        {
            acc_v0 = acc_v1 = *bias_ptr++;
        }

    #if defined(ARM_MATH_DSP)
        for (int32_t quads_left = num_col_a >> 2; quads_left != 0; quads_left--)
        {
            int32_t r0_w0, r0_w1;
            int32_t b0 = arm_nn_read_q15x2_ia(&b_vec_0);
            int32_t b1 = arm_nn_read_q15x2_ia(&b_vec_1);

            a_row_0 = read_and_pad_reordered(a_row_0, &r0_w0, &r0_w1);

            acc_v0 = SMLAD(r0_w0, b0, acc_v0);
            acc_v1 = SMLAD(r0_w0, b1, acc_v1);

            b0 = arm_nn_read_q15x2_ia(&b_vec_0);
            b1 = arm_nn_read_q15x2_ia(&b_vec_1);

            acc_v0 = SMLAD(r0_w1, b0, acc_v0);
            acc_v1 = SMLAD(r0_w1, b1, acc_v1);
        }
        int32_t tail = num_col_a & 0x3;
    #else
        int32_t tail = num_col_a;
    #endif
        for (; tail != 0; tail--)
        {
            const int8_t a_r0 = *a_row_0++;
            const int16_t b_v0 = *b_vec_0++;
            const int16_t b_v1 = *b_vec_1++;

            acc_v0 += a_r0 * b_v0;
            acc_v1 += a_r0 * b_v1;
        }

        *out_0++ = mat_mult_s8_s16_to_output(
            acc_v0, *out_mult, *out_shift, out_offset, activation_min, activation_max);
        *out_1++ = mat_mult_s8_s16_to_output(
            acc_v1, *out_mult, *out_shift, out_offset, activation_min, activation_max);
        out_mult++;
        out_shift++;
    }

    /* out_0 now sits at the start of the second result vector; skip past it */
    return out_0 + output_ch;
#else
    (void)input_a;
    (void)input_b;
    (void)output_ch;
    (void)out_shift;
    (void)out_mult;
    (void)out_offset;
    (void)activation_min;
    (void)activation_max;
    (void)aligned_num_col_a;
    (void)num_col_a;
    (void)output_bias;
    (void)out_0;
    /* To be completed */
    return NULL;
#endif
}
247