/* ----------------------------------------------------------------------
 * Project:      CMSIS DSP Library
 * Title:        arm_mat_mult_fast_q15.c
 * Description:  Q15 matrix multiplication (fast variant)
 *
 * $Date:        23 April 2021
 * $Revision:    V1.9.0
 *
 * Target Processor: Cortex-M and Cortex-A cores
 * -------------------------------------------------------------------- */
/*
 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "dsp/matrix_functions.h"

/**
  @ingroup groupMatrix
 */

/**
  @addtogroup MatrixMult
  @{
 */

/**
  @brief         Q15 matrix multiplication (fast variant).
  @param[in]     pSrcA      points to the first input matrix structure
  @param[in]     pSrcB      points to the second input matrix structure
  @param[out]    pDst       points to output matrix structure
  @param[in]     pState     points to the array for storing intermediate results
                            (it receives the transposed copy of pSrcB, i.e. numRows * numCols elements of pSrcB)
  @return        execution status
                   - \ref ARM_MATH_SUCCESS       : Operation successful
                   - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed

  @par           Scaling and Overflow Behavior
                   The difference between the function \ref arm_mat_mult_q15() and this fast variant is that
                   the fast variant uses a 32-bit rather than a 64-bit accumulator.
                   The result of each 1.15 x 1.15 multiplication is truncated to
                   2.30 format. These intermediate results are accumulated in a 32-bit register in 2.30
                   format. Finally, the accumulator is saturated and converted to a 1.15 result.
  @par
                   The fast version has the same overflow behavior as the standard version but provides
                   less precision since it discards the low 16 bits of each multiplication result.
                   In order to avoid overflows completely the input signals must be scaled down.
                   Scale down one of the input matrices by log2(numColsA) bits to avoid overflows,
                   as a total of numColsA additions are computed internally for each output element.
  @remark
                   Refer to \ref arm_mat_mult_q15() for a slower implementation of this function
                   which uses 64-bit accumulation to provide higher precision.
 */

