1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_cmplx_mat_mult_q15.c
4  * Description:  Q15 complex matrix multiplication
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/matrix_functions.h"
30 
31 /**
32   @ingroup groupMatrix
33  */
34 
35 /**
36   @addtogroup CmplxMatrixMult
37   @{
38  */
39 
40 /**
41   @brief         Q15 Complex matrix multiplication.
42   @param[in]     pSrcA      points to first input complex matrix structure
43   @param[in]     pSrcB      points to second input complex matrix structure
44   @param[out]    pDst       points to output complex matrix structure
45   @param[in]     pScratch   points to an array for storing intermediate results
46   @return        execution status
47                    - \ref ARM_MATH_SUCCESS       : Operation successful
48                    - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
49 
  @par           Conditions for optimum performance
                   Input, output and scratch buffers should be aligned on 32-bit boundaries
52 
53   @par           Scaling and Overflow Behavior
54                    The function is implemented using an internal 64-bit accumulator. The inputs to the
55                    multiplications are in 1.15 format and multiplications yield a 2.30 result.
56                    The 2.30 intermediate results are accumulated in a 64-bit accumulator in 34.30 format.
57                    This approach provides 33 guard bits and there is no risk of overflow. The 34.30 result is then
58                    truncated to 34.15 format by discarding the low 15 bits and then saturated to 1.15 format.
59  */
60 #if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)
61 
/*
 * Shift a 64-bit accumulator right by `shift` bits with saturation and
 * return the upper 32-bit word of the sqrshrl_sat48() result.
 * Both arguments are parenthesised so that expression arguments
 * (e.g. `n + 1`) expand with the intended precedence.
 */
