/* ----------------------------------------------------------------------
 * Project:      CMSIS DSP Library
 * Title:        arm_mat_mult_f32.c
 * Description:  Floating-point matrix multiplication
 *
 * $Date:        23 April 2021
 * $Revision:    V1.9.0
 *
 * Target Processor: Cortex-M and Cortex-A cores
 * -------------------------------------------------------------------- */
/*
 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "dsp/matrix_functions.h"

/**
 * @ingroup groupMatrix
 */

/**
 * @defgroup MatrixMult Matrix Multiplication
 *
 * Multiplies two matrices.
 *
 * \image html MatrixMultiplication.gif "Multiplication of two 3 x 3 matrices"

 * Matrix multiplication is only defined if the number of columns of the
 * first matrix equals the number of rows of the second matrix.
 * Multiplying an <code>M x N</code> matrix with an <code>N x P</code> matrix results
 * in an <code>M x P</code> matrix.
 * When matrix size checking is enabled, the functions check: (1) that the inner dimensions of
 * <code>pSrcA</code> and <code>pSrcB</code> are equal; and (2) that the size of the output
 * matrix equals the outer dimensions of <code>pSrcA</code> and <code>pSrcB</code>.
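 *
 * \par Usage sketch
 * The snippet below is an illustrative sketch (not taken from the original documentation)
 * of how the matrix instances might be set up and the multiplication invoked; the buffer
 * names and the 2 x 3 / 3 x 2 dimensions are arbitrary example values.
 * <pre>
 *     float32_t srcA[2 * 3] = { 1.0f, 2.0f, 3.0f,
 *                               4.0f, 5.0f, 6.0f };          // 2 x 3
 *     float32_t srcB[3 * 2] = { 7.0f,  8.0f,
 *                               9.0f, 10.0f,
 *                              11.0f, 12.0f };               // 3 x 2
 *     float32_t dst[2 * 2];                                  // 2 x 2 result
 *
 *     arm_matrix_instance_f32 matA, matB, matDst;
 *     arm_mat_init_f32(&matA,   2, 3, srcA);
 *     arm_mat_init_f32(&matB,   3, 2, srcB);
 *     arm_mat_init_f32(&matDst, 2, 2, dst);
 *
 *     arm_status status = arm_mat_mult_f32(&matA, &matB, &matDst);
 *     if (status != ARM_MATH_SUCCESS)
 *     {
 *         // size mismatch (only reported when ARM_MATH_MATRIX_CHECK is defined)
 *     }
 * </pre>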
 */


/**
 * @addtogroup MatrixMult
 * @{
 */

/**
 * @brief Floating-point matrix multiplication.
 * @param[in]       *pSrcA points to the first input matrix structure
 * @param[in]       *pSrcB points to the second input matrix structure
 * @param[out]      *pDst points to output matrix structure
 * @return          The function returns either
 * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
 */

#if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)

#define MATRIX_DIM3 3
#define MATRIX_DIM4 4

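/*
 * 2 x 2 kernel, documented in the same style as the 3 x 3 case below
 * (the gather offsets defined in the function index into these row-major layouts):
 *
 * A  =  {{a00, a01},
 *        {a10, a11}}
 * B  =  {{b00, b01},
 *        {b10, b11}}
 *
 * Dst = {{a00 b00 + a01 b10, a00 b01 + a01 b11},
 *        {a10 b00 + a11 b10, a10 b01 + a11 b11}}
 */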
__STATIC_INLINE arm_status arm_mat_mult_f32_2x2_mve(
    const arm_matrix_instance_f32 *pSrcA,
    const arm_matrix_instance_f32 *pSrcB,
    arm_matrix_instance_f32 *pDst)
{
    /* {a00, a00, a10, a10} */
    static const uint32_t  offsetA0[4] = { 0, 0, 2, 2 };
    /* {b00, b01, b00, b01} */
    static const uint32_t  offsetB0[4] = { 0, 1, 0, 1 };
    /* {a01, a01, a11, a11} */
    static const uint32_t  offsetA1[4] = { 1, 1, 3, 3 };
    /* {b10, b11, b10, b11} */
    static const uint32_t  offsetB1[4] = { 2, 3, 2, 3 };

    uint32x4_t vecOffsA, vecOffsB;
    f32x4_t vecInA, vecInB, vecDst;

    vecOffsA = vldrwq_u32((uint32_t const *) offsetA0);
    vecOffsB = vldrwq_u32((uint32_t const *) offsetB0);

    vecInA = vldrwq_gather_shifted_offset((float32_t const *) pSrcA->pData, vecOffsA);
    vecInB = vldrwq_gather_shifted_offset((float32_t const *) pSrcB->pData, vecOffsB);

    vecDst = vmulq(vecInA, vecInB);

    vecOffsA = vldrwq_u32((uint32_t const *) offsetA1);
    vecOffsB = vldrwq_u32((uint32_t const *) offsetB1);

    vecInA = vldrwq_gather_shifted_offset((float32_t const *) pSrcA->pData, vecOffsA);
    vecInB = vldrwq_gather_shifted_offset((float32_t const *) pSrcB->pData, vecOffsB);

    vecDst = vfmaq(vecDst, vecInA, vecInB);

    vstrwq_f32(pDst->pData, vecDst);

    return (ARM_MATH_SUCCESS);

}


/*
 * A  =  {{a00, a01, a02},
 *        {a10, a11, a12},
 *        {a20, a21, a22}}
 * B  =  {{b00, b01, b02},
 *        {b10, b11, b12},
 *        {b20, b21, b22}}
 *
 * Dst = {{a00 b00 + a01 b10 + a02 b20, a00 b01 + a01 b11 + a02 b21, a00 b02 + a01 b12 + a02 b22},
 *        {a10 b00 + a11 b10 + a12 b20, a10 b01 + a11 b11 + a12 b21, a10 b02 + a11 b12 + a12 b22},
 *        {a20 b00 + a21 b10 + a22 b20, a20 b01 + a21 b11 + a22 b21, a20 b02 + a21 b12 + a22 b22}}
 */
__STATIC_INLINE arm_status arm_mat_mult_f32_3x3_mve(
    const arm_matrix_instance_f32 *pSrcA,
    const arm_matrix_instance_f32 *pSrcB,
    arm_matrix_instance_f32 *pDst)
{
    float32_t   *pInB = pSrcB->pData; /* input data matrix pointer B */
    float32_t   *pInA = pSrcA->pData; /* input data matrix pointer A  */
    float32_t   *pOut = pDst->pData;  /* output data matrix pointer */
    float32_t   *pInA0, *pInA1, *pInA2;
    f32x4_t    vecMac0, vecMac1, vecMac2;
    f32x4_t    vecInB;
    float32_t const *pSrBVec;

    pSrBVec = (float32_t const *) pInB;

    pInA0 = pInA;
    pInA1 = pInA0 + MATRIX_DIM3;
    pInA2 = pInA1 + MATRIX_DIM3;
    /* enable predication to disable last (4th) vector element */
    mve_pred16_t p0 = vctp32q(MATRIX_DIM3);

    /*
     * load {b0,0, b0,1, b0,2, 0}
     */
    vecInB = vldrwq_z_f32(pSrBVec, p0);
    pSrBVec += MATRIX_DIM3;

    vecMac0 = vmulq(vecInB, *pInA0++);
    vecMac1 = vmulq(vecInB, *pInA1++);
    vecMac2 = vmulq(vecInB, *pInA2++);
    /*
     * load {b1,0, b1,1, b1,2, 0}
     */
    vecInB = vldrwq_z_f32(pSrBVec, p0);
    pSrBVec += MATRIX_DIM3;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    /*
     * load {b2,0, b2,1, b2,2, 0}
     */
    vecInB = vldrwq_z_f32(pSrBVec, p0);
    pSrBVec += MATRIX_DIM3;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);

    /* partial vector stores */
    vstrwq_p_f32(pOut, vecMac0, p0);
    pOut += MATRIX_DIM3;
    vstrwq_p_f32(pOut, vecMac1, p0);
    pOut += MATRIX_DIM3;
    vstrwq_p_f32(pOut, vecMac2, p0);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}




__STATIC_INLINE arm_status arm_mat_mult_f32_4x4_mve(
    const arm_matrix_instance_f32 *pSrcA,
    const arm_matrix_instance_f32 *pSrcB,
    arm_matrix_instance_f32 *pDst)
{
    float32_t const *pSrBVec;
    float32_t *pInB = pSrcB->pData; /* input data matrix pointer B */
    float32_t *pInA = pSrcA->pData; /* input data matrix pointer A  */
    float32_t *pOut = pDst->pData;  /* output data matrix pointer */
    float32_t *pInA0, *pInA1, *pInA2, *pInA3;
    f32x4_t vecMac0, vecMac1, vecMac2, vecMac3;
    f32x4_t vecInB;

    pSrBVec = (float32_t const *) pInB;

    pInA0 = pInA;
    pInA1 = pInA0 + MATRIX_DIM4;
    pInA2 = pInA1 + MATRIX_DIM4;
    pInA3 = pInA2 + MATRIX_DIM4;
    /*
     * load {b0,0, b0,1, b0,2, b0,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    vecMac0 = vmulq(vecInB, *pInA0++);
    vecMac1 = vmulq(vecInB, *pInA1++);
    vecMac2 = vmulq(vecInB, *pInA2++);
    vecMac3 = vmulq(vecInB, *pInA3++);
    /*
     * load {b1,0, b1,1, b1,2, b1,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
    /*
     * load {b2,0, b2,1, b2,2, b2,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
    /*
     * load {b3,0, b3,1, b3,2, b3,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);

    vst1q(pOut, vecMac0);
    pOut += MATRIX_DIM4;
    vst1q(pOut, vecMac1);
    pOut += MATRIX_DIM4;
    vst1q(pOut, vecMac2);
    pOut += MATRIX_DIM4;
    vst1q(pOut, vecMac3);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}


arm_status arm_mat_mult_f32(
  const arm_matrix_instance_f32 * pSrcA,
  const arm_matrix_instance_f32 * pSrcB,
  arm_matrix_instance_f32 * pDst)
{
    float32_t  *pInB = pSrcB->pData;        /* input data matrix pointer B */
    float32_t  *pInA = pSrcA->pData;        /* input data matrix pointer A  */
    float32_t  *pOut = pDst->pData;         /* output data matrix pointer */
    int         numRowsA = pSrcA->numRows;  /* number of rows of input matrix A */
    int         numColsB = pSrcB->numCols;  /* number of columns of input matrix B */
    int         numColsA = pSrcA->numCols;  /* number of columns of input matrix A */
    uint32_t    blkCnt;                     /* loop counters */
    uint32_t    i;
    arm_status status;

#ifdef ARM_MATH_MATRIX_CHECK

  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
     (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else
#endif /*      #ifdef ARM_MATH_MATRIX_CHECK    */
  {
      /* specialized routines for small square matrices */
    if(numRowsA == numColsB && numColsB == numColsA) {
        if (numRowsA == 1)
        {
           pOut[0] = pInA[0] * pInB[0];
           return(ARM_MATH_SUCCESS);
        }
        else if(numRowsA == 2)
            return arm_mat_mult_f32_2x2_mve(pSrcA, pSrcB, pDst);
        else if(numRowsA == 3)
            return arm_mat_mult_f32_3x3_mve(pSrcA, pSrcB, pDst);
        else if(numRowsA == 4)
            return arm_mat_mult_f32_4x4_mve(pSrcA, pSrcB, pDst);
    }

    /* main loop: process 4 rows at a time */
    i = numRowsA >> 2;
    while (i > 0U)
    {
        float32_t *pInA0, *pInA1, *pInA2, *pInA3;
        float32_t *pInB0;
        float32_t *pOut0, *pOut1, *pOut2, *pOut3;
        f32x4_t vecMac0, vecMac1, vecMac2, vecMac3;
        f32x4_t vecInB;

        /* pointers to 4 consecutive output rows */
        pOut0 = pOut;
        pOut1 = pOut0 + numColsB;
        pOut2 = pOut1 + numColsB;
        pOut3 = pOut2 + numColsB;
        pInB0 = pInB;

        uint32_t  k = numColsB >> 2;
        while (k > 0U)
        {
            /* pointers to 4 consecutive Matrix A rows */
            pInA0 = pInA;
            pInA1 = pInA0 + numColsA;
            pInA2 = pInA1 + numColsA;
            pInA3 = pInA2 + numColsA;

            vecMac0 = vdupq_n_f32(0.0f);
            vecMac1 = vdupq_n_f32(0.0f);
            vecMac2 = vdupq_n_f32(0.0f);
            vecMac3 = vdupq_n_f32(0.0f);

            blkCnt = numColsA;

            while (blkCnt > 0U)
            {
                /*
                 * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
                 */
                vecInB = *(f32x4_t *)pInB0; /* equivalent to vldrwq_f32(pInB0) */

                vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
                vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
                vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
                vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);

                pInB0 = pInB0 + numColsB;
                /*
                 * Decrement the blockSize loop counter
                 */
                blkCnt--;
            }

            /* Store the results (4 x 4 block) in the destination buffer */
            vst1q(pOut0, vecMac0);
            pOut0 += 4;
            vst1q(pOut1, vecMac1);
            pOut1 += 4;
            vst1q(pOut2, vecMac2);
            pOut2 += 4;
            vst1q(pOut3, vecMac3);
            pOut3 += 4;

            /*
             * rewind
             */
            pInB0 -= (numColsB * numColsA) - 4;
            k--;
        }

        int       colBLeft = numColsB & 3;
        if (colBLeft)
        {
            pInA0 = pInA;
            pInA1 = pInA0 + numColsA;
            pInA2 = pInA1 + numColsA;
            pInA3 = pInA2 + numColsA;
            mve_pred16_t p0 = vctp32q(colBLeft);

            vecMac0 = vdupq_n_f32(0.0f);
            vecMac1 = vdupq_n_f32(0.0f);
            vecMac2 = vdupq_n_f32(0.0f);
            vecMac3 = vdupq_n_f32(0.0f);

            blkCnt = numColsA;

            while (blkCnt > 0U)
            {
                /*
                 * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
                 */
                vecInB = vldrwq_z_f32(pInB0, p0);

                vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
                vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
                vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
                vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);

                pInB0 = pInB0 + numColsB;
                /*
                 * Decrement the blockSize loop counter
                 */
                blkCnt--;
            }

            /* Store the results (4 x colBLeft block) in the destination buffer */
            vstrwq_p_f32(pOut0, vecMac0, p0);
            vstrwq_p_f32(pOut1, vecMac1, p0);
            vstrwq_p_f32(pOut2, vecMac2, p0);
            vstrwq_p_f32(pOut3, vecMac3, p0);
        }

        /* move to next rows */
        pInA += 4 * numColsA;
        pOut += 4 * numColsB;
        i--;
    }

    /*
     * number of rows of Matrix A is not a multiple of 4:
     * process the remaining rows one at a time
     */
    if (numRowsA & 3)
    {
        i = numRowsA & 3;
        while (i > 0U)
        {
            float32_t   *pInA0;
            float32_t   *pInB0;
            float32_t   *pOut0;
            f32x4_t    vecInB;
            f32x4_t    vecMac0;

            pOut0 = pOut;
            pInB0 = pInB;

            uint32_t       k = numColsB >> 2;
            while (k > 0U)
            {
                pInA0 = pInA;

                vecMac0 = vdupq_n_f32(0.0f);
                blkCnt = numColsA;
                while (blkCnt > 0U)
                {
                    /*
                     * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
                     */
                    vecInB = *(f32x4_t *)pInB0; /* equivalent to vldrwq_f32(pInB0) */

                    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);

                    pInB0 = pInB0 + numColsB;
                    /*
                     * Decrement the blockSize loop counter
                     */
                    blkCnt--;
                }

                /* Store the results (1 x 4 block) in the destination buffer */
                vst1q(pOut0, vecMac0);
                pOut0 += 4;

                /*
                 * rewind
                 */
                pInB0 -= (numColsB * numColsA) - 4;
                k--;
            }

            int       colBLeft = numColsB & 3;
            if (colBLeft)
            {
                pInA0 = pInA;
                mve_pred16_t p0 = vctp32q(colBLeft);

                vecMac0 = vdupq_n_f32(0.0f);
                blkCnt = numColsA;
                while (blkCnt > 0U)
                {
                    /*
                     * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
                     */
                    vecInB = vldrwq_z_f32(pInB0, p0);

                    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);

                    pInB0 = pInB0 + numColsB;
                    /*
                     * Decrement the blockSize loop counter
                     */
                    blkCnt--;
                }
                /* Store the results (1 x colBLeft block) in the destination buffer */
                vstrwq_p_f32(pOut0, vecMac0, p0);
            }

            /* move to next row */
            pInA += 1 * numColsA;
            pOut += 1 * numColsB;
            i--;
        }

      }
      status = ARM_MATH_SUCCESS;
  }

  /* Return to application */
  return (status);
}
#else

#if defined(ARM_MATH_NEON)

#define GROUPOFROWS 8
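/* Number of rows of A processed together by the NEON kernel below;
   the remaining (numRowsA & 7) rows are handled one at a time afterwards. */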

arm_status arm_mat_mult_f32(
  const arm_matrix_instance_f32 * pSrcA,
  const arm_matrix_instance_f32 * pSrcB,
  arm_matrix_instance_f32 * pDst)
{
  float32_t *pIn1 = pSrcA->pData;                /* input data matrix pointer A */
  float32_t *pIn2 = pSrcB->pData;                /* input data matrix pointer B */
  float32_t *pInA = pSrcA->pData;                /* input data matrix pointer A  */
  float32_t *pOut = pDst->pData;                 /* output data matrix pointer */
  float32_t *px;                                 /* Temporary output data matrix pointer */
  float32_t sum;                                 /* Accumulator */
  uint16_t numRowsA = pSrcA->numRows;            /* number of rows of input matrix A */
  uint16_t numColsB = pSrcB->numCols;            /* number of columns of input matrix B */
  uint16_t numColsA = pSrcA->numCols;            /* number of columns of input matrix A */


  uint16_t col, i = 0U, j, row = numRowsA, rowCnt, colCnt;      /* loop counters */
  arm_status status;                             /* status of matrix multiplication */

  float32x4_t a0V, a1V, a2V, a3V, a4V, a5V, a6V, a7V;
  float32x4_t acc0, acc1, acc2, acc3, acc4, acc5, acc6, acc7, temp;
  float32x2_t accum = vdup_n_f32(0);
  float32_t *pIn1B = pSrcA->pData;
  float32_t *pIn1C = pSrcA->pData;
  float32_t *pIn1D = pSrcA->pData;
  float32_t *pIn1E = pSrcA->pData;
  float32_t *pIn1F = pSrcA->pData;
  float32_t *pIn1G = pSrcA->pData;
  float32_t *pIn1H = pSrcA->pData;

  float32_t *pxB, *pxC, *pxD, *pxE, *pxF, *pxG, *pxH;  /* Temporary output data matrix pointers */
  float32_t sum0, sum1, sum2, sum3, sum4, sum5, sum6, sum7;

#ifdef ARM_MATH_MATRIX_CHECK

  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
     (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else
#endif /*      #ifdef ARM_MATH_MATRIX_CHECK    */
  {
    /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
    /* Row loop */
    rowCnt = row >> 3;

    while(rowCnt > 0)
    {
      /* Output pointer is set to starting address of the row being processed */
      px = pOut + GROUPOFROWS*i;
      pxB = px + numColsB;
      pxC = px + 2*numColsB;
      pxD = px + 3*numColsB;
      pxE = px + 4*numColsB;
      pxF = px + 5*numColsB;
      pxG = px + 6*numColsB;
      pxH = px + 7*numColsB;

      /* For every row-wise pass, initialize the column loop counter */
      col = numColsB;

      /* For every row-wise pass, pIn2 is reset
       ** to the starting address of the pSrcB data */
      pIn2 = pSrcB->pData;

      j = 0U;

      /* Column loop */
      do
      {
        /* Set the accumulators sum0..sum7 to zero */
        sum0 = 0.0f;
        sum1 = 0.0f;
        sum2 = 0.0f;
        sum3 = 0.0f;
        sum4 = 0.0f;
        sum5 = 0.0f;
        sum6 = 0.0f;
        sum7 = 0.0f;

        /* Initiate the pointers pIn1 ... pIn1H to the starting addresses of the 8 rows of A being processed */
        pIn1 = pInA;
        pIn1B = pIn1 + numColsA;
        pIn1C = pIn1 + 2*numColsA;
        pIn1D = pIn1 + 3*numColsA;
        pIn1E = pIn1 + 4*numColsA;
        pIn1F = pIn1 + 5*numColsA;
        pIn1G = pIn1 + 6*numColsA;
        pIn1H = pIn1 + 7*numColsA;

        acc0 = vdupq_n_f32(0.0);
        acc1 = vdupq_n_f32(0.0);
        acc2 = vdupq_n_f32(0.0);
        acc3 = vdupq_n_f32(0.0);
        acc4 = vdupq_n_f32(0.0);
        acc5 = vdupq_n_f32(0.0);
        acc6 = vdupq_n_f32(0.0);
        acc7 = vdupq_n_f32(0.0);

        /* Compute 4 MACs simultaneously. */
        colCnt = numColsA >> 2U;

        /* Matrix multiplication */
        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
          a0V = vld1q_f32(pIn1);
          a1V = vld1q_f32(pIn1B);
          a2V = vld1q_f32(pIn1C);
          a3V = vld1q_f32(pIn1D);
          a4V = vld1q_f32(pIn1E);
          a5V = vld1q_f32(pIn1F);
          a6V = vld1q_f32(pIn1G);
          a7V = vld1q_f32(pIn1H);

          pIn1 += 4;
          pIn1B += 4;
          pIn1C += 4;
          pIn1D += 4;
          pIn1E += 4;
          pIn1F += 4;
          pIn1G += 4;
          pIn1H += 4;

          temp = vsetq_lane_f32(*pIn2,temp,0);
          pIn2 += numColsB;
          temp = vsetq_lane_f32(*pIn2,temp,1);
          pIn2 += numColsB;
          temp = vsetq_lane_f32(*pIn2,temp,2);
          pIn2 += numColsB;
          temp = vsetq_lane_f32(*pIn2,temp,3);
          pIn2 += numColsB;

          acc0 = vmlaq_f32(acc0,a0V,temp);
          acc1 = vmlaq_f32(acc1,a1V,temp);
          acc2 = vmlaq_f32(acc2,a2V,temp);
          acc3 = vmlaq_f32(acc3,a3V,temp);
          acc4 = vmlaq_f32(acc4,a4V,temp);
          acc5 = vmlaq_f32(acc5,a5V,temp);
          acc6 = vmlaq_f32(acc6,a6V,temp);
          acc7 = vmlaq_f32(acc7,a7V,temp);

          /* Decrement the loop count */
          colCnt--;
        }

        accum = vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
        sum0 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

        accum = vpadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
        sum1 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

        accum = vpadd_f32(vget_low_f32(acc2), vget_high_f32(acc2));
        sum2 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

        accum = vpadd_f32(vget_low_f32(acc3), vget_high_f32(acc3));
        sum3 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

        accum = vpadd_f32(vget_low_f32(acc4), vget_high_f32(acc4));
        sum4 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

        accum = vpadd_f32(vget_low_f32(acc5), vget_high_f32(acc5));
        sum5 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

        accum = vpadd_f32(vget_low_f32(acc6), vget_high_f32(acc6));
        sum6 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

        accum = vpadd_f32(vget_low_f32(acc7), vget_high_f32(acc7));
        sum7 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

        /* If the number of columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
         ** No loop unrolling is used. */
        colCnt = numColsA & 3;

        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
          sum0 += *pIn1++ * (*pIn2);
          sum1 += *pIn1B++ * (*pIn2);
          sum2 += *pIn1C++ * (*pIn2);
          sum3 += *pIn1D++ * (*pIn2);
          sum4 += *pIn1E++ * (*pIn2);
          sum5 += *pIn1F++ * (*pIn2);
          sum6 += *pIn1G++ * (*pIn2);
          sum7 += *pIn1H++ * (*pIn2);
          pIn2 += numColsB;

          /* Decrement the loop counter */
          colCnt--;
        }

        /* Store the result in the destination buffer */
        *px++ = sum0;
        *pxB++ = sum1;
        *pxC++ = sum2;
        *pxD++ = sum3;
        *pxE++ = sum4;
        *pxF++ = sum5;
        *pxG++ = sum6;
        *pxH++ = sum7;

        /* Update the pointer pIn2 to point to the starting address of the next column */
        j++;
        pIn2 = pSrcB->pData + j;

        /* Decrement the column loop counter */
        col--;

      } while (col > 0U);

      /* Update the pointer pInA to point to the starting address of the next group of rows */
      i = i + numColsB;
      pInA = pInA + GROUPOFROWS*numColsA;

      /* Decrement the row loop counter */
      rowCnt--;
    }

    /*

    Up to this point, i was the index of a group of rows computed by the previous loop.
    From here on, i is the index of a single row, since the code below processes
    one row at a time rather than groups of rows.

    */

    i = GROUPOFROWS*i;
    rowCnt = row & 7;

    while(rowCnt > 0)
    {
      /* Output pointer is set to starting address of the row being processed */
      px = pOut + i;

      /* For every row-wise pass, initialize the column loop counter */
      col = numColsB;

      /* For every row-wise pass, pIn2 is reset
       ** to the starting address of the pSrcB data */
      pIn2 = pSrcB->pData;

      j = 0U;

      /* Column loop */
      do
      {
        /* Set the variable sum, that acts as accumulator, to zero */
        sum = 0.0f;

        /* Initiate the pointer pIn1 to point to the starting address of the row of A being processed */
        pIn1 = pInA;

        acc0 = vdupq_n_f32(0.0);

        /* Compute 4 MACs simultaneously. */
        colCnt = numColsA >> 2U;

        /* Matrix multiplication   */
        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
          a0V = vld1q_f32(pIn1);      /* load 4 consecutive elements from the current row of A */
          pIn1 += 4;

          temp = vsetq_lane_f32(*pIn2,temp,0);
          pIn2 += numColsB;
          temp = vsetq_lane_f32(*pIn2,temp,1);
          pIn2 += numColsB;
          temp = vsetq_lane_f32(*pIn2,temp,2);
          pIn2 += numColsB;
          temp = vsetq_lane_f32(*pIn2,temp,3);
          pIn2 += numColsB;

          acc0 = vmlaq_f32(acc0,a0V,temp);

          /* Decrement the loop count */
          colCnt--;
        }

        accum = vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
        sum += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

        /* If the number of columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
         ** No loop unrolling is used. */
        colCnt = numColsA % 0x4U;

        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
          sum += *pIn1++ * (*pIn2);
          pIn2 += numColsB;

          /* Decrement the loop counter */
          colCnt--;
        }

        /* Store the result in the destination buffer */
        *px++ = sum;

        /* Update the pointer pIn2 to point to the starting address of the next column */
        j++;
        pIn2 = pSrcB->pData + j;

        /* Decrement the column loop counter */
        col--;

      } while (col > 0U);


      /* Update the pointer pInA to point to the starting address of the next row */
      i = i + numColsB;
      pInA = pInA + numColsA;

      /* Decrement the row loop counter */
      rowCnt--;

    }
    /* Set status as ARM_MATH_SUCCESS */
    status = ARM_MATH_SUCCESS;
  }

  /* Return to application */
  return (status);
}
#else
arm_status arm_mat_mult_f32(
  const arm_matrix_instance_f32 * pSrcA,
  const arm_matrix_instance_f32 * pSrcB,
        arm_matrix_instance_f32 * pDst)
{
  float32_t *pIn1 = pSrcA->pData;                /* Input data matrix pointer A */
  float32_t *pIn2 = pSrcB->pData;                /* Input data matrix pointer B */
  float32_t *pInA = pSrcA->pData;                /* Input data matrix pointer A */
  float32_t *pInB = pSrcB->pData;                /* Input data matrix pointer B */
  float32_t *pOut = pDst->pData;                 /* Output data matrix pointer */
  float32_t *px;                                 /* Temporary output data matrix pointer */
  float32_t sum;                                 /* Accumulator */
  uint16_t numRowsA = pSrcA->numRows;            /* Number of rows of input matrix A */
  uint16_t numColsB = pSrcB->numCols;            /* Number of columns of input matrix B */
  uint16_t numColsA = pSrcA->numCols;            /* Number of columns of input matrix A */
  uint32_t col, i = 0U, row = numRowsA, colCnt;  /* Loop counters */
  arm_status status;                             /* Status of matrix multiplication */

#ifdef ARM_MATH_MATRIX_CHECK

  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
      (pSrcA->numRows != pDst->numRows)  ||
      (pSrcB->numCols != pDst->numCols)    )
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else

#endif /* #ifdef ARM_MATH_MATRIX_CHECK */

  {
    /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
    /* row loop */
    do
    {
      /* Output pointer is set to starting address of row being processed */
      px = pOut + i;

      /* For every row-wise pass, initialize the column loop counter */
      col = numColsB;

      /* For every row-wise pass, pIn2 is reset to the starting address of the pSrcB data */
      pIn2 = pSrcB->pData;

      /* column loop */
      do
      {
        /* Set the variable sum, that acts as accumulator, to zero */
        sum = 0.0f;

        /* Initialize pointer pIn1 to point to starting address of the row of A being processed */
        pIn1 = pInA;

#if defined (ARM_MATH_LOOPUNROLL)

        /* Loop unrolling: Compute 4 MACs at a time. */
        colCnt = numColsA >> 2U;

        /* matrix multiplication */
        while (colCnt > 0U)
        {
          /* c(m,p) = a(m,1) * b(1,p) + a(m,2) * b(2,p) + .... + a(m,n) * b(n,p) */

          /* Perform the multiply-accumulates */
          sum += *pIn1++ * *pIn2;
          pIn2 += numColsB;

          sum += *pIn1++ * *pIn2;
          pIn2 += numColsB;

          sum += *pIn1++ * *pIn2;
          pIn2 += numColsB;

          sum += *pIn1++ * *pIn2;
          pIn2 += numColsB;

          /* Decrement loop counter */
          colCnt--;
        }

        /* Loop unrolling: Compute remaining MACs */
        colCnt = numColsA % 0x4U;

#else

        /* Initialize colCnt with the number of columns */
        colCnt = numColsA;

#endif /* #if defined (ARM_MATH_LOOPUNROLL) */

        while (colCnt > 0U)
        {
          /* c(m,p) = a(m,1) * b(1,p) + a(m,2) * b(2,p) + .... + a(m,n) * b(n,p) */

          /* Perform the multiply-accumulates */
          sum += *pIn1++ * *pIn2;
          pIn2 += numColsB;

          /* Decrement loop counter */
          colCnt--;
        }

        /* Store result in destination buffer */
        *px++ = sum;

        /* Decrement column loop counter */
        col--;

        /* Update pointer pIn2 to point to starting address of next column */
        pIn2 = pInB + (numColsB - col);

      } while (col > 0U);

      /* Update pointer pInA to point to starting address of next row */
      i = i + numColsB;
      pInA = pInA + numColsA;

      /* Decrement row loop counter */
      row--;

    } while (row > 0U);

    /* Set status as ARM_MATH_SUCCESS */
    status = ARM_MATH_SUCCESS;
  }

  /* Return to application */
  return (status);
}

#endif /* #if defined(ARM_MATH_NEON) */
#endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */

/**
 * @} end of MatrixMult group
 */