1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_mult_f64.c
4  * Description:  Floating-point matrix multiplication
5  *
6  * $Date:        10 August 2022
7  * $Revision:    V1.9.1
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/matrix_functions.h"
30 #if defined(ARM_MATH_NEON) && defined(__aarch64__)
31 #define GROUPOFROWS 8
32 #endif
33 
34 /**
35  * @ingroup groupMatrix
36  */
37 
38 /**
39  * @defgroup MatrixMult Matrix Multiplication
40  *
41  * Multiplies two matrices.
42  *
43  * \image html MatrixMultiplication.gif "Multiplication of two 3 x 3 matrices"
44 
45  * Matrix multiplication is only defined if the number of columns of the
46  * first matrix equals the number of rows of the second matrix.
47  * Multiplying an <code>M x N</code> matrix with an <code>N x P</code> matrix results
48  * in an <code>M x P</code> matrix.
49  * When matrix size checking is enabled, the functions check: (1) that the inner dimensions of
50  * <code>pSrcA</code> and <code>pSrcB</code> are equal; and (2) that the size of the output
51  * matrix equals the outer dimensions of <code>pSrcA</code> and <code>pSrcB</code>.
52  */
53 
54 
55 /**
56  * @addtogroup MatrixMult
57  * @{
58  */
59 
/**
 * @brief Double-precision floating-point matrix multiplication.
 * @param[in]       pSrcA points to the first input matrix structure
 * @param[in]       pSrcB points to the second input matrix structure
 * @param[out]      pDst  points to the output matrix structure
 * @return          The function returns either
 * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
 */
