1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_mult_opt_q31.c
4  * Description:  Q31 matrix multiplication
5  *
6  * $Date:        3 Nov 2021
7  * $Revision:    V1.10.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/matrix_functions.h"
30 
31 /**
32   @ingroup groupMatrix
33  */
34 
35 /**
36   @addtogroup MatrixMult
37   @{
38  */
39 
/**
  @brief         Q31 matrix multiplication.
  @param[in]     pSrcA      points to the first input matrix structure
  @param[in]     pSrcB      points to the second input matrix structure
  @param[out]    pDst       points to output matrix structure
  @param[in]     pState     points to the array for storing intermediate results
  @return        execution status
                   - \ref ARM_MATH_SUCCESS       : Operation successful
                   - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed

  @par           Scaling and Overflow Behavior
                   The function is implemented using an internal 64-bit accumulator.
                   The accumulator has a 2.62 format and maintains full precision of the intermediate
                   multiplication results but provides only a single guard bit. There is no saturation
                   on intermediate additions. Thus, if the accumulator overflows it wraps around and
                   distorts the result. The input signals should be scaled down to avoid intermediate
                   overflows. The input should therefore be scaled down by log2(numColsA) bits
                   to avoid overflows, as a total of numColsA additions are performed internally.
                   The 2.62 accumulator is right shifted by 31 bits and saturated to 1.31 format to yield the final result.
  @remark
                   Refer to \ref arm_mat_mult_fast_q31() for a faster but less precise implementation of this function.
  @remark
                   This function is a faster implementation of arm_mat_mult_q31 for MVE, but it requires
                   additional storage for intermediate results.
 */
