1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_mat_mult_q15.c
4 * Description: Q15 matrix multiplication
5 *
6 * $Date: 3 Nov 2021
7 * $Revision: V1.10.0
8 *
9 * Target Processor: Cortex-M and Cortex-A cores
10 * -------------------------------------------------------------------- */
11 /*
12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13 *
14 * SPDX-License-Identifier: Apache-2.0
15 *
16 * Licensed under the Apache License, Version 2.0 (the License); you may
17 * not use this file except in compliance with the License.
18 * You may obtain a copy of the License at
19 *
20 * www.apache.org/licenses/LICENSE-2.0
21 *
22 * Unless required by applicable law or agreed to in writing, software
23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 * See the License for the specific language governing permissions and
26 * limitations under the License.
27 */
28
29 #include "dsp/matrix_functions.h"
30
31 /**
32 @ingroup groupMatrix
33 */
34
35 /**
36 @addtogroup MatrixMult
37 @{
38 */
39
40 /**
41 @brief Q15 matrix multiplication.
42 @param[in] pSrcA points to the first input matrix structure
43 @param[in] pSrcB points to the second input matrix structure
44 @param[out] pDst points to output matrix structure
45 @param[in] pState points to the array for storing intermediate results
46 @return execution status
47 - \ref ARM_MATH_SUCCESS : Operation successful
48 - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
49
50 @par Scaling and Overflow Behavior
51 The function is implemented using an internal 64-bit accumulator. The inputs to the
52 multiplications are in 1.15 format and multiplications yield a 2.30 result.
53 The 2.30 intermediate results are accumulated in a 64-bit accumulator in 34.30 format.
54 This approach provides 33 guard bits and there is no risk of overflow.
55 The 34.30 result is then truncated to 34.15 format by discarding the low 15 bits
56 and then saturated to 1.15 format.
57 @par
58 Refer to \ref arm_mat_mult_fast_q15() for a faster but less precise version of this function.
59
60 @par pState
61 pState will contain the transpose of pSrcB
62 */
63 #if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)
64
/*
 * Scale a 64-bit accumulator down by `shift` bits and saturate it to 16 bits.
 *
 * The accumulator is passed to sqrshrl_sat48() with the negative amount
 * -(32 - shift); the upper 32-bit word of the returned value is then
 * extracted with ">> 32" and masked.
 * NOTE(review): per the ACLE description of sqrshrl_sat48, a negative amount is
 * a saturating (48-bit) left shift, so the upper word holds acc >> shift
 * saturated to 16 bits - confirm against the intrinsic reference.
 *
 * Both parameters are parenthesised so that expression arguments
 * (for example `a + b` as the shift) expand correctly.
 */
#define MVE_ASRL_SAT16(acc, shift)          ((sqrshrl_sat48((acc), -(32 - (shift))) >> 32) & 0xffffffff)