arm_mat_mult_fast_q15(const arm_matrix_instance_q15 * pSrcA,const arm_matrix_instance_q15 * pSrcB,arm_matrix_instance_q15 * pDst,q15_t * pState)67 arm_status arm_mat_mult_fast_q15(
68   const arm_matrix_instance_q15 * pSrcA,
69   const arm_matrix_instance_q15 * pSrcB,
70         arm_matrix_instance_q15 * pDst,
71         q15_t                   * pState)
72 {
73         q31_t sum;                                     /* Accumulator */
74         q15_t *pSrcBT = pState;                        /* Input data matrix pointer for transpose */
75         q15_t *pInA = pSrcA->pData;                    /* Input data matrix pointer A of Q15 type */
76         q15_t *pInB = pSrcB->pData;                    /* Input data matrix pointer B of Q15 type */
77         q15_t *px;                                     /* Temporary output data matrix pointer */
78         uint16_t numRowsA = pSrcA->numRows;            /* Number of rows of input matrix A */
79         uint16_t numColsB = pSrcB->numCols;            /* Number of columns of input matrix B */
80         uint16_t numColsA = pSrcA->numCols;            /* Number of columns of input matrix A */
81         uint16_t numRowsB = pSrcB->numRows;            /* Number of rows of input matrix B */
82         uint32_t col, i = 0U, row = numRowsB, colCnt;  /* Loop counters */
83         arm_status status;                             /* Status of matrix multiplication */
84 
85 #if defined (ARM_MATH_DSP)
86         q31_t in;                                      /* Temporary variable to hold the input value */
87         q31_t inA1, inB1, inA2, inB2;
88         q31_t sum2, sum3, sum4;
89         q15_t *pInA2, *pInB2, *px2;
90         uint32_t j = 0;
91 #else
92         q15_t in;                                      /* Temporary variable to hold the input value */
93         q15_t inA1, inB1, inA2, inB2;
94 #endif /* #if defined (ARM_MATH_DSP) */
95 
96 #ifdef ARM_MATH_MATRIX_CHECK
97 
98   /* Check for matrix mismatch condition */
99   if ((pSrcA->numCols != pSrcB->numRows) ||
100       (pSrcA->numRows != pDst->numRows)  ||
101       (pSrcB->numCols != pDst->numCols)    )
102   {
103     /* Set status as ARM_MATH_SIZE_MISMATCH */
104     status = ARM_MATH_SIZE_MISMATCH;
105   }
106   else
107 
108 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
109 
110   {
111     /* Matrix transpose */
112     do
113     {
114       /* The pointer px is set to starting address of column being processed */
115       px = pSrcBT + i;
116 
117       /* Apply loop unrolling and exchange columns with row elements */
118       col = numColsB >> 2U;
119 
120       /* First part of the processing with loop unrolling.  Compute 4 outputs at a time.
121        ** a second loop below computes the remaining 1 to 3 samples. */
122       while (col > 0U)
123       {
124 
125 #if defined (ARM_MATH_DSP)
126 
127         /* Read two elements from row */
128         in = read_q15x2_ia ((q15_t **) &pInB);
129 
130         /* Unpack and store one element in destination */
131 #ifndef ARM_MATH_BIG_ENDIAN
132         *px = (q15_t) in;
133 #else
134         *px = (q15_t) ((in & (q31_t) 0xffff0000) >> 16);
135 #endif /* #ifndef ARM_MATH_BIG_ENDIAN */
136 
137         /* Update pointer px to point to next row of transposed matrix */
138         px += numRowsB;
139 
140         /* Unpack and store second element in destination */
141 #ifndef ARM_MATH_BIG_ENDIAN
142         *px = (q15_t) ((in & (q31_t) 0xffff0000) >> 16);
143 #else
144         *px = (q15_t) in;
145 #endif /* #ifndef ARM_MATH_BIG_ENDIAN */
146 
147         /* Update pointer px to point to next row of transposed matrix */
148         px += numRowsB;
149 
150         in = read_q15x2_ia ((q15_t **) &pInB);
151 #ifndef ARM_MATH_BIG_ENDIAN
152         *px = (q15_t) in;
153 #else
154         *px = (q15_t) ((in & (q31_t) 0xffff0000) >> 16);
155 #endif /* #ifndef ARM_MATH_BIG_ENDIAN */
156         px += numRowsB;
157 
158 #ifndef ARM_MATH_BIG_ENDIAN
159         *px = (q15_t) ((in & (q31_t) 0xffff0000) >> 16);
160 #else
161         *px = (q15_t) in;
162 #endif /* #ifndef ARM_MATH_BIG_ENDIAN */
163         px += numRowsB;
164 
165 #else /* #if defined (ARM_MATH_DSP) */
166 
167         /* Read one element from row */
168         in = *pInB++;
169 
170         /* Store one element in destination */
171         *px = in;
172 
173         /* Update pointer px to point to next row of transposed matrix */
174         px += numRowsB;
175 
176         in = *pInB++;
177         *px = in;
178         px += numRowsB;
179 
180         in = *pInB++;
181         *px = in;
182         px += numRowsB;
183 
184         in = *pInB++;
185         *px = in;
186         px += numRowsB;
187 
188 #endif /* #if defined (ARM_MATH_DSP) */
189 
190         /* Decrement column loop counter */
191         col--;
192       }
193 
194       /* If the columns of pSrcB is not a multiple of 4, compute any remaining output samples here.
195        ** No loop unrolling is used. */
196       col = numColsB % 0x4U;
197 
198       while (col > 0U)
199       {
200         /* Read and store input element in destination */
201         *px = *pInB++;
202 
203         /* Update pointer px to point to next row of transposed matrix */
204         px += numRowsB;
205 
206         /* Decrement column loop counter */
207         col--;
208       }
209 
210       i++;
211 
212       /* Decrement row loop counter */
213       row--;
214 
215     } while (row > 0U);
216 
217     /* Reset variables for usage in following multiplication process */
218     row = numRowsA;
219     i = 0U;
220     px = pDst->pData;
221 
222 #if defined (ARM_MATH_DSP)
223     /* Process two rows from matrix A at a time and output two rows at a time */
224     row = row >> 1U;
225     px2 = px + numColsB;
226 #endif
227 
228     /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
229     /* row loop */
230     while (row > 0U)
231     {
232       /* For every row wise process, column loop counter is to be initiated */
233       col = numColsB;
234 
235       /* For every row wise process, pIn2 pointer is set to starting address of transposed pSrcB data */
236       pInB = pSrcBT;
237 
238 #if defined (ARM_MATH_DSP)
239       /* Process two (transposed) columns from matrix B at a time */
240       col = col >> 1U;
241       j = 0;
242 #endif
243 
244       /* column loop */
245       while (col > 0U)
246       {
247         /* Set variable sum, that acts as accumulator, to zero */
248         sum = 0;
249 
250         /* Initiate pointer pInA to point to starting address of column being processed */
251         pInA = pSrcA->pData + i;
252 
253 #if defined (ARM_MATH_DSP)
254         sum2 = 0;
255         sum3 = 0;
256         sum4 = 0;
257         pInB  = pSrcBT + j;
258         pInA2 = pInA + numColsA;
259         pInB2 = pInB + numRowsB;
260 
261         /* Read in two elements at once - allows dual MAC instruction */
262         colCnt = numColsA >> 1U;
263 #else
264         colCnt = numColsA >> 2U;
265 #endif
266 
267         /* matrix multiplication */
268         while (colCnt > 0U)
269         {
270           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
271 
272 #if defined (ARM_MATH_DSP)
273           /* read real and imag values from pSrcA and pSrcB buffer */
274           inA1 = read_q15x2_ia ((q15_t **) &pInA);
275           inB1 = read_q15x2_ia ((q15_t **) &pInB);
276 
277           inA2 = read_q15x2_ia ((q15_t **) &pInA2);
278           inB2 = read_q15x2_ia ((q15_t **) &pInB2);
279 
280           /* Multiply and Accumulates */
281           sum  = __SMLAD(inA1, inB1, sum);
282           sum2 = __SMLAD(inA1, inB2, sum2);
283           sum3 = __SMLAD(inA2, inB1, sum3);
284           sum4 = __SMLAD(inA2, inB2, sum4);
285 #else
286           /* read real and imag values from pSrcA and pSrcB buffer */
287           inA1 = *pInA++;
288           inB1 = *pInB++;
289           /* Multiply and Accumulates */
290           sum += inA1 * inB1;
291 
292           inA2 = *pInA++;
293           inB2 = *pInB++;
294           sum += inA2 * inB2;
295 
296           inA1 = *pInA++;
297           inB1 = *pInB++;
298           sum += inA1 * inB1;
299 
300           inA2 = *pInA++;
301           inB2 = *pInB++;
302           sum += inA2 * inB2;
303 #endif /* #if defined (ARM_MATH_DSP) */
304 
305           /* Decrement loop counter */
306           colCnt--;
307         }
308 
309         /* process odd column samples */
310 #if defined (ARM_MATH_DSP)
311         if (numColsA & 1U) {
312           inA1 = *pInA++;
313           inB1 = *pInB++;
314           inA2 = *pInA2++;
315           inB2 = *pInB2++;
316           sum  += inA1 * inB1;
317           sum2 += inA1 * inB2;
318           sum3 += inA2 * inB1;
319           sum4 += inA2 * inB2;
320         }
321 #else
322         colCnt = numColsA % 0x4U;
323 
324         while (colCnt > 0U)
325         {
326           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
327           sum += (q31_t) *pInA++ * *pInB++;
328 
329           /* Decrement loop counter */
330           colCnt--;
331         }
332 #endif /* #if defined (ARM_MATH_DSP) */
333 
334         /* Saturate and store result in destination buffer */
335         *px++  = (q15_t) (sum >> 15);
336 
337 #if defined (ARM_MATH_DSP)
338         *px++  = (q15_t) (sum2 >> 15);
339         *px2++ = (q15_t) (sum3 >> 15);
340         *px2++ = (q15_t) (sum4 >> 15);
341         j += numRowsB * 2;
342 #endif
343 
344         /* Decrement column loop counter */
345         col--;
346 
347       }
348 
349       i = i + numColsA;
350 
351 #if defined (ARM_MATH_DSP)
352       i = i + numColsA;
353       px = px2 + (numColsB & 1U);
354       px2 = px + numColsB;
355 #endif
356 
357       /* Decrement row loop counter */
358       row--;
359 
360     }
361 
362     /* Compute any remaining odd row/column below */
363 
364 #if defined (ARM_MATH_DSP)
365 
366     /* Compute remaining output column */
367     if (numColsB & 1U) {
368 
369       /* Avoid redundant computation of last element */
370       row = numRowsA & (~0x1);
371 
372       /* Point to remaining unfilled column in output matrix */
373       px = pDst->pData + numColsB-1;
374       pInA = pSrcA->pData;
375 
376       /* row loop */
377       while (row > 0)
378       {
379 
380         /* point to last column in matrix B */
381         pInB  = pSrcBT + numRowsB * (numColsB-1);
382 
383         /* Set variable sum, that acts as accumulator, to zero */
384         sum  = 0;
385 
386         /* Compute 4 columns at once */
387         colCnt = numColsA >> 2U;
388 
389         /* matrix multiplication */
390         while (colCnt > 0U)
391         {
392           inA1 = read_q15x2_ia ((q15_t **) &pInA);
393           inA2 = read_q15x2_ia ((q15_t **) &pInA);
394           inB1 = read_q15x2_ia ((q15_t **) &pInB);
395           inB2 = read_q15x2_ia ((q15_t **) &pInB);
396 
397           sum  = __SMLAD(inA1, inB1, sum);
398           sum  = __SMLAD(inA2, inB2, sum);
399 
400           /* Decrement loop counter */
401           colCnt--;
402         }
403 
404         colCnt = numColsA & 3U;
405         while (colCnt > 0U) {
406           sum += (q31_t) (*pInA++) * (*pInB++);
407           colCnt--;
408         }
409 
410         /* Store result in destination buffer */
411         *px = (q15_t) (sum  >> 15);
412         px += numColsB;
413 
414         /* Decrement row loop counter */
415         row--;
416       }
417     }
418 
419     /* Compute remaining output row */
420     if (numRowsA & 1U) {
421 
422       /* point to last row in output matrix */
423       px = pDst->pData + (numColsB) * (numRowsA-1);
424 
425       pInB  = pSrcBT;
426       col = numColsB;
427       i = 0U;
428 
429       /* col loop */
430       while (col > 0)
431       {
432         /* point to last row in matrix A */
433         pInA = pSrcA->pData + (numRowsA-1) * numColsA;
434 
435         /* Set variable sum, that acts as accumulator, to zero */
436         sum  = 0;
437 
438         /* Compute 4 columns at once */
439         colCnt = numColsA >> 2U;
440 
441         /* matrix multiplication */
442         while (colCnt > 0U)
443         {
444           inA1 = read_q15x2_ia ((q15_t **) &pInA);
445           inA2 = read_q15x2_ia ((q15_t **) &pInA);
446           inB1 = read_q15x2_ia ((q15_t **) &pInB);
447           inB2 = read_q15x2_ia ((q15_t **) &pInB);
448 
449           sum  = __SMLAD(inA1, inB1, sum);
450           sum  = __SMLAD(inA2, inB2, sum);
451 
452           /* Decrement loop counter */
453           colCnt--;
454         }
455 
456         colCnt = numColsA % 4U;
457         while (colCnt > 0U) {
458           sum += (q31_t) (*pInA++) * (*pInB++);
459 
460           colCnt--;
461         }
462 
463         /* Store result in destination buffer */
464         *px++ = (q15_t) (sum  >> 15);
465 
466         /* Decrement column loop counter */
467         col--;
468       }
469     }
470 
471 #endif /* #if defined (ARM_MATH_DSP) */
472 
473     /* Set status as ARM_MATH_SUCCESS */
474     status = ARM_MATH_SUCCESS;
475   }
476 
477   /* Return to application */
478   return (status);
479 }
480 
/**
  @} end of MatrixMult group
 */