1 /*
 * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_nn_batch_matmul_s16.c
22 * Description: Batch matrix multiplication. Does not perform transposes, see header file for details.
23 *
24 * $Date: 19 June 2024
25 * $Revision: V.1.0.0
26 *
27 * Target : Arm(R) M-Profile Architecture
28 *
29 * -------------------------------------------------------------------- */
30 #include "arm_nnfunctions.h"
31 #include "arm_nnsupportfunctions.h"
32
33 /**
34 * @ingroup Public
35 */
36
37 /**
38 * @addtogroup FC
39 * @{
40 */
41
42 /*
43 * s16 batch matrix multiplication
44 * Refer to header file for details.
45 */
arm_batch_matmul_s16(const cmsis_nn_context * ctx,const cmsis_nn_bmm_params * bmm_params,const cmsis_nn_per_tensor_quant_params * quant_params,const cmsis_nn_dims * input_lhs_dims,const int16_t * input_lhs,const cmsis_nn_dims * input_rhs_dims,const int16_t * input_rhs,const cmsis_nn_dims * output_dims,int16_t * output)46 arm_cmsis_nn_status arm_batch_matmul_s16(const cmsis_nn_context *ctx,
47 const cmsis_nn_bmm_params *bmm_params,
48 const cmsis_nn_per_tensor_quant_params *quant_params,
49 const cmsis_nn_dims *input_lhs_dims,
50 const int16_t *input_lhs,
51 const cmsis_nn_dims *input_rhs_dims,
52 const int16_t *input_rhs,
53 const cmsis_nn_dims *output_dims,
54 int16_t *output)
55 {
56 (void)ctx;
57 const int32_t output_batch = output_dims->n;
58 const int32_t output_height = output_dims->h;
59 const int32_t lhs_rows = input_lhs_dims->w;
60 const int32_t rhs_rows = input_rhs_dims->w;
61 const int32_t rhs_cols = input_rhs_dims->c;
62
63 const int32_t inner_lhs_diff = input_lhs_dims->h >= input_rhs_dims->h ? 0 : lhs_rows * rhs_cols;
64 const int32_t inner_rhs_diff = input_rhs_dims->h >= input_lhs_dims->h ? rhs_rows * rhs_cols : 0;
65 const int32_t outer_lhs_diff = input_lhs_dims->n >= input_rhs_dims->n
66 ? inner_lhs_diff
67 : -((lhs_rows * rhs_cols) - inner_lhs_diff) * input_lhs_dims->h;
68 const int32_t outer_rhs_diff = input_rhs_dims->n >= input_lhs_dims->n ? (rhs_rows * rhs_cols) - inner_rhs_diff
69 : -inner_rhs_diff * input_rhs_dims->h;
70
71 const int32_t reduced_multiplier = REDUCE_MULTIPLIER(quant_params->multiplier);
72
73 for (int i_out_batch = 0; i_out_batch < output_batch; i_out_batch++)
74 {
75 for (int i_out_height = 0; i_out_height < output_height; i_out_height++)
76 {
77
78 for (int j = 0; j < lhs_rows; j++)
79 {
80 arm_nn_vec_mat_mult_t_s16_s16(input_lhs,
81 input_rhs,
82 NULL,
83 output,
84 reduced_multiplier,
85 quant_params->shift,
86 rhs_cols,
87 rhs_rows,
88 bmm_params->fc_params.activation.min,
89 bmm_params->fc_params.activation.max);
90 input_lhs += rhs_cols;
91 output += rhs_rows;
92 }
93 input_lhs -= inner_lhs_diff;
94 input_rhs += inner_rhs_diff;
95 }
96 input_lhs += outer_lhs_diff;
97 input_rhs += outer_rhs_diff;
98 }
99
100 return ARM_CMSIS_NN_SUCCESS;
101 }
102
103 /**
104 * @} end of Doxygen group
105 */
106