65 #if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)
66 
/* Dimensions of the small square matrices served by the dedicated MVE kernels below. */
#define MATRIX_DIM2 2
#define MATRIX_DIM3 3
#define MATRIX_DIM4 4
70 
arm_mat_mult_opt_q31_2x2_mve(const arm_matrix_instance_q31 * pSrcA,const arm_matrix_instance_q31 * pSrcB,arm_matrix_instance_q31 * pDst)71 __STATIC_INLINE arm_status arm_mat_mult_opt_q31_2x2_mve(
72     const arm_matrix_instance_q31 * pSrcA,
73     const arm_matrix_instance_q31 * pSrcB,
74     arm_matrix_instance_q31 * pDst)
75 {
76     q31_t       *pInB = pSrcB->pData;  /* input data matrix pointer B */
77     q31_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
78     q31_t       *pOut = pDst->pData;   /* output data matrix pointer */
79     uint32x4_t   vecColBOffs;
80     q31_t       *pInA0 = pInA;
81     q31_t       *pInA1 = pInA0 + MATRIX_DIM2;
82     q63_t        acc0, acc1;
83     q31x4_t      vecB, vecA0, vecA1;
84     /* enable predication to disable half of vector elements */
85     mve_pred16_t p0 = vctp32q(MATRIX_DIM2);
86 
87     vecColBOffs = vidupq_u32((uint32_t)0, 1);
88     vecColBOffs = vecColBOffs * MATRIX_DIM2;
89 
90     pInB = pSrcB->pData;
91 
92     /* load 1st B column (partial load) */
93     vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);
94 
95     /* load A rows */
96     vecA0 = vldrwq_s32(pInA0);
97     vecA1 = vldrwq_s32(pInA1);
98 
99     acc0 = vrmlaldavhq(vecA0, vecB);
100     acc1 = vrmlaldavhq(vecA1, vecB);
101 
102     acc0 = asrl(acc0, 23);
103     acc1 = asrl(acc1, 23);
104 
105     pOut[0 * MATRIX_DIM2] = (q31_t) acc0;
106     pOut[1 * MATRIX_DIM2] = (q31_t) acc1;
107     pOut++;
108 
109     /* move to next B column */
110     pInB = pInB + 1;
111 
112     vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);
113 
114     acc0 = vrmlaldavhq(vecA0, vecB);
115     acc1 = vrmlaldavhq(vecA1, vecB);
116 
117     acc0 = asrl(acc0, 23);
118     acc1 = asrl(acc1, 23);
119 
120     pOut[0 * MATRIX_DIM2] = (q31_t) acc0;
121     pOut[1 * MATRIX_DIM2] = (q31_t) acc1;
122     /*
123      * Return to application
124      */
125     return (ARM_MATH_SUCCESS);
126 }
127 
128 
129 
arm_mat_mult_opt_q31_3x3_mve(const arm_matrix_instance_q31 * pSrcA,const arm_matrix_instance_q31 * pSrcB,arm_matrix_instance_q31 * pDst)130 __STATIC_INLINE arm_status arm_mat_mult_opt_q31_3x3_mve(
131     const arm_matrix_instance_q31 * pSrcA,
132     const arm_matrix_instance_q31 * pSrcB,
133     arm_matrix_instance_q31 * pDst)
134 {
135     q31_t       *pInB = pSrcB->pData;  /* input data matrix pointer B */
136     q31_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
137     q31_t       *pOut = pDst->pData;   /* output data matrix pointer */
138     uint32x4_t   vecColBOffs;
139     q31_t       *pInA0 = pInA;
140     q31_t       *pInA1 = pInA0 + MATRIX_DIM3;
141     q31_t       *pInA2 = pInA1 + MATRIX_DIM3;
142     q63_t        acc0, acc1, acc2;
143     q31x4_t      vecB, vecA;
144     /* enable predication to disable last (4th) vector element */
145     mve_pred16_t p0 = vctp32q(MATRIX_DIM3);
146 
147     vecColBOffs = vidupq_u32((uint32_t)0, 1);
148     vecColBOffs = vecColBOffs * MATRIX_DIM3;
149 
150     pInB = pSrcB->pData;
151 
152     vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);
153 
154     vecA = vldrwq_s32(pInA0);
155     acc0 = vrmlaldavhq(vecA, vecB);
156     vecA = vldrwq_s32(pInA1);
157     acc1 = vrmlaldavhq(vecA, vecB);
158     vecA = vldrwq_s32(pInA2);
159     acc2 = vrmlaldavhq(vecA, vecB);
160 
161     acc0 = asrl(acc0, 23);
162     acc1 = asrl(acc1, 23);
163     acc2 = asrl(acc2, 23);
164 
165     pOut[0 * MATRIX_DIM3] = (q31_t) acc0;
166     pOut[1 * MATRIX_DIM3] = (q31_t) acc1;
167     pOut[2 * MATRIX_DIM3] = (q31_t) acc2;
168     pOut++;
169 
170     /* move to next B column */
171     pInB = pInB + 1;
172 
173     vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);
174 
175     vecA = vldrwq_s32(pInA0);
176     acc0 = vrmlaldavhq(vecA, vecB);
177     vecA = vldrwq_s32(pInA1);
178     acc1 = vrmlaldavhq(vecA, vecB);
179     vecA = vldrwq_s32(pInA2);
180     acc2 = vrmlaldavhq(vecA, vecB);
181 
182     acc0 = asrl(acc0, 23);
183     acc1 = asrl(acc1, 23);
184     acc2 = asrl(acc2, 23);
185 
186     pOut[0 * MATRIX_DIM3] = (q31_t) acc0;
187     pOut[1 * MATRIX_DIM3] = (q31_t) acc1;
188     pOut[2 * MATRIX_DIM3] = (q31_t) acc2;
189     pOut++;
190 
191     /* move to next B column */
192     pInB = pInB + 1;
193 
194     vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);
195 
196     vecA = vldrwq_s32(pInA0);
197     acc0 = vrmlaldavhq(vecA, vecB);
198     vecA = vldrwq_s32(pInA1);
199     acc1 = vrmlaldavhq(vecA, vecB);
200     vecA = vldrwq_s32(pInA2);
201     acc2 = vrmlaldavhq(vecA, vecB);
202 
203     acc0 = asrl(acc0, 23);
204     acc1 = asrl(acc1, 23);
205     acc2 = asrl(acc2, 23);
206 
207     pOut[0 * MATRIX_DIM3] = (q31_t) acc0;
208     pOut[1 * MATRIX_DIM3] = (q31_t) acc1;
209     pOut[2 * MATRIX_DIM3] = (q31_t) acc2;
210     /*
211      * Return to application
212      */
213     return (ARM_MATH_SUCCESS);
214 }
215 
arm_mat_mult_opt_q31_4x4_mve(const arm_matrix_instance_q31 * pSrcA,const arm_matrix_instance_q31 * pSrcB,arm_matrix_instance_q31 * pDst)216 __STATIC_INLINE arm_status arm_mat_mult_opt_q31_4x4_mve(
217     const arm_matrix_instance_q31 * pSrcA,
218     const arm_matrix_instance_q31 * pSrcB,
219     arm_matrix_instance_q31 * pDst)
220 {
221     q31_t       *pInB = pSrcB->pData;  /* input data matrix pointer B */
222     q31_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
223     q31_t       *pOut = pDst->pData;   /* output data matrix pointer */
224     uint32x4_t   vecColBOffs;
225     q31_t       *pInA0 = pInA;
226     q31_t       *pInA1 = pInA0 + MATRIX_DIM4;
227     q31_t       *pInA2 = pInA1 + MATRIX_DIM4;
228     q31_t       *pInA3 = pInA2 + MATRIX_DIM4;
229     q63_t        acc0, acc1, acc2, acc3;
230     q31x4_t      vecB, vecA;
231 
232     vecColBOffs = vidupq_u32((uint32_t)0, 4);
233 
234     pInB = pSrcB->pData;
235 
236     vecB = vldrwq_gather_shifted_offset_s32(pInB, vecColBOffs);
237 
238     vecA = vldrwq_s32(pInA0);
239     acc0 = vrmlaldavhq(vecA, vecB);
240     vecA = vldrwq_s32(pInA1);
241     acc1 = vrmlaldavhq(vecA, vecB);
242     vecA = vldrwq_s32(pInA2);
243     acc2 = vrmlaldavhq(vecA, vecB);
244     vecA = vldrwq_s32(pInA3);
245     acc3 = vrmlaldavhq(vecA, vecB);
246 
247     acc0 = asrl(acc0, 23);
248     acc1 = asrl(acc1, 23);
249     acc2 = asrl(acc2, 23);
250     acc3 = asrl(acc3, 23);
251 
252     pOut[0 * MATRIX_DIM4] = (q31_t) acc0;
253     pOut[1 * MATRIX_DIM4] = (q31_t) acc1;
254     pOut[2 * MATRIX_DIM4] = (q31_t) acc2;
255     pOut[3 * MATRIX_DIM4] = (q31_t) acc3;
256     pOut++;
257 
258     /* move to next B column */
259     pInB = pInB + 1;
260 
261     vecB = vldrwq_gather_shifted_offset_s32(pInB, vecColBOffs);
262 
263     vecA = vldrwq_s32(pInA0);
264     acc0 = vrmlaldavhq(vecA, vecB);
265     vecA = vldrwq_s32(pInA1);
266     acc1 = vrmlaldavhq(vecA, vecB);
267     vecA = vldrwq_s32(pInA2);
268     acc2 = vrmlaldavhq(vecA, vecB);
269     vecA = vldrwq_s32(pInA3);
270     acc3 = vrmlaldavhq(vecA, vecB);
271 
272     acc0 = asrl(acc0, 23);
273     acc1 = asrl(acc1, 23);
274     acc2 = asrl(acc2, 23);
275     acc3 = asrl(acc3, 23);
276 
277     pOut[0 * MATRIX_DIM4] = (q31_t) acc0;
278     pOut[1 * MATRIX_DIM4] = (q31_t) acc1;
279     pOut[2 * MATRIX_DIM4] = (q31_t) acc2;
280     pOut[3 * MATRIX_DIM4] = (q31_t) acc3;
281 
282     pOut++;
283 
284     /* move to next B column */
285     pInB = pInB + 1;
286 
287     vecB = vldrwq_gather_shifted_offset_s32(pInB, vecColBOffs);
288 
289     vecA = vldrwq_s32(pInA0);
290     acc0 = vrmlaldavhq(vecA, vecB);
291     vecA = vldrwq_s32(pInA1);
292     acc1 = vrmlaldavhq(vecA, vecB);
293     vecA = vldrwq_s32(pInA2);
294     acc2 = vrmlaldavhq(vecA, vecB);
295     vecA = vldrwq_s32(pInA3);
296     acc3 = vrmlaldavhq(vecA, vecB);
297 
298     acc0 = asrl(acc0, 23);
299     acc1 = asrl(acc1, 23);
300     acc2 = asrl(acc2, 23);
301     acc3 = asrl(acc3, 23);
302 
303     pOut[0 * MATRIX_DIM4] = (q31_t) acc0;
304     pOut[1 * MATRIX_DIM4] = (q31_t) acc1;
305     pOut[2 * MATRIX_DIM4] = (q31_t) acc2;
306     pOut[3 * MATRIX_DIM4] = (q31_t) acc3;
307 
308     pOut++;
309 
310     /* move to next B column */
311     pInB = pInB + 1;
312 
313     vecB = vldrwq_gather_shifted_offset_s32(pInB, vecColBOffs);
314 
315     vecA = vldrwq_s32(pInA0);
316     acc0 = vrmlaldavhq(vecA, vecB);
317     vecA = vldrwq_s32(pInA1);
318     acc1 = vrmlaldavhq(vecA, vecB);
319     vecA = vldrwq_s32(pInA2);
320     acc2 = vrmlaldavhq(vecA, vecB);
321     vecA = vldrwq_s32(pInA3);
322     acc3 = vrmlaldavhq(vecA, vecB);
323 
324     acc0 = asrl(acc0, 23);
325     acc1 = asrl(acc1, 23);
326     acc2 = asrl(acc2, 23);
327     acc3 = asrl(acc3, 23);
328 
329     pOut[0 * MATRIX_DIM4] = (q31_t) acc0;
330     pOut[1 * MATRIX_DIM4] = (q31_t) acc1;
331     pOut[2 * MATRIX_DIM4] = (q31_t) acc2;
332     pOut[3 * MATRIX_DIM4] = (q31_t) acc3;
333     /*
334      * Return to application
335      */
336     return (ARM_MATH_SUCCESS);
337 }
338 
339 
arm_mat_mult_opt_q31(const arm_matrix_instance_q31 * pSrcA,const arm_matrix_instance_q31 * pSrcB,arm_matrix_instance_q31 * pDst,q31_t * pState)340 arm_status arm_mat_mult_opt_q31(
341     const arm_matrix_instance_q31 * pSrcA,
342     const arm_matrix_instance_q31 * pSrcB,
343     arm_matrix_instance_q31 * pDst,
344     q31_t *pState)
345 {
346     q31_t          *pInA = pSrcA->pData;        /* input data matrix pointer A */
347     q31_t          *pInB = pSrcB->pData;        /* input data matrix pointer B */
348     q31_t          *pInA2;
349     q31_t          *pInB2;
350     q31_t          *px;         /* Temporary output data matrix pointer */
351     q31_t          *px2;        /* Temporary output data matrix pointer */
352     uint32_t        numRowsA = pSrcA->numRows;  /* number of rows of input matrix A    */
353     uint32_t        numColsB = pSrcB->numCols;  /* number of columns of input matrix B */
354     uint32_t        numColsA = pSrcA->numCols;  /* number of columns of input matrix A */
355     uint32_t        numRowsB = pSrcB->numRows;  /* number of rows of input matrix A    */
356     uint32_t        col, i = 0u, j, row = numRowsB;     /* loop counters */
357     q31_t          *pSrcBT = pState;     /* input data matrix pointer for transpose */
358     uint32_t        blkCnt;     /* loop counters */
359     arm_status      status;                            /* Status of matrix multiplication */
360     arm_matrix_instance_q31 BT;
361 #ifdef ARM_MATH_MATRIX_CHECK
362 
363     /* Check for matrix mismatch condition */
364     if ((pSrcA->numCols != pSrcB->numRows) ||
365         (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols)) {
366         /* Set status as ARM_MATH_SIZE_MISMATCH */
367         status = ARM_MATH_SIZE_MISMATCH;
368     } else
369 #endif                          /* #ifdef ARM_MATH_MATRIX_CHECK */
370     {
371 
372          /* small squared matrix specialized routines */
373     if(numRowsA == numColsB && numColsB == numColsA) {
374         if (numRowsA == 1)
375         {
376           q63_t sum =  (q63_t) *pInA * *pInB;
377           pDst->pData[0] = (q31_t)(sum >> 31);
378           return (ARM_MATH_SUCCESS);
379         }
380         else if(numRowsA == 2)
381             return arm_mat_mult_opt_q31_2x2_mve(pSrcA, pSrcB, pDst);
382         else if(numRowsA == 3)
383             return arm_mat_mult_opt_q31_3x3_mve(pSrcA, pSrcB, pDst);
384         else if (numRowsA == 4)
385             return arm_mat_mult_opt_q31_4x4_mve(pSrcA, pSrcB, pDst);
386     }
387 
388 
389         /*
390          * Matrix transpose
391          */
392         BT.numRows = numColsB;
393         BT.numCols = numRowsB;
394         BT.pData = pSrcBT;
395 
396         arm_mat_trans_q31(pSrcB, &BT);
397 
398 
399         /*
400          * Reset the variables for the usage in the following multiplication process
401          */
402         i = 0;
403         row = numRowsA >> 1;
404         px = pDst->pData;
405         px2 = px + numColsB;
406 
407         /*
408          * main loop
409          * compute 2 x 2 output blocks
410          * with dot products (Matrix A rows * Transposed MAtrix B rows)
411          */
412         while (row > 0u) {
413             /*
414              * For every row wise process, the column loop counter is to be initiated
415              * Compute 2 columns and 2 rows in parrallel
416              */
417             col = numColsB >> 1;
418             j = 0;
419 
420             /*
421              * column pair loop
422              */
423             while (col > 0u) {
424                 q31_t const    *pSrcAVec, *pSrcBVec, *pSrcA2Vec, *pSrcB2Vec;
425                 q31x4_t         vecA, vecA2, vecB, vecB2;
426                 q63_t           acc0, acc1, acc2, acc3;
427 
428                 /*
429                  * Initiate the pointers
430                  * - 2 x consecutive Matrix A rows (i increment is 2 x numColsA)
431                  * - 2 x consecutive Matrix B' rows (j increment is 2 x numRowsB)
432                  */
433                 pInA = pSrcA->pData + i;
434                 pInA2 = pInA + numColsA;
435                 pInB = pSrcBT + j;
436                 pInB2 = pInB + numRowsB;
437 
438 
439                 pSrcAVec = (q31_t const *) pInA;
440                 pSrcA2Vec = (q31_t const *) pInA2;
441                 pSrcBVec = (q31_t const *) pInB;
442                 pSrcB2Vec = (q31_t const *) pInB2;
443 
444                 acc0 = 0LL;
445                 acc1 = 0LL;
446                 acc2 = 0LL;
447                 acc3 = 0LL;
448 
449                 /* load scheduling */
450                 vecA = vld1q(pSrcAVec);
451                 pSrcAVec += 4;
452 
453                 blkCnt = (numColsA / 4);
454                 while (blkCnt > 0U) {
455                     vecB = vld1q(pSrcBVec);
456                     pSrcBVec += 4;
457                     acc0 = vrmlaldavhaq(acc0, vecA, vecB);
458                     vecA2 = vld1q(pSrcA2Vec);
459                     pSrcA2Vec += 4;
460                     acc1 = vrmlaldavhaq(acc1, vecA2, vecB);
461                     vecB2 = vld1q(pSrcB2Vec);
462                     pSrcB2Vec += 4;
463                     acc2 = vrmlaldavhaq(acc2, vecA, vecB2);
464                     vecA = vld1q(pSrcAVec);
465                     pSrcAVec += 4;
466                     acc3 = vrmlaldavhaq(acc3, vecA2, vecB2);
467 
468                     blkCnt--;
469                 }
470                 /*
471                  * tail
472                  * (will be merged thru tail predication)
473                  */
474                 blkCnt = (numColsA & 3);
475                 if (blkCnt > 0U) {
476                     mve_pred16_t    p0 = vctp32q(blkCnt);
477                     vecB = vld1q(pSrcBVec);
478                     acc0 = vrmlaldavhaq_p(acc0, vecA, vecB, p0);
479                     vecA2 = vld1q(pSrcA2Vec);
480                     acc1 = vrmlaldavhaq_p(acc1, vecA2, vecB, p0);
481                     vecB2 = vld1q(pSrcB2Vec);
482                     acc2 = vrmlaldavhaq_p(acc2, vecA, vecB2, p0);
483                     vecA = vld1q(pSrcAVec);
484                     acc3 = vrmlaldavhaq_p(acc3, vecA2, vecB2, p0);
485                 }
486 
487                 /* Convert to 1.31 */
488                 acc0 = asrl(acc0, 23);
489                 acc1 = asrl(acc1, 23);
490                 acc2 = asrl(acc2, 23);
491                 acc3 = asrl(acc3, 23);
492 
493                 /* Store the results (2 x 2 block) in the destination buffer */
494                 *px++ = (q31_t) acc0;
495                 *px++ = (q31_t) acc2;
496                 *px2++ = (q31_t) acc1;
497                 *px2++ = (q31_t) acc3;
498 
499                 j += numRowsB * 2;
500                 /*
501                  * Decrement the column pair loop counter
502                  */
503                 col--;
504 
505             }
506 
507             i = i + numColsA * 2;
508             px = px2 + (numColsB & 1u);
509             px2 = px + numColsB;
510             /*
511              * Decrement the row pair loop counter
512              */
513             row--;
514         }
515 
516         /*
517          * Compute remaining row and/or column below
518          */
519         if (numColsB & 1u) {
520             row = numRowsA & (~0x1);    //avoid redundant computation
521             px = pDst->pData + numColsB - 1;
522             i = 0;
523 
524             /*
525              * row loop
526              */
527             while (row > 0) {
528                 q31_t const    *pSrcAVec, *pSrcBVec;
529                 q31x4_t         vecA, vecB;
530                 q63_t           acc0;
531 
532                 /*
533                  * point to last column in matrix B
534                  */
535                 pInB = pSrcBT + numRowsB * (numColsB - 1);
536                 pInA = pSrcA->pData + i;
537 
538                 pSrcAVec = (q31_t const *) pInA;
539                 pSrcBVec = (q31_t const *) pInB;
540 
541                 /* single dot-product */
542                 acc0 = 0LL;
543                 blkCnt = (numColsA / 4);
544                 while (blkCnt > 0U) {
545                     vecA = vld1q(pSrcAVec);
546                     pSrcAVec += 4;
547                     vecB = vld1q(pSrcBVec);
548                     pSrcBVec += 4;
549                     acc0 = vrmlaldavhaq(acc0, vecA, vecB);
550 
551                     blkCnt--;
552                 }
553                 /*
554                  * tail
555                  * (will be merged thru tail predication)
556                  */
557                 blkCnt = (numColsA & 3);
558                 if (blkCnt > 0U) {
559                     mve_pred16_t    p0 = vctp32q(blkCnt);
560                     vecA = vld1q(pSrcAVec);
561                     vecB = vld1q(pSrcBVec);
562                     acc0 = vrmlaldavhaq_p(acc0, vecA, vecB, p0);
563                 }
564 
565                 acc0 = asrl(acc0, 23);
566                 *px = (q31_t) acc0;
567 
568                 px += numColsB;
569 
570                 i += numColsA;
571                 /*
572                  * Decrement the row loop counter
573                  */
574                 row--;
575             }
576         }
577 
578         if (numRowsA & 1u) {
579             col = numColsB;
580             i = 0u;
581             /*
582              * point to last row in output matrix
583              */
584             px = pDst->pData + (numColsB) * (numRowsA - 1);
585             /*
586              * col loop
587              */
588             while (col > 0) {
589                 q31_t const    *pSrcAVec, *pSrcBVec;
590                 q31x4_t         vecA, vecB;
591                 q63_t           acc0;
592 
593                 /*
594                  * point to last row in matrix A
595                  */
596                 pInA = pSrcA->pData + (numRowsA - 1) * numColsA;
597                 pInB = pSrcBT + i;
598 
599                 /*
600                  * Set the variable sum, that acts as accumulator, to zero
601                  */
602                 pSrcAVec = (q31_t const *) pInA;
603                 pSrcBVec = (q31_t const *) pInB;
604                 acc0 = 0LL;
605 
606                 blkCnt = (numColsA / 4);
607                 while (blkCnt > 0U) {
608                     vecA = vld1q(pSrcAVec);
609                     pSrcAVec += 4;
610                     vecB = vld1q(pSrcBVec);
611                     pSrcBVec += 4;
612                     acc0 = vrmlaldavhaq(acc0, vecA, vecB);
613 
614                     blkCnt--;
615                 }
616                 /*
617                  * tail
618                  * (will be merged thru tail predication)
619                  */
620                 blkCnt = (numColsA & 3);
621                 if (blkCnt > 0U) {
622                     mve_pred16_t    p0 = vctp32q(blkCnt);
623                     vecA = vld1q(pSrcAVec);
624                     vecB = vld1q(pSrcBVec);
625                     acc0 = vrmlaldavhaq_p(acc0, vecA, vecB, p0);
626                 }
627 
628                 acc0 = asrl(acc0, 23);
629                 *px++ = (q31_t) acc0;
630 
631                 i += numColsA;
632                 /*
633                  * Decrement the col loop counter
634                  */
635                 col--;
636             }
637         }
638         /* Set status as ARM_MATH_SUCCESS */
639         status = ARM_MATH_SUCCESS;
640     }
641     /*
642      * Return to application
643      */
644     return (status);
645 }
646 
647 #else
arm_mat_mult_opt_q31(const arm_matrix_instance_q31 * pSrcA,const arm_matrix_instance_q31 * pSrcB,arm_matrix_instance_q31 * pDst,q31_t * pState)648 arm_status arm_mat_mult_opt_q31(
649   const arm_matrix_instance_q31 * pSrcA,
650   const arm_matrix_instance_q31 * pSrcB,
651         arm_matrix_instance_q31 * pDst,
652         q31_t *pState)
653 {
654   q31_t *pIn1 = pSrcA->pData;                    /* Input data matrix pointer A */
655   q31_t *pIn2 = pSrcB->pData;                    /* Input data matrix pointer B */
656   q31_t *pInA = pSrcA->pData;                    /* Input data matrix pointer A */
657   q31_t *pInB = pSrcB->pData;                    /* Input data matrix pointer B */
658   q31_t *pOut = pDst->pData;                     /* Output data matrix pointer */
659   q31_t *px;                                     /* Temporary output data matrix pointer */
660   q63_t sum;                                     /* Accumulator */
661   uint16_t numRowsA = pSrcA->numRows;            /* Number of rows of input matrix A */
662   uint16_t numColsB = pSrcB->numCols;            /* Number of columns of input matrix B */
663   uint16_t numColsA = pSrcA->numCols;            /* Number of columns of input matrix A */
664   uint32_t col, i = 0U, row = numRowsA, colCnt;  /* Loop counters */
665   arm_status status;                             /* Status of matrix multiplication */
666   (void)pState;
667 #ifdef ARM_MATH_MATRIX_CHECK
668 
669   /* Check for matrix mismatch condition */
670   if ((pSrcA->numCols != pSrcB->numRows) ||
671       (pSrcA->numRows != pDst->numRows)  ||
672       (pSrcB->numCols != pDst->numCols)    )
673   {
674     /* Set status as ARM_MATH_SIZE_MISMATCH */
675     status = ARM_MATH_SIZE_MISMATCH;
676   }
677   else
678 
679 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
680 
681   {
682     /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
683     /* row loop */
684     do
685     {
686       /* Output pointer is set to starting address of row being processed */
687       px = pOut + i;
688 
689       /* For every row wise process, column loop counter is to be initiated */
690       col = numColsB;
691 
692       /* For every row wise process, pIn2 pointer is set to starting address of pSrcB data */
693       pIn2 = pSrcB->pData;
694 
695       /* column loop */
696       do
697       {
698         /* Set the variable sum, that acts as accumulator, to zero */
699         sum = 0;
700 
701         /* Initialize pointer pIn1 to point to starting address of column being processed */
702         pIn1 = pInA;
703 
704 #if defined (ARM_MATH_LOOPUNROLL)
705 
706         /* Loop unrolling: Compute 4 MACs at a time. */
707         colCnt = numColsA >> 2U;
708 
709         /* matrix multiplication */
710         while (colCnt > 0U)
711         {
712           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
713 
714           /* Perform the multiply-accumulates */
715           sum += (q63_t) *pIn1++ * *pIn2;
716           pIn2 += numColsB;
717 
718           sum += (q63_t) *pIn1++ * *pIn2;
719           pIn2 += numColsB;
720 
721           sum += (q63_t) *pIn1++ * *pIn2;
722           pIn2 += numColsB;
723 
724           sum += (q63_t) *pIn1++ * *pIn2;
725           pIn2 += numColsB;
726 
727           /* Decrement loop counter */
728           colCnt--;
729         }
730 
731         /* Loop unrolling: Compute remaining MACs */
732         colCnt = numColsA % 0x4U;
733 
734 #else
735 
736         /* Initialize cntCnt with number of columns */
737         colCnt = numColsA;
738 
739 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
740 
741         while (colCnt > 0U)
742         {
743           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
744 
745           /* Perform the multiply-accumulates */
746           sum += (q63_t) *pIn1++ * *pIn2;
747           pIn2 += numColsB;
748 
749           /* Decrement loop counter */
750           colCnt--;
751         }
752 
753         /* Convert result from 2.62 to 1.31 format and store in destination buffer */
754         *px++ = (q31_t) (sum >> 31);
755 
756         /* Decrement column loop counter */
757         col--;
758 
759         /* Update pointer pIn2 to point to starting address of next column */
760         pIn2 = pInB + (numColsB - col);
761 
762       } while (col > 0U);
763 
764       /* Update pointer pInA to point to starting address of next row */
765       i = i + numColsB;
766       pInA = pInA + numColsA;
767 
768       /* Decrement row loop counter */
769       row--;
770 
771     } while (row > 0U);
772 
773     /* Set status as ARM_MATH_SUCCESS */
774     status = ARM_MATH_SUCCESS;
775   }
776 
777   /* Return to application */
778   return (status);
779 }
780 #endif /* defined(ARM_MATH_MVEI) */
781 
782 /**
783   @} end of MatrixMult group
784  */
785