1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_mat_vec_mult_q15.c
4 * Description: Q15 matrix and vector multiplication
5 *
6 * $Date: 23 April 2021
7 *
8 * $Revision: V1.9.0
9 *
10 * Target Processor: Cortex-M and Cortex-A cores
11 * -------------------------------------------------------------------- */
12 /*
13 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
14 *
15 * SPDX-License-Identifier: Apache-2.0
16 *
17 * Licensed under the Apache License, Version 2.0 (the License); you may
18 * not use this file except in compliance with the License.
19 * You may obtain a copy of the License at
20 *
21 * www.apache.org/licenses/LICENSE-2.0
22 *
23 * Unless required by applicable law or agreed to in writing, software
24 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
25 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 * See the License for the specific language governing permissions and
27 * limitations under the License.
28 */
29
30 #include "dsp/matrix_functions.h"
31
32 /**
33 * @ingroup groupMatrix
34 */
35
36
37
38 /**
39 * @addtogroup MatrixVectMult
40 * @{
41 */
42
43 /**
44 * @brief Q15 matrix and vector multiplication.
45 * @param[in] *pSrcMat points to the input matrix structure
46 * @param[in] *pVec points to input vector
47 * @param[out] *pDst points to output vector
48 */
49 #if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)
50
51 #include "arm_helium_utils.h"
52
/*
 * Helium (MVE) implementation of the Q15 matrix x vector product.
 *
 *   pSrcMat  input matrix structure (numRows x numCols, row-major in pData)
 *   pSrcVec  input vector, numCols elements are read
 *   pDstVec  output vector, numRows elements are written
 *
 * Every output element is the dot product of one matrix row with the input
 * vector. Products are accumulated in a 64-bit accumulator and narrowed with
 * MVE_ASRL_SAT16(acc, 15) (presumably an arithmetic shift right by 15 with
 * 16-bit saturation - confirm against arm_helium_utils.h).
 *
 * Rows are processed three at a time so that each vector load is shared by
 * three accumulators; the 0..2 left-over rows are handled afterwards.
 */
