1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_mult_q15.c
4  * Description:  Q15 matrix multiplication
5  *
6  * $Date:        3 Nov 2021
7  * $Revision:    V1.10.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/matrix_functions.h"
30 
31 /**
32   @ingroup groupMatrix
33  */
34 
35 /**
36   @addtogroup MatrixMult
37   @{
38  */
39 
40 /**
41   @brief         Q15 matrix multiplication.
42   @param[in]     pSrcA      points to the first input matrix structure
43   @param[in]     pSrcB      points to the second input matrix structure
44   @param[out]    pDst       points to output matrix structure
45   @param[in]     pState     points to the array for storing intermediate results
46   @return        execution status
47                    - \ref ARM_MATH_SUCCESS       : Operation successful
48                    - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
49 
50   @par           Scaling and Overflow Behavior
51                    The function is implemented using an internal 64-bit accumulator. The inputs to the
52                    multiplications are in 1.15 format and multiplications yield a 2.30 result.
53                    The 2.30 intermediate results are accumulated in a 64-bit accumulator in 34.30 format.
54                    This approach provides 33 guard bits and there is no risk of overflow.
55                    The 34.30 result is then truncated to 34.15 format by discarding the low 15 bits
56                    and then saturated to 1.15 format.
57   @par
58                    Refer to \ref arm_mat_mult_fast_q15() for a faster but less precise version of this function.
59 
60   @par             pState
61                    pState will contain the transpose of pSrcB
62  */
63 #if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)
64 
/*
 * Scale a 64-bit accumulator down by `shift` bits and narrow it to 16 bits.
 *
 * The accumulator is handed to sqrshrl_sat48() with a shift amount of
 * -(32 - shift); the upper 32 bits of that result are then extracted.
 * Both arguments are parenthesised so that compound expressions
 * (e.g. `a + b` or `n + 1`) expand as a single operand.
 */
#define MVE_ASRL_SAT16(acc, shift)          ((sqrshrl_sat48((acc), -(32 - (shift))) >> 32) & 0xffffffff)

