1 /*
2  * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nn_batch_matmul_s16.c
22  * Description:  Batch matrix multiplication. Does not perform transposes, see header file for details.
23  *
24  * $Date:        19 June 2024
25  * $Revision:    V.1.0.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 #include "arm_nnfunctions.h"
31 #include "arm_nnsupportfunctions.h"
32 
33 /**
34  * @ingroup Public
35  */
36 
37 /**
38  * @addtogroup FC
39  * @{
40  */
41 
42 /*
43  * s16 batch matrix multiplication
44  * Refer to header file for details.
45  */
arm_batch_matmul_s16(const cmsis_nn_context * ctx,const cmsis_nn_bmm_params * bmm_params,const cmsis_nn_per_tensor_quant_params * quant_params,const cmsis_nn_dims * input_lhs_dims,const int16_t * input_lhs,const cmsis_nn_dims * input_rhs_dims,const int16_t * input_rhs,const cmsis_nn_dims * output_dims,int16_t * output)46 arm_cmsis_nn_status arm_batch_matmul_s16(const cmsis_nn_context *ctx,
47                                          const cmsis_nn_bmm_params *bmm_params,
48                                          const cmsis_nn_per_tensor_quant_params *quant_params,
49                                          const cmsis_nn_dims *input_lhs_dims,
50                                          const int16_t *input_lhs,
51                                          const cmsis_nn_dims *input_rhs_dims,
52                                          const int16_t *input_rhs,
53                                          const cmsis_nn_dims *output_dims,
54                                          int16_t *output)
55 {
56     (void)ctx;
57     const int32_t output_batch = output_dims->n;
58     const int32_t output_height = output_dims->h;
59     const int32_t lhs_rows = input_lhs_dims->w;
60     const int32_t rhs_rows = input_rhs_dims->w;
61     const int32_t rhs_cols = input_rhs_dims->c;
62 
63     const int32_t inner_lhs_diff = input_lhs_dims->h >= input_rhs_dims->h ? 0 : lhs_rows * rhs_cols;
64     const int32_t inner_rhs_diff = input_rhs_dims->h >= input_lhs_dims->h ? rhs_rows * rhs_cols : 0;
65     const int32_t outer_lhs_diff = input_lhs_dims->n >= input_rhs_dims->n
66         ? inner_lhs_diff
67         : -((lhs_rows * rhs_cols) - inner_lhs_diff) * input_lhs_dims->h;
68     const int32_t outer_rhs_diff = input_rhs_dims->n >= input_lhs_dims->n ? (rhs_rows * rhs_cols) - inner_rhs_diff
69                                                                           : -inner_rhs_diff * input_rhs_dims->h;
70 
71     const int32_t reduced_multiplier = REDUCE_MULTIPLIER(quant_params->multiplier);
72 
73     for (int i_out_batch = 0; i_out_batch < output_batch; i_out_batch++)
74     {
75         for (int i_out_height = 0; i_out_height < output_height; i_out_height++)
76         {
77 
78             for (int j = 0; j < lhs_rows; j++)
79             {
80                 arm_nn_vec_mat_mult_t_s16_s16(input_lhs,
81                                               input_rhs,
82                                               NULL,
83                                               output,
84                                               reduced_multiplier,
85                                               quant_params->shift,
86                                               rhs_cols,
87                                               rhs_rows,
88                                               bmm_params->fc_params.activation.min,
89                                               bmm_params->fc_params.activation.max);
90                 input_lhs += rhs_cols;
91                 output += rhs_rows;
92             }
93             input_lhs -= inner_lhs_diff;
94             input_rhs += inner_rhs_diff;
95         }
96         input_lhs += outer_lhs_diff;
97         input_rhs += outer_rhs_diff;
98     }
99 
100     return ARM_CMSIS_NN_SUCCESS;
101 }
102 
103 /**
104  * @} end of Doxygen group
105  */
106