#define MVE_ASRL_SAT16(acc, shift)          ((sqrshrl_sat48((acc), -(32-(shift))) >> 32) & 0xffffffff)
63 
arm_mat_cmplx_mult_q15(const arm_matrix_instance_q15 * pSrcA,const arm_matrix_instance_q15 * pSrcB,arm_matrix_instance_q15 * pDst,q15_t * pScratch)64 arm_status arm_mat_cmplx_mult_q15(
65   const arm_matrix_instance_q15 * pSrcA,
66   const arm_matrix_instance_q15 * pSrcB,
67         arm_matrix_instance_q15 * pDst,
68         q15_t                   * pScratch)
69 {
70     q15_t const *pInA = (q15_t const *) pSrcA->pData;   /* input data matrix pointer A of Q15 type */
71     q15_t const *pInB = (q15_t const *) pSrcB->pData;   /* input data matrix pointer B of Q15 type */
72     q15_t const *pInB2;
73     q15_t       *px;               /* Temporary output data matrix pointer */
74     uint32_t     numRowsA = pSrcA->numRows;    /* number of rows of input matrix A    */
75     uint32_t     numColsB = pSrcB->numCols;    /* number of columns of input matrix B */
76     uint32_t     numColsA = pSrcA->numCols;    /* number of columns of input matrix A */
77     uint32_t     numRowsB = pSrcB->numRows;    /* number of rows of input matrix A    */
78     uint32_t     col, i = 0u, j, row = numRowsB;   /* loop counters */
79     uint32_t  blkCnt;           /* loop counters */
80     uint16x8_t vecOffs, vecColBOffs;
81     arm_status status;                             /* Status of matrix multiplication */
82     (void)pScratch;
83 
84 #ifdef ARM_MATH_MATRIX_CHECK
85 
86   /* Check for matrix mismatch condition */
87   if ((pSrcA->numCols != pSrcB->numRows) ||
88       (pSrcA->numRows != pDst->numRows)  ||
89       (pSrcB->numCols != pDst->numCols)    )
90   {
91     /* Set status as ARM_MATH_SIZE_MISMATCH */
92     status = ARM_MATH_SIZE_MISMATCH;
93   }
94   else
95 
96 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
97 
98   {
99     vecColBOffs[0] = 0;
100     vecColBOffs[1] = 1;
101     vecColBOffs[2] = numColsB * CMPLX_DIM;
102     vecColBOffs[3] = (numColsB * CMPLX_DIM) + 1;
103     vecColBOffs[4] = 2 * numColsB * CMPLX_DIM;
104     vecColBOffs[5] = 2 * (numColsB * CMPLX_DIM) + 1;
105     vecColBOffs[6] = 3 * numColsB * CMPLX_DIM;
106     vecColBOffs[7] = 3 * (numColsB * CMPLX_DIM) + 1;
107 
108     /*
109      * Reset the variables for the usage in the following multiplication process
110      */
111     i = 0;
112     row = numRowsA;
113     px = pDst->pData;
114 
115     /*
116      * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
117      */
118 
119     /*
120      * row loop
121      */
122     while (row > 0u)
123     {
124         /*
125          * For every row wise process, the column loop counter is to be initiated
126          */
127         col = numColsB >> 1;
128         j = 0;
129         /*
130          * column loop
131          */
132         while (col > 0u)
133         {
134             q15_t const *pSrcAVec;
135             //, *pSrcBVec, *pSrcB2Vec;
136             q15x8_t vecA, vecB, vecB2;
137             q63_t     acc0, acc1, acc2, acc3;
138 
139             /*
140              * Initiate the pointer pIn1 to point to the starting address of the column being processed
141              */
142             pInA = pSrcA->pData + i;
143             pInB = pSrcB->pData + j;
144             pInB2 = pInB + CMPLX_DIM;
145 
146             j += 2 * CMPLX_DIM;
147             /*
148              * Decrement the column loop counter
149              */
150             col--;
151 
152             /*
153              * Initiate the pointers
154              * - current Matrix A rows
155              * - 2 x consecutive Matrix B' rows (j increment is 2 x numRowsB)
156              */
157             pSrcAVec = (q15_t const *) pInA;
158 
159             acc0 = 0LL;
160             acc1 = 0LL;
161             acc2 = 0LL;
162             acc3 = 0LL;
163 
164             vecOffs = vecColBOffs;
165 
166 
167             blkCnt = (numColsA * CMPLX_DIM) >> 3;
168             while (blkCnt > 0U)
169             {
170                 vecA = vld1q(pSrcAVec);
171                 pSrcAVec += 8;
172                 vecB = vldrhq_gather_shifted_offset(pInB, vecOffs);
173 
174                 acc0 = vmlsldavaq_s16(acc0, vecA, vecB);
175                 acc1 = vmlaldavaxq_s16(acc1, vecA, vecB);
176                 vecB2 = vldrhq_gather_shifted_offset(pInB2, vecOffs);
177                 /*
178                  * move Matrix B read offsets, 4 rows down
179                  */
180                 vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (numColsB * 4 * CMPLX_DIM));
181 
182                 acc2 = vmlsldavaq_s16(acc2, vecA, vecB2);
183                 acc3 = vmlaldavaxq_s16(acc3, vecA, vecB2);
184 
185                 blkCnt--;
186             }
187 
188             /*
189              * tail
190              */
191             blkCnt = (numColsA * CMPLX_DIM) & 7;
192             if (blkCnt > 0U)
193             {
194                 mve_pred16_t p0 = vctp16q(blkCnt);
195                 vecB = vldrhq_gather_shifted_offset(pInB, vecOffs);
196 
197                 vecA = vldrhq_z_s16(pSrcAVec, p0);
198 
199                 acc0 = vmlsldavaq_s16(acc0, vecA, vecB);
200                 acc1 = vmlaldavaxq_s16(acc1, vecA, vecB);
201                 vecB2 = vldrhq_gather_shifted_offset(pInB2, vecOffs);
202 
203                 /*
204                  * move Matrix B read offsets, 4 rows down
205                  */
206                 vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (numColsB * 4 * CMPLX_DIM));
207 
208                 acc2 = vmlsldavaq_s16(acc2, vecA, vecB2);
209                 acc3 = vmlaldavaxq_s16(acc3, vecA, vecB2);
210 
211             }
212             /*
213              * Convert to 1.15, Store the results (1 x 2 block) in the destination buffer
214              */
215             *px++ = (q15_t)MVE_ASRL_SAT16(acc0, 15);
216             *px++ = (q15_t)MVE_ASRL_SAT16(acc1, 15);
217             *px++ = (q15_t)MVE_ASRL_SAT16(acc2, 15);
218             *px++ = (q15_t)MVE_ASRL_SAT16(acc3, 15);
219         }
220 
221         col = numColsB & 1;
222         /*
223          * column loop
224          */
225         while (col > 0u)
226         {
227 
228             q15_t const *pSrcAVec;
229             //, *pSrcBVec, *pSrcB2Vec;
230             q15x8_t vecA, vecB;
231             q63_t     acc0, acc1;
232 
233             /*
234              * Initiate the pointer pIn1 to point to the starting address of the column being processed
235              */
236             pInA = pSrcA->pData + i;
237             pInB = pSrcB->pData + j;
238 
239             j += CMPLX_DIM;
240             /*
241              * Decrement the column loop counter
242              */
243             col--;
244 
245             /*
246              * Initiate the pointers
247              * - current Matrix A rows
248              * - 2 x consecutive Matrix B' rows (j increment is 2 x numRowsB)
249              */
250             pSrcAVec = (q15_t const *) pInA;
251 
252             acc0 = 0LL;
253             acc1 = 0LL;
254 
255 
256             vecOffs = vecColBOffs;
257 
258 
259 
260             blkCnt = (numColsA * CMPLX_DIM) >> 3;
261             while (blkCnt > 0U)
262             {
263                 vecA = vld1q(pSrcAVec);
264                 pSrcAVec += 8;
265                 vecB = vldrhq_gather_shifted_offset(pInB, vecOffs);
266 
267                 acc0 = vmlsldavaq_s16(acc0, vecA, vecB);
268                 acc1 = vmlaldavaxq_s16(acc1, vecA, vecB);
269                 /*
270                  * move Matrix B read offsets, 4 rows down
271                  */
272                 vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (numColsB * 4 * CMPLX_DIM));
273 
274                 blkCnt--;
275             }
276 
277             /*
278              * tail
279              */
280             blkCnt = (numColsA * CMPLX_DIM) & 7;
281             if (blkCnt > 0U)
282             {
283                 mve_pred16_t p0 = vctp16q(blkCnt);
284                 vecB = vldrhq_gather_shifted_offset(pInB, vecOffs);
285                 vecA = vldrhq_z_s16(pSrcAVec, p0);
286 
287                 acc0 = vmlsldavaq_s16(acc0, vecA, vecB);
288                 acc1 = vmlaldavaxq_s16(acc1, vecA, vecB);
289 
290             }
291             /*
292              * Convert to 1.15, Store the results (1 x 2 block) in the destination buffer
293              */
294             *px++ = (q15_t)MVE_ASRL_SAT16(acc0, 15);
295             *px++ = (q15_t)MVE_ASRL_SAT16(acc1, 15);
296 
297         }
298 
299         i = i + numColsA * CMPLX_DIM;
300 
301         /*
302          * Decrement the row loop counter
303          */
304         row--;
305     }
306 
307 
308     status = ARM_MATH_SUCCESS;
309   }
310 
311   /* Return to application */
312   return (status);
313 }
314 #else
arm_mat_cmplx_mult_q15(const arm_matrix_instance_q15 * pSrcA,const arm_matrix_instance_q15 * pSrcB,arm_matrix_instance_q15 * pDst,q15_t * pScratch)315 arm_status arm_mat_cmplx_mult_q15(
316   const arm_matrix_instance_q15 * pSrcA,
317   const arm_matrix_instance_q15 * pSrcB,
318         arm_matrix_instance_q15 * pDst,
319         q15_t                   * pScratch)
320 {
321         q15_t *pSrcBT = pScratch;                      /* input data matrix pointer for transpose */
322         q15_t *pInA = pSrcA->pData;                    /* input data matrix pointer A of Q15 type */
323         q15_t *pInB = pSrcB->pData;                    /* input data matrix pointer B of Q15 type */
324         q15_t *px;                                     /* Temporary output data matrix pointer */
325         uint16_t numRowsA = pSrcA->numRows;            /* number of rows of input matrix A */
326         uint16_t numColsB = pSrcB->numCols;            /* number of columns of input matrix B */
327         uint16_t numColsA = pSrcA->numCols;            /* number of columns of input matrix A */
328         uint16_t numRowsB = pSrcB->numRows;            /* number of rows of input matrix A */
329         q63_t sumReal, sumImag;                        /* accumulator */
330         uint32_t col, i = 0U, row = numRowsB, colCnt;  /* Loop counters */
331         arm_status status;                             /* Status of matrix multiplication */
332 
333 #if defined (ARM_MATH_DSP)
334         q31_t prod1, prod2;
335         q31_t pSourceA, pSourceB;
336 #else
337         q15_t a, b, c, d;
338 #endif /* #if defined (ARM_MATH_DSP) */
339 
340 #ifdef ARM_MATH_MATRIX_CHECK
341 
342   /* Check for matrix mismatch condition */
343   if ((pSrcA->numCols != pSrcB->numRows) ||
344       (pSrcA->numRows != pDst->numRows)  ||
345       (pSrcB->numCols != pDst->numCols)    )
346   {
347     /* Set status as ARM_MATH_SIZE_MISMATCH */
348     status = ARM_MATH_SIZE_MISMATCH;
349   }
350   else
351 
352 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
353 
354   {
355     /* Matrix transpose */
356     do
357     {
358       /* The pointer px is set to starting address of column being processed */
359       px = pSrcBT + i;
360 
361 #if defined (ARM_MATH_LOOPUNROLL)
362 
363       /* Apply loop unrolling and exchange the columns with row elements */
364       col = numColsB >> 2;
365 
366       /* First part of the processing with loop unrolling.  Compute 4 outputs at a time.
367          a second loop below computes the remaining 1 to 3 samples. */
368       while (col > 0U)
369       {
370         /* Read two elements from row */
371         write_q15x2 (px, read_q15x2_ia (&pInB));
372 
373         /* Update pointer px to point to next row of transposed matrix */
374         px += numRowsB * 2;
375 
376         /* Read two elements from row */
377         write_q15x2 (px, read_q15x2_ia (&pInB));
378 
379         /* Update pointer px to point to next row of transposed matrix */
380         px += numRowsB * 2;
381 
382         /* Read two elements from row */
383         write_q15x2 (px, read_q15x2_ia (&pInB));
384 
385         /* Update pointer px to point to next row of transposed matrix */
386         px += numRowsB * 2;
387 
388         /* Read two elements from row */
389         write_q15x2 (px, read_q15x2_ia (&pInB));
390 
391         /* Update pointer px to point to next row of transposed matrix */
392         px += numRowsB * 2;
393 
394         /* Decrement column loop counter */
395         col--;
396       }
397 
398       /* If the columns of pSrcB is not a multiple of 4, compute any remaining output samples here.
399        ** No loop unrolling is used. */
400       col = numColsB % 0x4U;
401 
402 #else
403 
404         /* Initialize blkCnt with number of samples */
405         col = numColsB;
406 
407 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
408 
409       while (col > 0U)
410       {
411         /* Read two elements from row */
412         write_q15x2 (px, read_q15x2_ia (&pInB));
413 
414         /* Update pointer px to point to next row of transposed matrix */
415         px += numRowsB * 2;
416 
417         /* Decrement column loop counter */
418         col--;
419       }
420 
421       i = i + 2U;
422 
423       /* Decrement row loop counter */
424       row--;
425 
426     } while (row > 0U);
427 
428     /* Reset variables for usage in following multiplication process */
429     row = numRowsA;
430     i = 0U;
431     px = pDst->pData;
432 
433     /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
434     /* row loop */
435     do
436     {
437       /* For every row wise process, column loop counter is to be initiated */
438       col = numColsB;
439 
440       /* For every row wise process, pIn2 pointer is set to starting address of transposed pSrcB data */
441       pInB = pSrcBT;
442 
443       /* column loop */
444       do
445       {
446         /* Set variable sum, that acts as accumulator, to zero */
447         sumReal = 0;
448         sumImag = 0;
449 
450         /* Initiate pointer pInA to point to starting address of column being processed */
451         pInA = pSrcA->pData + i * 2;
452 
453         /* Apply loop unrolling and compute 2 MACs simultaneously. */
454         colCnt = numColsA >> 1U;
455 
456         /* matrix multiplication */
457         while (colCnt > 0U)
458         {
459           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
460 
461 #if defined (ARM_MATH_DSP)
462 
463           /* read real and imag values from pSrcA and pSrcB buffer */
464           pSourceA = read_q15x2_ia ((q15_t **) &pInA);
465           pSourceB = read_q15x2_ia ((q15_t **) &pInB);
466 
467           /* Multiply and Accumlates */
468 #ifdef ARM_MATH_BIG_ENDIAN
469           prod1 = -__SMUSD(pSourceA, pSourceB);
470 #else
471           prod1 = __SMUSD(pSourceA, pSourceB);
472 #endif
473           prod2 = __SMUADX(pSourceA, pSourceB);
474           sumReal += (q63_t) prod1;
475           sumImag += (q63_t) prod2;
476 
477           /* read real and imag values from pSrcA and pSrcB buffer */
478           pSourceA = read_q15x2_ia ((q15_t **) &pInA);
479           pSourceB = read_q15x2_ia ((q15_t **) &pInB);
480 
481           /* Multiply and Accumlates */
482 #ifdef ARM_MATH_BIG_ENDIAN
483           prod1 = -__SMUSD(pSourceA, pSourceB);
484 #else
485           prod1 = __SMUSD(pSourceA, pSourceB);
486 #endif
487           prod2 = __SMUADX(pSourceA, pSourceB);
488           sumReal += (q63_t) prod1;
489           sumImag += (q63_t) prod2;
490 
491 #else /* #if defined (ARM_MATH_DSP) */
492 
493           /* read real and imag values from pSrcA buffer */
494           a = *pInA;
495           b = *(pInA + 1U);
496           /* read real and imag values from pSrcB buffer */
497           c = *pInB;
498           d = *(pInB + 1U);
499 
500           /* Multiply and Accumlates */
501           sumReal += (q31_t) a *c;
502           sumImag += (q31_t) a *d;
503           sumReal -= (q31_t) b *d;
504           sumImag += (q31_t) b *c;
505 
506           /* read next real and imag values from pSrcA buffer */
507           a = *(pInA + 2U);
508           b = *(pInA + 3U);
509           /* read next real and imag values from pSrcB buffer */
510           c = *(pInB + 2U);
511           d = *(pInB + 3U);
512 
513           /* update pointer */
514           pInA += 4U;
515 
516           /* Multiply and Accumlates */
517           sumReal += (q31_t) a * c;
518           sumImag += (q31_t) a * d;
519           sumReal -= (q31_t) b * d;
520           sumImag += (q31_t) b * c;
521           /* update pointer */
522           pInB += 4U;
523 
524 #endif /* #if defined (ARM_MATH_DSP) */
525 
526           /* Decrement loop counter */
527           colCnt--;
528         }
529 
530         /* process odd column samples */
531         if ((numColsA & 0x1U) > 0U)
532         {
533           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
534 
535 #if defined (ARM_MATH_DSP)
536           /* read real and imag values from pSrcA and pSrcB buffer */
537           pSourceA = read_q15x2_ia ((q15_t **) &pInA);
538           pSourceB = read_q15x2_ia ((q15_t **) &pInB);
539 
540           /* Multiply and Accumlates */
541 #ifdef ARM_MATH_BIG_ENDIAN
542           prod1 = -__SMUSD(pSourceA, pSourceB);
543 #else
544           prod1 = __SMUSD(pSourceA, pSourceB);
545 #endif
546           prod2 = __SMUADX(pSourceA, pSourceB);
547           sumReal += (q63_t) prod1;
548           sumImag += (q63_t) prod2;
549 
550 #else /* #if defined (ARM_MATH_DSP) */
551 
552           /* read real and imag values from pSrcA and pSrcB buffer */
553           a = *pInA++;
554           b = *pInA++;
555           c = *pInB++;
556           d = *pInB++;
557 
558           /* Multiply and Accumlates */
559           sumReal += (q31_t) a * c;
560           sumImag += (q31_t) a * d;
561           sumReal -= (q31_t) b * d;
562           sumImag += (q31_t) b * c;
563 
564 #endif /* #if defined (ARM_MATH_DSP) */
565 
566         }
567 
568         /* Saturate and store result in destination buffer */
569         *px++ = (q15_t) (__SSAT(sumReal >> 15, 16));
570         *px++ = (q15_t) (__SSAT(sumImag >> 15, 16));
571 
572         /* Decrement column loop counter */
573         col--;
574 
575       } while (col > 0U);
576 
577       i = i + numColsA;
578 
579       /* Decrement row loop counter */
580       row--;
581 
582     } while (row > 0U);
583 
584     /* Set status as ARM_MATH_SUCCESS */
585     status = ARM_MATH_SUCCESS;
586   }
587 
588   /* Return to application */
589   return (status);
590 }
591 #endif /* defined(ARM_MATH_MVEI) */
592 
/**
  @} end of CmplxMatrixMult group
 */
596