1 /*
2  * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_vector_sum_s8
22  * Description:  Generic function for calculating vector sums
23  *
24  * $Date:        05 Sep 2024
25  * $Revision:    V.3.0.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33 /**
34  *  @ingroup Public
35  */
36 
37 /**
38  * @addtogroup FC
39  * @{
40  */
41 
42 /*
 * S8 vector sum function in preparation for e.g. kernel sums in fully connected and matrix multiplication layer functions
44  *
 * Refer to the header file for details.
46  *
47  */
arm_vector_sum_s8(int32_t * vector_sum_buf,const int32_t vector_cols,const int32_t vector_rows,const int8_t * vector_data,const int32_t lhs_offset,const int32_t rhs_offset,const int32_t * bias_data)48 arm_cmsis_nn_status arm_vector_sum_s8(int32_t *vector_sum_buf,
49                                       const int32_t vector_cols,
50                                       const int32_t vector_rows,
51                                       const int8_t *vector_data,
52                                       const int32_t lhs_offset,
53                                       const int32_t rhs_offset,
54                                       const int32_t *bias_data)
55 {
56 
57     if (bias_data)
58     {
59         memcpy(vector_sum_buf, bias_data, vector_rows * sizeof(int32_t));
60     }
61     else
62     {
63         memset(vector_sum_buf, 0, vector_rows * sizeof(int32_t));
64     }
65 
66     if (lhs_offset)
67     {
68 #if defined(ARM_MATH_MVEI)
69 
70         const int32_t row_loop_cnt = vector_rows / 5;
71         for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
72         {
73             const int32_t col_loop_cnt = (vector_cols + 15) / 16;
74             const int8_t *vector_0 = vector_data;
75             const int8_t *vector_1 = vector_data + vector_cols;
76             const int8_t *vector_2 = vector_data + 2 * vector_cols;
77             const int8_t *vector_3 = vector_data + 3 * vector_cols;
78             const int8_t *vector_4 = vector_data + 4 * vector_cols;
79             int32_t vector_sum_0 = 0;
80             int32_t vector_sum_1 = 0;
81             int32_t vector_sum_2 = 0;
82             int32_t vector_sum_3 = 0;
83             int32_t vector_sum_4 = 0;
84             uint32_t col_cnt = (uint32_t)vector_cols;
85             for (int i = 0; i < col_loop_cnt; i++)
86             {
87                 mve_pred16_t p = vctp8q(col_cnt);
88                 col_cnt -= 16;
89                 const int8x16_t ker_0 = vldrbq_z_s8(vector_0, p);
90                 vector_sum_0 = vaddvaq_s8(vector_sum_0, ker_0);
91                 const int8x16_t ker_1 = vldrbq_z_s8(vector_1, p);
92                 vector_sum_1 = vaddvaq_s8(vector_sum_1, ker_1);
93                 const int8x16_t ker_2 = vldrbq_z_s8(vector_2, p);
94                 vector_sum_2 = vaddvaq_s8(vector_sum_2, ker_2);
95                 const int8x16_t ker_3 = vldrbq_z_s8(vector_3, p);
96                 vector_sum_3 = vaddvaq_s8(vector_sum_3, ker_3);
97                 const int8x16_t ker_4 = vldrbq_z_s8(vector_4, p);
98                 vector_sum_4 = vaddvaq_s8(vector_sum_4, ker_4);
99                 vector_0 += 16;
100                 vector_1 += 16;
101                 vector_2 += 16;
102                 vector_3 += 16;
103                 vector_4 += 16;
104             }
105             vector_data += 5 * vector_cols;
106 
107             if (rhs_offset)
108             {
109                 vector_sum_0 += vector_cols * rhs_offset;
110                 vector_sum_1 += vector_cols * rhs_offset;
111                 vector_sum_2 += vector_cols * rhs_offset;
112                 vector_sum_3 += vector_cols * rhs_offset;
113                 vector_sum_4 += vector_cols * rhs_offset;
114             }
115 
116             vector_sum_0 *= lhs_offset;
117             vector_sum_1 *= lhs_offset;
118             vector_sum_2 *= lhs_offset;
119             vector_sum_3 *= lhs_offset;
120             vector_sum_4 *= lhs_offset;
121 
122             vector_sum_buf[0] += vector_sum_0;
123             vector_sum_buf[1] += vector_sum_1;
124             vector_sum_buf[2] += vector_sum_2;
125             vector_sum_buf[3] += vector_sum_3;
126             vector_sum_buf[4] += vector_sum_4;
127             vector_sum_buf += 5;
128         }
129         const int32_t loop_cnt = vector_rows % 5;
130         for (int i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
131         {
132             const int32_t col_loop_cnt = (vector_cols + 15) / 16;
133             const int8_t *vector_0 = vector_data;
134             int32_t vector_sum_0 = 0;
135             uint32_t col_cnt = (uint32_t)vector_cols;
136             for (int i = 0; i < col_loop_cnt; i++)
137             {
138                 mve_pred16_t p = vctp8q(col_cnt);
139                 col_cnt -= 16;
140                 const int8x16_t ker_0 = vldrbq_z_s8(vector_0, p);
141                 vector_sum_0 = vaddvaq_s8(vector_sum_0, ker_0);
142                 vector_0 += 16;
143             }
144             vector_data += vector_cols;
145             if (rhs_offset)
146             {
147                 vector_sum_0 += vector_cols * rhs_offset;
148             }
149             vector_sum_0 *= lhs_offset;
150 
151             vector_sum_buf[i_row_loop_cnt] += vector_sum_0;
152         }
153 #else
154         for (int i = 0; i < vector_rows; i++)
155         {
156             int32_t sum = 0;
157             for (int j = 0; j < vector_cols; j++)
158             {
159                 sum += *vector_data++;
160             }
161             if (rhs_offset)
162             {
163                 sum += vector_cols * rhs_offset;
164             }
165             *vector_sum_buf++ += sum * lhs_offset;
166         }
167 #endif
168     }
169 
170     return (ARM_CMSIS_NN_SUCCESS);
171 }
172 
173 /**
174  * @} end of FC group
175  */
176