/* ----------------------------------------------------------------------
 * Project:      CMSIS DSP Library
 * Title:        arm_mat_mult_opt_q31.c
 * Description:  Q31 matrix multiplication
 *
 * $Date:        3 Nov 2021
 * $Revision:    V1.10.0
 *
 * Target Processor: Cortex-M and Cortex-A cores
 * -------------------------------------------------------------------- */
/*
 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "dsp/matrix_functions.h"

/**
  @ingroup groupMatrix
 */

/**
  @addtogroup MatrixMult
  @{
 */

/**
  @brief         Q31 matrix multiplication.
  @param[in]     pSrcA      points to the first input matrix structure
  @param[in]     pSrcB      points to the second input matrix structure
  @param[out]    pDst       points to output matrix structure
  @param[in]     pState     points to the array for storing intermediate results
  @return        execution status
                   - \ref ARM_MATH_SUCCESS       : Operation successful
                   - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed

  @par           Scaling and Overflow Behavior
                   The function is implemented using an internal 64-bit accumulator.
                   The accumulator has a 2.62 format and maintains full precision of the intermediate
                   multiplication results but provides only a single guard bit. There is no saturation
                   on intermediate additions. Thus, if the accumulator overflows it wraps around and
                   distorts the result. The input signals should be scaled down to avoid intermediate
                   overflows. The input is thus scaled down by log2(numColsA) bits
                   to avoid overflows, as a total of numColsA additions are performed internally.
                   The 2.62 accumulator is right shifted by 31 bits and saturated to 1.31 format to yield the final result.
  @remark
                   Refer to \ref arm_mat_mult_fast_q31() for a faster but less precise implementation of this function.
  @remark
                   On MVE targets, this function is a faster implementation of arm_mat_mult_q31, but it requires
                   additional storage for intermediate results.
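
  @par           Usage
                   A minimal calling sketch with illustrative sizes; on MVE targets
                   pState must be able to hold the transpose of B, i.e. at least
                   numRowsB * numColsB elements:
  @code
      #define ROWS_A 8
      #define COLS_A 6
      #define COLS_B 5

      static q31_t A[ROWS_A * COLS_A];        // operand A, Q31
      static q31_t B[COLS_A * COLS_B];        // operand B, Q31
      static q31_t C[ROWS_A * COLS_B];        // result C = A * B
      static q31_t scratch[COLS_A * COLS_B];  // scratch for B transposed

      arm_matrix_instance_q31 matA = { ROWS_A, COLS_A, A };
      arm_matrix_instance_q31 matB = { COLS_A, COLS_B, B };
      arm_matrix_instance_q31 matC = { ROWS_A, COLS_B, C };

      arm_status st = arm_mat_mult_opt_q31(&matA, &matB, &matC, scratch);
  @endcode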
 */
#if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)

#define MATRIX_DIM2 2
#define MATRIX_DIM3 3
#define MATRIX_DIM4 4

__STATIC_INLINE arm_status arm_mat_mult_opt_q31_2x2_mve(
    const arm_matrix_instance_q31 * pSrcA,
    const arm_matrix_instance_q31 * pSrcB,
    arm_matrix_instance_q31 * pDst)
{
    q31_t *pInB = pSrcB->pData; /* input data matrix pointer B */
    q31_t *pInA = pSrcA->pData; /* input data matrix pointer A */
    q31_t *pOut = pDst->pData;  /* output data matrix pointer */
    uint32x4_t vecColBOffs;
    q31_t *pInA0 = pInA;
    q31_t *pInA1 = pInA0 + MATRIX_DIM2;
    q63_t acc0, acc1;
    q31x4_t vecB, vecA0, vecA1;
    /* enable predication to disable half of vector elements */
    mve_pred16_t p0 = vctp32q(MATRIX_DIM2);

    vecColBOffs = vidupq_u32((uint32_t)0, 1);
    vecColBOffs = vecColBOffs * MATRIX_DIM2;
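    /*
     * vecColBOffs is now {0, 2, 4, 6}; combined with the two-lane predicate,
     * the gather load below fetches B[0] and B[2], i.e. one column of the
     * row-major 2x2 matrix
     */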

    pInB = pSrcB->pData;

    /* load 1st B column (partial load) */
    vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);

    /* load A rows */
    vecA0 = vldrwq_s32(pInA0);
    vecA1 = vldrwq_s32(pInA1);

    acc0 = vrmlaldavhq(vecA0, vecB);
    acc1 = vrmlaldavhq(vecA1, vecB);

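    /*
     * vrmlaldavhq returns the high 64 bits of the rounded wide accumulation
     * (the low 8 bits of the sum of products are dropped); the arithmetic
     * shift by a further 23 bits below completes the 31-bit conversion from
     * the 2.62 product format to 1.31
     */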
    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);

    pOut[0 * MATRIX_DIM2] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM2] = (q31_t) acc1;
    pOut++;

    /* move to next B column */
    pInB = pInB + 1;

    vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);

    acc0 = vrmlaldavhq(vecA0, vecB);
    acc1 = vrmlaldavhq(vecA1, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);

    pOut[0 * MATRIX_DIM2] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM2] = (q31_t) acc1;
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}



__STATIC_INLINE arm_status arm_mat_mult_opt_q31_3x3_mve(
    const arm_matrix_instance_q31 * pSrcA,
    const arm_matrix_instance_q31 * pSrcB,
    arm_matrix_instance_q31 * pDst)
{
    q31_t *pInB = pSrcB->pData; /* input data matrix pointer B */
    q31_t *pInA = pSrcA->pData; /* input data matrix pointer A */
    q31_t *pOut = pDst->pData;  /* output data matrix pointer */
    uint32x4_t vecColBOffs;
    q31_t *pInA0 = pInA;
    q31_t *pInA1 = pInA0 + MATRIX_DIM3;
    q31_t *pInA2 = pInA1 + MATRIX_DIM3;
    q63_t acc0, acc1, acc2;
    q31x4_t vecB, vecA;
    /* enable predication to disable last (4th) vector element */
    mve_pred16_t p0 = vctp32q(MATRIX_DIM3);

    vecColBOffs = vidupq_u32((uint32_t)0, 1);
    vecColBOffs = vecColBOffs * MATRIX_DIM3;

    pInB = pSrcB->pData;

    vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);

    pOut[0 * MATRIX_DIM3] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM3] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM3] = (q31_t) acc2;
    pOut++;

    /* move to next B column */
    pInB = pInB + 1;

    vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);

    pOut[0 * MATRIX_DIM3] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM3] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM3] = (q31_t) acc2;
    pOut++;

    /* move to next B column */
    pInB = pInB + 1;

    vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);

    pOut[0 * MATRIX_DIM3] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM3] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM3] = (q31_t) acc2;
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}

__STATIC_INLINE arm_status arm_mat_mult_opt_q31_4x4_mve(
    const arm_matrix_instance_q31 * pSrcA,
    const arm_matrix_instance_q31 * pSrcB,
    arm_matrix_instance_q31 * pDst)
{
    q31_t *pInB = pSrcB->pData; /* input data matrix pointer B */
    q31_t *pInA = pSrcA->pData; /* input data matrix pointer A */
    q31_t *pOut = pDst->pData;  /* output data matrix pointer */
    uint32x4_t vecColBOffs;
    q31_t *pInA0 = pInA;
    q31_t *pInA1 = pInA0 + MATRIX_DIM4;
    q31_t *pInA2 = pInA1 + MATRIX_DIM4;
    q31_t *pInA3 = pInA2 + MATRIX_DIM4;
    q63_t acc0, acc1, acc2, acc3;
    q31x4_t vecB, vecA;

    vecColBOffs = vidupq_u32((uint32_t)0, 4);
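    /*
     * for a 4x4 matrix the column offsets {0, 4, 8, 12} fill all four
     * lanes, so the gather loads need no tail predication
     */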

    pInB = pSrcB->pData;

    vecB = vldrwq_gather_shifted_offset_s32(pInB, vecColBOffs);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA3);
    acc3 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);
    acc3 = asrl(acc3, 23);

    pOut[0 * MATRIX_DIM4] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM4] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM4] = (q31_t) acc2;
    pOut[3 * MATRIX_DIM4] = (q31_t) acc3;
    pOut++;

    /* move to next B column */
    pInB = pInB + 1;

    vecB = vldrwq_gather_shifted_offset_s32(pInB, vecColBOffs);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA3);
    acc3 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);
    acc3 = asrl(acc3, 23);

    pOut[0 * MATRIX_DIM4] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM4] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM4] = (q31_t) acc2;
    pOut[3 * MATRIX_DIM4] = (q31_t) acc3;

    pOut++;

    /* move to next B column */
    pInB = pInB + 1;

    vecB = vldrwq_gather_shifted_offset_s32(pInB, vecColBOffs);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA3);
    acc3 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);
    acc3 = asrl(acc3, 23);

    pOut[0 * MATRIX_DIM4] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM4] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM4] = (q31_t) acc2;
    pOut[3 * MATRIX_DIM4] = (q31_t) acc3;

    pOut++;

    /* move to next B column */
    pInB = pInB + 1;

    vecB = vldrwq_gather_shifted_offset_s32(pInB, vecColBOffs);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA3);
    acc3 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);
    acc3 = asrl(acc3, 23);

    pOut[0 * MATRIX_DIM4] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM4] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM4] = (q31_t) acc2;
    pOut[3 * MATRIX_DIM4] = (q31_t) acc3;
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}


arm_status arm_mat_mult_opt_q31(
    const arm_matrix_instance_q31 * pSrcA,
    const arm_matrix_instance_q31 * pSrcB,
    arm_matrix_instance_q31 * pDst,
    q31_t *pState)
{
    q31_t *pInA = pSrcA->pData; /* input data matrix pointer A */
    q31_t *pInB = pSrcB->pData; /* input data matrix pointer B */
    q31_t *pInA2;
    q31_t *pInB2;
    q31_t *px;  /* Temporary output data matrix pointer */
    q31_t *px2; /* Temporary output data matrix pointer */
    uint32_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
    uint32_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
    uint32_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
    uint32_t numRowsB = pSrcB->numRows; /* number of rows of input matrix B */
    uint32_t col, i = 0u, j, row = numRowsB; /* loop counters */
    q31_t *pSrcBT = pState;     /* input data matrix pointer for transpose */
    uint32_t blkCnt;            /* loop counters */
    arm_status status;          /* Status of matrix multiplication */
    arm_matrix_instance_q31 BT;
#ifdef ARM_MATH_MATRIX_CHECK

    /* Check for matrix mismatch condition */
    if ((pSrcA->numCols != pSrcB->numRows) ||
        (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols)) {
        /* Set status as ARM_MATH_SIZE_MISMATCH */
        status = ARM_MATH_SIZE_MISMATCH;
    } else
#endif /* #ifdef ARM_MATH_MATRIX_CHECK */
    {

        /* small square matrix specialized routines */
        if (numRowsA == numColsB && numColsB == numColsA) {
            if (numRowsA == 1)
            {
                q63_t sum = (q63_t) *pInA * *pInB;
                pDst->pData[0] = (q31_t)(sum >> 31);
                return (ARM_MATH_SUCCESS);
            }
            else if (numRowsA == 2)
                return arm_mat_mult_opt_q31_2x2_mve(pSrcA, pSrcB, pDst);
            else if (numRowsA == 3)
                return arm_mat_mult_opt_q31_3x3_mve(pSrcA, pSrcB, pDst);
            else if (numRowsA == 4)
                return arm_mat_mult_opt_q31_4x4_mve(pSrcA, pSrcB, pDst);
        }


        /*
         * Matrix transpose
         */
        BT.numRows = numColsB;
        BT.numCols = numRowsB;
        BT.pData = pSrcBT;

        arm_mat_trans_q31(pSrcB, &BT);
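        /*
         * B is transposed into the pState scratch buffer so that every dot
         * product below reads both operands from contiguous memory, using
         * plain vector loads instead of strided gathers
         */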


        /*
         * Reset the variables for the usage in the following multiplication process
         */
        i = 0;
        row = numRowsA >> 1;
        px = pDst->pData;
        px2 = px + numColsB;

        /*
         * main loop
         * compute 2 x 2 output blocks
         * with dot products (Matrix A rows * Transposed Matrix B rows)
         */
        while (row > 0u) {
            /*
             * For every row wise process, the column loop counter is to be initiated
             * Compute 2 columns and 2 rows in parallel
             */
            col = numColsB >> 1;
            j = 0;

            /*
             * column pair loop
             */
            while (col > 0u) {
                q31_t const *pSrcAVec, *pSrcBVec, *pSrcA2Vec, *pSrcB2Vec;
                q31x4_t vecA, vecA2, vecB, vecB2;
                q63_t acc0, acc1, acc2, acc3;

                /*
                 * Initiate the pointers
                 * - 2 x consecutive Matrix A rows (i increment is 2 x numColsA)
                 * - 2 x consecutive Matrix B' rows (j increment is 2 x numRowsB)
                 */
                pInA = pSrcA->pData + i;
                pInA2 = pInA + numColsA;
                pInB = pSrcBT + j;
                pInB2 = pInB + numRowsB;


                pSrcAVec = (q31_t const *) pInA;
                pSrcA2Vec = (q31_t const *) pInA2;
                pSrcBVec = (q31_t const *) pInB;
                pSrcB2Vec = (q31_t const *) pInB2;

                acc0 = 0LL;
                acc1 = 0LL;
                acc2 = 0LL;
                acc3 = 0LL;

                /* load scheduling: pre-load the first A vector so each iteration
                   below overlaps one load with the multiply-accumulates */
                vecA = vld1q(pSrcAVec);
                pSrcAVec += 4;

                blkCnt = (numColsA / 4);
                while (blkCnt > 0U) {
                    vecB = vld1q(pSrcBVec);
                    pSrcBVec += 4;
                    acc0 = vrmlaldavhaq(acc0, vecA, vecB);
                    vecA2 = vld1q(pSrcA2Vec);
                    pSrcA2Vec += 4;
                    acc1 = vrmlaldavhaq(acc1, vecA2, vecB);
                    vecB2 = vld1q(pSrcB2Vec);
                    pSrcB2Vec += 4;
                    acc2 = vrmlaldavhaq(acc2, vecA, vecB2);
                    vecA = vld1q(pSrcAVec);
                    pSrcAVec += 4;
                    acc3 = vrmlaldavhaq(acc3, vecA2, vecB2);

                    blkCnt--;
                }
                /*
                 * tail
                 * (will be merged through tail predication)
                 */
                blkCnt = (numColsA & 3);
                if (blkCnt > 0U) {
                    mve_pred16_t p0 = vctp32q(blkCnt);
                    vecB = vld1q(pSrcBVec);
                    acc0 = vrmlaldavhaq_p(acc0, vecA, vecB, p0);
                    vecA2 = vld1q(pSrcA2Vec);
                    acc1 = vrmlaldavhaq_p(acc1, vecA2, vecB, p0);
                    vecB2 = vld1q(pSrcB2Vec);
                    acc2 = vrmlaldavhaq_p(acc2, vecA, vecB2, p0);
                    vecA = vld1q(pSrcAVec);
                    acc3 = vrmlaldavhaq_p(acc3, vecA2, vecB2, p0);
                }

                /* Convert to 1.31 */
                acc0 = asrl(acc0, 23);
                acc1 = asrl(acc1, 23);
                acc2 = asrl(acc2, 23);
                acc3 = asrl(acc3, 23);

                /* Store the results (2 x 2 block) in the destination buffer */
                *px++ = (q31_t) acc0;
                *px++ = (q31_t) acc2;
                *px2++ = (q31_t) acc1;
                *px2++ = (q31_t) acc3;

                j += numRowsB * 2;
                /*
                 * Decrement the column pair loop counter
                 */
                col--;

            }

            i = i + numColsA * 2;
            px = px2 + (numColsB & 1u);
            px2 = px + numColsB;
            /*
             * Decrement the row pair loop counter
             */
            row--;
        }

        /*
         * Compute remaining row and/or column below
         */
        if (numColsB & 1u) {
            row = numRowsA & (~0x1); /* avoid redundant computation */
            px = pDst->pData + numColsB - 1;
            i = 0;

            /*
             * row loop
             */
            while (row > 0) {
                q31_t const *pSrcAVec, *pSrcBVec;
                q31x4_t vecA, vecB;
                q63_t acc0;

                /*
                 * point to last column in matrix B
                 */
                pInB = pSrcBT + numRowsB * (numColsB - 1);
                pInA = pSrcA->pData + i;

                pSrcAVec = (q31_t const *) pInA;
                pSrcBVec = (q31_t const *) pInB;

                /* single dot-product */
                acc0 = 0LL;
                blkCnt = (numColsA / 4);
                while (blkCnt > 0U) {
                    vecA = vld1q(pSrcAVec);
                    pSrcAVec += 4;
                    vecB = vld1q(pSrcBVec);
                    pSrcBVec += 4;
                    acc0 = vrmlaldavhaq(acc0, vecA, vecB);

                    blkCnt--;
                }
                /*
                 * tail
                 * (will be merged through tail predication)
                 */
                blkCnt = (numColsA & 3);
                if (blkCnt > 0U) {
                    mve_pred16_t p0 = vctp32q(blkCnt);
                    vecA = vld1q(pSrcAVec);
                    vecB = vld1q(pSrcBVec);
                    acc0 = vrmlaldavhaq_p(acc0, vecA, vecB, p0);
                }

                acc0 = asrl(acc0, 23);
                *px = (q31_t) acc0;

                px += numColsB;

                i += numColsA;
                /*
                 * Decrement the row loop counter
                 */
                row--;
            }
        }

        if (numRowsA & 1u) {
            col = numColsB;
            i = 0u;
            /*
             * point to last row in output matrix
             */
            px = pDst->pData + (numColsB) * (numRowsA - 1);
            /*
             * col loop
             */
            while (col > 0) {
                q31_t const *pSrcAVec, *pSrcBVec;
                q31x4_t vecA, vecB;
                q63_t acc0;

                /*
                 * point to last row in matrix A
                 */
                pInA = pSrcA->pData + (numRowsA - 1) * numColsA;
                pInB = pSrcBT + i;

                /*
                 * Set acc0, which acts as the accumulator, to zero
                 */
                pSrcAVec = (q31_t const *) pInA;
                pSrcBVec = (q31_t const *) pInB;
                acc0 = 0LL;

                blkCnt = (numColsA / 4);
                while (blkCnt > 0U) {
                    vecA = vld1q(pSrcAVec);
                    pSrcAVec += 4;
                    vecB = vld1q(pSrcBVec);
                    pSrcBVec += 4;
                    acc0 = vrmlaldavhaq(acc0, vecA, vecB);

                    blkCnt--;
                }
                /*
                 * tail
                 * (will be merged through tail predication)
                 */
                blkCnt = (numColsA & 3);
                if (blkCnt > 0U) {
                    mve_pred16_t p0 = vctp32q(blkCnt);
                    vecA = vld1q(pSrcAVec);
                    vecB = vld1q(pSrcBVec);
                    acc0 = vrmlaldavhaq_p(acc0, vecA, vecB, p0);
                }

                acc0 = asrl(acc0, 23);
                *px++ = (q31_t) acc0;

                i += numColsA;
                /*
                 * Decrement the col loop counter
                 */
                col--;
            }
        }
        /* Set status as ARM_MATH_SUCCESS */
        status = ARM_MATH_SUCCESS;
    }
    /*
     * Return to application
     */
    return (status);
}

#else
arm_status arm_mat_mult_opt_q31(
    const arm_matrix_instance_q31 * pSrcA,
    const arm_matrix_instance_q31 * pSrcB,
    arm_matrix_instance_q31 * pDst,
    q31_t *pState)
{
    q31_t *pIn1 = pSrcA->pData; /* Input data matrix pointer A */
    q31_t *pIn2 = pSrcB->pData; /* Input data matrix pointer B */
    q31_t *pInA = pSrcA->pData; /* Input data matrix pointer A */
    q31_t *pInB = pSrcB->pData; /* Input data matrix pointer B */
    q31_t *pOut = pDst->pData;  /* Output data matrix pointer */
    q31_t *px;                  /* Temporary output data matrix pointer */
    q63_t sum;                  /* Accumulator */
    uint16_t numRowsA = pSrcA->numRows; /* Number of rows of input matrix A */
    uint16_t numColsB = pSrcB->numCols; /* Number of columns of input matrix B */
    uint16_t numColsA = pSrcA->numCols; /* Number of columns of input matrix A */
    uint32_t col, i = 0U, row = numRowsA, colCnt; /* Loop counters */
    arm_status status;          /* Status of matrix multiplication */
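    /* the scalar fallback computes the dot products in place, so the scratch buffer is unused */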
    (void)pState;
#ifdef ARM_MATH_MATRIX_CHECK

    /* Check for matrix mismatch condition */
    if ((pSrcA->numCols != pSrcB->numRows) ||
        (pSrcA->numRows != pDst->numRows) ||
        (pSrcB->numCols != pDst->numCols) )
    {
        /* Set status as ARM_MATH_SIZE_MISMATCH */
        status = ARM_MATH_SIZE_MISMATCH;
    }
    else

#endif /* #ifdef ARM_MATH_MATRIX_CHECK */

    {
        /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
        /* row loop */
        do
        {
            /* Output pointer is set to starting address of row being processed */
            px = pOut + i;

            /* For every row wise process, column loop counter is to be initiated */
            col = numColsB;

            /* For every row wise process, pIn2 pointer is set to starting address of pSrcB data */
            pIn2 = pSrcB->pData;

            /* column loop */
            do
            {
                /* Set the variable sum, that acts as accumulator, to zero */
                sum = 0;

                /* Initialize pointer pIn1 to point to the starting address of the row being processed */
                pIn1 = pInA;

#if defined (ARM_MATH_LOOPUNROLL)

                /* Loop unrolling: Compute 4 MACs at a time. */
                colCnt = numColsA >> 2U;

                /* matrix multiplication */
                while (colCnt > 0U)
                {
                    /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */

                    /* Perform the multiply-accumulates */
                    sum += (q63_t) *pIn1++ * *pIn2;
                    pIn2 += numColsB;

                    sum += (q63_t) *pIn1++ * *pIn2;
                    pIn2 += numColsB;

                    sum += (q63_t) *pIn1++ * *pIn2;
                    pIn2 += numColsB;

                    sum += (q63_t) *pIn1++ * *pIn2;
                    pIn2 += numColsB;

                    /* Decrement loop counter */
                    colCnt--;
                }

                /* Loop unrolling: Compute remaining MACs */
                colCnt = numColsA % 0x4U;

#else

                /* Initialize colCnt with number of columns */
                colCnt = numColsA;

#endif /* #if defined (ARM_MATH_LOOPUNROLL) */

                while (colCnt > 0U)
                {
                    /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */

                    /* Perform the multiply-accumulates */
                    sum += (q63_t) *pIn1++ * *pIn2;
                    pIn2 += numColsB;

                    /* Decrement loop counter */
                    colCnt--;
                }

                /* Convert result from 2.62 to 1.31 format and store in destination buffer */
                *px++ = (q31_t) (sum >> 31);

                /* Decrement column loop counter */
                col--;

                /* Update pointer pIn2 to point to starting address of next column */
                pIn2 = pInB + (numColsB - col);

            } while (col > 0U);

            /* Update pointer pInA to point to starting address of next row */
            i = i + numColsB;
            pInA = pInA + numColsA;

            /* Decrement row loop counter */
            row--;

        } while (row > 0U);

        /* Set status as ARM_MATH_SUCCESS */
        status = ARM_MATH_SUCCESS;
    }

    /* Return to application */
    return (status);
}
#endif /* defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE) */

/**
  @} end of MatrixMult group
 */