/* Dimensions served by the dedicated small square-matrix kernels */
#define MATRIX_DIM2 2
#define MATRIX_DIM3 3
#define MATRIX_DIM4 4
70 
arm_mat_mult_q15_2x2_mve(const arm_matrix_instance_q15 * pSrcA,const arm_matrix_instance_q15 * pSrcB,arm_matrix_instance_q15 * pDst)71 __STATIC_INLINE arm_status arm_mat_mult_q15_2x2_mve(
72     const arm_matrix_instance_q15 * pSrcA,
73     const arm_matrix_instance_q15 * pSrcB,
74     arm_matrix_instance_q15 * pDst)
75 {
76     q15_t       *pInB = pSrcB->pData;  /* input data matrix pointer B */
77     q15_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
78     q15_t       *pOut = pDst->pData;   /* output data matrix pointer */
79     uint16x8_t  vecColBOffs;
80     q15_t       *pInA0 = pInA;
81     q15_t       *pInA1 = pInA0 + MATRIX_DIM2;
82     q63_t        acc0, acc1;
83     q15x8_t     vecB, vecA0, vecA1;
84     mve_pred16_t p0 = vctp16q(MATRIX_DIM2);
85 
86     vecColBOffs = vidupq_u16((uint32_t)0, 2); /* MATRIX_DIM2 */
87 
88     pInB = pSrcB->pData;
89 
90     vecB = vldrhq_gather_shifted_offset_z_s16((q15_t const *)pInB, vecColBOffs, p0);
91 
92     vecA0 = vldrhq_s16(pInA0);
93     vecA1 = vldrhq_s16(pInA1);
94 
95     acc0 = vmlaldavq(vecA0, vecB);
96     acc1 = vmlaldavq(vecA1, vecB);
97 
98     acc0 = asrl(acc0, 15);
99     acc1 = asrl(acc1, 15);
100 
101     pOut[0 * MATRIX_DIM2] = (q15_t) __SSAT(acc0, 16);
102     pOut[1 * MATRIX_DIM2] = (q15_t) __SSAT(acc1, 16);
103     pOut++;
104 
105     /* move to next B column */
106     pInB = pInB + 1;
107 
108     vecB = vldrhq_gather_shifted_offset_z_s16(pInB, vecColBOffs, p0);
109 
110     acc0 = vmlaldavq(vecA0, vecB);
111     acc1 = vmlaldavq(vecA1, vecB);
112 
113     acc0 = asrl(acc0, 15);
114     acc1 = asrl(acc1, 15);
115 
116     pOut[0 * MATRIX_DIM2] = (q15_t) __SSAT(acc0, 16);
117     pOut[1 * MATRIX_DIM2] = (q15_t) __SSAT(acc1, 16);
118 
119     /*
120      * Return to application
121      */
122     return (ARM_MATH_SUCCESS);
123 }
124 
125 
126 
arm_mat_mult_q15_3x3_mve(const arm_matrix_instance_q15 * pSrcA,const arm_matrix_instance_q15 * pSrcB,arm_matrix_instance_q15 * pDst)127 __STATIC_INLINE arm_status arm_mat_mult_q15_3x3_mve(
128     const arm_matrix_instance_q15 * pSrcA,
129     const arm_matrix_instance_q15 * pSrcB,
130     arm_matrix_instance_q15 * pDst)
131 {
132     q15_t       *pInB = pSrcB->pData;  /* input data matrix pointer B */
133     q15_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
134     q15_t       *pOut = pDst->pData;   /* output data matrix pointer */
135     uint16x8_t vecColBOffs;
136     q15_t       *pInA0 = pInA;
137     q15_t       *pInA1 = pInA0 + MATRIX_DIM3;
138     q15_t       *pInA2 = pInA1 + MATRIX_DIM3;
139     q63_t        acc0, acc1, acc2;
140     q15x8_t    vecB, vecA0, vecA1, vecA2;
141     mve_pred16_t p0 = vctp16q(MATRIX_DIM3);
142 
143     vecColBOffs = vidupq_u16((uint32_t)0, 1);
144     vecColBOffs = vecColBOffs * MATRIX_DIM3;
145 
146     pInB = pSrcB->pData;
147 
148     vecB = vldrhq_gather_shifted_offset_z_s16((q15_t const *)pInB, vecColBOffs, p0);
149 
150     vecA0 = vldrhq_s16(pInA0);
151     vecA1 = vldrhq_s16(pInA1);
152     vecA2 = vldrhq_s16(pInA2);
153 
154     acc0 = vmlaldavq(vecA0, vecB);
155     acc1 = vmlaldavq(vecA1, vecB);
156     acc2 = vmlaldavq(vecA2, vecB);
157 
158     acc0 = asrl(acc0, 15);
159     acc1 = asrl(acc1, 15);
160     acc2 = asrl(acc2, 15);
161 
162     pOut[0 * MATRIX_DIM3] = (q15_t) __SSAT(acc0, 16);
163     pOut[1 * MATRIX_DIM3] = (q15_t) __SSAT(acc1, 16);
164     pOut[2 * MATRIX_DIM3] = (q15_t) __SSAT(acc2, 16);
165     pOut++;
166 
167     /* move to next B column */
168     pInB = pInB + 1;
169 
170     vecB = vldrhq_gather_shifted_offset_z_s16(pInB, vecColBOffs, p0);
171 
172     acc0 = vmlaldavq(vecA0, vecB);
173     acc1 = vmlaldavq(vecA1, vecB);
174     acc2 = vmlaldavq(vecA2, vecB);
175 
176     acc0 = asrl(acc0, 15);
177     acc1 = asrl(acc1, 15);
178     acc2 = asrl(acc2, 15);
179 
180     pOut[0 * MATRIX_DIM3] = (q15_t) __SSAT(acc0, 16);
181     pOut[1 * MATRIX_DIM3] = (q15_t) __SSAT(acc1, 16);
182     pOut[2 * MATRIX_DIM3] = (q15_t) __SSAT(acc2, 16);
183     pOut++;
184 
185     /* move to next B column */
186     pInB = pInB + 1;
187 
188     vecB = vldrhq_gather_shifted_offset_z_s16(pInB, vecColBOffs, p0);
189 
190     acc0 = vmlaldavq(vecA0, vecB);
191     acc1 = vmlaldavq(vecA1, vecB);
192     acc2 = vmlaldavq(vecA2, vecB);
193 
194     acc0 = asrl(acc0, 15);
195     acc1 = asrl(acc1, 15);
196     acc2 = asrl(acc2, 15);
197 
198     pOut[0 * MATRIX_DIM3] = (q15_t) __SSAT(acc0, 16);
199     pOut[1 * MATRIX_DIM3] = (q15_t) __SSAT(acc1, 16);
200     pOut[2 * MATRIX_DIM3] = (q15_t) __SSAT(acc2, 16);
201     /*
202      * Return to application
203      */
204     return (ARM_MATH_SUCCESS);
205 }
206 
207 
arm_mat_mult_q15_4x4_mve(const arm_matrix_instance_q15 * pSrcA,const arm_matrix_instance_q15 * pSrcB,arm_matrix_instance_q15 * pDst)208 __STATIC_INLINE arm_status arm_mat_mult_q15_4x4_mve(
209     const arm_matrix_instance_q15 * pSrcA,
210     const arm_matrix_instance_q15 * pSrcB,
211     arm_matrix_instance_q15 * pDst)
212 {
213     q15_t       *pInB = pSrcB->pData;  /* input data matrix pointer B */
214     q15_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
215     q15_t       *pOut = pDst->pData;   /* output data matrix pointer */
216     uint16x8_t vecColBOffs;
217     q15_t       *pInA0 = pInA;
218     q15_t       *pInA1 = pInA0 + MATRIX_DIM4;
219     q15_t       *pInA2 = pInA1 + MATRIX_DIM4;
220     q15_t       *pInA3 = pInA2 + MATRIX_DIM4;
221     q63_t        acc0, acc1, acc2, acc3;
222     q15x8_t     vecB, vecA0, vecA1, vecA2, vecA3;
223     mve_pred16_t p0 = vctp16q(MATRIX_DIM4);
224 
225     vecColBOffs = vidupq_u16((uint32_t)0, 4);
226 
227     pInB = pSrcB->pData;
228 
229     vecB = vldrhq_gather_shifted_offset_z_s16((q15_t const *)pInB, vecColBOffs, p0);
230 
231     vecA0 = vldrhq_s16(pInA0);
232     vecA1 = vldrhq_s16(pInA1);
233     vecA2 = vldrhq_s16(pInA2);
234     vecA3 = vldrhq_s16(pInA3);
235 
236     acc0 = vmlaldavq(vecA0, vecB);
237     acc1 = vmlaldavq(vecA1, vecB);
238     acc2 = vmlaldavq(vecA2, vecB);
239     acc3 = vmlaldavq(vecA3, vecB);
240 
241     acc0 = asrl(acc0, 15);
242     acc1 = asrl(acc1, 15);
243     acc2 = asrl(acc2, 15);
244     acc3 = asrl(acc3, 15);
245 
246     pOut[0 * MATRIX_DIM4] = (q15_t) __SSAT(acc0, 16);
247     pOut[1 * MATRIX_DIM4] = (q15_t) __SSAT(acc1, 16);
248     pOut[2 * MATRIX_DIM4] = (q15_t) __SSAT(acc2, 16);
249     pOut[3 * MATRIX_DIM4] = (q15_t) __SSAT(acc3, 16);
250     pOut++;
251 
252     /* move to next B column */
253     pInB = pInB + 1;
254 
255     vecB = vldrhq_gather_shifted_offset_z_s16(pInB, vecColBOffs, p0);
256 
257     acc0 = vmlaldavq(vecA0, vecB);
258     acc1 = vmlaldavq(vecA1, vecB);
259     acc2 = vmlaldavq(vecA2, vecB);
260     acc3 = vmlaldavq(vecA3, vecB);
261 
262     acc0 = asrl(acc0, 15);
263     acc1 = asrl(acc1, 15);
264     acc2 = asrl(acc2, 15);
265     acc3 = asrl(acc3, 15);
266 
267     pOut[0 * MATRIX_DIM4] = (q15_t) __SSAT(acc0, 16);
268     pOut[1 * MATRIX_DIM4] = (q15_t) __SSAT(acc1, 16);
269     pOut[2 * MATRIX_DIM4] = (q15_t) __SSAT(acc2, 16);
270     pOut[3 * MATRIX_DIM4] = (q15_t) __SSAT(acc3, 16);
271 
272     pOut++;
273 
274     /* move to next B column */
275     pInB = pInB + 1;
276 
277     vecB = vldrhq_gather_shifted_offset_z_s16(pInB, vecColBOffs, p0);
278 
279     acc0 = vmlaldavq(vecA0, vecB);
280     acc1 = vmlaldavq(vecA1, vecB);
281     acc2 = vmlaldavq(vecA2, vecB);
282     acc3 = vmlaldavq(vecA3, vecB);
283 
284     acc0 = asrl(acc0, 15);
285     acc1 = asrl(acc1, 15);
286     acc2 = asrl(acc2, 15);
287     acc3 = asrl(acc3, 15);
288 
289     pOut[0 * MATRIX_DIM4] = (q15_t) __SSAT(acc0, 16);
290     pOut[1 * MATRIX_DIM4] = (q15_t) __SSAT(acc1, 16);
291     pOut[2 * MATRIX_DIM4] = (q15_t) __SSAT(acc2, 16);
292     pOut[3 * MATRIX_DIM4] = (q15_t) __SSAT(acc3, 16);
293 
294     pOut++;
295 
296     /* move to next B column */
297     pInB = pInB + 1;
298 
299     vecB = vldrhq_gather_shifted_offset_z_s16(pInB, vecColBOffs, p0);
300 
301     acc0 = vmlaldavq(vecA0, vecB);
302     acc1 = vmlaldavq(vecA1, vecB);
303     acc2 = vmlaldavq(vecA2, vecB);
304     acc3 = vmlaldavq(vecA3, vecB);
305 
306     acc0 = asrl(acc0, 15);
307     acc1 = asrl(acc1, 15);
308     acc2 = asrl(acc2, 15);
309     acc3 = asrl(acc3, 15);
310 
311     pOut[0 * MATRIX_DIM4] = (q15_t) __SSAT(acc0, 16);
312     pOut[1 * MATRIX_DIM4] = (q15_t) __SSAT(acc1, 16);
313     pOut[2 * MATRIX_DIM4] = (q15_t) __SSAT(acc2, 16);
314     pOut[3 * MATRIX_DIM4] = (q15_t) __SSAT(acc3, 16);
315     /*
316      * Return to application
317      */
318     return (ARM_MATH_SUCCESS);
319 }
320 
321 
arm_mat_mult_q15(const arm_matrix_instance_q15 * pSrcA,const arm_matrix_instance_q15 * pSrcB,arm_matrix_instance_q15 * pDst,q15_t * pState)322 arm_status arm_mat_mult_q15(
323     const arm_matrix_instance_q15 * pSrcA,
324     const arm_matrix_instance_q15 * pSrcB,
325     arm_matrix_instance_q15 * pDst,
326     q15_t * pState)
327 {
328     q15_t          *pInA = pSrcA->pData;        /* input data matrix pointer A */
329     q15_t          *pInB = pSrcB->pData;        /* input data matrix pointer B */
330     q15_t          *pInA2;
331     q15_t          *pInB2;
332     q15_t          *px;         /* Temporary output data matrix pointer */
333     q15_t          *px2;        /* Temporary output data matrix pointer */
334     uint32_t        numRowsA = pSrcA->numRows;  /* number of rows of input matrix A    */
335     uint32_t        numColsB = pSrcB->numCols;  /* number of columns of input matrix B */
336     uint32_t        numColsA = pSrcA->numCols;  /* number of columns of input matrix A */
337     uint32_t        numRowsB = pSrcB->numRows;  /* number of rows of input matrix A    */
338     uint32_t        col, i = 0u, j, row = numRowsB;     /* loop counters */
339     q15_t          *pSrcBT = pState;    /* input data matrix pointer for transpose */
340     uint32_t        blkCnt;     /* loop counters */
341     arm_status      status;                             /* Status of matrix multiplication */
342     arm_matrix_instance_q15 BT;
343 
344 #ifdef ARM_MATH_MATRIX_CHECK
345 
346     /* Check for matrix mismatch condition */
347     if ((pSrcA->numCols != pSrcB->numRows) ||
348       (pSrcA->numRows != pDst->numRows)  ||
349       (pSrcB->numCols != pDst->numCols)    )
350     {
351         /* Set status as ARM_MATH_SIZE_MISMATCH */
352         status = ARM_MATH_SIZE_MISMATCH;
353     }
354     else
355 #endif
356     {
357         /* small squared matrix specialized routines */
358         if (numRowsA == numColsB && numColsB == numColsA) {
359 
360             if (numRowsA == 1) {
361                 q63_t           sum;
362                 sum = pInA[0] * pInB[0];
363                 pDst->pData[0] = (q15_t) __SSAT((sum >> 15), 16);
364                 return (ARM_MATH_SUCCESS);
365             } else if (numRowsA == 2)
366                 return arm_mat_mult_q15_2x2_mve(pSrcA, pSrcB, pDst);
367             else if (numRowsA == 3)
368                 return arm_mat_mult_q15_3x3_mve(pSrcA, pSrcB, pDst);
369             else if (numRowsA == 4)
370                 return arm_mat_mult_q15_4x4_mve(pSrcA, pSrcB, pDst);
371         }
372 
373         /*
374          * Matrix transpose
375          */
376 
377         BT.numRows = numColsB;
378         BT.numCols = numRowsB;
379         BT.pData = pSrcBT;
380 
381         arm_mat_trans_q15(pSrcB, &BT);
382 
383 
384         /*
385          * Reset the variables for the usage in the following multiplication process
386          */
387         i = 0;
388         row = numRowsA >> 1;
389         px = pDst->pData;
390         px2 = px + numColsB;
391 
392         /*
393          * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
394          */
395 
396         /*
397          * row loop
398          */
399         while (row > 0u) {
400             /*
401              * For every row wise process, the column loop counter is to be initiated
402              */
403             col = numColsB >> 1;
404             /*
405              * For every row wise process, the pIn2 pointer is set
406              * to the starting address of the transposed pSrcB data
407              */
408             pInB = pSrcBT;
409             pInB2 = pInB + numRowsB;
410             j = 0;
411 
412             /*
413              * column loop
414              */
415             while (col > 0u) {
416                 q15_t const    *pSrcAVec, *pSrcBVec, *pSrcA2Vec, *pSrcB2Vec;
417                 q15x8_t         vecA, vecA2, vecB, vecB2;
418                 q63_t           acc0, acc1, acc2, acc3;
419 
420                 /*
421                  * Initiate the pointer pIn1 to point to the starting address of the column being processed
422                  */
423                 pInA = pSrcA->pData + i;
424                 pInA2 = pInA + numColsA;
425                 pInB = pSrcBT + j;
426                 pInB2 = pInB + numRowsB;
427 
428 
429                 pSrcAVec = (q15_t const *) pInA;
430                 pSrcA2Vec = (q15_t const *) pInA2;
431                 pSrcBVec = (q15_t const *) pInB;
432                 pSrcB2Vec = (q15_t const *) pInB2;
433 
434                 acc0 = 0LL;
435                 acc1 = 0LL;
436                 acc2 = 0LL;
437                 acc3 = 0LL;
438 
439                 vecA = vld1q(pSrcAVec);
440                 pSrcAVec += 8;
441 
442                 blkCnt = numColsA / 8;
443                 while (blkCnt > 0U) {
444                     vecB = vld1q(pSrcBVec);
445                     pSrcBVec += 8;
446                     acc0 = vmlaldavaq(acc0, vecA, vecB);
447                     vecA2 = vld1q(pSrcA2Vec);
448                     pSrcA2Vec += 8;
449                     acc1 = vmlaldavaq(acc1, vecA2, vecB);
450                     vecB2 = vld1q(pSrcB2Vec);
451                     pSrcB2Vec += 8;
452                     acc2 = vmlaldavaq(acc2, vecA, vecB2);
453                     vecA = vld1q(pSrcAVec);
454                     pSrcAVec += 8;
455                     acc3 = vmlaldavaq(acc3, vecA2, vecB2);
456 
457                     blkCnt--;
458                 }
459                 /*
460                  * tail
461                  */
462                 blkCnt = numColsA & 7;
463                 if (blkCnt > 0U) {
464                     mve_pred16_t    p0 = vctp16q(blkCnt);
465                     vecB = vld1q(pSrcBVec);
466                     acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0);
467                     vecA2 = vld1q(pSrcA2Vec);
468                     acc1 = vmlaldavaq_p(acc1, vecA2, vecB, p0);
469                     vecB2 = vld1q(pSrcB2Vec);
470                     acc2 = vmlaldavaq_p(acc2, vecA, vecB2, p0);
471                     vecA = vld1q(pSrcAVec);
472                     acc3 = vmlaldavaq_p(acc3, vecA2, vecB2, p0);
473                 }
474 
475                 *px++ = (q15_t) MVE_ASRL_SAT16(acc0, 15);
476                 *px++ = (q15_t) MVE_ASRL_SAT16(acc2, 15);
477                 *px2++ = (q15_t) MVE_ASRL_SAT16(acc1, 15);
478                 *px2++ = (q15_t) MVE_ASRL_SAT16(acc3, 15);
479                 j += numRowsB * 2;
480                 /*
481                  * Decrement the column loop counter
482                  */
483                 col--;
484 
485             }
486 
487             i = i + numColsA * 2;
488             px = px2 + (numColsB & 1u);
489             px2 = px + numColsB;
490             /*
491              * Decrement the row loop counter
492              */
493             row--;
494         }
495 
496         /*
497          * Compute remaining row and/or column below
498          */
499 
500         if (numColsB & 1u) {
501             row = numRowsA & (~0x1);    //avoid redundant computation
502             px = pDst->pData + numColsB - 1;
503             i = 0;
504 
505             /*
506              * row loop
507              */
508             while (row > 0) {
509                 q15_t const    *pSrcAVec, *pSrcBVec;
510                 q15x8_t         vecA, vecB;
511                 q63_t           acc0;
512 
513                 /*
514                  * point to last column in matrix B
515                  */
516                 pInB = pSrcBT + numRowsB * (numColsB - 1);
517                 pInA = pSrcA->pData + i;
518 
519                 pSrcAVec = (q15_t const *) pInA;
520                 pSrcBVec = (q15_t const *) pInB;
521 
522                 acc0 = 0LL;
523                 blkCnt = (numColsA) / 8;
524                 while (blkCnt > 0U) {
525                     vecA = vld1q(pSrcAVec);
526                     pSrcAVec += 8;
527                     vecB = vld1q(pSrcBVec);
528                     pSrcBVec += 8;
529                     acc0 = vmlaldavaq(acc0, vecA, vecB);
530 
531                     blkCnt--;
532                 }
533                 /*
534                  * tail
535                  */
536                 blkCnt = (numColsA & 7);
537                 if (blkCnt > 0U) {
538                     mve_pred16_t    p0 = vctp16q(blkCnt);
539                     vecA = vld1q(pSrcAVec);
540                     vecB = vld1q(pSrcBVec);
541                     acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0);
542                 }
543 
544                 *px = (q15_t) MVE_ASRL_SAT16(acc0, 15);
545 
546                 px += numColsB;
547 
548                 i += numColsA;
549                 /*
550                  * Decrement the row loop counter
551                  */
552                 row--;
553             }
554         }
555 
556         if (numRowsA & 1u) {
557             col = numColsB;
558             i = 0u;
559             /*
560              * point to last row in output matrix
561              */
562             px = pDst->pData + (numColsB) * (numRowsA - 1);
563             /*
564              * col loop
565              */
566             while (col > 0) {
567                 q15_t const    *pSrcAVec, *pSrcBVec;
568                 q15x8_t         vecA, vecB;
569                 q63_t           acc0;
570 
571                 /*
572                  * point to last row in matrix A
573                  */
574                 pInA = pSrcA->pData + (numRowsA - 1) * numColsA;
575                 pInB = pSrcBT + i;
576 
577                 /*
578                  * Set the variable sum, that acts as accumulator, to zero
579                  */
580                 pSrcAVec = (q15_t const *) pInA;
581                 pSrcBVec = (q15_t const *) pInB;
582                 acc0 = 0LL;
583 
584                 blkCnt = ((numColsA) / 8);
585                 while (blkCnt > 0U) {
586                     vecA = vld1q(pSrcAVec);
587                     pSrcAVec += 8;
588                     vecB = vld1q(pSrcBVec);
589                     pSrcBVec += 8;
590                     acc0 = vmlaldavaq(acc0, vecA, vecB);
591 
592                     blkCnt--;
593                 }
594                 /*
595                  * tail
596                  */
597                 blkCnt = (numColsA & 7);
598                 if (blkCnt > 0U) {
599                     mve_pred16_t    p0 = vctp16q(blkCnt);
600                     vecA = vld1q(pSrcAVec);
601                     vecB = vld1q(pSrcBVec);
602                     acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0);
603                 }
604 
605                 *px++ = (q15_t) MVE_ASRL_SAT16(acc0, 15);
606 
607                 i += numColsA;
608 
609                 /*
610                  * Decrement the col loop counter
611                  */
612                 col--;
613             }
614         }
615 
616         /* Set status as ARM_MATH_SUCCESS */
617         status = ARM_MATH_SUCCESS;
618     }
619     /* Return to application */
620     return (status);
621 }
622 
623 #else
arm_mat_mult_q15(const arm_matrix_instance_q15 * pSrcA,const arm_matrix_instance_q15 * pSrcB,arm_matrix_instance_q15 * pDst,q15_t * pState)624 arm_status arm_mat_mult_q15(
625   const arm_matrix_instance_q15 * pSrcA,
626   const arm_matrix_instance_q15 * pSrcB,
627         arm_matrix_instance_q15 * pDst,
628         q15_t                   * pState)
629 {
630         q63_t sum;                                     /* Accumulator */
631 
632 #if defined (ARM_MATH_DSP)                             /* != CM0 */
633 
634         q15_t *pSrcBT = pState;                        /* Input data matrix pointer for transpose */
635         q15_t *pInA = pSrcA->pData;                    /* Input data matrix pointer A of Q15 type */
636         q15_t *pInB = pSrcB->pData;                    /* Input data matrix pointer B of Q15 type */
637         q15_t *px;                                     /* Temporary output data matrix pointer */
638         uint16_t numRowsA = pSrcA->numRows;            /* Number of rows of input matrix A */
639         uint16_t numColsB = pSrcB->numCols;            /* Number of columns of input matrix B */
640         uint16_t numColsA = pSrcA->numCols;            /* Number of columns of input matrix A */
641         uint16_t numRowsB = pSrcB->numRows;            /* Number of rows of input matrix B */
642         uint32_t col, i = 0U, row = numRowsB, colCnt;  /* Loop counters */
643         arm_status status;                             /* Status of matrix multiplication */
644 
645         q31_t inA1, inB1, inA2, inB2;
646         arm_matrix_instance_q15 BT;
647 
648 #ifdef ARM_MATH_MATRIX_CHECK
649 
650   /* Check for matrix mismatch condition */
651   if ((pSrcA->numCols != pSrcB->numRows) ||
652       (pSrcA->numRows != pDst->numRows)  ||
653       (pSrcB->numCols != pDst->numCols)    )
654   {
655     /* Set status as ARM_MATH_SIZE_MISMATCH */
656     status = ARM_MATH_SIZE_MISMATCH;
657   }
658   else
659 
660 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
661   {
662 
663     BT.numRows = numColsB;
664     BT.numCols = numRowsB;
665     BT.pData = pSrcBT;
666 
667     arm_mat_trans_q15(pSrcB,&BT);
668     /* Reset variables for usage in following multiplication process */
669     row = numRowsA;
670     i = 0U;
671     px = pDst->pData;
672 
673     /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
674     /* row loop */
675     do
676     {
677       /* For every row wise process, column loop counter is to be initiated */
678       col = numColsB;
679 
680       /* For every row wise process, pIn2 pointer is set to starting address of transposed pSrcB data */
681       pInB = pSrcBT;
682 
683       /* column loop */
684       do
685       {
686         /* Set variable sum, that acts as accumulator, to zero */
687         sum = 0;
688 
689         /* Initiate pointer pInA to point to starting address of column being processed */
690         pInA = pSrcA->pData + i;
691 
692         /* Apply loop unrolling and compute 2 MACs simultaneously. */
693         colCnt = numColsA >> 2U;
694 
695         /* matrix multiplication */
696         while (colCnt > 0U)
697         {
698           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
699 
700           /* read real and imag values from pSrcA and pSrcB buffer */
701           inA1 = read_q15x2_ia (&pInA);
702           inB1 = read_q15x2_ia (&pInB);
703 
704           inA2 = read_q15x2_ia (&pInA);
705           inB2 = read_q15x2_ia (&pInB);
706 
707           /* Multiply and Accumulates */
708           sum = __SMLALD(inA1, inB1, sum);
709           sum = __SMLALD(inA2, inB2, sum);
710 
711           /* Decrement loop counter */
712           colCnt--;
713         }
714 
715         /* process remaining column samples */
716         colCnt = numColsA % 0x4U;
717 
718         while (colCnt > 0U)
719         {
720           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
721           sum += *pInA++ * *pInB++;
722 
723           /* Decrement loop counter */
724           colCnt--;
725         }
726 
727         /* Saturate and store result in destination buffer */
728         *px = (q15_t) (__SSAT((sum >> 15), 16));
729         px++;
730 
731         /* Decrement column loop counter */
732         col--;
733 
734       } while (col > 0U);
735 
736       i = i + numColsA;
737 
738       /* Decrement row loop counter */
739       row--;
740 
741     } while (row > 0U);
742 
743 #else /* #if defined (ARM_MATH_DSP) */
744 
745         q15_t *pIn1 = pSrcA->pData;                    /* Input data matrix pointer A */
746         q15_t *pIn2 = pSrcB->pData;                    /* Input data matrix pointer B */
747         q15_t *pInA = pSrcA->pData;                    /* Input data matrix pointer A of Q15 type */
748         q15_t *pInB = pSrcB->pData;                    /* Input data matrix pointer B of Q15 type */
749         q15_t *pOut = pDst->pData;                     /* Output data matrix pointer */
750         q15_t *px;                                     /* Temporary output data matrix pointer */
751         uint16_t numColsB = pSrcB->numCols;            /* Number of columns of input matrix B */
752         uint16_t numColsA = pSrcA->numCols;            /* Number of columns of input matrix A */
753         uint16_t numRowsA = pSrcA->numRows;            /* Number of rows of input matrix A    */
754         uint32_t col, i = 0U, row = numRowsA, colCnt;  /* Loop counters */
755         arm_status status;                             /* Status of matrix multiplication */
756         (void)pState;
757 
758 #ifdef ARM_MATH_MATRIX_CHECK
759 
760   /* Check for matrix mismatch condition */
761   if ((pSrcA->numCols != pSrcB->numRows) ||
762       (pSrcA->numRows != pDst->numRows)  ||
763       (pSrcB->numCols != pDst->numCols)    )
764   {
765     /* Set status as ARM_MATH_SIZE_MISMATCH */
766     status = ARM_MATH_SIZE_MISMATCH;
767   }
768   else
769 
770 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
771 
772   {
773     /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
774     /* row loop */
775     do
776     {
777       /* Output pointer is set to starting address of the row being processed */
778       px = pOut + i;
779 
780       /* For every row wise process, column loop counter is to be initiated */
781       col = numColsB;
782 
783       /* For every row wise process, pIn2 pointer is set to starting address of pSrcB data */
784       pIn2 = pSrcB->pData;
785 
786       /* column loop */
787       do
788       {
789         /* Set the variable sum, that acts as accumulator, to zero */
790         sum = 0;
791 
792         /* Initiate pointer pIn1 to point to starting address of pSrcA */
793         pIn1 = pInA;
794 
795         /* Matrix A columns number of MAC operations are to be performed */
796         colCnt = numColsA;
797 
798         /* matrix multiplication */
799         while (colCnt > 0U)
800         {
801           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
802 
803           /* Perform multiply-accumulates */
804           sum += (q31_t) * pIn1++ * *pIn2;
805           pIn2 += numColsB;
806 
807           /* Decrement loop counter */
808           colCnt--;
809         }
810 
811         /* Convert result from 34.30 to 1.15 format and store saturated value in destination buffer */
812 
813         /* Saturate and store result in destination buffer */
814         *px++ = (q15_t) __SSAT((sum >> 15), 16);
815 
816         /* Decrement column loop counter */
817         col--;
818 
819         /* Update pointer pIn2 to point to starting address of next column */
820         pIn2 = pInB + (numColsB - col);
821 
822       } while (col > 0U);
823 
824       /* Update pointer pSrcA to point to starting address of next row */
825       i = i + numColsB;
826       pInA = pInA + numColsA;
827 
828       /* Decrement row loop counter */
829       row--;
830 
831     } while (row > 0U);
832 
833 #endif /* #if defined (ARM_MATH_DSP) */
834 
835     /* Set status as ARM_MATH_SUCCESS */
836     status = ARM_MATH_SUCCESS;
837   }
838 
839   /* Return to application */
840   return (status);
841 }
842 #endif /* defined(ARM_MATH_MVEI) */
843 
844 /**
845   @} end of MatrixMult group
846  */
847