1 /*
2 * Copyright (C) 2010-2021 Arm Limited or its affiliates. All rights reserved.
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_nn_mat_mult_kernel_q7_q15.c
22 * Description: Matrix-multiplication function for convolution
23 *
24 * $Date: January 26, 2021
25 * $Revision: V.1.0.2
26 *
27 * Target Processor: Cortex-M cores
28 * -------------------------------------------------------------------- */
29
30 #include "arm_nnfunctions.h"
31 #include "arm_nnsupportfunctions.h"
32
33 /**
34 * @brief Matrix-multiplication function for convolution.
35 *
36 * @details Refer to header file for details.
37 *
38 */
39
arm_nn_mat_mult_kernel_q7_q15(const q7_t * pA,const q15_t * pInBuffer,const uint16_t ch_im_out,const uint16_t numCol_A,const uint16_t bias_shift,const uint16_t out_shift,const q7_t * bias,q7_t * pOut)40 q7_t *arm_nn_mat_mult_kernel_q7_q15(const q7_t *pA,
41 const q15_t *pInBuffer,
42 const uint16_t ch_im_out,
43 const uint16_t numCol_A,
44 const uint16_t bias_shift,
45 const uint16_t out_shift,
46 const q7_t *bias,
47 q7_t *pOut)
48 {
49 #if defined(ARM_MATH_DSP)
50 /* set up the second output pointers */
51 q7_t *pOut2 = pOut + ch_im_out;
52 const q7_t *pBias = bias;
53
54 uint16_t rowCnt = ch_im_out >> 1;
55 /* this loop over rows in A */
56 while (rowCnt)
57 {
58 /* setup pointers for B */
59 const q15_t *pB = pInBuffer;
60 const q15_t *pB2 = pB + numCol_A;
61
62 /* align the second pointer for A */
63 const q7_t *pA2 = pA + numCol_A;
64
65 /* init the sum with bias */
66 q31_t sum = ((q31_t)(*pBias) << bias_shift) + NN_ROUND(out_shift);
67 q31_t sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
68 q31_t sum3 = ((q31_t)(*pBias) << bias_shift) + NN_ROUND(out_shift);
69 q31_t sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
70
71 uint16_t colCnt = numCol_A >> 2;
72 /* accumulate over the vector */
73 while (colCnt)
74 {
75 q31_t inA11, inA12, inA21, inA22;
76
77 q31_t inB1 = arm_nn_read_q15x2_ia(&pB);
78 q31_t inB2 = arm_nn_read_q15x2_ia(&pB2);
79
80 pA = read_and_pad(pA, &inA11, &inA12);
81 pA2 = read_and_pad(pA2, &inA21, &inA22);
82
83 sum = __SMLAD(inA11, inB1, sum);
84 sum2 = __SMLAD(inA11, inB2, sum2);
85 sum3 = __SMLAD(inA21, inB1, sum3);
86 sum4 = __SMLAD(inA21, inB2, sum4);
87
88 inB1 = arm_nn_read_q15x2_ia(&pB);
89 inB2 = arm_nn_read_q15x2_ia(&pB2);
90
91 sum = __SMLAD(inA12, inB1, sum);
92 sum2 = __SMLAD(inA12, inB2, sum2);
93 sum3 = __SMLAD(inA22, inB1, sum3);
94 sum4 = __SMLAD(inA22, inB2, sum4);
95
96 colCnt--;
97 } /* while over colCnt */
98 colCnt = numCol_A & 0x3;
99 while (colCnt)
100 {
101 q7_t inA1 = *pA++;
102 q15_t inB1 = *pB++;
103 q7_t inA2 = *pA2++;
104 q15_t inB2 = *pB2++;
105
106 sum += inA1 * inB1;
107 sum2 += inA1 * inB2;
108 sum3 += inA2 * inB1;
109 sum4 += inA2 * inB2;
110 colCnt--;
111 } /* while over colCnt */
112 *pOut++ = (q7_t)__SSAT((sum >> out_shift), 8);
113 *pOut++ = (q7_t)__SSAT((sum3 >> out_shift), 8);
114 *pOut2++ = (q7_t)__SSAT((sum2 >> out_shift), 8);
115 *pOut2++ = (q7_t)__SSAT((sum4 >> out_shift), 8);
116
117 /* skip the row computed with A2 */
118 pA += numCol_A;
119 rowCnt--;
120 } /* for over ch_im_out */
121
122 /* compute left-over row if any */
123 if (ch_im_out & 0x1)
124 {
125 /* setup pointers for B */
126 const q15_t *pB = pInBuffer;
127 const q15_t *pB2 = pB + numCol_A;
128
129 /* load the bias */
130 q31_t sum = ((q31_t)(*pBias) << bias_shift) + NN_ROUND(out_shift);
131 q31_t sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
132
133 uint16_t colCnt = numCol_A >> 2;
134 while (colCnt)
135 {
136 q31_t inA11, inA12;
137
138 q31_t inB1 = arm_nn_read_q15x2_ia(&pB);
139 q31_t inB2 = arm_nn_read_q15x2_ia(&pB2);
140
141 pA = read_and_pad(pA, &inA11, &inA12);
142
143 sum = __SMLAD(inA11, inB1, sum);
144 sum2 = __SMLAD(inA11, inB2, sum2);
145
146 inB1 = arm_nn_read_q15x2_ia(&pB);
147 inB2 = arm_nn_read_q15x2_ia(&pB2);
148
149 sum = __SMLAD(inA12, inB1, sum);
150 sum2 = __SMLAD(inA12, inB2, sum2);
151
152 colCnt--;
153 }
154 colCnt = numCol_A & 0x3;
155 while (colCnt)
156 {
157 q7_t inA1 = *pA++;
158 q15_t inB1 = *pB++;
159 q15_t inB2 = *pB2++;
160
161 sum += inA1 * inB1;
162 sum2 += inA1 * inB2;
163 colCnt--;
164 }
165
166 *pOut++ = (q7_t)__SSAT((sum >> out_shift), 8);
167 *pOut2++ = (q7_t)__SSAT((sum2 >> out_shift), 8);
168 }
169
170 pOut += ch_im_out;
171
172 /* return the new output pointer with offset */
173 return pOut;
174 #else
175 (void)pA;
176 (void)pInBuffer;
177 (void)ch_im_out;
178 (void)numCol_A;
179 (void)bias_shift;
180 (void)out_shift;
181 (void)bias;
182 (void)pOut;
183 /* To be completed */
184 return NULL;
185 #endif /* ARM_MATH_DSP */
186 }
187