1 /*
2 * SPDX-FileCopyrightText: Copyright 2010-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_nn_mat_mult_s8.c
22 * Description: General Matrix-multiplication function
23 *
24 * $Date: 26 October 2022
25 * $Revision: V.2.0.8
26 *
27 * Target Processor: Cortex-M cores
28 * -------------------------------------------------------------------- */
29
30 #include "arm_nnsupportfunctions.h"
31
/*
 * s8 general matrix multiplication with per-channel requantization.
 *
 * For every column batch b in [0, col_batches) and output channel ch in [0, output_ch):
 *
 *   acc  = sum_{k < row_len} input_row[ch * row_len + k] * (input_col[b * row_len + k] + col_offset)
 *   acc += bias[ch]                                   (skipped when bias == NULL)
 *   acc  = requantize(acc, output_mult[ch], output_shift[ch]) + out_offset
 *   out[b * output_ch + ch] = clamp(acc, activation_min, activation_max)
 *
 * Column batches are processed four at a time, so each weight row is loaded once per group of
 * four. Any remaining batches are processed one at a time. row_offset is not used.
 *
 * Returns:
 * - When built with MVE (ARM_MATH_MVEI): a pointer just past the last written output
 *   (out + col_batches * output_ch).
 * - Otherwise: NULL. Nothing is computed or written.
 *
 * Refer to the header file for the full parameter documentation.
 */

int8_t *arm_nn_mat_mult_s8(const int8_t *input_row,
                           const int8_t *input_col,
                           const uint16_t output_ch,
                           const uint16_t col_batches,
                           const int32_t *output_shift,
                           const int32_t *output_mult,
                           const int32_t out_offset,
                           const int32_t col_offset,
                           const int32_t row_offset,
                           const int16_t activation_min,
                           const int16_t activation_max,
                           const uint16_t row_len,
                           const int32_t *const bias,
                           int8_t *out)
{
#if defined(ARM_MATH_MVEI)
    (void)row_offset; /* Not used by this implementation. */

    /* Number of 8-lane iterations covering one row; the last one may be partial. */
    const int32_t row_loop_cnt = (row_len + 7) / 8;

    /* col_offset is added to every column element; the broadcast is loop-invariant. */
    const int16x8_t col_offset_vec = vdupq_n_s16((int16_t)col_offset);

    /* Byte offsets of one output channel's result within four consecutive column batches. */
    const uint32x4_t scatter_offset = {0, output_ch, output_ch * 2, output_ch * 3};

    const int8_t *col_base = input_col;
    int32_t cols_left = col_batches;

    /* Groups of four column batches: each weight row is reused for four dot products. */
    for (; cols_left >= 4; cols_left -= 4)
    {
        for (int i_out_ch = 0; i_out_ch < output_ch; i_out_ch++)
        {
            int32_t row_len_tmp = row_len;
            const int8_t *ip_r0 = input_row + (i_out_ch * row_len);
            const int8_t *ip_c0 = col_base;
            const int8_t *ip_c1 = col_base + row_len;
            const int8_t *ip_c2 = col_base + (2 * row_len);
            const int8_t *ip_c3 = col_base + (3 * row_len);

            int32_t acc_0 = 0;
            int32_t acc_1 = 0;
            int32_t acc_2 = 0;
            int32_t acc_3 = 0;

            for (int i_row_loop = 0; i_row_loop < row_loop_cnt; i_row_loop++)
            {
                const mve_pred16_t p = vctp16q((uint32_t)row_len_tmp);
                row_len_tmp -= 8;

                /* Predicated (zeroing) loads: lanes past the end of the row are not read from
                 * memory, so the tail iteration never reads beyond the column/row buffers.
                 * Inactive lanes are also excluded by the predicated accumulation below. */
                const int16x8_t c0 = vaddq_s16(vldrbq_z_s16(ip_c0, p), col_offset_vec);
                const int16x8_t c1 = vaddq_s16(vldrbq_z_s16(ip_c1, p), col_offset_vec);
                const int16x8_t c2 = vaddq_s16(vldrbq_z_s16(ip_c2, p), col_offset_vec);
                const int16x8_t c3 = vaddq_s16(vldrbq_z_s16(ip_c3, p), col_offset_vec);
                const int16x8_t r0 = vldrbq_z_s16(ip_r0, p);
                ip_c0 += 8;
                ip_c1 += 8;
                ip_c2 += 8;
                ip_c3 += 8;
                ip_r0 += 8;

                acc_0 = vmladavaq_p_s16(acc_0, r0, c0, p);
                acc_1 = vmladavaq_p_s16(acc_1, r0, c1, p);
                acc_2 = vmladavaq_p_s16(acc_2, r0, c2, p);
                acc_3 = vmladavaq_p_s16(acc_3, r0, c3, p);
            }

            int32x4_t res = {acc_0, acc_1, acc_2, acc_3};
            if (bias)
            {
                res = vaddq_n_s32(res, bias[i_out_ch]);
            }
            res = arm_requantize_mve(res, output_mult[i_out_ch], output_shift[i_out_ch]);
            res = vaddq_n_s32(res, out_offset);

            res = vmaxq_s32(res, vdupq_n_s32(activation_min));
            res = vminq_s32(res, vdupq_n_s32(activation_max));

            /* Lane j is written to out[j * output_ch + i_out_ch] (low byte of each lane). */
            vstrbq_scatter_offset_s32(&out[i_out_ch], scatter_offset, res);
        }
        out += 4 * output_ch;
        col_base += 4 * row_len;
    }

    /* Remaining (fewer than four) column batches, one at a time. */
    for (; cols_left > 0; cols_left--)
    {
        for (int i_out_ch = 0; i_out_ch < output_ch; i_out_ch++)
        {
            int32_t row_len_tmp = row_len;
            const int8_t *ip_r0 = input_row + (i_out_ch * row_len);
            const int8_t *ip_c0 = col_base;
            int32_t acc_0 = 0;

            for (int i_row_loop = 0; i_row_loop < row_loop_cnt; i_row_loop++)
            {
                const mve_pred16_t p = vctp16q((uint32_t)row_len_tmp);
                row_len_tmp -= 8;

                /* Predicated loads avoid reading past the end of the buffers on the tail. */
                const int16x8_t c0 = vaddq_s16(vldrbq_z_s16(ip_c0, p), col_offset_vec);
                const int16x8_t r0 = vldrbq_z_s16(ip_r0, p);
                ip_c0 += 8;
                ip_r0 += 8;

                acc_0 = vmladavaq_p_s16(acc_0, r0, c0, p);
            }

            if (bias)
            {
                acc_0 += bias[i_out_ch];
            }
            acc_0 = arm_nn_requantize(acc_0, output_mult[i_out_ch], output_shift[i_out_ch]);
            acc_0 += out_offset;
            acc_0 = MAX(acc_0, activation_min);
            acc_0 = MIN(acc_0, activation_max);
            out[i_out_ch] = (int8_t)acc_0;
        }
        out += output_ch;
        col_base += row_len;
    }

    return out;

#else
    (void)input_row;
    (void)input_col;
    (void)output_ch;
    (void)col_batches;
    (void)output_shift;
    (void)output_mult;
    (void)out_offset;
    (void)col_offset;
    (void)row_offset;
    (void)activation_min;
    (void)activation_max;
    (void)row_len;
    (void)bias;
    (void)out;
    return NULL;
#endif
}
181