/* Dimensions handled by the dedicated small-square-matrix kernels. */
#define MATRIX_DIM2 2
#define MATRIX_DIM3 3
#define MATRIX_DIM4 4
70
arm_mat_mult_q15_2x2_mve(const arm_matrix_instance_q15 * pSrcA,const arm_matrix_instance_q15 * pSrcB,arm_matrix_instance_q15 * pDst)71 __STATIC_INLINE arm_status arm_mat_mult_q15_2x2_mve(
72 const arm_matrix_instance_q15 * pSrcA,
73 const arm_matrix_instance_q15 * pSrcB,
74 arm_matrix_instance_q15 * pDst)
75 {
76 q15_t *pInB = pSrcB->pData; /* input data matrix pointer B */
77 q15_t *pInA = pSrcA->pData; /* input data matrix pointer A */
78 q15_t *pOut = pDst->pData; /* output data matrix pointer */
79 uint16x8_t vecColBOffs;
80 q15_t *pInA0 = pInA;
81 q15_t *pInA1 = pInA0 + MATRIX_DIM2;
82 q63_t acc0, acc1;
83 q15x8_t vecB, vecA0, vecA1;
84 mve_pred16_t p0 = vctp16q(MATRIX_DIM2);
85
86 vecColBOffs = vidupq_u16((uint32_t)0, 2); /* MATRIX_DIM2 */
87
88 pInB = pSrcB->pData;
89
90 vecB = vldrhq_gather_shifted_offset_z_s16((q15_t const *)pInB, vecColBOffs, p0);
91
92 vecA0 = vldrhq_s16(pInA0);
93 vecA1 = vldrhq_s16(pInA1);
94
95 acc0 = vmlaldavq(vecA0, vecB);
96 acc1 = vmlaldavq(vecA1, vecB);
97
98 acc0 = asrl(acc0, 15);
99 acc1 = asrl(acc1, 15);
100
101 pOut[0 * MATRIX_DIM2] = (q15_t) __SSAT(acc0, 16);
102 pOut[1 * MATRIX_DIM2] = (q15_t) __SSAT(acc1, 16);
103 pOut++;
104
105 /* move to next B column */
106 pInB = pInB + 1;
107
108 vecB = vldrhq_gather_shifted_offset_z_s16(pInB, vecColBOffs, p0);
109
110 acc0 = vmlaldavq(vecA0, vecB);
111 acc1 = vmlaldavq(vecA1, vecB);
112
113 acc0 = asrl(acc0, 15);
114 acc1 = asrl(acc1, 15);
115
116 pOut[0 * MATRIX_DIM2] = (q15_t) __SSAT(acc0, 16);
117 pOut[1 * MATRIX_DIM2] = (q15_t) __SSAT(acc1, 16);
118
119 /*
120 * Return to application
121 */
122 return (ARM_MATH_SUCCESS);
123 }
124
125
126
arm_mat_mult_q15_3x3_mve(const arm_matrix_instance_q15 * pSrcA,const arm_matrix_instance_q15 * pSrcB,arm_matrix_instance_q15 * pDst)127 __STATIC_INLINE arm_status arm_mat_mult_q15_3x3_mve(
128 const arm_matrix_instance_q15 * pSrcA,
129 const arm_matrix_instance_q15 * pSrcB,
130 arm_matrix_instance_q15 * pDst)
131 {
132 q15_t *pInB = pSrcB->pData; /* input data matrix pointer B */
133 q15_t *pInA = pSrcA->pData; /* input data matrix pointer A */
134 q15_t *pOut = pDst->pData; /* output data matrix pointer */
135 uint16x8_t vecColBOffs;
136 q15_t *pInA0 = pInA;
137 q15_t *pInA1 = pInA0 + MATRIX_DIM3;
138 q15_t *pInA2 = pInA1 + MATRIX_DIM3;
139 q63_t acc0, acc1, acc2;
140 q15x8_t vecB, vecA0, vecA1, vecA2;
141 mve_pred16_t p0 = vctp16q(MATRIX_DIM3);
142
143 vecColBOffs = vidupq_u16((uint32_t)0, 1);
144 vecColBOffs = vecColBOffs * MATRIX_DIM3;
145
146 pInB = pSrcB->pData;
147
148 vecB = vldrhq_gather_shifted_offset_z_s16((q15_t const *)pInB, vecColBOffs, p0);
149
150 vecA0 = vldrhq_s16(pInA0);
151 vecA1 = vldrhq_s16(pInA1);
152 vecA2 = vldrhq_s16(pInA2);
153
154 acc0 = vmlaldavq(vecA0, vecB);
155 acc1 = vmlaldavq(vecA1, vecB);
156 acc2 = vmlaldavq(vecA2, vecB);
157
158 acc0 = asrl(acc0, 15);
159 acc1 = asrl(acc1, 15);
160 acc2 = asrl(acc2, 15);
161
162 pOut[0 * MATRIX_DIM3] = (q15_t) __SSAT(acc0, 16);
163 pOut[1 * MATRIX_DIM3] = (q15_t) __SSAT(acc1, 16);
164 pOut[2 * MATRIX_DIM3] = (q15_t) __SSAT(acc2, 16);
165 pOut++;
166
167 /* move to next B column */
168 pInB = pInB + 1;
169
170 vecB = vldrhq_gather_shifted_offset_z_s16(pInB, vecColBOffs, p0);
171
172 acc0 = vmlaldavq(vecA0, vecB);
173 acc1 = vmlaldavq(vecA1, vecB);
174 acc2 = vmlaldavq(vecA2, vecB);
175
176 acc0 = asrl(acc0, 15);
177 acc1 = asrl(acc1, 15);
178 acc2 = asrl(acc2, 15);
179
180 pOut[0 * MATRIX_DIM3] = (q15_t) __SSAT(acc0, 16);
181 pOut[1 * MATRIX_DIM3] = (q15_t) __SSAT(acc1, 16);
182 pOut[2 * MATRIX_DIM3] = (q15_t) __SSAT(acc2, 16);
183 pOut++;
184
185 /* move to next B column */
186 pInB = pInB + 1;
187
188 vecB = vldrhq_gather_shifted_offset_z_s16(pInB, vecColBOffs, p0);
189
190 acc0 = vmlaldavq(vecA0, vecB);
191 acc1 = vmlaldavq(vecA1, vecB);
192 acc2 = vmlaldavq(vecA2, vecB);
193
194 acc0 = asrl(acc0, 15);
195 acc1 = asrl(acc1, 15);
196 acc2 = asrl(acc2, 15);
197
198 pOut[0 * MATRIX_DIM3] = (q15_t) __SSAT(acc0, 16);
199 pOut[1 * MATRIX_DIM3] = (q15_t) __SSAT(acc1, 16);
200 pOut[2 * MATRIX_DIM3] = (q15_t) __SSAT(acc2, 16);
201 /*
202 * Return to application
203 */
204 return (ARM_MATH_SUCCESS);
205 }
206
207
arm_mat_mult_q15_4x4_mve(const arm_matrix_instance_q15 * pSrcA,const arm_matrix_instance_q15 * pSrcB,arm_matrix_instance_q15 * pDst)208 __STATIC_INLINE arm_status arm_mat_mult_q15_4x4_mve(
209 const arm_matrix_instance_q15 * pSrcA,
210 const arm_matrix_instance_q15 * pSrcB,
211 arm_matrix_instance_q15 * pDst)
212 {
213 q15_t *pInB = pSrcB->pData; /* input data matrix pointer B */
214 q15_t *pInA = pSrcA->pData; /* input data matrix pointer A */
215 q15_t *pOut = pDst->pData; /* output data matrix pointer */
216 uint16x8_t vecColBOffs;
217 q15_t *pInA0 = pInA;
218 q15_t *pInA1 = pInA0 + MATRIX_DIM4;
219 q15_t *pInA2 = pInA1 + MATRIX_DIM4;
220 q15_t *pInA3 = pInA2 + MATRIX_DIM4;
221 q63_t acc0, acc1, acc2, acc3;
222 q15x8_t vecB, vecA0, vecA1, vecA2, vecA3;
223 mve_pred16_t p0 = vctp16q(MATRIX_DIM4);
224
225 vecColBOffs = vidupq_u16((uint32_t)0, 4);
226
227 pInB = pSrcB->pData;
228
229 vecB = vldrhq_gather_shifted_offset_z_s16((q15_t const *)pInB, vecColBOffs, p0);
230
231 vecA0 = vldrhq_s16(pInA0);
232 vecA1 = vldrhq_s16(pInA1);
233 vecA2 = vldrhq_s16(pInA2);
234 vecA3 = vldrhq_s16(pInA3);
235
236 acc0 = vmlaldavq(vecA0, vecB);
237 acc1 = vmlaldavq(vecA1, vecB);
238 acc2 = vmlaldavq(vecA2, vecB);
239 acc3 = vmlaldavq(vecA3, vecB);
240
241 acc0 = asrl(acc0, 15);
242 acc1 = asrl(acc1, 15);
243 acc2 = asrl(acc2, 15);
244 acc3 = asrl(acc3, 15);
245
246 pOut[0 * MATRIX_DIM4] = (q15_t) __SSAT(acc0, 16);
247 pOut[1 * MATRIX_DIM4] = (q15_t) __SSAT(acc1, 16);
248 pOut[2 * MATRIX_DIM4] = (q15_t) __SSAT(acc2, 16);
249 pOut[3 * MATRIX_DIM4] = (q15_t) __SSAT(acc3, 16);
250 pOut++;
251
252 /* move to next B column */
253 pInB = pInB + 1;
254
255 vecB = vldrhq_gather_shifted_offset_z_s16(pInB, vecColBOffs, p0);
256
257 acc0 = vmlaldavq(vecA0, vecB);
258 acc1 = vmlaldavq(vecA1, vecB);
259 acc2 = vmlaldavq(vecA2, vecB);
260 acc3 = vmlaldavq(vecA3, vecB);
261
262 acc0 = asrl(acc0, 15);
263 acc1 = asrl(acc1, 15);
264 acc2 = asrl(acc2, 15);
265 acc3 = asrl(acc3, 15);
266
267 pOut[0 * MATRIX_DIM4] = (q15_t) __SSAT(acc0, 16);
268 pOut[1 * MATRIX_DIM4] = (q15_t) __SSAT(acc1, 16);
269 pOut[2 * MATRIX_DIM4] = (q15_t) __SSAT(acc2, 16);
270 pOut[3 * MATRIX_DIM4] = (q15_t) __SSAT(acc3, 16);
271
272 pOut++;
273
274 /* move to next B column */
275 pInB = pInB + 1;
276
277 vecB = vldrhq_gather_shifted_offset_z_s16(pInB, vecColBOffs, p0);
278
279 acc0 = vmlaldavq(vecA0, vecB);
280 acc1 = vmlaldavq(vecA1, vecB);
281 acc2 = vmlaldavq(vecA2, vecB);
282 acc3 = vmlaldavq(vecA3, vecB);
283
284 acc0 = asrl(acc0, 15);
285 acc1 = asrl(acc1, 15);
286 acc2 = asrl(acc2, 15);
287 acc3 = asrl(acc3, 15);
288
289 pOut[0 * MATRIX_DIM4] = (q15_t) __SSAT(acc0, 16);
290 pOut[1 * MATRIX_DIM4] = (q15_t) __SSAT(acc1, 16);
291 pOut[2 * MATRIX_DIM4] = (q15_t) __SSAT(acc2, 16);
292 pOut[3 * MATRIX_DIM4] = (q15_t) __SSAT(acc3, 16);
293
294 pOut++;
295
296 /* move to next B column */
297 pInB = pInB + 1;
298
299 vecB = vldrhq_gather_shifted_offset_z_s16(pInB, vecColBOffs, p0);
300
301 acc0 = vmlaldavq(vecA0, vecB);
302 acc1 = vmlaldavq(vecA1, vecB);
303 acc2 = vmlaldavq(vecA2, vecB);
304 acc3 = vmlaldavq(vecA3, vecB);
305
306 acc0 = asrl(acc0, 15);
307 acc1 = asrl(acc1, 15);
308 acc2 = asrl(acc2, 15);
309 acc3 = asrl(acc3, 15);
310
311 pOut[0 * MATRIX_DIM4] = (q15_t) __SSAT(acc0, 16);
312 pOut[1 * MATRIX_DIM4] = (q15_t) __SSAT(acc1, 16);
313 pOut[2 * MATRIX_DIM4] = (q15_t) __SSAT(acc2, 16);
314 pOut[3 * MATRIX_DIM4] = (q15_t) __SSAT(acc3, 16);
315 /*
316 * Return to application
317 */
318 return (ARM_MATH_SUCCESS);
319 }
320
321
arm_mat_mult_q15(const arm_matrix_instance_q15 * pSrcA,const arm_matrix_instance_q15 * pSrcB,arm_matrix_instance_q15 * pDst,q15_t * pState)322 arm_status arm_mat_mult_q15(
323 const arm_matrix_instance_q15 * pSrcA,
324 const arm_matrix_instance_q15 * pSrcB,
325 arm_matrix_instance_q15 * pDst,
326 q15_t * pState)
327 {
328 q15_t *pInA = pSrcA->pData; /* input data matrix pointer A */
329 q15_t *pInB = pSrcB->pData; /* input data matrix pointer B */
330 q15_t *pInA2;
331 q15_t *pInB2;
332 q15_t *px; /* Temporary output data matrix pointer */
333 q15_t *px2; /* Temporary output data matrix pointer */
334 uint32_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
335 uint32_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
336 uint32_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
337 uint32_t numRowsB = pSrcB->numRows; /* number of rows of input matrix A */
338 uint32_t col, i = 0u, j, row = numRowsB; /* loop counters */
339 q15_t *pSrcBT = pState; /* input data matrix pointer for transpose */
340 uint32_t blkCnt; /* loop counters */
341 arm_status status; /* Status of matrix multiplication */
342 arm_matrix_instance_q15 BT;
343
344 #ifdef ARM_MATH_MATRIX_CHECK
345
346 /* Check for matrix mismatch condition */
347 if ((pSrcA->numCols != pSrcB->numRows) ||
348 (pSrcA->numRows != pDst->numRows) ||
349 (pSrcB->numCols != pDst->numCols) )
350 {
351 /* Set status as ARM_MATH_SIZE_MISMATCH */
352 status = ARM_MATH_SIZE_MISMATCH;
353 }
354 else
355 #endif
356 {
357 /* small squared matrix specialized routines */
358 if (numRowsA == numColsB && numColsB == numColsA) {
359
360 if (numRowsA == 1) {
361 q63_t sum;
362 sum = pInA[0] * pInB[0];
363 pDst->pData[0] = (q15_t) __SSAT((sum >> 15), 16);
364 return (ARM_MATH_SUCCESS);
365 } else if (numRowsA == 2)
366 return arm_mat_mult_q15_2x2_mve(pSrcA, pSrcB, pDst);
367 else if (numRowsA == 3)
368 return arm_mat_mult_q15_3x3_mve(pSrcA, pSrcB, pDst);
369 else if (numRowsA == 4)
370 return arm_mat_mult_q15_4x4_mve(pSrcA, pSrcB, pDst);
371 }
372
373 /*
374 * Matrix transpose
375 */
376
377 BT.numRows = numColsB;
378 BT.numCols = numRowsB;
379 BT.pData = pSrcBT;
380
381 arm_mat_trans_q15(pSrcB, &BT);
382
383
384 /*
385 * Reset the variables for the usage in the following multiplication process
386 */
387 i = 0;
388 row = numRowsA >> 1;
389 px = pDst->pData;
390 px2 = px + numColsB;
391
392 /*
393 * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
394 */
395
396 /*
397 * row loop
398 */
399 while (row > 0u) {
400 /*
401 * For every row wise process, the column loop counter is to be initiated
402 */
403 col = numColsB >> 1;
404 /*
405 * For every row wise process, the pIn2 pointer is set
406 * to the starting address of the transposed pSrcB data
407 */
408 pInB = pSrcBT;
409 pInB2 = pInB + numRowsB;
410 j = 0;
411
412 /*
413 * column loop
414 */
415 while (col > 0u) {
416 q15_t const *pSrcAVec, *pSrcBVec, *pSrcA2Vec, *pSrcB2Vec;
417 q15x8_t vecA, vecA2, vecB, vecB2;
418 q63_t acc0, acc1, acc2, acc3;
419
420 /*
421 * Initiate the pointer pIn1 to point to the starting address of the column being processed
422 */
423 pInA = pSrcA->pData + i;
424 pInA2 = pInA + numColsA;
425 pInB = pSrcBT + j;
426 pInB2 = pInB + numRowsB;
427
428
429 pSrcAVec = (q15_t const *) pInA;
430 pSrcA2Vec = (q15_t const *) pInA2;
431 pSrcBVec = (q15_t const *) pInB;
432 pSrcB2Vec = (q15_t const *) pInB2;
433
434 acc0 = 0LL;
435 acc1 = 0LL;
436 acc2 = 0LL;
437 acc3 = 0LL;
438
439 vecA = vld1q(pSrcAVec);
440 pSrcAVec += 8;
441
442 blkCnt = numColsA / 8;
443 while (blkCnt > 0U) {
444 vecB = vld1q(pSrcBVec);
445 pSrcBVec += 8;
446 acc0 = vmlaldavaq(acc0, vecA, vecB);
447 vecA2 = vld1q(pSrcA2Vec);
448 pSrcA2Vec += 8;
449 acc1 = vmlaldavaq(acc1, vecA2, vecB);
450 vecB2 = vld1q(pSrcB2Vec);
451 pSrcB2Vec += 8;
452 acc2 = vmlaldavaq(acc2, vecA, vecB2);
453 vecA = vld1q(pSrcAVec);
454 pSrcAVec += 8;
455 acc3 = vmlaldavaq(acc3, vecA2, vecB2);
456
457 blkCnt--;
458 }
459 /*
460 * tail
461 */
462 blkCnt = numColsA & 7;
463 if (blkCnt > 0U) {
464 mve_pred16_t p0 = vctp16q(blkCnt);
465 vecB = vld1q(pSrcBVec);
466 acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0);
467 vecA2 = vld1q(pSrcA2Vec);
468 acc1 = vmlaldavaq_p(acc1, vecA2, vecB, p0);
469 vecB2 = vld1q(pSrcB2Vec);
470 acc2 = vmlaldavaq_p(acc2, vecA, vecB2, p0);
471 vecA = vld1q(pSrcAVec);
472 acc3 = vmlaldavaq_p(acc3, vecA2, vecB2, p0);
473 }
474
475 *px++ = (q15_t) MVE_ASRL_SAT16(acc0, 15);
476 *px++ = (q15_t) MVE_ASRL_SAT16(acc2, 15);
477 *px2++ = (q15_t) MVE_ASRL_SAT16(acc1, 15);
478 *px2++ = (q15_t) MVE_ASRL_SAT16(acc3, 15);
479 j += numRowsB * 2;
480 /*
481 * Decrement the column loop counter
482 */
483 col--;
484
485 }
486
487 i = i + numColsA * 2;
488 px = px2 + (numColsB & 1u);
489 px2 = px + numColsB;
490 /*
491 * Decrement the row loop counter
492 */
493 row--;
494 }
495
496 /*
497 * Compute remaining row and/or column below
498 */
499
500 if (numColsB & 1u) {
501 row = numRowsA & (~0x1); //avoid redundant computation
502 px = pDst->pData + numColsB - 1;
503 i = 0;
504
505 /*
506 * row loop
507 */
508 while (row > 0) {
509 q15_t const *pSrcAVec, *pSrcBVec;
510 q15x8_t vecA, vecB;
511 q63_t acc0;
512
513 /*
514 * point to last column in matrix B
515 */
516 pInB = pSrcBT + numRowsB * (numColsB - 1);
517 pInA = pSrcA->pData + i;
518
519 pSrcAVec = (q15_t const *) pInA;
520 pSrcBVec = (q15_t const *) pInB;
521
522 acc0 = 0LL;
523 blkCnt = (numColsA) / 8;
524 while (blkCnt > 0U) {
525 vecA = vld1q(pSrcAVec);
526 pSrcAVec += 8;
527 vecB = vld1q(pSrcBVec);
528 pSrcBVec += 8;
529 acc0 = vmlaldavaq(acc0, vecA, vecB);
530
531 blkCnt--;
532 }
533 /*
534 * tail
535 */
536 blkCnt = (numColsA & 7);
537 if (blkCnt > 0U) {
538 mve_pred16_t p0 = vctp16q(blkCnt);
539 vecA = vld1q(pSrcAVec);
540 vecB = vld1q(pSrcBVec);
541 acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0);
542 }
543
544 *px = (q15_t) MVE_ASRL_SAT16(acc0, 15);
545
546 px += numColsB;
547
548 i += numColsA;
549 /*
550 * Decrement the row loop counter
551 */
552 row--;
553 }
554 }
555
556 if (numRowsA & 1u) {
557 col = numColsB;
558 i = 0u;
559 /*
560 * point to last row in output matrix
561 */
562 px = pDst->pData + (numColsB) * (numRowsA - 1);
563 /*
564 * col loop
565 */
566 while (col > 0) {
567 q15_t const *pSrcAVec, *pSrcBVec;
568 q15x8_t vecA, vecB;
569 q63_t acc0;
570
571 /*
572 * point to last row in matrix A
573 */
574 pInA = pSrcA->pData + (numRowsA - 1) * numColsA;
575 pInB = pSrcBT + i;
576
577 /*
578 * Set the variable sum, that acts as accumulator, to zero
579 */
580 pSrcAVec = (q15_t const *) pInA;
581 pSrcBVec = (q15_t const *) pInB;
582 acc0 = 0LL;
583
584 blkCnt = ((numColsA) / 8);
585 while (blkCnt > 0U) {
586 vecA = vld1q(pSrcAVec);
587 pSrcAVec += 8;
588 vecB = vld1q(pSrcBVec);
589 pSrcBVec += 8;
590 acc0 = vmlaldavaq(acc0, vecA, vecB);
591
592 blkCnt--;
593 }
594 /*
595 * tail
596 */
597 blkCnt = (numColsA & 7);
598 if (blkCnt > 0U) {
599 mve_pred16_t p0 = vctp16q(blkCnt);
600 vecA = vld1q(pSrcAVec);
601 vecB = vld1q(pSrcBVec);
602 acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0);
603 }
604
605 *px++ = (q15_t) MVE_ASRL_SAT16(acc0, 15);
606
607 i += numColsA;
608
609 /*
610 * Decrement the col loop counter
611 */
612 col--;
613 }
614 }
615
616 /* Set status as ARM_MATH_SUCCESS */
617 status = ARM_MATH_SUCCESS;
618 }
619 /* Return to application */
620 return (status);
621 }
622
623 #else
arm_mat_mult_q15(const arm_matrix_instance_q15 * pSrcA,const arm_matrix_instance_q15 * pSrcB,arm_matrix_instance_q15 * pDst,q15_t * pState)624 arm_status arm_mat_mult_q15(
625 const arm_matrix_instance_q15 * pSrcA,
626 const arm_matrix_instance_q15 * pSrcB,
627 arm_matrix_instance_q15 * pDst,
628 q15_t * pState)
629 {
630 q63_t sum; /* Accumulator */
631
632 #if defined (ARM_MATH_DSP) /* != CM0 */
633
634 q15_t *pSrcBT = pState; /* Input data matrix pointer for transpose */
635 q15_t *pInA = pSrcA->pData; /* Input data matrix pointer A of Q15 type */
636 q15_t *pInB = pSrcB->pData; /* Input data matrix pointer B of Q15 type */
637 q15_t *px; /* Temporary output data matrix pointer */
638 uint16_t numRowsA = pSrcA->numRows; /* Number of rows of input matrix A */
639 uint16_t numColsB = pSrcB->numCols; /* Number of columns of input matrix B */
640 uint16_t numColsA = pSrcA->numCols; /* Number of columns of input matrix A */
641 uint16_t numRowsB = pSrcB->numRows; /* Number of rows of input matrix B */
642 uint32_t col, i = 0U, row = numRowsB, colCnt; /* Loop counters */
643 arm_status status; /* Status of matrix multiplication */
644
645 q31_t inA1, inB1, inA2, inB2;
646 arm_matrix_instance_q15 BT;
647
648 #ifdef ARM_MATH_MATRIX_CHECK
649
650 /* Check for matrix mismatch condition */
651 if ((pSrcA->numCols != pSrcB->numRows) ||
652 (pSrcA->numRows != pDst->numRows) ||
653 (pSrcB->numCols != pDst->numCols) )
654 {
655 /* Set status as ARM_MATH_SIZE_MISMATCH */
656 status = ARM_MATH_SIZE_MISMATCH;
657 }
658 else
659
660 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
661 {
662
663 BT.numRows = numColsB;
664 BT.numCols = numRowsB;
665 BT.pData = pSrcBT;
666
667 arm_mat_trans_q15(pSrcB,&BT);
668 /* Reset variables for usage in following multiplication process */
669 row = numRowsA;
670 i = 0U;
671 px = pDst->pData;
672
673 /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
674 /* row loop */
675 do
676 {
677 /* For every row wise process, column loop counter is to be initiated */
678 col = numColsB;
679
680 /* For every row wise process, pIn2 pointer is set to starting address of transposed pSrcB data */
681 pInB = pSrcBT;
682
683 /* column loop */
684 do
685 {
686 /* Set variable sum, that acts as accumulator, to zero */
687 sum = 0;
688
689 /* Initiate pointer pInA to point to starting address of column being processed */
690 pInA = pSrcA->pData + i;
691
692 /* Apply loop unrolling and compute 2 MACs simultaneously. */
693 colCnt = numColsA >> 2U;
694
695 /* matrix multiplication */
696 while (colCnt > 0U)
697 {
698 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
699
700 /* read real and imag values from pSrcA and pSrcB buffer */
701 inA1 = read_q15x2_ia (&pInA);
702 inB1 = read_q15x2_ia (&pInB);
703
704 inA2 = read_q15x2_ia (&pInA);
705 inB2 = read_q15x2_ia (&pInB);
706
707 /* Multiply and Accumulates */
708 sum = __SMLALD(inA1, inB1, sum);
709 sum = __SMLALD(inA2, inB2, sum);
710
711 /* Decrement loop counter */
712 colCnt--;
713 }
714
715 /* process remaining column samples */
716 colCnt = numColsA % 0x4U;
717
718 while (colCnt > 0U)
719 {
720 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
721 sum += *pInA++ * *pInB++;
722
723 /* Decrement loop counter */
724 colCnt--;
725 }
726
727 /* Saturate and store result in destination buffer */
728 *px = (q15_t) (__SSAT((sum >> 15), 16));
729 px++;
730
731 /* Decrement column loop counter */
732 col--;
733
734 } while (col > 0U);
735
736 i = i + numColsA;
737
738 /* Decrement row loop counter */
739 row--;
740
741 } while (row > 0U);
742
743 #else /* #if defined (ARM_MATH_DSP) */
744
745 q15_t *pIn1 = pSrcA->pData; /* Input data matrix pointer A */
746 q15_t *pIn2 = pSrcB->pData; /* Input data matrix pointer B */
747 q15_t *pInA = pSrcA->pData; /* Input data matrix pointer A of Q15 type */
748 q15_t *pInB = pSrcB->pData; /* Input data matrix pointer B of Q15 type */
749 q15_t *pOut = pDst->pData; /* Output data matrix pointer */
750 q15_t *px; /* Temporary output data matrix pointer */
751 uint16_t numColsB = pSrcB->numCols; /* Number of columns of input matrix B */
752 uint16_t numColsA = pSrcA->numCols; /* Number of columns of input matrix A */
753 uint16_t numRowsA = pSrcA->numRows; /* Number of rows of input matrix A */
754 uint32_t col, i = 0U, row = numRowsA, colCnt; /* Loop counters */
755 arm_status status; /* Status of matrix multiplication */
756 (void)pState;
757
758 #ifdef ARM_MATH_MATRIX_CHECK
759
760 /* Check for matrix mismatch condition */
761 if ((pSrcA->numCols != pSrcB->numRows) ||
762 (pSrcA->numRows != pDst->numRows) ||
763 (pSrcB->numCols != pDst->numCols) )
764 {
765 /* Set status as ARM_MATH_SIZE_MISMATCH */
766 status = ARM_MATH_SIZE_MISMATCH;
767 }
768 else
769
770 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
771
772 {
773 /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
774 /* row loop */
775 do
776 {
777 /* Output pointer is set to starting address of the row being processed */
778 px = pOut + i;
779
780 /* For every row wise process, column loop counter is to be initiated */
781 col = numColsB;
782
783 /* For every row wise process, pIn2 pointer is set to starting address of pSrcB data */
784 pIn2 = pSrcB->pData;
785
786 /* column loop */
787 do
788 {
789 /* Set the variable sum, that acts as accumulator, to zero */
790 sum = 0;
791
792 /* Initiate pointer pIn1 to point to starting address of pSrcA */
793 pIn1 = pInA;
794
795 /* Matrix A columns number of MAC operations are to be performed */
796 colCnt = numColsA;
797
798 /* matrix multiplication */
799 while (colCnt > 0U)
800 {
801 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
802
803 /* Perform multiply-accumulates */
804 sum += (q31_t) * pIn1++ * *pIn2;
805 pIn2 += numColsB;
806
807 /* Decrement loop counter */
808 colCnt--;
809 }
810
811 /* Convert result from 34.30 to 1.15 format and store saturated value in destination buffer */
812
813 /* Saturate and store result in destination buffer */
814 *px++ = (q15_t) __SSAT((sum >> 15), 16);
815
816 /* Decrement column loop counter */
817 col--;
818
819 /* Update pointer pIn2 to point to starting address of next column */
820 pIn2 = pInB + (numColsB - col);
821
822 } while (col > 0U);
823
824 /* Update pointer pSrcA to point to starting address of next row */
825 i = i + numColsB;
826 pInA = pInA + numColsA;
827
828 /* Decrement row loop counter */
829 row--;
830
831 } while (row > 0U);
832
833 #endif /* #if defined (ARM_MATH_DSP) */
834
835 /* Set status as ARM_MATH_SUCCESS */
836 status = ARM_MATH_SUCCESS;
837 }
838
839 /* Return to application */
840 return (status);
841 }
842 #endif /* defined(ARM_MATH_MVEI) */
843
844 /**
845 @} end of MatrixMult group
846 */
847