void arm_mat_vec_mult_q15(
    const arm_matrix_instance_q15 * pSrcMat,
    const q15_t     *pSrcVec,
    q15_t           *pDstVec)
{
    const q15_t *pMatSrc = pSrcMat->pData;  /* start of the next unprocessed row */
    uint32_t     numRows = pSrcMat->numRows;
    uint32_t     numCols = pSrcMat->numCols;
    q15_t       *px = pDstVec;              /* output write pointer */
    int32_t      row = numRows;             /* rows still to be processed */
    uint32_t     blkCnt;                    /* column loop counter */

    /*
     * compute 3x64-bit accumulators per loop
     */
    while (row >= 3)
    {
        /* Row cursors for three consecutive rows, and the vector cursor */
        const q15_t *pMat0Vec = pMatSrc;
        const q15_t *pMat1Vec = pMat0Vec + numCols;
        const q15_t *pMat2Vec = pMat1Vec + numCols;
        const q15_t *pVec = pSrcVec;
        q63_t        acc0 = 0LL;
        q63_t        acc1 = 0LL;
        q63_t        acc2 = 0LL;
        q15x8_t      vecMatA0, vecMatA1, vecMatA2, vecIn;

        /* Full blocks of 8 columns */
        blkCnt = numCols >> 3;
        while (blkCnt > 0U)
        {
            vecMatA0 = vld1q(pMat0Vec);
            pMat0Vec += 8;
            vecMatA1 = vld1q(pMat1Vec);
            pMat1Vec += 8;
            vecMatA2 = vld1q(pMat2Vec);
            pMat2Vec += 8;
            vecIn = vld1q(pVec);
            pVec += 8;

            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
            acc2 = vmlaldavaq(acc2, vecIn, vecMatA2);

            blkCnt--;
        }

        /*
         * Tail: 1..7 remaining columns.
         * All loads are predicated so that no element beyond the end of the
         * row (and, for the last row, beyond the end of the matrix buffer)
         * is accessed; inactive lanes are loaded as zero.
         */
        blkCnt = numCols & 7;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp16q(blkCnt);

            vecMatA0 = vldrhq_z_s16(pMat0Vec, p0);
            vecMatA1 = vldrhq_z_s16(pMat1Vec, p0);
            vecMatA2 = vldrhq_z_s16(pMat2Vec, p0);
            vecIn = vldrhq_z_s16(pVec, p0);

            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
            acc2 = vmlaldavaq(acc2, vecIn, vecMatA2);
        }

        *px++ = MVE_ASRL_SAT16(acc0, 15);
        *px++ = MVE_ASRL_SAT16(acc1, 15);
        *px++ = MVE_ASRL_SAT16(acc2, 15);

        /* Advance to the next group of three rows */
        pMatSrc += numCols * 3;
        row -= 3;
    }

    /*
     * process a remaining pair of rows
     */
    if (row >= 2)
    {
        const q15_t *pMat0Vec = pMatSrc;
        const q15_t *pMat1Vec = pMat0Vec + numCols;
        const q15_t *pVec = pSrcVec;
        q63_t        acc0 = 0LL;
        q63_t        acc1 = 0LL;
        q15x8_t      vecMatA0, vecMatA1, vecIn;

        /* Full blocks of 8 columns */
        blkCnt = numCols >> 3;
        while (blkCnt > 0U)
        {
            vecMatA0 = vld1q(pMat0Vec);
            pMat0Vec += 8;
            vecMatA1 = vld1q(pMat1Vec);
            pMat1Vec += 8;
            vecIn = vld1q(pVec);
            pVec += 8;

            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);

            blkCnt--;
        }

        /* Tail: predicated loads, see the three-row loop above */
        blkCnt = numCols & 7;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp16q(blkCnt);

            vecMatA0 = vldrhq_z_s16(pMat0Vec, p0);
            vecMatA1 = vldrhq_z_s16(pMat1Vec, p0);
            vecIn = vldrhq_z_s16(pVec, p0);

            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
        }

        *px++ = MVE_ASRL_SAT16(acc0, 15);
        *px++ = MVE_ASRL_SAT16(acc1, 15);

        /* Advance past the two rows just processed */
        pMatSrc += numCols * 2;
        row -= 2;
    }

    /*
     * process a single remaining row
     */
    if (row >= 1)
    {
        const q15_t *pMat0Vec = pMatSrc;
        const q15_t *pVec = pSrcVec;
        q63_t        acc0 = 0LL;
        q15x8_t      vecMatA0, vecIn;

        /* Full blocks of 8 columns */
        blkCnt = numCols >> 3;
        while (blkCnt > 0U)
        {
            vecMatA0 = vld1q(pMat0Vec);
            pMat0Vec += 8;
            vecIn = vld1q(pVec);
            pVec += 8;
            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            blkCnt--;
        }

        /* Tail: predicated loads so the read stops at the end of the matrix */
        blkCnt = numCols & 7;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp16q(blkCnt);

            vecMatA0 = vldrhq_z_s16(pMat0Vec, p0);
            vecIn = vldrhq_z_s16(pVec, p0);
            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
        }
        *px++ = MVE_ASRL_SAT16(acc0, 15);
    }
}
267
268 #else
/*
 * Scalar / DSP-extension implementation of the Q15 matrix x vector product.
 *
 *   pSrcMat  input matrix structure (numRows x numCols, row-major in pData)
 *   pVec     input vector, numCols elements are read
 *   pDst     output vector, numRows elements are written
 *
 * Every output element is the dot product of one matrix row with the input
 * vector, accumulated on 64 bits, shifted right by 15 and saturated to
 * 16 bits with __SSAT.
 *
 * Rows are processed four at a time (two columns per step, so each packed
 * vector read is shared by four rows); the 0..3 left-over rows are then
 * processed one by one, four columns per step.
 */
