1 /*
2 * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_vector_sum_s8
22 * Description: Generic function for calculating vector sums
23 *
24 * $Date: 05 Sep 2024
25 * $Revision: V.3.0.0
26 *
27 * Target : Arm(R) M-Profile Architecture
28 *
29 * -------------------------------------------------------------------- */
30
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33 /**
34 * @ingroup Public
35 */
36
37 /**
38 * @addtogroup FC
39 * @{
40 */
41
42 /*
 * S8 vector sum function in preparation for e.g. kernel sums in fully connected and matrix multiplication layer functions
44 *
45 * Refer header file for details.
46 *
47 */
arm_vector_sum_s8(int32_t * vector_sum_buf,const int32_t vector_cols,const int32_t vector_rows,const int8_t * vector_data,const int32_t lhs_offset,const int32_t rhs_offset,const int32_t * bias_data)48 arm_cmsis_nn_status arm_vector_sum_s8(int32_t *vector_sum_buf,
49 const int32_t vector_cols,
50 const int32_t vector_rows,
51 const int8_t *vector_data,
52 const int32_t lhs_offset,
53 const int32_t rhs_offset,
54 const int32_t *bias_data)
55 {
56
57 if (bias_data)
58 {
59 memcpy(vector_sum_buf, bias_data, vector_rows * sizeof(int32_t));
60 }
61 else
62 {
63 memset(vector_sum_buf, 0, vector_rows * sizeof(int32_t));
64 }
65
66 if (lhs_offset)
67 {
68 #if defined(ARM_MATH_MVEI)
69
70 const int32_t row_loop_cnt = vector_rows / 5;
71 for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
72 {
73 const int32_t col_loop_cnt = (vector_cols + 15) / 16;
74 const int8_t *vector_0 = vector_data;
75 const int8_t *vector_1 = vector_data + vector_cols;
76 const int8_t *vector_2 = vector_data + 2 * vector_cols;
77 const int8_t *vector_3 = vector_data + 3 * vector_cols;
78 const int8_t *vector_4 = vector_data + 4 * vector_cols;
79 int32_t vector_sum_0 = 0;
80 int32_t vector_sum_1 = 0;
81 int32_t vector_sum_2 = 0;
82 int32_t vector_sum_3 = 0;
83 int32_t vector_sum_4 = 0;
84 uint32_t col_cnt = (uint32_t)vector_cols;
85 for (int i = 0; i < col_loop_cnt; i++)
86 {
87 mve_pred16_t p = vctp8q(col_cnt);
88 col_cnt -= 16;
89 const int8x16_t ker_0 = vldrbq_z_s8(vector_0, p);
90 vector_sum_0 = vaddvaq_s8(vector_sum_0, ker_0);
91 const int8x16_t ker_1 = vldrbq_z_s8(vector_1, p);
92 vector_sum_1 = vaddvaq_s8(vector_sum_1, ker_1);
93 const int8x16_t ker_2 = vldrbq_z_s8(vector_2, p);
94 vector_sum_2 = vaddvaq_s8(vector_sum_2, ker_2);
95 const int8x16_t ker_3 = vldrbq_z_s8(vector_3, p);
96 vector_sum_3 = vaddvaq_s8(vector_sum_3, ker_3);
97 const int8x16_t ker_4 = vldrbq_z_s8(vector_4, p);
98 vector_sum_4 = vaddvaq_s8(vector_sum_4, ker_4);
99 vector_0 += 16;
100 vector_1 += 16;
101 vector_2 += 16;
102 vector_3 += 16;
103 vector_4 += 16;
104 }
105 vector_data += 5 * vector_cols;
106
107 if (rhs_offset)
108 {
109 vector_sum_0 += vector_cols * rhs_offset;
110 vector_sum_1 += vector_cols * rhs_offset;
111 vector_sum_2 += vector_cols * rhs_offset;
112 vector_sum_3 += vector_cols * rhs_offset;
113 vector_sum_4 += vector_cols * rhs_offset;
114 }
115
116 vector_sum_0 *= lhs_offset;
117 vector_sum_1 *= lhs_offset;
118 vector_sum_2 *= lhs_offset;
119 vector_sum_3 *= lhs_offset;
120 vector_sum_4 *= lhs_offset;
121
122 vector_sum_buf[0] += vector_sum_0;
123 vector_sum_buf[1] += vector_sum_1;
124 vector_sum_buf[2] += vector_sum_2;
125 vector_sum_buf[3] += vector_sum_3;
126 vector_sum_buf[4] += vector_sum_4;
127 vector_sum_buf += 5;
128 }
129 const int32_t loop_cnt = vector_rows % 5;
130 for (int i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
131 {
132 const int32_t col_loop_cnt = (vector_cols + 15) / 16;
133 const int8_t *vector_0 = vector_data;
134 int32_t vector_sum_0 = 0;
135 uint32_t col_cnt = (uint32_t)vector_cols;
136 for (int i = 0; i < col_loop_cnt; i++)
137 {
138 mve_pred16_t p = vctp8q(col_cnt);
139 col_cnt -= 16;
140 const int8x16_t ker_0 = vldrbq_z_s8(vector_0, p);
141 vector_sum_0 = vaddvaq_s8(vector_sum_0, ker_0);
142 vector_0 += 16;
143 }
144 vector_data += vector_cols;
145 if (rhs_offset)
146 {
147 vector_sum_0 += vector_cols * rhs_offset;
148 }
149 vector_sum_0 *= lhs_offset;
150
151 vector_sum_buf[i_row_loop_cnt] += vector_sum_0;
152 }
153 #else
154 for (int i = 0; i < vector_rows; i++)
155 {
156 int32_t sum = 0;
157 for (int j = 0; j < vector_cols; j++)
158 {
159 sum += *vector_data++;
160 }
161 if (rhs_offset)
162 {
163 sum += vector_cols * rhs_offset;
164 }
165 *vector_sum_buf++ += sum * lhs_offset;
166 }
167 #endif
168 }
169
170 return (ARM_CMSIS_NN_SUCCESS);
171 }
172
173 /**
174 * @} end of FC group
175 */
176