1 /*
2 * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_nn_batch_matmul_s8.c
22 * Description: Batch matrix multiplication. Does not perform transposes, see header file for details.
23 *
24 * $Date: 5 Sep 2024
25 * $Revision: V.1.0.1
26 *
27 * Target : Arm(R) M-Profile Architecture
28 *
29 * -------------------------------------------------------------------- */
30 #include "arm_nnfunctions.h"
31 #include "arm_nnsupportfunctions.h"
32
33 /**
34 * @ingroup Public
35 */
36
37 /**
38 * @addtogroup FC
39 * @{
40 */
41
42 /*
43 * s8 batch matrix multiplication
44 * Refer to header file for details.
45 */
arm_batch_matmul_s8(const cmsis_nn_context * ctx,const cmsis_nn_bmm_params * bmm_params,const cmsis_nn_per_tensor_quant_params * quant_params,const cmsis_nn_dims * input_lhs_dims,const int8_t * input_lhs,const cmsis_nn_dims * input_rhs_dims,const int8_t * input_rhs,const cmsis_nn_dims * output_dims,int8_t * output)46 arm_cmsis_nn_status arm_batch_matmul_s8(const cmsis_nn_context *ctx,
47 const cmsis_nn_bmm_params *bmm_params,
48 const cmsis_nn_per_tensor_quant_params *quant_params,
49 const cmsis_nn_dims *input_lhs_dims,
50 const int8_t *input_lhs,
51 const cmsis_nn_dims *input_rhs_dims,
52 const int8_t *input_rhs,
53 const cmsis_nn_dims *output_dims,
54 int8_t *output)
55 {
56 (void)ctx;
57 #if defined(ARM_MATH_MVEI)
58 if (ctx->buf == NULL)
59 {
60 return ARM_CMSIS_NN_ARG_ERROR;
61 }
62 int32_t *vector_sum_buf = (int32_t *)ctx->buf;
63 #endif
64 const int32_t output_batch = output_dims->n;
65 const int32_t output_height = output_dims->h;
66 const int32_t lhs_rows = input_lhs_dims->w;
67 const int32_t rhs_rows = input_rhs_dims->w;
68 const int32_t rhs_cols = input_rhs_dims->c;
69
70 const int32_t inner_lhs_diff = input_lhs_dims->h >= input_rhs_dims->h ? 0 : lhs_rows * rhs_cols;
71 const int32_t inner_rhs_diff = input_rhs_dims->h >= input_lhs_dims->h ? rhs_rows * rhs_cols : 0;
72 const int32_t outer_lhs_diff = input_lhs_dims->n >= input_rhs_dims->n
73 ? inner_lhs_diff
74 : -((lhs_rows * rhs_cols) - inner_lhs_diff) * input_lhs_dims->h;
75 const int32_t outer_rhs_diff = input_rhs_dims->n >= input_lhs_dims->n ? (rhs_rows * rhs_cols) - inner_rhs_diff
76 : -inner_rhs_diff * input_rhs_dims->h;
77
78 for (int i_out_batch = 0; i_out_batch < output_batch; i_out_batch++)
79 {
80 for (int i_out_height = 0; i_out_height < output_height; i_out_height++)
81 {
82
83 #if defined(ARM_MATH_MVEI)
84 arm_vector_sum_s8(vector_sum_buf,
85 rhs_cols,
86 rhs_rows,
87 input_rhs,
88 bmm_params->fc_params.input_offset,
89 bmm_params->fc_params.filter_offset,
90 NULL);
91 #endif
92 for (int i_lhs_rows = 0; i_lhs_rows < lhs_rows; i_lhs_rows++)
93 {
94 arm_nn_vec_mat_mult_t_s8(input_lhs,
95 input_rhs,
96 #if defined(ARM_MATH_MVEI)
97 vector_sum_buf,
98 #else
99 NULL,
100 #endif
101 NULL,
102 output,
103 bmm_params->fc_params.input_offset,
104 bmm_params->fc_params.output_offset,
105 quant_params->multiplier,
106 quant_params->shift,
107 rhs_cols,
108 rhs_rows,
109 bmm_params->fc_params.activation.min,
110 bmm_params->fc_params.activation.max,
111 1,
112 bmm_params->fc_params.filter_offset);
113
114 input_lhs += rhs_cols;
115 output += rhs_rows;
116 }
117 input_lhs -= inner_lhs_diff;
118 input_rhs += inner_rhs_diff;
119 }
120 input_lhs += outer_lhs_diff;
121 input_rhs += outer_rhs_diff;
122 }
123
124 return ARM_CMSIS_NN_SUCCESS;
125 }
126
127 /**
128 * @} end of Doxygen group
129 */
130