1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_mult_fast_q31.c
4  * Description:  Q31 matrix multiplication (fast variant)
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/matrix_functions.h"
30 
31 /**
32   @ingroup groupMatrix
33  */
34 
35 /**
36   @addtogroup MatrixMult
37   @{
38  */
39 
40 /**
41   @brief         Q31 matrix multiplication (fast variant).
42   @param[in]     pSrcA      points to the first input matrix structure
43   @param[in]     pSrcB      points to the second input matrix structure
44   @param[out]    pDst       points to output matrix structure
45   @return        execution status
46                    - \ref ARM_MATH_SUCCESS       : Operation successful
47                    - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
48 
49   @par           Scaling and Overflow Behavior
50                    The difference between the function \ref arm_mat_mult_q31() and this fast variant is that
51                    the fast variant use a 32-bit rather than a 64-bit accumulator.
52                    The result of each 1.31 x 1.31 multiplication is truncated to
53                    2.30 format. These intermediate results are accumulated in a 32-bit register in 2.30
54                    format. Finally, the accumulator is saturated and converted to a 1.31 result.
55   @par
56                    The fast version has the same overflow behavior as the standard version but provides
57                    less precision since it discards the low 32 bits of each multiplication result.
58                    In order to avoid overflows completely the input signals must be scaled down.
59                    Scale down one of the input matrices by log2(numColsA) bits to avoid overflows,
60                    as a total of numColsA additions are computed internally for each output element.
61   @remark
62                    Refer to \ref arm_mat_mult_q31() for a slower implementation of this function
63                    which uses 64-bit accumulation to provide higher precision.
64  */
65 
arm_mat_mult_fast_q31(const arm_matrix_instance_q31 * pSrcA,const arm_matrix_instance_q31 * pSrcB,arm_matrix_instance_q31 * pDst)66 arm_status arm_mat_mult_fast_q31(
67   const arm_matrix_instance_q31 * pSrcA,
68   const arm_matrix_instance_q31 * pSrcB,
69         arm_matrix_instance_q31 * pDst)
70 {
71   q31_t *pInA = pSrcA->pData;                    /* Input data matrix pointer A */
72   q31_t *pInB = pSrcB->pData;                    /* Input data matrix pointer B */
73   q31_t *pInA2;
74   q31_t *px;                                     /* Temporary output data matrix pointer */
75   q31_t *px2;
76   q31_t sum1, sum2, sum3, sum4;                  /* Accumulator */
77   q31_t inA1, inA2, inB1, inB2;
78   uint16_t numRowsA = pSrcA->numRows;            /* Number of rows of input matrix A */
79   uint16_t numColsB = pSrcB->numCols;            /* Number of columns of input matrix B */
80   uint16_t numColsA = pSrcA->numCols;            /* Number of columns of input matrix A */
81   uint32_t col, i = 0U, j, row = numRowsA, colCnt;  /* Loop counters */
82   arm_status status;                             /* Status of matrix multiplication */
83 
84 
85 #ifdef ARM_MATH_MATRIX_CHECK
86 
87   /* Check for matrix mismatch condition */
88   if ((pSrcA->numCols != pSrcB->numRows) ||
89       (pSrcA->numRows != pDst->numRows)  ||
90       (pSrcB->numCols != pDst->numCols)    )
91   {
92     /* Set status as ARM_MATH_SIZE_MISMATCH */
93     status = ARM_MATH_SIZE_MISMATCH;
94   }
95   else
96 
97 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
98 
99   {
100     px = pDst->pData;
101 
102     row = row >> 1U;
103     px2 = px + numColsB;
104 
105     /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
106     /* row loop */
107     while (row > 0U)
108     {
109       /* For every row wise process, column loop counter is to be initiated */
110       col = numColsB;
111 
112       /* For every row wise process, pIn2 pointer is set to starting address of pSrcB data */
113       pInB = pSrcB->pData;
114 
115       j = 0U;
116 
117       col = col >> 1U;
118 
119       /* column loop */
120       while (col > 0U)
121       {
122         /* Set the variable sum, that acts as accumulator, to zero */
123         sum1 = 0;
124         sum2 = 0;
125         sum3 = 0;
126         sum4 = 0;
127 
128         /* Initiate data pointers */
129         pInA = pSrcA->pData + i;
130         pInB = pSrcB->pData + j;
131         pInA2 = pInA + numColsA;
132 
133         colCnt = numColsA;
134 
135         /* matrix multiplication */
136         while (colCnt > 0U)
137         {
138           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
139 
140           inA1 = *pInA++;
141           inB1 = pInB[0];
142           inA2 = *pInA2++;
143           inB2 = pInB[1];
144           pInB += numColsB;
145 
146 #if defined (ARM_MATH_DSP)
147           sum1 = __SMMLA(inA1, inB1, sum1);
148           sum2 = __SMMLA(inA1, inB2, sum2);
149           sum3 = __SMMLA(inA2, inB1, sum3);
150           sum4 = __SMMLA(inA2, inB2, sum4);
151 #else
152           sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) inA1 * inB1)) >> 32);
153           sum2 = (q31_t) ((((q63_t) sum2 << 32) + ((q63_t) inA1 * inB2)) >> 32);
154           sum3 = (q31_t) ((((q63_t) sum3 << 32) + ((q63_t) inA2 * inB1)) >> 32);
155           sum4 = (q31_t) ((((q63_t) sum4 << 32) + ((q63_t) inA2 * inB2)) >> 32);
156 #endif
157 
158           /* Decrement loop counter */
159           colCnt--;
160         }
161 
162         /* Convert the result from 2.30 to 1.31 format and store in destination buffer */
163         *px++  = sum1 << 1;
164         *px++  = sum2 << 1;
165         *px2++ = sum3 << 1;
166         *px2++ = sum4 << 1;
167 
168         j += 2;
169 
170         /* Decrement column loop counter */
171         col--;
172       }
173 
174       i = i + (numColsA << 1U);
175       px  = px2 + (numColsB & 1U);
176       px2 = px  +  numColsB;
177 
178       /* Decrement row loop counter */
179       row--;
180     }
181 
182     /* Compute any remaining odd row/column below */
183 
184     /* Compute remaining output column */
185     if (numColsB & 1U) {
186 
187       /* Avoid redundant computation of last element */
188       row = numRowsA & (~1U);
189 
190       /* Point to remaining unfilled column in output matrix */
191       px = pDst->pData + numColsB-1;
192       pInA = pSrcA->pData;
193 
194       /* row loop */
195       while (row > 0)
196       {
197 
198         /* point to last column in matrix B */
199         pInB  = pSrcB->pData + numColsB-1;
200 
201         /* Set variable sum1, that acts as accumulator, to zero */
202         sum1  = 0;
203 
204 #if defined (ARM_MATH_LOOPUNROLL)
205 
206         /* Loop unrolling: Compute 4 columns at a time. */
207         colCnt = numColsA >> 2U;
208 
209         /* matrix multiplication */
210         while (colCnt > 0U)
211         {
212 #if defined (ARM_MATH_DSP)
213           sum1 = __SMMLA(*pInA++, *pInB, sum1);
214 #else
215           sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) *pInA++ * *pInB)) >> 32);
216 #endif
217           pInB += numColsB;
218 
219 #if defined (ARM_MATH_DSP)
220           sum1 = __SMMLA(*pInA++, *pInB, sum1);
221 #else
222           sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) *pInA++ * *pInB)) >> 32);
223 #endif
224           pInB += numColsB;
225 
226 #if defined (ARM_MATH_DSP)
227           sum1 = __SMMLA(*pInA++, *pInB, sum1);
228 #else
229           sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) *pInA++ * *pInB)) >> 32);
230 #endif
231           pInB += numColsB;
232 
233 #if defined (ARM_MATH_DSP)
234           sum1 = __SMMLA(*pInA++, *pInB, sum1);
235 #else
236           sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) *pInA++ * *pInB)) >> 32);
237 #endif
238           pInB += numColsB;
239 
240           /* Decrement loop counter */
241           colCnt--;
242         }
243 
244         /* Loop unrolling: Compute remaining column */
245         colCnt = numColsA % 4U;
246 
247 #else
248 
249         /* Initialize colCnt with number of columns */
250         colCnt = numColsA;
251 
252 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
253 
254         while (colCnt > 0U) {
255 #if defined (ARM_MATH_DSP)
256           sum1 = __SMMLA(*pInA++, *pInB, sum1);
257 #else
258           sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) *pInA++ * *pInB)) >> 32);
259 #endif
260           pInB += numColsB;
261 
262           colCnt--;
263         }
264 
265         /* Convert the result from 2.30 to 1.31 format and store in destination buffer */
266         *px = sum1 << 1;
267         px += numColsB;
268 
269         /* Decrement row loop counter */
270         row--;
271       }
272     }
273 
274     /* Compute remaining output row */
275     if (numRowsA & 1U) {
276 
277       /* point to last row in output matrix */
278       px = pDst->pData + (numColsB) * (numRowsA-1);
279 
280       col = numColsB;
281       i = 0U;
282 
283       /* col loop */
284       while (col > 0)
285       {
286 
287         /* point to last row in matrix A */
288         pInA = pSrcA->pData + (numRowsA-1) * numColsA;
289         pInB  = pSrcB->pData + i;
290 
291         /* Set variable sum1, that acts as accumulator, to zero */
292         sum1  = 0;
293 
294 #if defined (ARM_MATH_LOOPUNROLL)
295 
296         /* Loop unrolling: Compute 4 columns at a time. */
297         colCnt = numColsA >> 2U;
298 
299         /* matrix multiplication */
300         while (colCnt > 0U)
301         {
302           inA1 = *pInA++;
303           inA2 = *pInA++;
304           inB1 = *pInB;
305           pInB += numColsB;
306           inB2 = *pInB;
307           pInB += numColsB;
308 #if defined (ARM_MATH_DSP)
309           sum1 = __SMMLA(inA1, inB1, sum1);
310           sum1 = __SMMLA(inA2, inB2, sum1);
311 #else
312           sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) inA1 * inB1)) >> 32);
313           sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) inA2 * inB2)) >> 32);
314 #endif
315 
316           inA1 = *pInA++;
317           inA2 = *pInA++;
318           inB1 = *pInB;
319           pInB += numColsB;
320           inB2 = *pInB;
321           pInB += numColsB;
322 #if defined (ARM_MATH_DSP)
323           sum1 = __SMMLA(inA1, inB1, sum1);
324           sum1 = __SMMLA(inA2, inB2, sum1);
325 #else
326           sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) inA1 * inB1)) >> 32);
327           sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) inA2 * inB2)) >> 32);
328 #endif
329 
330           /* Decrement loop counter */
331           colCnt--;
332         }
333 
334         /* Loop unrolling: Compute remaining column */
335         colCnt = numColsA % 4U;
336 
337 #else
338 
339         /* Initialize colCnt with number of columns */
340         colCnt = numColsA;
341 
342 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
343 
344         while (colCnt > 0U) {
345 #if defined (ARM_MATH_DSP)
346           sum1 = __SMMLA(*pInA++, *pInB, sum1);
347 #else
348           sum1 = (q31_t) ((((q63_t) sum1 << 32) + ((q63_t) *pInA++ * *pInB)) >> 32);
349 #endif
350           pInB += numColsB;
351 
352           colCnt--;
353         }
354 
355         /* Saturate and store the result in the destination buffer */
356         *px++ = sum1 << 1;
357         i++;
358 
359         /* Decrement col loop counter */
360         col--;
361       }
362     }
363 
364     /* Set status as ARM_MATH_SUCCESS */
365     status = ARM_MATH_SUCCESS;
366   }
367 
368   /* Return to application */
369   return (status);
370 }
371 
372 /**
373   @} end of MatrixMult group
374  */
375