1 /*
2  * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nn_batch_matmul_s8.c
22  * Description:  Batch matrix multiplication. Does not perform transposes, see header file for details.
23  *
24  * $Date:        5 Sep 2024
25  * $Revision:    V.1.0.1
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 #include "arm_nnfunctions.h"
31 #include "arm_nnsupportfunctions.h"
32 
33 /**
34  * @ingroup Public
35  */
36 
37 /**
38  * @addtogroup FC
39  * @{
40  */
41 
/*
 * s8 batch matrix multiplication.
 * Refer to header file for details.
 */
arm_batch_matmul_s8(const cmsis_nn_context * ctx,const cmsis_nn_bmm_params * bmm_params,const cmsis_nn_per_tensor_quant_params * quant_params,const cmsis_nn_dims * input_lhs_dims,const int8_t * input_lhs,const cmsis_nn_dims * input_rhs_dims,const int8_t * input_rhs,const cmsis_nn_dims * output_dims,int8_t * output)46 arm_cmsis_nn_status arm_batch_matmul_s8(const cmsis_nn_context *ctx,
47                                         const cmsis_nn_bmm_params *bmm_params,
48                                         const cmsis_nn_per_tensor_quant_params *quant_params,
49                                         const cmsis_nn_dims *input_lhs_dims,
50                                         const int8_t *input_lhs,
51                                         const cmsis_nn_dims *input_rhs_dims,
52                                         const int8_t *input_rhs,
53                                         const cmsis_nn_dims *output_dims,
54                                         int8_t *output)
55 {
56     (void)ctx;
57 #if defined(ARM_MATH_MVEI)
58     if (ctx->buf == NULL)
59     {
60         return ARM_CMSIS_NN_ARG_ERROR;
61     }
62     int32_t *vector_sum_buf = (int32_t *)ctx->buf;
63 #endif
64     const int32_t output_batch = output_dims->n;
65     const int32_t output_height = output_dims->h;
66     const int32_t lhs_rows = input_lhs_dims->w;
67     const int32_t rhs_rows = input_rhs_dims->w;
68     const int32_t rhs_cols = input_rhs_dims->c;
69 
70     const int32_t inner_lhs_diff = input_lhs_dims->h >= input_rhs_dims->h ? 0 : lhs_rows * rhs_cols;
71     const int32_t inner_rhs_diff = input_rhs_dims->h >= input_lhs_dims->h ? rhs_rows * rhs_cols : 0;
72     const int32_t outer_lhs_diff = input_lhs_dims->n >= input_rhs_dims->n
73         ? inner_lhs_diff
74         : -((lhs_rows * rhs_cols) - inner_lhs_diff) * input_lhs_dims->h;
75     const int32_t outer_rhs_diff = input_rhs_dims->n >= input_lhs_dims->n ? (rhs_rows * rhs_cols) - inner_rhs_diff
76                                                                           : -inner_rhs_diff * input_rhs_dims->h;
77 
78     for (int i_out_batch = 0; i_out_batch < output_batch; i_out_batch++)
79     {
80         for (int i_out_height = 0; i_out_height < output_height; i_out_height++)
81         {
82 
83 #if defined(ARM_MATH_MVEI)
84             arm_vector_sum_s8(vector_sum_buf,
85                               rhs_cols,
86                               rhs_rows,
87                               input_rhs,
88                               bmm_params->fc_params.input_offset,
89                               bmm_params->fc_params.filter_offset,
90                               NULL);
91 #endif
92             for (int i_lhs_rows = 0; i_lhs_rows < lhs_rows; i_lhs_rows++)
93             {
94                 arm_nn_vec_mat_mult_t_s8(input_lhs,
95                                          input_rhs,
96 #if defined(ARM_MATH_MVEI)
97                                          vector_sum_buf,
98 #else
99                                          NULL,
100 #endif
101                                          NULL,
102                                          output,
103                                          bmm_params->fc_params.input_offset,
104                                          bmm_params->fc_params.output_offset,
105                                          quant_params->multiplier,
106                                          quant_params->shift,
107                                          rhs_cols,
108                                          rhs_rows,
109                                          bmm_params->fc_params.activation.min,
110                                          bmm_params->fc_params.activation.max,
111                                          1,
112                                          bmm_params->fc_params.filter_offset);
113 
114                 input_lhs += rhs_cols;
115                 output += rhs_rows;
116             }
117             input_lhs -= inner_lhs_diff;
118             input_rhs += inner_rhs_diff;
119         }
120         input_lhs += outer_lhs_diff;
121         input_rhs += outer_rhs_diff;
122     }
123 
124     return ARM_CMSIS_NN_SUCCESS;
125 }
126 
127 /**
128  * @} end of Doxygen group
129  */
130