68 
69 #if defined(ARM_MATH_NEON) && defined(__aarch64__)
arm_mat_mult_f64(const arm_matrix_instance_f64 * pSrcA,const arm_matrix_instance_f64 * pSrcB,arm_matrix_instance_f64 * pDst)70 ARM_DSP_ATTRIBUTE arm_status arm_mat_mult_f64(
71   const arm_matrix_instance_f64 * pSrcA,
72   const arm_matrix_instance_f64 * pSrcB,
73   arm_matrix_instance_f64 * pDst)
74 {
75   float64_t *pIn1 = pSrcA->pData;                /* input data matrix pointer A */
76   float64_t *pIn2 = pSrcB->pData;                /* input data matrix pointer B */
77   float64_t *pInA = pSrcA->pData;                /* input data matrix pointer A  */
78   float64_t *pOut = pDst->pData;                 /* output data matrix pointer */
79   float64_t *px;                                 /* Temporary output data matrix pointer */
80   float64_t sum;                                 /* Accumulator */
81   uint32_t numRowsA = pSrcA->numRows;            /* number of rows of input matrix A */
82   uint32_t numColsB = pSrcB->numCols;            /* number of columns of input matrix B */
83   uint32_t numColsA = pSrcA->numCols;            /* number of columns of input matrix A */
84 
85 
86   uint32_t col, i = 0U, j, row = numRowsA, rowCnt, colCnt;      /* loop counters */
87   arm_status status;                             /* status of matrix multiplication */
88 
89   float64x2_t a0V, a1V, a2V, a3V, a4V, a5V, a6V, a7V;
90   float64x2_t acc0,acc1,acc2,acc3,acc4,acc5,acc6,acc7,temp;
91   float64_t *pIn1B = pSrcA->pData;
92   float64_t *pIn1C = pSrcA->pData;
93   float64_t *pIn1D = pSrcA->pData;
94   float64_t *pIn1E = pSrcA->pData;
95   float64_t *pIn1F = pSrcA->pData;
96   float64_t *pIn1G = pSrcA->pData;
97   float64_t *pIn1H = pSrcA->pData;
98 
99   float64_t *pxB,*pxC, *pxD, *pxE, *pxF, *pxG, *pxH;                                 /* Temporary output data matrix pointer */
100   float64_t sum0,sum1, sum2,sum3, sum4, sum5 , sum6, sum7;
101 
102 #ifdef ARM_MATH_MATRIX_CHECK
103 
104   /* Check for matrix mismatch condition */
105   if ((pSrcA->numCols != pSrcB->numRows) ||
106      (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
107   {
108     /* Set status as ARM_MATH_SIZE_MISMATCH */
109     status = ARM_MATH_SIZE_MISMATCH;
110   }
111   else
112 #endif /*      #ifdef ARM_MATH_MATRIX_CHECK    */
113   {
114     /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
115     /* Row loop */
116     rowCnt = row >> 3;
117 
118     while(rowCnt > 0)
119     {
120       /* Output pointer is set to starting address of the row being processed */
121       px = pOut + GROUPOFROWS*i;
122       pxB = px + numColsB;
123       pxC = px + 2*numColsB;
124       pxD = px + 3*numColsB;
125       pxE = px + 4*numColsB;
126       pxF = px + 5*numColsB;
127       pxG = px + 6*numColsB;
128       pxH = px + 7*numColsB;
129 
130       /* For every row wise process, the column loop counter is to be initiated */
131       col = numColsB;
132 
133       /* For every row wise process, the pIn2 pointer is set
134        ** to the starting address of the pSrcB data */
135       pIn2 = pSrcB->pData;
136 
137       j = 0U;
138 
139       /* Column loop */
140       do
141       {
142         /* Set the variable sum, that acts as accumulator, to zero */
143         sum0 = 0.0;
144         sum1 = 0.0;
145         sum2 = 0.0;
146         sum3 = 0.0;
147         sum4 = 0.0;
148         sum5 = 0.0;
149         sum6 = 0.0;
150         sum7 = 0.0;
151 
152         /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
153         pIn1 = pInA;
154         pIn1B = pIn1 + numColsA;
155         pIn1C = pIn1 + 2*numColsA;
156         pIn1D = pIn1 + 3*numColsA;
157         pIn1E = pIn1 + 4*numColsA;
158         pIn1F = pIn1 + 5*numColsA;
159         pIn1G = pIn1 + 6*numColsA;
160         pIn1H = pIn1 + 7*numColsA;
161 
162         acc0 = vdupq_n_f64(0.0);
163         acc1 = vdupq_n_f64(0.0);
164         acc2 = vdupq_n_f64(0.0);
165         acc3 = vdupq_n_f64(0.0);
166         acc4 = vdupq_n_f64(0.0);
167         acc5 = vdupq_n_f64(0.0);
168         acc6 = vdupq_n_f64(0.0);
169         acc7 = vdupq_n_f64(0.0);
170 
171         /* Compute 2 MACs simultaneously. */
172         colCnt = numColsA >> 1U;
173 
174         /* Matrix multiplication */
175         while (colCnt > 0U)
176         {
177           /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
178           a0V = vld1q_f64(pIn1);
179           a1V = vld1q_f64(pIn1B);
180           a2V = vld1q_f64(pIn1C);
181           a3V = vld1q_f64(pIn1D);
182           a4V = vld1q_f64(pIn1E);
183           a5V = vld1q_f64(pIn1F);
184           a6V = vld1q_f64(pIn1G);
185           a7V = vld1q_f64(pIn1H);
186 
187           pIn1 += 2;
188           pIn1B += 2;
189           pIn1C += 2;
190           pIn1D += 2;
191           pIn1E += 2;
192           pIn1F += 2;
193           pIn1G += 2;
194           pIn1H += 2;
195 
196           temp = vsetq_lane_f64(*pIn2,temp,0);
197           pIn2 += numColsB;
198           temp = vsetq_lane_f64(*pIn2,temp,1);
199           pIn2 += numColsB;
200 
201 
202           acc0 = vmlaq_f64(acc0,a0V,temp);
203           acc1 = vmlaq_f64(acc1,a1V,temp);
204           acc2 = vmlaq_f64(acc2,a2V,temp);
205           acc3 = vmlaq_f64(acc3,a3V,temp);
206           acc4 = vmlaq_f64(acc4,a4V,temp);
207           acc5 = vmlaq_f64(acc5,a5V,temp);
208           acc6 = vmlaq_f64(acc6,a6V,temp);
209           acc7 = vmlaq_f64(acc7,a7V,temp);
210 
211           /* Decrement the loop count */
212           colCnt--;
213         }
214 
215         sum0 += vaddvq_f64(acc0);
216         sum1 += vaddvq_f64(acc1);
217         sum2 += vaddvq_f64(acc2);
218         sum3 += vaddvq_f64(acc3);
219         sum4 += vaddvq_f64(acc4);
220         sum5 += vaddvq_f64(acc5);
221         sum6 += vaddvq_f64(acc6);
222         sum7 += vaddvq_f64(acc7);
223 
224         /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
225          ** No loop unrolling is used. */
226         colCnt = numColsA & 1;
227 
228         while (colCnt > 0U)
229         {
230           /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
231           sum0 += *pIn1++ * (*pIn2);
232           sum1 += *pIn1B++ * (*pIn2);
233           sum2 += *pIn1C++ * (*pIn2);
234           sum3 += *pIn1D++ * (*pIn2);
235           sum4 += *pIn1E++ * (*pIn2);
236           sum5 += *pIn1F++ * (*pIn2);
237           sum6 += *pIn1G++ * (*pIn2);
238           sum7 += *pIn1H++ * (*pIn2);
239           pIn2 += numColsB;
240 
241           /* Decrement the loop counter */
242           colCnt--;
243         }
244 
245         /* Store the result in the destination buffer */
246         *px++ = sum0;
247         *pxB++ = sum1;
248         *pxC++ = sum2;
249         *pxD++ = sum3;
250         *pxE++ = sum4;
251         *pxF++ = sum5;
252         *pxG++ = sum6;
253         *pxH++ = sum7;
254 
255         /* Update the pointer pIn2 to point to the  starting address of the next column */
256         j++;
257         pIn2 = pSrcB->pData + j;
258 
259         /* Decrement the column loop counter */
260         col--;
261 
262       } while (col > 0U);
263 
264       /* Update the pointer pInA to point to the  starting address of the next row */
265       i = i + numColsB;
266       pInA = pInA + GROUPOFROWS*numColsA;
267 
268       /* Decrement the row loop counter */
269       rowCnt--;
270     }
271 
272     /*
273 
274     i was the index of a group of rows computed by previous loop.
275     Now i is the index of a row since below code is computing row per row
276     and no more group of row per group of rows.
277 
278     */
279 
280     i = GROUPOFROWS*i;
281     rowCnt = row & 7;
282 
283     while(rowCnt > 0)
284     {
285       /* Output pointer is set to starting address of the row being processed */
286       px = pOut + i;
287 
288       /* For every row wise process, the column loop counter is to be initiated */
289       col = numColsB;
290 
291       /* For every row wise process, the pIn2 pointer is set
292        ** to the starting address of the pSrcB data */
293       pIn2 = pSrcB->pData;
294 
295       j = 0U;
296 
297       /* Column loop */
298       do
299       {
300         /* Set the variable sum, that acts as accumulator, to zero */
301         sum = 0.0;
302 
303         /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
304         pIn1 = pInA;
305 
306         acc0 = vdupq_n_f64(0.0);
307 
308         /* Compute 4 MACs simultaneously. */
309         colCnt = numColsA >> 1U;
310 
311         /* Matrix multiplication   */
312         while (colCnt > 0U)
313         {
314           /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
315           a0V = vld1q_f64(pIn1);  // load & separate real/imag pSrcA (de-interleave 2)
316           pIn1 += 2;
317 
318           temp = vsetq_lane_f64(*pIn2,temp,0);
319           pIn2 += numColsB;
320           temp = vsetq_lane_f64(*pIn2,temp,1);
321           pIn2 += numColsB;
322 
323 
324           acc0 = vmlaq_f64(acc0,a0V,temp);
325 
326           /* Decrement the loop count */
327           colCnt--;
328         }
329 
330         //accum = vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
331         sum += vaddvq_f64(acc0);
332 
333         /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
334          ** No loop unrolling is used. */
335         colCnt = numColsA % 0x2U;
336 
337         while (colCnt > 0U)
338         {
339           /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
340           sum += *pIn1++ * (*pIn2);
341           pIn2 += numColsB;
342 
343           /* Decrement the loop counter */
344           colCnt--;
345         }
346 
347         /* Store the result in the destination buffer */
348         *px++ = sum;
349 
350         /* Update the pointer pIn2 to point to the  starting address of the next column */
351         j++;
352         pIn2 = pSrcB->pData + j;
353 
354         /* Decrement the column loop counter */
355         col--;
356 
357       } while (col > 0U);
358 
359 
360       /* Update the pointer pInA to point to the  starting address of the next row */
361       i = i + numColsB;
362       pInA = pInA + numColsA;
363 
364       /* Decrement the row loop counter */
365       rowCnt--;
366 
367     }
368     /* Set status as ARM_MATH_SUCCESS */
369     status = ARM_MATH_SUCCESS;
370   }
371 
372   /* Return to application */
373   return (status);
374 }
375 #else
arm_mat_mult_f64(const arm_matrix_instance_f64 * pSrcA,const arm_matrix_instance_f64 * pSrcB,arm_matrix_instance_f64 * pDst)376 ARM_DSP_ATTRIBUTE arm_status arm_mat_mult_f64(
377   const arm_matrix_instance_f64 * pSrcA,
378   const arm_matrix_instance_f64 * pSrcB,
379         arm_matrix_instance_f64 * pDst)
380 {
381   float64_t *pIn1 = pSrcA->pData;                /* Input data matrix pointer A */
382   float64_t *pIn2 = pSrcB->pData;                /* Input data matrix pointer B */
383   float64_t *pInA = pSrcA->pData;                /* Input data matrix pointer A */
384   float64_t *pInB = pSrcB->pData;                /* Input data matrix pointer B */
385   float64_t *pOut = pDst->pData;                 /* Output data matrix pointer */
386   float64_t *px;                                 /* Temporary output data matrix pointer */
387   float64_t sum;                                 /* Accumulator */
388   uint16_t numRowsA = pSrcA->numRows;            /* Number of rows of input matrix A */
389   uint16_t numColsB = pSrcB->numCols;            /* Number of columns of input matrix B */
390   uint16_t numColsA = pSrcA->numCols;            /* Number of columns of input matrix A */
391   uint64_t col, i = 0U, row = numRowsA, colCnt;  /* Loop counters */
392   arm_status status;                             /* Status of matrix multiplication */
393 
394 #ifdef ARM_MATH_MATRIX_CHECK
395 
396   /* Check for matrix mismatch condition */
397   if ((pSrcA->numCols != pSrcB->numRows) ||
398       (pSrcA->numRows != pDst->numRows)  ||
399       (pSrcB->numCols != pDst->numCols)    )
400   {
401     /* Set status as ARM_MATH_SIZE_MISMATCH */
402     status = ARM_MATH_SIZE_MISMATCH;
403   }
404   else
405 
406 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
407 
408   {
409     /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
410     /* row loop */
411     do
412     {
413       /* Output pointer is set to starting address of row being processed */
414       px = pOut + i;
415 
416       /* For every row wise process, column loop counter is to be initiated */
417       col = numColsB;
418 
419       /* For every row wise process, pIn2 pointer is set to starting address of pSrcB data */
420       pIn2 = pSrcB->pData;
421 
422       /* column loop */
423       do
424       {
425         /* Set the variable sum, that acts as accumulator, to zero */
426         sum = 0.0;
427 
428         /* Initialize pointer pIn1 to point to starting address of column being processed */
429         pIn1 = pInA;
430 
431 #if defined (ARM_MATH_LOOPUNROLL)
432 
433         /* Loop unrolling: Compute 4 MACs at a time. */
434         colCnt = numColsA >> 2U;
435 
436         /* matrix multiplication */
437         while (colCnt > 0U)
438         {
439           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
440 
441           /* Perform the multiply-accumulates */
442           sum += *pIn1++ * *pIn2;
443           pIn2 += numColsB;
444 
445           sum += *pIn1++ * *pIn2;
446           pIn2 += numColsB;
447 
448           sum += *pIn1++ * *pIn2;
449           pIn2 += numColsB;
450 
451           sum += *pIn1++ * *pIn2;
452           pIn2 += numColsB;
453 
454           /* Decrement loop counter */
455           colCnt--;
456         }
457 
458         /* Loop unrolling: Compute remaining MACs */
459         colCnt = numColsA % 0x4U;
460 
461 #else
462 
463         /* Initialize cntCnt with number of columns */
464         colCnt = numColsA;
465 
466 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
467 
468         while (colCnt > 0U)
469         {
470           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
471 
472           /* Perform the multiply-accumulates */
473           sum += *pIn1++ * *pIn2;
474           pIn2 += numColsB;
475 
476           /* Decrement loop counter */
477           colCnt--;
478         }
479 
480         /* Store result in destination buffer */
481         *px++ = sum;
482 
483         /* Decrement column loop counter */
484         col--;
485 
486         /* Update pointer pIn2 to point to starting address of next column */
487         pIn2 = pInB + (numColsB - col);
488 
489       } while (col > 0U);
490 
491       /* Update pointer pInA to point to starting address of next row */
492       i = i + numColsB;
493       pInA = pInA + numColsA;
494 
495       /* Decrement row loop counter */
496       row--;
497 
498     } while (row > 0U);
499 
500     /* Set status as ARM_MATH_SUCCESS */
501     status = ARM_MATH_SUCCESS;
502   }
503 
504   /* Return to application */
505   return (status);
506 }
507 #endif
508 
509 /**
510  * @} end of MatrixMult group
511  */
512