/*
 * SPDX-FileCopyrightText: Copyright 2022-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:     CMSIS NN Library
 * Title:       arm_elementwise_mul_s16_s8.c
 * Description: Elementwise multiplication of 16-bit inputs with 8-bit output
 *
 * $Date:       20 January 2023
 * $Revision:   V.2.0.0
 *
 * Target:      Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup BasicMath
 * @{
 */

/*
 * s16 elementwise multiplication with s8 output
 *
 * Refer to the header file for details.
 *
 */
arm_cmsis_nn_status arm_elementwise_mul_s16_s8(const int16_t *input_1_vect,
                                               const int16_t *input_2_vect,
                                               int8_t *output,
                                               const int32_t out_offset,
                                               const int32_t out_mult,
                                               const int32_t out_shift,
                                               const int32_t block_size,
                                               const int32_t batch_size,
                                               const int32_t batch_offset)
{

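    /* Process one block of 'block_size' elements per batch. Both input
     * pointers step by block_size per batch; the output pointer steps by
     * batch_offset * block_size in total. */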
    for (int i = 0; i < batch_size; i++)
    {
        int32_t loop_count = block_size;
#if defined(ARM_MATH_MVEI)

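        /* MVE path: each predicated iteration loads four s16 elements per
         * operand, widened to 32-bit lanes. The vctp32q predicate masks off
         * the excess lanes in the final iteration, so no scalar tail loop
         * is needed. */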
        const int16_t *input_1_ptr = input_1_vect;
        const int16_t *input_2_ptr = input_2_vect;
        int8_t *output_ptr = output;

        while (loop_count > 0)
        {
            mve_pred16_t pred = vctp32q(loop_count);

            int32x4_t input_1 = vldrhq_z_s32(input_1_ptr, pred);
            int32x4_t input_2 = vldrhq_z_s32(input_2_ptr, pred);

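            /* Multiply in 32-bit lanes; the product of two 16-bit values
             * always fits in 32 bits. */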
            int32x4_t res_0 = vmulq_s32(input_1, input_2);

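            /* Requantize to the s8 output scale and add the output offset
             * (zero point). */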
            res_0 = arm_requantize_mve_32x4(res_0, vdupq_n_s32(out_mult), vdupq_n_s32(out_shift));
            res_0 = vaddq_n_s32(res_0, out_offset);

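            /* Saturate to the s8 range [NN_Q7_MIN, NN_Q7_MAX], i.e. [-128, 127]. */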
            res_0 = vmaxq_s32(res_0, vdupq_n_s32(NN_Q7_MIN));
            res_0 = vminq_s32(res_0, vdupq_n_s32(NN_Q7_MAX));

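            /* Predicated narrowing store: one byte per active lane. */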
            vstrbq_p_s32(output_ptr, res_0, pred);
            input_1_ptr += 4;
            input_2_ptr += 4;

            output_ptr += 4;
            loop_count -= 4;
        }

        input_1_vect += block_size;
        input_2_vect += block_size;
        output += block_size;

#else
#if defined(ARM_MATH_DSP)

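        /* DSP path: read two s16 values per operand at once, multiply the
         * bottom halfwords with SMULBB and the top halfwords with SMULTT,
         * then pack both requantized s8 results into a single 16-bit store. */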
        while (loop_count > 1)
        {
            int32_t input_1 = arm_nn_read_q15x2_ia(&input_1_vect);
            int32_t input_2 = arm_nn_read_q15x2_ia(&input_2_vect);

            int32_t mul_res = SMULBB(input_1, input_2);
            mul_res = arm_nn_requantize(mul_res, out_mult, out_shift) + out_offset;
            mul_res = CLAMP(mul_res, NN_Q7_MAX, NN_Q7_MIN);
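            /* Keep the first result in the low byte; the second result is
             * OR-ed into the high byte below. */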
            int32_t mul = (int16_t)(mul_res & 0xFF);

            mul_res = SMULTT(input_1, input_2);
            mul_res = arm_nn_requantize(mul_res, out_mult, out_shift) + out_offset;
            mul_res = CLAMP(mul_res, NN_Q7_MAX, NN_Q7_MIN);
            mul |= (int16_t)mul_res << 8;

            arm_nn_write_s8x2_ia(&output, mul);
            loop_count -= 2;
        }
#endif
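        /* Scalar loop: processes every element in the pure C build, or the
         * odd trailing element left over by the two-at-a-time DSP loop. */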
        for (int j = 0; j < loop_count; j++, input_1_vect++, input_2_vect++, output++)
        {
            /* C = A * B */
            int32_t mul_res = (*input_1_vect) * (*input_2_vect);
            mul_res = arm_nn_requantize(mul_res, out_mult, out_shift) + out_offset;

            mul_res = CLAMP(mul_res, NN_Q7_MAX, NN_Q7_MIN);

            *output = (int8_t)mul_res;
        }

#endif

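        /* 'output' has already advanced by block_size, so skip the remaining
         * (batch_offset - 1) * block_size elements to reach the next batch. */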
        output += (batch_offset - 1) * block_size;
    }
    return ARM_CMSIS_NN_SUCCESS;
}
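
/*
 * Example usage, as a minimal sketch. The quantization parameters below
 * (out_offset, out_mult, out_shift) are hypothetical placeholders; real
 * values are produced by the model conversion tooling.
 *
 *     int16_t input_1[4] = {100, 200, -300, 400};
 *     int16_t input_2[4] = {50, -60, 70, 80};
 *     int8_t output[4];
 *
 *     arm_cmsis_nn_status status = arm_elementwise_mul_s16_s8(
 *         input_1, input_2, output,
 *         0,          // out_offset (hypothetical)
 *         1073741824, // out_mult (hypothetical, 0.5 in Q31)
 *         -8,         // out_shift (hypothetical)
 *         4,          // block_size: elements per batch
 *         1,          // batch_size
 *         1);         // batch_offset
 */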
/**
 * @} end of BasicMath group
 */