1 /*
2  * SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates
3  * <open-source-office@arm.com>
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  *
7  * Licensed under the Apache License, Version 2.0 (the License); you may
8  * not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  * www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
15  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19 
20 /* ----------------------------------------------------------------------
21  * Project:      CMSIS NN Library
22  * Title:        arm_nn_mat_mult_kernel_s16.c
23  * Description:  Matrix-multiplication function for convolution
24  *
25  * $Date:        5 January 2023
26  * $Revision:    V.1.2.0
27  *
28  * Target :  Arm(R) M-Profile Architecture
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33 
34 /**
35  * @ingroup groupSupport
36  */
37 
38 /**
39  * @addtogroup supportConvolution
40  * @{
41  */
42 
43 /*
44  * Matrix-multiplication function for convolution with per-channel requantization.
45  *
46  * Refer to the header file for details.
47  *
48  */
49 
/*
 * Multiplies rows of int8 data (input_a) with two int16 columns (input_b),
 * then requantizes each accumulator per output channel and clamps it to
 * [activation_min, activation_max].
 *
 * Layout as used by this function:
 *  - input_a holds output_ch consecutive rows of num_col_a int8 values.
 *  - input_b holds two consecutive columns of num_col_a int16 values each.
 *  - Results for the first column are written to out_0[0 .. output_ch - 1],
 *    results for the second column to out_0[output_ch .. 2 * output_ch - 1].
 *
 * Requantization:
 *  - output_bias != NULL: the 64-bit bias of the channel is added to the
 *    accumulator and the sum is scaled with arm_nn_requantize_s64() using
 *    REDUCE_MULTIPLIER(out_mult[ch]) and out_shift[ch].
 *  - output_bias == NULL: the 32-bit accumulator is scaled with
 *    arm_nn_requantize() using out_mult[ch] and out_shift[ch].
 *
 * Returns out_0 advanced by 2 * output_ch elements when ARM_MATH_DSP is
 * defined and ARM_MATH_MVEI is not; in every other configuration nothing is
 * computed and NULL is returned.
 *
 * Loop counters are int32_t, matching the type of output_ch and num_col_a,
 * so that large dimensions are not truncated to 16 bits.
 */
