1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_mult_f64.c
4  * Description:  Floating-point matrix multiplication
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/matrix_functions.h"
30 
31 /**
32  * @ingroup groupMatrix
33  */
34 
35 /**
36  * @defgroup MatrixMult Matrix Multiplication
37  *
38  * Multiplies two matrices.
39  *
40  * \image html MatrixMultiplication.gif "Multiplication of two 3 x 3 matrices"
41 
42  * Matrix multiplication is only defined if the number of columns of the
43  * first matrix equals the number of rows of the second matrix.
44  * Multiplying an <code>M x N</code> matrix with an <code>N x P</code> matrix results
45  * in an <code>M x P</code> matrix.
46  * When matrix size checking is enabled, the functions check: (1) that the inner dimensions of
47  * <code>pSrcA</code> and <code>pSrcB</code> are equal; and (2) that the size of the output
48  * matrix equals the outer dimensions of <code>pSrcA</code> and <code>pSrcB</code>.
49  */
50 
51 
52 /**
53  * @addtogroup MatrixMult
54  * @{
55  */
56 
57 /**
58  * @brief Floating-point matrix multiplication.
59  * @param[in]       *pSrcA points to the first input matrix structure
60  * @param[in]       *pSrcB points to the second input matrix structure
61  * @param[out]      *pDst points to output matrix structure
62  * @return     		The function returns either
63  * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
64  */
65 
66 
arm_mat_mult_f64(const arm_matrix_instance_f64 * pSrcA,const arm_matrix_instance_f64 * pSrcB,arm_matrix_instance_f64 * pDst)67 arm_status arm_mat_mult_f64(
68   const arm_matrix_instance_f64 * pSrcA,
69   const arm_matrix_instance_f64 * pSrcB,
70         arm_matrix_instance_f64 * pDst)
71 {
72   float64_t *pIn1 = pSrcA->pData;                /* Input data matrix pointer A */
73   float64_t *pIn2 = pSrcB->pData;                /* Input data matrix pointer B */
74   float64_t *pInA = pSrcA->pData;                /* Input data matrix pointer A */
75   float64_t *pInB = pSrcB->pData;                /* Input data matrix pointer B */
76   float64_t *pOut = pDst->pData;                 /* Output data matrix pointer */
77   float64_t *px;                                 /* Temporary output data matrix pointer */
78   float64_t sum;                                 /* Accumulator */
79   uint16_t numRowsA = pSrcA->numRows;            /* Number of rows of input matrix A */
80   uint16_t numColsB = pSrcB->numCols;            /* Number of columns of input matrix B */
81   uint16_t numColsA = pSrcA->numCols;            /* Number of columns of input matrix A */
82   uint64_t col, i = 0U, row = numRowsA, colCnt;  /* Loop counters */
83   arm_status status;                             /* Status of matrix multiplication */
84 
85 #ifdef ARM_MATH_MATRIX_CHECK
86 
87   /* Check for matrix mismatch condition */
88   if ((pSrcA->numCols != pSrcB->numRows) ||
89       (pSrcA->numRows != pDst->numRows)  ||
90       (pSrcB->numCols != pDst->numCols)    )
91   {
92     /* Set status as ARM_MATH_SIZE_MISMATCH */
93     status = ARM_MATH_SIZE_MISMATCH;
94   }
95   else
96 
97 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
98 
99   {
100     /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
101     /* row loop */
102     do
103     {
104       /* Output pointer is set to starting address of row being processed */
105       px = pOut + i;
106 
107       /* For every row wise process, column loop counter is to be initiated */
108       col = numColsB;
109 
110       /* For every row wise process, pIn2 pointer is set to starting address of pSrcB data */
111       pIn2 = pSrcB->pData;
112 
113       /* column loop */
114       do
115       {
116         /* Set the variable sum, that acts as accumulator, to zero */
117         sum = 0.0f;
118 
119         /* Initialize pointer pIn1 to point to starting address of column being processed */
120         pIn1 = pInA;
121 
122 #if defined (ARM_MATH_LOOPUNROLL)
123 
124         /* Loop unrolling: Compute 4 MACs at a time. */
125         colCnt = numColsA >> 2U;
126 
127         /* matrix multiplication */
128         while (colCnt > 0U)
129         {
130           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
131 
132           /* Perform the multiply-accumulates */
133           sum += *pIn1++ * *pIn2;
134           pIn2 += numColsB;
135 
136           sum += *pIn1++ * *pIn2;
137           pIn2 += numColsB;
138 
139           sum += *pIn1++ * *pIn2;
140           pIn2 += numColsB;
141 
142           sum += *pIn1++ * *pIn2;
143           pIn2 += numColsB;
144 
145           /* Decrement loop counter */
146           colCnt--;
147         }
148 
149         /* Loop unrolling: Compute remaining MACs */
150         colCnt = numColsA % 0x4U;
151 
152 #else
153 
154         /* Initialize cntCnt with number of columns */
155         colCnt = numColsA;
156 
157 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
158 
159         while (colCnt > 0U)
160         {
161           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
162 
163           /* Perform the multiply-accumulates */
164           sum += *pIn1++ * *pIn2;
165           pIn2 += numColsB;
166 
167           /* Decrement loop counter */
168           colCnt--;
169         }
170 
171         /* Store result in destination buffer */
172         *px++ = sum;
173 
174         /* Decrement column loop counter */
175         col--;
176 
177         /* Update pointer pIn2 to point to starting address of next column */
178         pIn2 = pInB + (numColsB - col);
179 
180       } while (col > 0U);
181 
182       /* Update pointer pInA to point to starting address of next row */
183       i = i + numColsB;
184       pInA = pInA + numColsA;
185 
186       /* Decrement row loop counter */
187       row--;
188 
189     } while (row > 0U);
190 
191     /* Set status as ARM_MATH_SUCCESS */
192     status = ARM_MATH_SUCCESS;
193   }
194 
195   /* Return to application */
196   return (status);
197 }
198 
199 
200 /**
201  * @} end of MatrixMult group
202  */
203