1 /*
2  * SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nn_mat_mult_kernel_s8_s16.c
22  * Description:  Matrix-multiplication function for convolution
23  *
24  * $Date:        29 May 2023
25  * $Revision:    V.2.0.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  * -------------------------------------------------------------------- */
29 
30 #include "arm_nnfunctions.h"
31 #include "arm_nnsupportfunctions.h"
32 
33 /*
34  * Matrix-multiplication function for convolution with per-channel requantization.
35  *
 * Refer to the header file for details.
37  *
38  */
39 
#if !defined(ARM_MATH_MVEI)
/*
 * Turn one raw accumulator into an int8 output value:
 * requantize it with the channel's multiplier/shift via arm_nn_requantize(),
 * add `offset`, clamp the result to [act_min, act_max] and narrow to int8.
 */
static inline int8_t mat_mult_s8_s16_finish(const int32_t acc,
                                            const int32_t mult,
                                            const int32_t shift,
                                            const int32_t offset,
                                            const int16_t act_min,
                                            const int16_t act_max)
{
    int32_t val = arm_nn_requantize(acc, mult, shift) + offset;
    val = MAX(val, act_min);
    val = MIN(val, act_max);
    return (int8_t)val;
}
#endif

/*
 * Multiply the int8 matrix A by two int16 columns of B and emit two int8
 * output columns, one value per output channel in each column.
 *
 * Data layout, as the code below reads and writes it:
 *   - A holds output_ch rows of num_col_a elements, stored back to back.
 *   - B column 0 starts at input_b and B column 1 starts at
 *     input_b + aligned_num_col_a.
 *   - Results for B column 0 go to out_0[0 .. output_ch - 1].
 *   - Results for B column 1 go to out_0[output_ch .. 2 * output_ch - 1].
 *   - output_bias may be NULL. In that case every accumulator starts at 0.
 *     Otherwise output_bias[ch] seeds both accumulators of channel ch.
 *   - out_mult[ch] / out_shift[ch] give the requantization of channel ch.
 *
 * With ARM_MATH_DSP, groups of 4 A elements are unpacked with
 * read_and_pad_reordered() and paired with consecutive q15x2 words of B.
 * This presumably relies on the caller storing B in the matching reordered
 * layout (TODO confirm against the caller). The last num_col_a % 4 elements
 * are processed in plain order by the scalar loop.
 *
 * Returns out_0 advanced past both output columns (out_0 + 2 * output_ch).
 * When built with ARM_MATH_MVEI the kernel is not implemented and returns NULL.
 */
