/* ----------------------------------------------------------------------
 * Project:      CMSIS DSP Library
 * Title:        arm_mat_mult_f32.c
 * Description:  Floating-point matrix multiplication
 *
 * $Date:        23 April 2021
 * $Revision:    V1.9.0
 *
 * Target Processor: Cortex-M and Cortex-A cores
 * -------------------------------------------------------------------- */
/*
 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "dsp/matrix_functions.h"

#if defined(ARM_MATH_NEON)
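/* Number of rows of A processed per iteration of the outer row loop in the NEON kernel below */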
#define GROUPOFROWS 8
#endif

/**
 * @ingroup groupMatrix
 */

/**
 * @defgroup MatrixMult Matrix Multiplication
 *
 * Multiplies two matrices.
 *
 * @par Multiplication of two 3x3 matrices:
 *
 * \f[
 * \begin{pmatrix}
 *  a_{1,1} & a_{1,2} & a_{1,3} \\
 *  a_{2,1} & a_{2,2} & a_{2,3} \\
 *  a_{3,1} & a_{3,2} & a_{3,3} \\
 * \end{pmatrix}
 *
 * \begin{pmatrix}
 *  b_{1,1} & b_{1,2} & b_{1,3} \\
 *  b_{2,1} & b_{2,2} & b_{2,3} \\
 *  b_{3,1} & b_{3,2} & b_{3,3} \\
 * \end{pmatrix}
 * =
 * \begin{pmatrix}
 *  a_{1,1} b_{1,1}+a_{1,2} b_{2,1}+a_{1,3} b_{3,1} & a_{1,1} b_{1,2}+a_{1,2} b_{2,2}+a_{1,3} b_{3,2} & a_{1,1} b_{1,3}+a_{1,2} b_{2,3}+a_{1,3} b_{3,3} \\
 *  a_{2,1} b_{1,1}+a_{2,2} b_{2,1}+a_{2,3} b_{3,1} & a_{2,1} b_{1,2}+a_{2,2} b_{2,2}+a_{2,3} b_{3,2} & a_{2,1} b_{1,3}+a_{2,2} b_{2,3}+a_{2,3} b_{3,3} \\
 *  a_{3,1} b_{1,1}+a_{3,2} b_{2,1}+a_{3,3} b_{3,1} & a_{3,1} b_{1,2}+a_{3,2} b_{2,2}+a_{3,3} b_{3,2} & a_{3,1} b_{1,3}+a_{3,2} b_{2,3}+a_{3,3} b_{3,3} \\
 * \end{pmatrix}
 * \f]
 *
 * Matrix multiplication is only defined if the number of columns of the
 * first matrix equals the number of rows of the second matrix.
 * Multiplying an <code>M x N</code> matrix with an <code>N x P</code> matrix results
 * in an <code>M x P</code> matrix.
 * When matrix size checking is enabled, the functions check: (1) that the inner dimensions of
 * <code>pSrcA</code> and <code>pSrcB</code> are equal; and (2) that the size of the output
 * matrix equals the outer dimensions of <code>pSrcA</code> and <code>pSrcB</code>.
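 *
 * @par Example
 * A minimal usage sketch (matrix sizes and contents are illustrative):
 * @code
 * float32_t a[2 * 3] = { 1.0f, 2.0f, 3.0f,
 *                        4.0f, 5.0f, 6.0f };
 * float32_t b[3 * 2] = { 1.0f, 0.0f,
 *                        0.0f, 1.0f,
 *                        1.0f, 1.0f };
 * float32_t d[2 * 2];
 *
 * arm_matrix_instance_f32 A = { .numRows = 2, .numCols = 3, .pData = a };
 * arm_matrix_instance_f32 B = { .numRows = 3, .numCols = 2, .pData = b };
 * arm_matrix_instance_f32 D = { .numRows = 2, .numCols = 2, .pData = d };
 *
 * arm_status status = arm_mat_mult_f32(&A, &B, &D);
 * if (status != ARM_MATH_SUCCESS)
 * {
 *     // ARM_MATH_SIZE_MISMATCH is only reported when ARM_MATH_MATRIX_CHECK is defined
 * }
 * @endcode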
 */


/**
 * @addtogroup MatrixMult
 * @{
 */



#if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)

#define MATRIX_DIM3 3
#define MATRIX_DIM4 4

__STATIC_INLINE arm_status arm_mat_mult_f32_2x2_mve(
    const arm_matrix_instance_f32 *pSrcA,
    const arm_matrix_instance_f32 *pSrcB,
    arm_matrix_instance_f32 *pDst)
{
    /* {a00, a00, a10, a10} */
    static const uint32_t  offsetA0[4] = { 0, 0, 2, 2 };
    /* {b00, b01, b00, b01} */
    static const uint32_t  offsetB0[4] = { 0, 1, 0, 1 };
    /* {a01, a01, a11, a11} */
    static const uint32_t  offsetA1[4] = { 1, 1, 3, 3 };
    /* {b10, b11, b10, b11} */
    static const uint32_t  offsetB1[4] = { 2, 3, 2, 3 };
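    /*
     * With these gather patterns, the vector multiply and the fused
     * multiply-accumulate below evaluate the full 2x2 product at once:
     *   dst = { a00*b00 + a01*b10,  a00*b01 + a01*b11,
     *           a10*b00 + a11*b10,  a10*b01 + a11*b11 }
     */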

    uint32x4_t vecOffsA, vecOffsB;
    f32x4_t vecInA, vecInB, vecDst;

    vecOffsA = vldrwq_u32((uint32_t const *) offsetA0);
    vecOffsB = vldrwq_u32((uint32_t const *) offsetB0);

    vecInA = vldrwq_gather_shifted_offset((float32_t const *) pSrcA->pData, vecOffsA);
    vecInB = vldrwq_gather_shifted_offset((float32_t const *) pSrcB->pData, vecOffsB);

    vecDst = vmulq(vecInA, vecInB);

    vecOffsA = vldrwq_u32((uint32_t const *) offsetA1);
    vecOffsB = vldrwq_u32((uint32_t const *) offsetB1);

    vecInA = vldrwq_gather_shifted_offset((float32_t const *) pSrcA->pData, vecOffsA);
    vecInB = vldrwq_gather_shifted_offset((float32_t const *) pSrcB->pData, vecOffsB);

    vecDst = vfmaq(vecDst, vecInA, vecInB);

    vstrwq_f32(pDst->pData, vecDst);

    return (ARM_MATH_SUCCESS);

}


/*
 * A  =  {{a00, a01, a02},
 *        {a10, a11, a12},
 *        {a20, a21, a22}}
 * B  =  {{b00, b01, b02},
 *        {b10, b11, b12},
 *        {b20, b21, b22}}
 *
 * Dst = {{a00 b00 + a01 b10 + a02 b20, a00 b01 + a01 b11 + a02 b21, a00 b02 + a01 b12 + a02 b22},
 *        {a10 b00 + a11 b10 + a12 b20, a10 b01 + a11 b11 + a12 b21, a10 b02 + a11 b12 + a12 b22},
 *        {a20 b00 + a21 b10 + a22 b20, a20 b01 + a21 b11 + a22 b21, a20 b02 + a21 b12 + a22 b22}}
 */
__STATIC_INLINE arm_status arm_mat_mult_f32_3x3_mve(
    const arm_matrix_instance_f32 *pSrcA,
    const arm_matrix_instance_f32 *pSrcB,
    arm_matrix_instance_f32 *pDst)
{
    float32_t   *pInB = pSrcB->pData; /* input data matrix pointer B */
    float32_t   *pInA = pSrcA->pData; /* input data matrix pointer A */
    float32_t   *pOut = pDst->pData;  /* output data matrix pointer */
    float32_t   *pInA0, *pInA1, *pInA2;
    f32x4_t    vecMac0, vecMac1, vecMac2;
    f32x4_t    vecInB;
    float32_t const *pSrBVec;

    pSrBVec = (float32_t const *) pInB;

    pInA0 = pInA;
    pInA1 = pInA0 + MATRIX_DIM3;
    pInA2 = pInA1 + MATRIX_DIM3;
    /* enable predication to disable last (4th) vector element */
    mve_pred16_t p0 = vctp32q(MATRIX_DIM3);

    /*
     * load {b0,0, b0,1, b0,2, 0}
     */
    vecInB = vldrwq_z_f32(pSrBVec, p0);
    pSrBVec += MATRIX_DIM3;

    vecMac0 = vmulq(vecInB, *pInA0++);
    vecMac1 = vmulq(vecInB, *pInA1++);
    vecMac2 = vmulq(vecInB, *pInA2++);
    /*
     * load {b1,0, b1,1, b1,2, 0}
     */
    vecInB = vldrwq_z_f32(pSrBVec, p0);
    pSrBVec += MATRIX_DIM3;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    /*
     * load {b2,0, b2,1, b2,2, 0}
     */
    vecInB = vldrwq_z_f32(pSrBVec, p0);
    pSrBVec += MATRIX_DIM3;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);

    /* partial vector stores */
    vstrwq_p_f32(pOut, vecMac0, p0);
    pOut += MATRIX_DIM3;
    vstrwq_p_f32(pOut, vecMac1, p0);
    pOut += MATRIX_DIM3;
    vstrwq_p_f32(pOut, vecMac2, p0);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}




__STATIC_INLINE arm_status arm_mat_mult_f32_4x4_mve(
    const arm_matrix_instance_f32 *pSrcA,
    const arm_matrix_instance_f32 *pSrcB,
    arm_matrix_instance_f32 *pDst)
{
    float32_t const *pSrBVec;
    float32_t *pInB = pSrcB->pData; /* input data matrix pointer B */
    float32_t *pInA = pSrcA->pData; /* input data matrix pointer A */
    float32_t *pOut = pDst->pData;  /* output data matrix pointer */
    float32_t *pInA0, *pInA1, *pInA2, *pInA3;
    f32x4_t vecMac0, vecMac1, vecMac2, vecMac3;
    f32x4_t vecInB;

    pSrBVec = (float32_t const *) pInB;

    pInA0 = pInA;
    pInA1 = pInA0 + MATRIX_DIM4;
    pInA2 = pInA1 + MATRIX_DIM4;
    pInA3 = pInA2 + MATRIX_DIM4;
    /*
     * load {b0,0, b0,1, b0,2, b0,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    vecMac0 = vmulq(vecInB, *pInA0++);
    vecMac1 = vmulq(vecInB, *pInA1++);
    vecMac2 = vmulq(vecInB, *pInA2++);
    vecMac3 = vmulq(vecInB, *pInA3++);
    /*
     * load {b1,0, b1,1, b1,2, b1,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
    /*
     * load {b2,0, b2,1, b2,2, b2,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
    /*
     * load {b3,0, b3,1, b3,2, b3,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);

    vst1q(pOut, vecMac0);
    pOut += MATRIX_DIM4;
    vst1q(pOut, vecMac1);
    pOut += MATRIX_DIM4;
    vst1q(pOut, vecMac2);
    pOut += MATRIX_DIM4;
    vst1q(pOut, vecMac3);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}


/**
 * @brief Floating-point matrix multiplication.
 * @param[in]       *pSrcA points to the first input matrix structure
 * @param[in]       *pSrcB points to the second input matrix structure
 * @param[out]      *pDst points to output matrix structure
 * @return          The function returns either
 * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
 */
arm_status arm_mat_mult_f32(
  const arm_matrix_instance_f32 * pSrcA,
  const arm_matrix_instance_f32 * pSrcB,
  arm_matrix_instance_f32 * pDst)
{
    float32_t  *pInB = pSrcB->pData;        /* input data matrix pointer B */
    float32_t  *pInA = pSrcA->pData;        /* input data matrix pointer A */
    float32_t  *pOut = pDst->pData;         /* output data matrix pointer */
    int         numRowsA = pSrcA->numRows;  /* number of rows of input matrix A */
    int         numColsB = pSrcB->numCols;  /* number of columns of input matrix B */
    int         numColsA = pSrcA->numCols;  /* number of columns of input matrix A */
    uint32_t    blkCnt;                     /* loop counters */
    uint32_t    i;
    arm_status status;

#ifdef ARM_MATH_MATRIX_CHECK

  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
     (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else
#endif /*      #ifdef ARM_MATH_MATRIX_CHECK    */
  {
      /* specialized routines for small square matrices */
    if (numRowsA == numColsB && numColsB == numColsA) {
        if (numRowsA == 1)
        {
           pOut[0] = pInA[0] * pInB[0];
           return (ARM_MATH_SUCCESS);
        }
        else if (numRowsA == 2)
            return arm_mat_mult_f32_2x2_mve(pSrcA, pSrcB, pDst);
        else if (numRowsA == 3)
            return arm_mat_mult_f32_3x3_mve(pSrcA, pSrcB, pDst);
        else if (numRowsA == 4)
            return arm_mat_mult_f32_4x4_mve(pSrcA, pSrcB, pDst);
    }
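    /*
     * General case: the loops below compute the output in 4x4 tiles
     * (4 rows of A times 4 columns of B), using predicated loads/stores for
     * leftover columns and a separate single-row loop for leftover rows.
     * Each output element corresponds to the scalar reference computation
     * (sketch, row-major storage as used throughout this file):
     *
     *   for (row = 0; row < numRowsA; row++)
     *     for (col = 0; col < numColsB; col++) {
     *       float32_t acc = 0.0f;
     *       for (k = 0; k < numColsA; k++)
     *         acc += pSrcA->pData[row * numColsA + k] * pSrcB->pData[k * numColsB + col];
     *       pDst->pData[row * numColsB + col] = acc;
     *     }
     */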

    /* main loop process 4 rows */
    i = numRowsA >> 2;
    while (i > 0U)
    {
        float32_t *pInA0, *pInA1, *pInA2, *pInA3;
        float32_t *pInB0;
        float32_t *pOut0, *pOut1, *pOut2, *pOut3;
        f32x4_t vecMac0, vecMac1, vecMac2, vecMac3;
        f32x4_t vecInB;

        /* pointers to 4 consecutive output rows */
        pOut0 = pOut;
        pOut1 = pOut0 + numColsB;
        pOut2 = pOut1 + numColsB;
        pOut3 = pOut2 + numColsB;
        pInB0 = pInB;

        uint32_t  k = numColsB >> 2;
        while (k > 0U)
        {
            /* pointers to 4 consecutive Matrix A rows */
            pInA0 = pInA;
            pInA1 = pInA0 + numColsA;
            pInA2 = pInA1 + numColsA;
            pInA3 = pInA2 + numColsA;

            vecMac0 = vdupq_n_f32(0.0f);
            vecMac1 = vdupq_n_f32(0.0f);
            vecMac2 = vdupq_n_f32(0.0f);
            vecMac3 = vdupq_n_f32(0.0f);

            blkCnt = numColsA;

            while (blkCnt > 0U)
            {
                /*
                 * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
                 */
                vecInB = *(f32x4_t *) pInB0; /* equivalent to vldrwq_f32(pInB0) */

                vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
                vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
                vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
                vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);

                pInB0 = pInB0 + numColsB;
                /*
                 * Decrement the blockSize loop counter
                 */
                blkCnt--;
            }

            /* Store the results (4 x 4 block) in the destination buffer */
            vst1q(pOut0, vecMac0);
            pOut0 += 4;
            vst1q(pOut1, vecMac1);
            pOut1 += 4;
            vst1q(pOut2, vecMac2);
            pOut2 += 4;
            vst1q(pOut3, vecMac3);
            pOut3 += 4;

            /*
             * rewind
             */
            pInB0 -= (numColsB * numColsA) - 4;
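            /* pInB0 advanced by numColsA * numColsB elements during the dot
               products; the rewind above steps back by that amount and moves
               4 elements right, to the next 4-column block of B. */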
            k--;
        }

        int       colBLeft = numColsB & 3;
        if (colBLeft)
        {
            pInA0 = pInA;
            pInA1 = pInA0 + numColsA;
            pInA2 = pInA1 + numColsA;
            pInA3 = pInA2 + numColsA;
            mve_pred16_t p0 = vctp32q(colBLeft);

            vecMac0 = vdupq_n_f32(0.0f);
            vecMac1 = vdupq_n_f32(0.0f);
            vecMac2 = vdupq_n_f32(0.0f);
            vecMac3 = vdupq_n_f32(0.0f);

            blkCnt = numColsA;

            while (blkCnt > 0U)
            {
                /*
                 * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
                 */
                vecInB = vldrwq_z_f32(pInB0, p0);

                vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
                vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
                vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
                vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);

                pInB0 = pInB0 + numColsB;
                /*
                 * Decrement the blockSize loop counter
                 */
                blkCnt--;
            }

            /* Store the results (4 x colBLeft block) in the destination buffer */
            vstrwq_p_f32(pOut0, vecMac0, p0);
            vstrwq_p_f32(pOut1, vecMac1, p0);
            vstrwq_p_f32(pOut2, vecMac2, p0);
            vstrwq_p_f32(pOut3, vecMac3, p0);
        }

        /* move to next rows */
        pInA += 4 * numColsA;
        pOut += 4 * numColsB;
        i--;
    }

    /*
     * non multiple of 4 rows for Matrix A
     * process single row
     */
    if (numRowsA & 3)
    {
        i = numRowsA & 3;
        while (i > 0U)
        {
            float32_t   *pInA0;
            float32_t   *pInB0;
            float32_t   *pOut0;
            f32x4_t    vecInB;
            f32x4_t    vecMac0;

            pOut0 = pOut;
            pInB0 = pInB;

            uint32_t       k = numColsB >> 2;
            while (k > 0U)
            {
                pInA0 = pInA;

                vecMac0 = vdupq_n_f32(0.0f);
                blkCnt = numColsA;
                while (blkCnt > 0U)
                {
                    /*
                     * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
                     */
                    vecInB = *(f32x4_t *) pInB0; /* equivalent to vldrwq_f32(pInB0) */

                    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);

                    pInB0 = pInB0 + numColsB;
                    /*
                     * Decrement the blockSize loop counter
                     */
                    blkCnt--;
                }

                /* Store the results (1 x 4 block) in the destination buffer */
                vst1q(pOut0, vecMac0);
                pOut0 += 4;

                /*
                 * rewind
                 */
                pInB0 -= (numColsB * numColsA) - 4;
                k--;
            }

            int       colBLeft = numColsB & 3;
            if (colBLeft)
            {
                pInA0 = pInA;
                mve_pred16_t p0 = vctp32q(colBLeft);

                vecMac0 = vdupq_n_f32(0.0f);
                blkCnt = numColsA;
                while (blkCnt > 0U)
                {
                    /*
                     * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
                     */
                    vecInB = vldrwq_z_f32(pInB0, p0);

                    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);

                    pInB0 = pInB0 + numColsB;
                    /*
                     * Decrement the blockSize loop counter
                     */
                    blkCnt--;
                }
                /* Store the results (1 x colBLeft block) in the destination buffer */
                vstrwq_p_f32(pOut0, vecMac0, p0);
            }

            /* move to next row */
            pInA += 1 * numColsA;
            pOut += 1 * numColsB;
            i--;
        }

      }
      status = ARM_MATH_SUCCESS;
  }

  /* Return to application */
  return (status);
}
#else

#if defined(ARM_MATH_NEON)
/**
 * @brief Floating-point matrix multiplication.
 * @param[in]       *pSrcA points to the first input matrix structure
 * @param[in]       *pSrcB points to the second input matrix structure
 * @param[out]      *pDst points to output matrix structure
 * @return          The function returns either
 * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
 */
arm_status arm_mat_mult_f32(
  const arm_matrix_instance_f32 * pSrcA,
  const arm_matrix_instance_f32 * pSrcB,
  arm_matrix_instance_f32 * pDst)
{
  float32_t *pIn1 = pSrcA->pData;                /* input data matrix pointer A */
  float32_t *pIn2 = pSrcB->pData;                /* input data matrix pointer B */
  float32_t *pInA = pSrcA->pData;                /* input data matrix pointer A */
  float32_t *pOut = pDst->pData;                 /* output data matrix pointer */
  float32_t *px;                                 /* Temporary output data matrix pointer */
  float32_t sum;                                 /* Accumulator */
  uint16_t numRowsA = pSrcA->numRows;            /* number of rows of input matrix A */
  uint16_t numColsB = pSrcB->numCols;            /* number of columns of input matrix B */
  uint16_t numColsA = pSrcA->numCols;            /* number of columns of input matrix A */


  uint32_t col, i = 0U, j, row = numRowsA, rowCnt, colCnt;      /* loop counters */
  arm_status status;                             /* status of matrix multiplication */

  float32x4_t a0V, a1V, a2V, a3V, a4V, a5V, a6V, a7V;
  float32x4_t acc0, acc1, acc2, acc3, acc4, acc5, acc6, acc7, temp;
  float32x2_t accum = vdup_n_f32(0);
  float32_t *pIn1B = pSrcA->pData;
  float32_t *pIn1C = pSrcA->pData;
  float32_t *pIn1D = pSrcA->pData;
  float32_t *pIn1E = pSrcA->pData;
  float32_t *pIn1F = pSrcA->pData;
  float32_t *pIn1G = pSrcA->pData;
  float32_t *pIn1H = pSrcA->pData;

  float32_t *pxB, *pxC, *pxD, *pxE, *pxF, *pxG, *pxH;            /* Temporary output data matrix pointers */
  float32_t sum0, sum1, sum2, sum3, sum4, sum5, sum6, sum7;

#ifdef ARM_MATH_MATRIX_CHECK

  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
     (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else
#endif /*      #ifdef ARM_MATH_MATRIX_CHECK    */
  {
    /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
    /* Row loop */
    rowCnt = row >> 3;
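    /*
     * Each pass of this loop produces GROUPOFROWS (8) complete rows of the
     * output: eight rows of A are combined with one column of B at a time.
     * Rows left over when numRowsA is not a multiple of 8 are handled by the
     * single-row loop further down (rowCnt = row & 7).
     */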

    while (rowCnt > 0)
    {
      /* Output pointer is set to starting address of the row being processed */
      px = pOut + GROUPOFROWS*i;
      pxB = px + numColsB;
      pxC = px + 2*numColsB;
      pxD = px + 3*numColsB;
      pxE = px + 4*numColsB;
      pxF = px + 5*numColsB;
      pxG = px + 6*numColsB;
      pxH = px + 7*numColsB;

      /* For every row wise process, the column loop counter is to be initiated */
      col = numColsB;

      /* For every row wise process, the pIn2 pointer is set
       ** to the starting address of the pSrcB data */
      pIn2 = pSrcB->pData;

      j = 0U;

      /* Column loop */
      do
      {
        /* Set the variable sum, that acts as accumulator, to zero */
        sum0 = 0.0f;
        sum1 = 0.0f;
        sum2 = 0.0f;
        sum3 = 0.0f;
        sum4 = 0.0f;
        sum5 = 0.0f;
        sum6 = 0.0f;
        sum7 = 0.0f;

        /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
        pIn1 = pInA;
        pIn1B = pIn1 + numColsA;
        pIn1C = pIn1 + 2*numColsA;
        pIn1D = pIn1 + 3*numColsA;
        pIn1E = pIn1 + 4*numColsA;
        pIn1F = pIn1 + 5*numColsA;
        pIn1G = pIn1 + 6*numColsA;
        pIn1H = pIn1 + 7*numColsA;

        acc0 = vdupq_n_f32(0.0);
        acc1 = vdupq_n_f32(0.0);
        acc2 = vdupq_n_f32(0.0);
        acc3 = vdupq_n_f32(0.0);
        acc4 = vdupq_n_f32(0.0);
        acc5 = vdupq_n_f32(0.0);
        acc6 = vdupq_n_f32(0.0);
        acc7 = vdupq_n_f32(0.0);

        /* Compute 4 MACs simultaneously. */
        colCnt = numColsA >> 2U;

        /* Matrix multiplication */
        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
          a0V = vld1q_f32(pIn1);
          a1V = vld1q_f32(pIn1B);
          a2V = vld1q_f32(pIn1C);
          a3V = vld1q_f32(pIn1D);
          a4V = vld1q_f32(pIn1E);
          a5V = vld1q_f32(pIn1F);
          a6V = vld1q_f32(pIn1G);
          a7V = vld1q_f32(pIn1H);

          pIn1 += 4;
          pIn1B += 4;
          pIn1C += 4;
          pIn1D += 4;
          pIn1E += 4;
          pIn1F += 4;
          pIn1G += 4;
          pIn1H += 4;

          temp = vsetq_lane_f32(*pIn2, temp, 0);
          pIn2 += numColsB;
          temp = vsetq_lane_f32(*pIn2, temp, 1);
          pIn2 += numColsB;
          temp = vsetq_lane_f32(*pIn2, temp, 2);
          pIn2 += numColsB;
          temp = vsetq_lane_f32(*pIn2, temp, 3);
          pIn2 += numColsB;
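
          /* temp now holds 4 consecutive elements of the current column of B;
             each vmlaq below adds 4 products of that column segment with one
             of the 8 rows of A being processed. */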

          acc0 = vmlaq_f32(acc0, a0V, temp);
          acc1 = vmlaq_f32(acc1, a1V, temp);
          acc2 = vmlaq_f32(acc2, a2V, temp);
          acc3 = vmlaq_f32(acc3, a3V, temp);
          acc4 = vmlaq_f32(acc4, a4V, temp);
          acc5 = vmlaq_f32(acc5, a5V, temp);
          acc6 = vmlaq_f32(acc6, a6V, temp);
          acc7 = vmlaq_f32(acc7, a7V, temp);

          /* Decrement the loop count */
          colCnt--;
        }

        accum = vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
        sum0 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

        accum = vpadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
        sum1 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

        accum = vpadd_f32(vget_low_f32(acc2), vget_high_f32(acc2));
        sum2 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

        accum = vpadd_f32(vget_low_f32(acc3), vget_high_f32(acc3));
        sum3 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

        accum = vpadd_f32(vget_low_f32(acc4), vget_high_f32(acc4));
        sum4 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

        accum = vpadd_f32(vget_low_f32(acc5), vget_high_f32(acc5));
        sum5 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

        accum = vpadd_f32(vget_low_f32(acc6), vget_high_f32(acc6));
        sum6 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

        accum = vpadd_f32(vget_low_f32(acc7), vget_high_f32(acc7));
        sum7 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

        /* If the number of columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
         ** No loop unrolling is used. */
        colCnt = numColsA & 3;

        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
          sum0 += *pIn1++ * (*pIn2);
          sum1 += *pIn1B++ * (*pIn2);
          sum2 += *pIn1C++ * (*pIn2);
          sum3 += *pIn1D++ * (*pIn2);
          sum4 += *pIn1E++ * (*pIn2);
          sum5 += *pIn1F++ * (*pIn2);
          sum6 += *pIn1G++ * (*pIn2);
          sum7 += *pIn1H++ * (*pIn2);
          pIn2 += numColsB;

          /* Decrement the loop counter */
          colCnt--;
        }

        /* Store the result in the destination buffer */
        *px++ = sum0;
        *pxB++ = sum1;
        *pxC++ = sum2;
        *pxD++ = sum3;
        *pxE++ = sum4;
        *pxF++ = sum5;
        *pxG++ = sum6;
        *pxH++ = sum7;

        /* Update the pointer pIn2 to point to the starting address of the next column */
        j++;
        pIn2 = pSrcB->pData + j;

        /* Decrement the column loop counter */
        col--;

      } while (col > 0U);

      /* Update the pointer pInA to point to the starting address of the next row */
      i = i + numColsB;
      pInA = pInA + GROUPOFROWS*numColsA;

      /* Decrement the row loop counter */
      rowCnt--;
    }

    /*
     * In the loop above, i indexed the output by groups of GROUPOFROWS rows.
     * The loop below processes one row at a time, so rescale i to a
     * per-row output index.
     */

    i = GROUPOFROWS*i;
    rowCnt = row & 7;

    while (rowCnt > 0)
    {
      /* Output pointer is set to starting address of the row being processed */
      px = pOut + i;

      /* For every row wise process, the column loop counter is to be initiated */
      col = numColsB;

      /* For every row wise process, the pIn2 pointer is set
       ** to the starting address of the pSrcB data */
      pIn2 = pSrcB->pData;

      j = 0U;

      /* Column loop */
      do
      {
        /* Set the variable sum, that acts as accumulator, to zero */
        sum = 0.0f;

        /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
        pIn1 = pInA;

        acc0 = vdupq_n_f32(0.0);

        /* Compute 4 MACs simultaneously. */
        colCnt = numColsA >> 2U;

        /* Matrix multiplication */
        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
          a0V = vld1q_f32(pIn1);  /* load 4 consecutive elements from the current row of A */
          pIn1 += 4;

          temp = vsetq_lane_f32(*pIn2, temp, 0);
          pIn2 += numColsB;
          temp = vsetq_lane_f32(*pIn2, temp, 1);
          pIn2 += numColsB;
          temp = vsetq_lane_f32(*pIn2, temp, 2);
          pIn2 += numColsB;
          temp = vsetq_lane_f32(*pIn2, temp, 3);
          pIn2 += numColsB;

          acc0 = vmlaq_f32(acc0, a0V, temp);

          /* Decrement the loop count */
          colCnt--;
        }

        accum = vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
        sum += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

        /* If the number of columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
         ** No loop unrolling is used. */
        colCnt = numColsA % 0x4U;

        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
          sum += *pIn1++ * (*pIn2);
          pIn2 += numColsB;

          /* Decrement the loop counter */
          colCnt--;
        }

        /* Store the result in the destination buffer */
        *px++ = sum;

        /* Update the pointer pIn2 to point to the starting address of the next column */
        j++;
        pIn2 = pSrcB->pData + j;

        /* Decrement the column loop counter */
        col--;

      } while (col > 0U);


      /* Update the pointer pInA to point to the starting address of the next row */
      i = i + numColsB;
      pInA = pInA + numColsA;

      /* Decrement the row loop counter */
      rowCnt--;

    }
    /* Set status as ARM_MATH_SUCCESS */
    status = ARM_MATH_SUCCESS;
  }

  /* Return to application */
  return (status);
}
#else
/**
 * @brief Floating-point matrix multiplication.
 * @param[in]       *pSrcA points to the first input matrix structure
 * @param[in]       *pSrcB points to the second input matrix structure
 * @param[out]      *pDst points to output matrix structure
 * @return          The function returns either
 * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
 */
arm_status arm_mat_mult_f32(
  const arm_matrix_instance_f32 * pSrcA,
  const arm_matrix_instance_f32 * pSrcB,
        arm_matrix_instance_f32 * pDst)
{
  float32_t *pIn1 = pSrcA->pData;                /* Input data matrix pointer A */
  float32_t *pIn2 = pSrcB->pData;                /* Input data matrix pointer B */
  float32_t *pInA = pSrcA->pData;                /* Input data matrix pointer A */
  float32_t *pInB = pSrcB->pData;                /* Input data matrix pointer B */
  float32_t *pOut = pDst->pData;                 /* Output data matrix pointer */
  float32_t *px;                                 /* Temporary output data matrix pointer */
  float32_t sum;                                 /* Accumulator */
  uint16_t numRowsA = pSrcA->numRows;            /* Number of rows of input matrix A */
  uint16_t numColsB = pSrcB->numCols;            /* Number of columns of input matrix B */
  uint16_t numColsA = pSrcA->numCols;            /* Number of columns of input matrix A */
  uint32_t col, i = 0U, row = numRowsA, colCnt;  /* Loop counters */
  arm_status status;                             /* Status of matrix multiplication */

#ifdef ARM_MATH_MATRIX_CHECK

  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
      (pSrcA->numRows != pDst->numRows)  ||
      (pSrcB->numCols != pDst->numCols)    )
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else

#endif /* #ifdef ARM_MATH_MATRIX_CHECK */

  {
    /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
    /* row loop */
    do
    {
      /* Output pointer is set to starting address of row being processed */
      px = pOut + i;

      /* For every row wise process, column loop counter is to be initiated */
      col = numColsB;

      /* For every row wise process, pIn2 pointer is set to starting address of pSrcB data */
      pIn2 = pSrcB->pData;

      /* column loop */
      do
      {
        /* Set the variable sum, that acts as accumulator, to zero */
        sum = 0.0f;

        /* Initialize pointer pIn1 to point to starting address of column being processed */
        pIn1 = pInA;

#if defined (ARM_MATH_LOOPUNROLL)

        /* Loop unrolling: Compute 4 MACs at a time. */
        colCnt = numColsA >> 2U;

        /* matrix multiplication */
        while (colCnt > 0U)
        {
          /* c(m,p) = a(m,1) * b(1,p) + a(m,2) * b(2,p) + .... + a(m,n) * b(n,p) */

          /* Perform the multiply-accumulates */
          sum += *pIn1++ * *pIn2;
          pIn2 += numColsB;

          sum += *pIn1++ * *pIn2;
          pIn2 += numColsB;

          sum += *pIn1++ * *pIn2;
          pIn2 += numColsB;

          sum += *pIn1++ * *pIn2;
          pIn2 += numColsB;

          /* Decrement loop counter */
          colCnt--;
        }

        /* Loop unrolling: Compute remaining MACs */
        colCnt = numColsA % 0x4U;

#else

        /* Initialize colCnt with the number of columns */
        colCnt = numColsA;

#endif /* #if defined (ARM_MATH_LOOPUNROLL) */

        while (colCnt > 0U)
        {
          /* c(m,p) = a(m,1) * b(1,p) + a(m,2) * b(2,p) + .... + a(m,n) * b(n,p) */

          /* Perform the multiply-accumulates */
          sum += *pIn1++ * *pIn2;
          pIn2 += numColsB;

          /* Decrement loop counter */
          colCnt--;
        }

        /* Store result in destination buffer */
        *px++ = sum;

        /* Decrement column loop counter */
        col--;

        /* Update pointer pIn2 to point to starting address of next column */
        pIn2 = pInB + (numColsB - col);

      } while (col > 0U);

      /* Update pointer pInA to point to starting address of next row */
      i = i + numColsB;
      pInA = pInA + numColsA;

      /* Decrement row loop counter */
      row--;

    } while (row > 0U);

    /* Set status as ARM_MATH_SUCCESS */
    status = ARM_MATH_SUCCESS;
  }

  /* Return to application */
  return (status);
}

#endif /* #if defined(ARM_MATH_NEON) */
#endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */

/**
 * @} end of MatrixMult group
 */