int16_t *arm_nn_mat_mult_kernel_s16(const int8_t *input_a,
                                    const int16_t *input_b,
                                    const int32_t output_ch,
                                    const int32_t *out_shift,
                                    const int32_t *out_mult,
                                    const int16_t activation_min,
                                    const int16_t activation_max,
                                    const int32_t num_col_a,
                                    const int64_t *const output_bias,
                                    int16_t *out_0)
{

#if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
    /* set up the second output pointer: results for the second column of B */
    int16_t *out_1 = out_0 + output_ch;
    const int64_t *bias = output_bias;
    const int8_t *ip_a0 = input_a;

    /* Same width as output_ch so that the row-pair count cannot be truncated */
    int32_t row_count = output_ch / 2;

    /* this loop over rows in A, two rows (output channels) per iteration */
    while (row_count > 0)
    {
        /* setup pointers for B */
        const int16_t *ip_b0 = input_b;
        const int16_t *ip_b1 = ip_b0 + num_col_a;

        /* align the second pointer for A */
        const int8_t *ip_a1 = ip_a0 + num_col_a;

        /* Init accumulator for channel N and N + 1 */
        int32_t ch_0_out_0 = 0;
        int32_t ch_0_out_1 = 0;
        int32_t ch_1_out_0 = 0;
        int32_t ch_1_out_1 = 0;

        /* Same width as num_col_a so that the block count cannot be truncated */
        int32_t col_count = num_col_a / 4;

        /* accumulate over the vector, four columns per iteration */
        while (col_count > 0)
        {
            int32_t a01, a02, a11, a12;
            int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            /* NOTE(review): read_and_pad presumably widens four int8 values into
             * two packed int16 pairs ordered to match the q15x2 reads - confirm */
            ip_a0 = read_and_pad(ip_a0, &a01, &a02);
            ip_a1 = read_and_pad(ip_a1, &a11, &a12);

            ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a11, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a11, b1, ch_1_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a12, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a12, b1, ch_1_out_1);

            col_count--;
        } /* while over col_count */

        /* remaining 0..3 columns, one at a time */
        col_count = num_col_a & 0x3;
        while (col_count > 0)
        {
            int8_t a0 = *ip_a0++;
            int16_t b0 = *ip_b0++;
            int8_t a1 = *ip_a1++;
            int16_t b1 = *ip_b1++;

            ch_0_out_0 += a0 * b0;
            ch_0_out_1 += a0 * b1;
            ch_1_out_0 += a1 * b0;
            ch_1_out_1 += a1 * b1;
            col_count--;
        } /* while over col_count */

        /* channel N: add bias (if any) and requantize both column results */
        if (bias)
        {
            int32_t reduced_multiplier = REDUCE_MULTIPLIER(*out_mult);
            int64_t acc_64 = ch_0_out_0 + *bias;
            ch_0_out_0 = arm_nn_requantize_s64(acc_64, reduced_multiplier, *out_shift);
            /* the same bias applies to both columns; advance after the second use */
            acc_64 = ch_0_out_1 + *bias++;
            ch_0_out_1 = arm_nn_requantize_s64(acc_64, reduced_multiplier, *out_shift);
            out_mult++;
        }
        else
        {
            ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
            ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
            out_mult++;
        }
        ch_0_out_0 = MAX(ch_0_out_0, activation_min);
        ch_0_out_0 = MIN(ch_0_out_0, activation_max);
        *out_0++ = (int16_t)ch_0_out_0;

        ch_0_out_1 = MAX(ch_0_out_1, activation_min);
        ch_0_out_1 = MIN(ch_0_out_1, activation_max);
        *out_1++ = (int16_t)ch_0_out_1;
        out_shift++;

        /* channel N + 1: same sequence with the next multiplier, shift and bias */
        if (bias)
        {
            int32_t reduced_multiplier = REDUCE_MULTIPLIER(*out_mult);
            int64_t acc_64 = ch_1_out_0 + *bias;
            ch_1_out_0 = arm_nn_requantize_s64(acc_64, reduced_multiplier, *out_shift);
            acc_64 = ch_1_out_1 + *bias++;
            ch_1_out_1 = arm_nn_requantize_s64(acc_64, reduced_multiplier, *out_shift);
            out_mult++;
        }
        else
        {
            ch_1_out_0 = arm_nn_requantize(ch_1_out_0, *out_mult, *out_shift);
            ch_1_out_1 = arm_nn_requantize(ch_1_out_1, *out_mult, *out_shift);
            out_mult++;
        }
        ch_1_out_0 = MAX(ch_1_out_0, activation_min);
        ch_1_out_0 = MIN(ch_1_out_0, activation_max);
        *out_0++ = (int16_t)ch_1_out_0;

        ch_1_out_1 = MAX(ch_1_out_1, activation_min);
        ch_1_out_1 = MIN(ch_1_out_1, activation_max);
        *out_1++ = (int16_t)ch_1_out_1;
        out_shift++;

        /* skip row: ip_a0 already walked row N, step over row N + 1 as well */
        ip_a0 += num_col_a;
        row_count--;
    }

    /* compute the last odd numbered row if any */
    if (output_ch & 0x1)
    {
        /* setup pointers for B */
        const int16_t *ip_b0 = input_b;
        const int16_t *ip_b1 = ip_b0 + num_col_a;

        int32_t ch_0_out_0 = 0;
        int32_t ch_0_out_1 = 0;

        int32_t col_count = num_col_a / 4;
        while (col_count > 0)
        {
            int32_t a01, a02;
            int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ip_a0 = read_and_pad(ip_a0, &a01, &a02);

            ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);
            ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);

            col_count--;
        }

        /* remaining 0..3 columns, one at a time */
        col_count = num_col_a & 0x3;
        while (col_count > 0)
        {
            int8_t a0 = *ip_a0++;
            int16_t b0 = *ip_b0++;
            int16_t b1 = *ip_b1++;

            ch_0_out_0 += a0 * b0;
            ch_0_out_1 += a0 * b1;
            col_count--;
        }
        if (bias)
        {
            int32_t reduced_multiplier = REDUCE_MULTIPLIER(*out_mult);
            int64_t acc_64 = ch_0_out_0 + *bias;
            ch_0_out_0 = arm_nn_requantize_s64(acc_64, reduced_multiplier, *out_shift);
            acc_64 = ch_0_out_1 + *bias++;
            ch_0_out_1 = arm_nn_requantize_s64(acc_64, reduced_multiplier, *out_shift);
        }
        else
        {
            ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
            ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
        }
        ch_0_out_0 = MAX(ch_0_out_0, activation_min);
        ch_0_out_0 = MIN(ch_0_out_0, activation_max);
        *out_0++ = (int16_t)ch_0_out_0;

        ch_0_out_1 = MAX(ch_0_out_1, activation_min);
        ch_0_out_1 = MIN(ch_0_out_1, activation_max);
        *out_1++ = (int16_t)ch_0_out_1;
        out_mult++;
        out_shift++;
    }

    /* out_0 has advanced by output_ch; skip the second column written via out_1 */
    out_0 += output_ch;

    /* return the new output pointer with offset */
    return out_0;
#else
    (void)input_a;
    (void)input_b;
    (void)output_ch;
    (void)out_shift;
    (void)out_mult;
    (void)activation_min;
    (void)activation_max;
    (void)num_col_a;
    (void)output_bias;
    (void)out_0;
    /* To be completed */
    return NULL;
#endif
}
261 
262 /**
263  * @} end of Doxygen group
264  */