void arm_mat_vec_mult_q15(const arm_matrix_instance_q15 *pSrcMat, const q15_t *pVec, q15_t *pDst)
{
    uint32_t numRows = pSrcMat->numRows;
    uint32_t numCols = pSrcMat->numCols;
    const q15_t *pSrcA = pSrcMat->pData;
    const q15_t *pInA1;      /* cursor in matrix row 1 of the current group */
    const q15_t *pInA2;      /* cursor in matrix row 2 of the current group */
    const q15_t *pInA3;      /* cursor in matrix row 3 of the current group */
    const q15_t *pInA4;      /* cursor in matrix row 4 of the current group */
    const q15_t *pInVec;     /* cursor in the input vector */
    q15_t *px = pDst;        /* output write pointer */
    /* Element offset of the current row inside pSrcA. Kept on 32 bits: a
     * 16-bit offset wraps as soon as the matrix holds more than 65535
     * elements, which makes the remaining rows read from the wrong place. */
    uint32_t i = 0u;
    uint32_t row, colCnt;    /* loop counters */
    q31_t matData, matData2, vecData, vecData2;

    /* Process 4 rows at a time */
    row = numRows >> 2;

    /* row loop: dot product of four consecutive rows of pSrcA with the vector */
    while (row > 0u) {
        /* Initialize accumulators */
        q63_t sum1 = 0;
        q63_t sum2 = 0;
        q63_t sum3 = 0;
        q63_t sum4 = 0;

        /* Restart from the beginning of the vector for every row group */
        pInVec = pVec;

        /* Point each cursor at the start of its row */
        pInA1 = pSrcA + i;
        pInA2 = pInA1 + numCols;
        pInA3 = pInA2 + numCols;
        pInA4 = pInA3 + numCols;

        /* Loop unrolling: process 2 columns per iteration */
        colCnt = numCols >> 1;
        while (colCnt > 0u) {
            /* Read 2 packed values from the vector */
            vecData = read_q15x2_ia (&pInVec);

            /* Read 2 packed values from each of the 4 rows and multiply-accumulate */
            matData = read_q15x2_ia (&pInA1);
            sum1 = __SMLALD(matData, vecData, sum1);
            matData = read_q15x2_ia (&pInA2);
            sum2 = __SMLALD(matData, vecData, sum2);
            matData = read_q15x2_ia (&pInA3);
            sum3 = __SMLALD(matData, vecData, sum3);
            matData = read_q15x2_ia (&pInA4);
            sum4 = __SMLALD(matData, vecData, sum4);

            colCnt--;
        }

        /* Odd column count: one last column remains */
        if (numCols & 1u) {
            vecData = *pInVec++;
            sum1 += (q63_t)*pInA1++ * vecData;
            sum2 += (q63_t)*pInA2++ * vecData;
            sum3 += (q63_t)*pInA3++ * vecData;
            sum4 += (q63_t)*pInA4++ * vecData;
        }

        /* Saturate and store the result in the destination buffer */
        *px++ = (q15_t)(__SSAT((sum1 >> 15), 16));
        *px++ = (q15_t)(__SSAT((sum2 >> 15), 16));
        *px++ = (q15_t)(__SSAT((sum3 >> 15), 16));
        *px++ = (q15_t)(__SSAT((sum4 >> 15), 16));

        /* Skip the four rows just processed */
        i = i + numCols * 4;

        row--;
    }

    /* process any remaining rows (0..3), one row at a time */
    row = numRows & 3u;
    while (row > 0u) {

        q63_t sum = 0;
        pInVec = pVec;
        pInA1 = pSrcA + i;

        /* Loop unrolling: process 4 columns per iteration */
        colCnt = numCols >> 2;
        while (colCnt > 0u) {
            vecData = read_q15x2_ia (&pInVec);
            vecData2 = read_q15x2_ia (&pInVec);
            matData = read_q15x2_ia (&pInA1);
            matData2 = read_q15x2_ia (&pInA1);
            sum = __SMLALD(matData, vecData, sum);
            sum = __SMLALD(matData2, vecData2, sum);
            colCnt--;
        }

        /* process the remaining 0..3 columns of the row */
        colCnt = numCols & 3u;
        while (colCnt > 0u) {
            sum += (q63_t)*pInA1++ * *pInVec++;
            colCnt--;
        }

        *px++ = (q15_t)(__SSAT((sum >> 15), 16));
        i = i + numCols;
        row--;
    }
}
384 #endif /* defined(ARM_MATH_MVEI) */
385
386 /**
387 * @} end of MatrixVectMult group
388 */
389