int8_t *arm_nn_mat_mult_kernel_s8_s16(const int8_t *input_a,
                                      const int16_t *input_b,
                                      const uint16_t output_ch,
                                      const int32_t *out_shift,
                                      const int32_t *out_mult,
                                      const int32_t out_offset,
                                      const int16_t activation_min,
                                      const int16_t activation_max,
                                      const int32_t num_col_a,
                                      const int32_t aligned_num_col_a,
                                      const int32_t *const output_bias,
                                      int8_t *out_0)
{
#if !defined(ARM_MATH_MVEI)
    /* The two B columns are shared by every output channel. */
    const int16_t *const col_b0 = input_b;
    const int16_t *const col_b1 = input_b + aligned_num_col_a;

    /* Output column for B column 1 directly follows the one for B column 0. */
    int8_t *const dst_c0 = out_0;
    int8_t *const dst_c1 = out_0 + output_ch;

    /* First row of A that has not been processed yet. */
    const int8_t *a_next = input_a;

    int32_t ch = 0;

    /* Handle two output channels (two rows of A) per iteration. */
    for (; ch + 1 < output_ch; ch += 2)
    {
        const int8_t *pa0 = a_next;
        const int8_t *pa1 = pa0 + num_col_a;
        a_next = pa1 + num_col_a;

        const int16_t *pb0 = col_b0;
        const int16_t *pb1 = col_b1;

        /* accRC = dot(A row for channel ch + R, B column C). */
        const int32_t bias_r0 = output_bias ? output_bias[ch] : 0;
        const int32_t bias_r1 = output_bias ? output_bias[ch + 1] : 0;
        int32_t acc00 = bias_r0;
        int32_t acc01 = bias_r0;
        int32_t acc10 = bias_r1;
        int32_t acc11 = bias_r1;

    #if defined(ARM_MATH_DSP)
        for (int32_t quads = num_col_a / 4; quads != 0; quads--)
        {
            int32_t a0_w0, a0_w1, a1_w0, a1_w1;
            int32_t b0_w = arm_nn_read_q15x2_ia(&pb0);
            int32_t b1_w = arm_nn_read_q15x2_ia(&pb1);

            pa0 = read_and_pad_reordered(pa0, &a0_w0, &a0_w1);
            pa1 = read_and_pad_reordered(pa1, &a1_w0, &a1_w1);

            acc00 = SMLAD(a0_w0, b0_w, acc00);
            acc01 = SMLAD(a0_w0, b1_w, acc01);
            acc10 = SMLAD(a1_w0, b0_w, acc10);
            acc11 = SMLAD(a1_w0, b1_w, acc11);

            b0_w = arm_nn_read_q15x2_ia(&pb0);
            b1_w = arm_nn_read_q15x2_ia(&pb1);

            acc00 = SMLAD(a0_w1, b0_w, acc00);
            acc01 = SMLAD(a0_w1, b1_w, acc01);
            acc10 = SMLAD(a1_w1, b0_w, acc10);
            acc11 = SMLAD(a1_w1, b1_w, acc11);
        }
        int32_t tail = num_col_a & 0x3;
    #else
        int32_t tail = num_col_a;
    #endif
        /* Scalar path: the leftover elements (or everything without DSP). */
        for (; tail != 0; tail--)
        {
            const int8_t x0 = *pa0++;
            const int8_t x1 = *pa1++;
            const int16_t y0 = *pb0++;
            const int16_t y1 = *pb1++;

            acc00 += x0 * y0;
            acc01 += x0 * y1;
            acc10 += x1 * y0;
            acc11 += x1 * y1;
        }

        dst_c0[ch] =
            mat_mult_s8_s16_finish(acc00, out_mult[ch], out_shift[ch], out_offset, activation_min, activation_max);
        dst_c1[ch] =
            mat_mult_s8_s16_finish(acc01, out_mult[ch], out_shift[ch], out_offset, activation_min, activation_max);
        dst_c0[ch + 1] = mat_mult_s8_s16_finish(
            acc10, out_mult[ch + 1], out_shift[ch + 1], out_offset, activation_min, activation_max);
        dst_c1[ch + 1] = mat_mult_s8_s16_finish(
            acc11, out_mult[ch + 1], out_shift[ch + 1], out_offset, activation_min, activation_max);
    }

    /* An odd output_ch leaves one last row of A; here ch == output_ch - 1. */
    if (output_ch & 0x1)
    {
        const int8_t *pa = a_next;
        const int16_t *pb0 = col_b0;
        const int16_t *pb1 = col_b1;

        const int32_t bias_r = output_bias ? output_bias[ch] : 0;
        int32_t acc0 = bias_r;
        int32_t acc1 = bias_r;

    #if defined(ARM_MATH_DSP)
        for (int32_t quads = num_col_a >> 2; quads != 0; quads--)
        {
            int32_t a_w0, a_w1;
            int32_t b0_w = arm_nn_read_q15x2_ia(&pb0);
            int32_t b1_w = arm_nn_read_q15x2_ia(&pb1);

            pa = read_and_pad_reordered(pa, &a_w0, &a_w1);

            acc0 = SMLAD(a_w0, b0_w, acc0);
            acc1 = SMLAD(a_w0, b1_w, acc1);

            b0_w = arm_nn_read_q15x2_ia(&pb0);
            b1_w = arm_nn_read_q15x2_ia(&pb1);

            acc0 = SMLAD(a_w1, b0_w, acc0);
            acc1 = SMLAD(a_w1, b1_w, acc1);
        }
        int32_t tail = num_col_a & 0x3;
    #else
        int32_t tail = num_col_a;
    #endif
        for (; tail != 0; tail--)
        {
            const int8_t x = *pa++;
            const int16_t y0 = *pb0++;
            const int16_t y1 = *pb1++;

            acc0 += x * y0;
            acc1 += x * y1;
        }

        dst_c0[ch] =
            mat_mult_s8_s16_finish(acc0, out_mult[ch], out_shift[ch], out_offset, activation_min, activation_max);
        dst_c1[ch] =
            mat_mult_s8_s16_finish(acc1, out_mult[ch], out_shift[ch], out_offset, activation_min, activation_max);
    }

    /* Both output columns (2 * output_ch values) have been written. */
    return out_0 + 2 * output_ch;
#else
    /* No MVE implementation is provided; all arguments are unused. */
    (void)input_a;
    (void)input_b;
    (void)output_ch;
    (void)out_shift;
    (void)out_mult;
    (void)out_offset;
    (void)activation_min;
    (void)activation_max;
    (void)aligned_num_col_a;
    (void)num_col_a;
    (void)output_bias;
    (void)out_0;
    return NULL;
#endif
}
247