1 /*
2  * Copyright (C) 2010-2020 Arm Limited or its affiliates. All rights reserved.
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nn_mat_mult_kernel_s8_s16_reordered.c
22  * Description:  Matrix-multiplication function for convolution with reordered columns
23  *
24  * $Date:        09. October 2020
25  * $Revision:    V.1.0.3
26  *
27  * Target Processor:  Cortex-M cores
28  * -------------------------------------------------------------------- */
29 
30 #include "arm_nnfunctions.h"
31 #include "arm_nnsupportfunctions.h"
32 
33 /*
34  * Matrix-multiplication with re-ordered input and bias inputs for convolution with per-channel
35  *        requantization. The re-ordering is a consequence of sign extension is done by the SXTB16 command.
36  *
37  * Refer header file for details. This function differs from arm_nn_mat_mult_kernel_s8_s16(), in that it uses
38  *        read_and_pad_reordered() instead of arm_nn_mat_mult_kernel_s8_s16(). Investigating the cycles impact and
39  *        unifying these two functions is a potential future improvement.
40  *
41  */
42 
arm_nn_mat_mult_kernel_s8_s16_reordered(const q7_t * input_a,const q15_t * input_b,const uint16_t output_ch,const int32_t * out_shift,const int32_t * out_mult,const int32_t out_offset,const int16_t activation_min,const int16_t activation_max,const uint16_t num_col_a,const int32_t * const output_bias,q7_t * out_0)43 q7_t *arm_nn_mat_mult_kernel_s8_s16_reordered(const q7_t *input_a,
44                                               const q15_t *input_b,
45                                               const uint16_t output_ch,
46                                               const int32_t *out_shift,
47                                               const int32_t *out_mult,
48                                               const int32_t out_offset,
49                                               const int16_t activation_min,
50                                               const int16_t activation_max,
51                                               const uint16_t num_col_a,
52                                               const int32_t *const output_bias,
53                                               q7_t *out_0)
54 {
55 #if defined(ARM_MATH_DSP)
56     /* set up the second output pointers */
57     q7_t *out_1 = out_0 + output_ch;
58     const int32_t *bias = output_bias;
59 
60     uint16_t row_count = output_ch / 2;
61     const q7_t *ip_a0 = input_a;
62     /* this loop over rows in A */
63     while (row_count)
64     {
65         /* setup pointers for B */
66         const q15_t *ip_b0 = input_b;
67         const q15_t *ip_b1 = ip_b0 + num_col_a;
68 
69         /* align the second pointer for A */
70         const q7_t *ip_a1 = ip_a0 + num_col_a;
71 
72         /* Init accumulator with bias for channel N and N + 1 */
73         q31_t ch_0_out_0 = *bias;
74         q31_t ch_0_out_1 = *bias++;
75         q31_t ch_1_out_0 = *bias;
76         q31_t ch_1_out_1 = *bias++;
77 
78         uint16_t col_count = num_col_a / 4;
79         /* accumulate over the vector */
80         while (col_count)
81         {
82             q31_t a01, a02, a11, a12;
83             q31_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
84             q31_t b1 = arm_nn_read_q15x2_ia(&ip_b1);
85 
86             ip_a0 = read_and_pad_reordered(ip_a0, &a01, &a02);
87             ip_a1 = read_and_pad_reordered(ip_a1, &a11, &a12);
88 
89             ch_0_out_0 = __SMLAD(a01, b0, ch_0_out_0);
90             ch_0_out_1 = __SMLAD(a01, b1, ch_0_out_1);
91             ch_1_out_0 = __SMLAD(a11, b0, ch_1_out_0);
92             ch_1_out_1 = __SMLAD(a11, b1, ch_1_out_1);
93 
94             b0 = arm_nn_read_q15x2_ia(&ip_b0);
95             b1 = arm_nn_read_q15x2_ia(&ip_b1);
96 
97             ch_0_out_0 = __SMLAD(a02, b0, ch_0_out_0);
98             ch_0_out_1 = __SMLAD(a02, b1, ch_0_out_1);
99             ch_1_out_0 = __SMLAD(a12, b0, ch_1_out_0);
100             ch_1_out_1 = __SMLAD(a12, b1, ch_1_out_1);
101 
102             col_count--;
103         } /* while over col_count */
104 
105         ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
106         ch_0_out_0 += out_offset;
107         ch_0_out_0 = MAX(ch_0_out_0, activation_min);
108         ch_0_out_0 = MIN(ch_0_out_0, activation_max);
109         *out_0++ = (q7_t)ch_0_out_0;
110 
111         ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
112         ch_0_out_1 += out_offset;
113         ch_0_out_1 = MAX(ch_0_out_1, activation_min);
114         ch_0_out_1 = MIN(ch_0_out_1, activation_max);
115         *out_1++ = (q7_t)ch_0_out_1;
116         out_mult++;
117         out_shift++;
118 
119         ch_1_out_0 = arm_nn_requantize(ch_1_out_0, *out_mult, *out_shift);
120         ch_1_out_0 += out_offset;
121         ch_1_out_0 = MAX(ch_1_out_0, activation_min);
122         ch_1_out_0 = MIN(ch_1_out_0, activation_max);
123         *out_0++ = (q7_t)ch_1_out_0;
124 
125         ch_1_out_1 = arm_nn_requantize(ch_1_out_1, *out_mult, *out_shift);
126         ch_1_out_1 += out_offset;
127         ch_1_out_1 = MAX(ch_1_out_1, activation_min);
128         ch_1_out_1 = MIN(ch_1_out_1, activation_max);
129         *out_1++ = (q7_t)ch_1_out_1;
130         out_mult++;
131         out_shift++;
132 
133         /* skip row */
134         ip_a0 += num_col_a;
135         row_count--;
136     }
137 
138     if (output_ch & 1)
139     {
140         /* setup pointers for B */
141         const q15_t *ip_b0 = input_b;
142         const q15_t *ip_b1 = ip_b0 + num_col_a;
143 
144         /* Init accumulator with bias for channel N + 1 */
145         q31_t ch_0_out_0 = *bias;
146         q31_t ch_0_out_1 = ch_0_out_0;
147 
148         int32_t col_count = num_col_a / 4;
149         while (col_count)
150         {
151             q31_t a01, a02;
152             q31_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
153             q31_t b1 = arm_nn_read_q15x2_ia(&ip_b1);
154 
155             ip_a0 = read_and_pad_reordered(ip_a0, &a01, &a02);
156 
157             ch_0_out_0 = __SMLAD(a01, b0, ch_0_out_0);
158             ch_0_out_1 = __SMLAD(a01, b1, ch_0_out_1);
159 
160             b0 = arm_nn_read_q15x2_ia(&ip_b0);
161             b1 = arm_nn_read_q15x2_ia(&ip_b1);
162 
163             ch_0_out_0 = __SMLAD(a02, b0, ch_0_out_0);
164             ch_0_out_1 = __SMLAD(a02, b1, ch_0_out_1);
165 
166             col_count--;
167         } /* while over col_count */
168 
169         ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
170         ch_0_out_0 += out_offset;
171         ch_0_out_0 = MAX(ch_0_out_0, activation_min);
172         ch_0_out_0 = MIN(ch_0_out_0, activation_max);
173         *out_0++ = (q7_t)ch_0_out_0;
174 
175         ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
176         ch_0_out_1 += out_offset;
177         ch_0_out_1 = MAX(ch_0_out_1, activation_min);
178         ch_0_out_1 = MIN(ch_0_out_1, activation_max);
179         *out_1++ = (q7_t)ch_0_out_1;
180     }
181 
182     out_0 += output_ch;
183 
184     /* return the new output pointer with offset */
185     return out_0;
186 #else
187     (void)input_a;
188     (void)input_b;
189     (void)output_ch;
190     (void)out_shift;
191     (void)out_mult;
192     (void)out_offset;
193     (void)activation_min;
194     (void)activation_max;
195     (void)num_col_a;
196     (void)output_bias;
197     (void)out_0;
198     /* To be completed */
199     return NULL;
200 #endif
201 }
202