/*
 * SPDX-FileCopyrightText: Copyright 2010-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_s8.c
 * Description:  General Matrix-multiplication function
 *
 * $Date:        26 October 2022
 * $Revision:    V.2.0.8
 *
 * Target Processor:  Cortex-M cores
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/*
 * s8 General matrix multiplication function with per-channel requantization for up to 4 column batches.
 *
 * Refer to the header file for details.
 *
 */
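
/*
 * Illustrative sketch (not part of the original source): one possible call,
 * multiplying an 8-channel kernel matrix against 4 column batches of length
 * 16. All buffer names and values below are hypothetical; in CMSIS-NN this
 * kernel is typically invoked from the convolution functions with im2col
 * column buffers.
 *
 *     int8_t *next = arm_nn_mat_mult_s8(kernel_data, col_buffer, 8, 4,
 *                                       per_ch_shift, per_ch_mult, out_offset,
 *                                       col_offset, 0, -128, 127, 16,
 *                                       bias_data, output);
 */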

int8_t *arm_nn_mat_mult_s8(const int8_t *input_row,
                           const int8_t *input_col,
                           const uint16_t output_ch,
                           const uint16_t col_batches,
                           const int32_t *output_shift,
                           const int32_t *output_mult,
                           const int32_t out_offset,
                           const int32_t col_offset,
                           const int32_t row_offset,
                           const int16_t activation_min,
                           const int16_t activation_max,
                           const uint16_t row_len,
                           const int32_t *const bias,
                           int8_t *out)
{
#if defined(ARM_MATH_MVEI)
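    /* row_offset is not needed by this MVE implementation and is explicitly
       ignored below. */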
    (void)row_offset;
    if (col_batches == 4)
    {
        for (int i_out_ch = 0; i_out_ch < output_ch; i_out_ch++)
        {
            int32_t row_len_tmp = row_len;
            const int8_t *ip_r0 = input_row + (i_out_ch * row_len);
            const int8_t *ip_c0 = input_col;
            const int8_t *ip_c1 = input_col + row_len;
            const int8_t *ip_c2 = input_col + (2 * row_len);
            const int8_t *ip_c3 = input_col + (3 * row_len);

            int32_t acc_0 = 0;
            int32_t acc_1 = 0;
            int32_t acc_2 = 0;
            int32_t acc_3 = 0;
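
            /* Each MVE iteration widens 8 int8 elements to int16, so round the
               row length up to a whole number of 8-element blocks. */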
            const int32_t row_loop_cnt = (row_len + 7) / 8;

            for (int i_row_loop = 0; i_row_loop < row_loop_cnt; i_row_loop++)
            {
                mve_pred16_t p = vctp16q((uint32_t)row_len_tmp);
                const int16x8_t offset = vdupq_x_n_s16(col_offset, p);
                row_len_tmp -= 8;

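                /* Tail handling: on the final block, the predicate p keeps
                   only the remaining lanes active in the multiply-accumulates,
                   and the row load below is zero-predicated (vldrbq_z_s16). */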
                int16x8_t c0 = vldrbq_s16(ip_c0);
                ip_c0 += 8;
                c0 = vaddq_s16(c0, offset);

                int16x8_t c1 = vldrbq_s16(ip_c1);
                ip_c1 += 8;
                c1 = vaddq_s16(c1, offset);

                int16x8_t c2 = vldrbq_s16(ip_c2);
                ip_c2 += 8;
                c2 = vaddq_s16(c2, offset);

                int16x8_t c3 = vldrbq_s16(ip_c3);
                ip_c3 += 8;
                c3 = vaddq_s16(c3, offset);

                int16x8_t r0 = vldrbq_z_s16(ip_r0, p);
                ip_r0 += 8;

                acc_0 = vmladavaq_p_s16(acc_0, r0, c0, p);
                acc_1 = vmladavaq_p_s16(acc_1, r0, c1, p);
                acc_2 = vmladavaq_p_s16(acc_2, r0, c2, p);
                acc_3 = vmladavaq_p_s16(acc_3, r0, c3, p);
            }

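            /* Collect the four batch accumulators into one vector so that the
               bias add, requantization and clamping are done in a single pass. */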
            int32x4_t res = {acc_0, acc_1, acc_2, acc_3};
            if (bias)
            {
                res = vaddq_n_s32(res, bias[i_out_ch]);
            }
            res = arm_requantize_mve(res, output_mult[i_out_ch], output_shift[i_out_ch]);
            res = vaddq_n_s32(res, out_offset);

            res = vmaxq_s32(res, vdupq_n_s32(activation_min));
            res = vminq_s32(res, vdupq_n_s32(activation_max));

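            /* Narrow each 32-bit lane to its low byte and store the four batch
               results for this channel output_ch bytes apart, interleaving the
               batches in the output buffer. */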
            const uint32x4_t scatter_offset = {0, output_ch, output_ch * 2, output_ch * 3};
            vstrbq_scatter_offset_s32(&out[i_out_ch], scatter_offset, res);
        }
        out += 4 * output_ch;
    }
    else
    {
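        /* col_batches is less than 4 here, so (col_batches & ~0x3) is 0 and
           (col_batches & 0x3) equals col_batches: the loop simply covers the
           remaining 0..3 column batches one at a time. */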
        for (int i_col_batch = (col_batches & ~0x3); i_col_batch < (col_batches & 0x3); i_col_batch++)
        {
            for (int i_out_ch = 0; i_out_ch < output_ch; i_out_ch++)
            {
                int32_t row_len_tmp = row_len;

                const int8_t *ip_r0 = input_row + (i_out_ch * row_len);
                const int8_t *ip_c0 = input_col + (i_col_batch * row_len);
                int32_t acc_0 = 0;
                const int32_t row_loop_cnt = (row_len + 7) / 8;

                for (int i_row_loop = 0; i_row_loop < row_loop_cnt; i_row_loop++)
                {
                    const mve_pred16_t p = vctp16q((uint32_t)row_len_tmp);
                    const int16x8_t offset = vdupq_x_n_s16(col_offset, p);
                    row_len_tmp -= 8;

                    int16x8_t c0 = vldrbq_s16(ip_c0);
                    ip_c0 += 8;
                    c0 = vaddq_s16(c0, offset);

                    int16x8_t r0 = vldrbq_z_s16(ip_r0, p);
                    ip_r0 += 8;
                    acc_0 = vmladavaq_p_s16(acc_0, r0, c0, p);
                }

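                /* Scalar epilogue: add the per-channel bias, requantize with
                   the per-channel multiplier and shift, apply the output
                   offset, then clamp to the activation range. */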
                if (bias)
                {
                    acc_0 += bias[i_out_ch];
                }
                acc_0 = arm_nn_requantize(acc_0, output_mult[i_out_ch], output_shift[i_out_ch]);
                acc_0 += out_offset;
                acc_0 = MAX(acc_0, activation_min);
                acc_0 = MIN(acc_0, activation_max);
                out[i_out_ch] = (int8_t)acc_0;
            }
            out += output_ch;
        }
    }
    return out;

#else
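    /* Without MVE (Helium) support there is no implementation; return NULL so
       the caller can tell that this kernel is unavailable. */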
    (void)input_row;
    (void)input_col;
    (void)output_ch;
    (void)col_batches;
    (void)output_shift;
    (void)output_mult;
    (void)out_offset;
    (void)col_offset;
    (void)row_offset;
    (void)activation_min;
    (void)activation_max;
    (void)row_len;
    (void)bias;
    (void)out;
    return NULL;
#endif
}