/* ----------------------------------------------------------------------
 * Project:      CMSIS DSP Library
 * Title:        arm_mat_cmplx_mult_f16.c
 * Description:  Floating-point complex matrix multiplication
 *
 * $Date:        23 April 2021
 * $Revision:    V1.9.0
 *
 * Target Processor: Cortex-M and Cortex-A cores
 * -------------------------------------------------------------------- */
/*
 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "dsp/matrix_functions_f16.h"

#if defined(ARM_FLOAT16_SUPPORTED)


/**
  @ingroup groupMatrix
 */


/**
  @addtogroup CmplxMatrixMult
  @{
 */

/**
  @brief         Floating-point complex matrix multiplication.
  @param[in]     pSrcA      points to first input complex matrix structure
  @param[in]     pSrcB      points to second input complex matrix structure
  @param[out]    pDst       points to output complex matrix structure
  @return        execution status
                   - \ref ARM_MATH_SUCCESS       : Operation successful
                   - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
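
  @par           Example
                 A minimal usage sketch. The 2x2 dimensions, the sample
                 values and the variable names below are illustrative
                 assumptions, not library requirements. Matrix data is
                 stored as interleaved real/imaginary float16_t pairs in
                 row-major order.
  @code
      // A = [1+2j  3+4j ; 5+6j  7+8j],  B = 2x2 identity
      float16_t aData[2 * 2 * 2] = { 1.0f16, 2.0f16, 3.0f16, 4.0f16,
                                     5.0f16, 6.0f16, 7.0f16, 8.0f16 };
      float16_t bData[2 * 2 * 2] = { 1.0f16, 0.0f16, 0.0f16, 0.0f16,
                                     0.0f16, 0.0f16, 1.0f16, 0.0f16 };
      float16_t cData[2 * 2 * 2];

      // fields: numRows, numCols, pData
      arm_matrix_instance_f16 matA = { 2, 2, aData };
      arm_matrix_instance_f16 matB = { 2, 2, bData };
      arm_matrix_instance_f16 matC = { 2, 2, cData };

      if (arm_mat_cmplx_mult_f16(&matA, &matB, &matC) != ARM_MATH_SUCCESS)
      {
          // size mismatch: handle the error
      }
  @endcode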
 */

#if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) && defined(__CMSIS_GCC_H)
#pragma message "Scalar version of arm_mat_cmplx_mult_f16 built. Helium version has build issues with gcc."
#endif

#if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) && !defined(__CMSIS_GCC_H)

#include "arm_helium_utils.h"

#define DONTCARE            0 /* inactive lane content */


__STATIC_FORCEINLINE arm_status arm_mat_cmplx_mult_f16_2x2_mve(
    const arm_matrix_instance_f16 * pSrcA,
    const arm_matrix_instance_f16 * pSrcB,
    arm_matrix_instance_f16 * pDst)
{
#define MATRIX_DIM 2
    float16_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float16_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
    float16_t       *pOut = pDst->pData;   /* output data matrix pointer */
    uint16x8_t     vecColBOffs0, vecColAOffs0, vecColAOffs1;
    float16_t       *pInA0 = pInA;
    f16x8_t        acc0, acc1;
    f16x8_t        vecB, vecA0, vecA1;
    f16x8_t        vecTmp;
    uint16_t         tmp;
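    /*
     * B gather pattern: lanes 0-3 load the two complex elements of
     * B column 0, lanes 4-7 the two complex elements of B column 1,
     * so a single gather fetches both columns of the 2x2 matrix.
     */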
    static const uint16_t offsetB0[8] = { 0, 1,
        MATRIX_DIM * CMPLX_DIM, MATRIX_DIM * CMPLX_DIM + 1,
        2, 3,
        MATRIX_DIM * CMPLX_DIM + 2 , MATRIX_DIM * CMPLX_DIM + 3,
    };


    vecColBOffs0 = vldrhq_u16((uint16_t const *) offsetB0);

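    /*
     * wrapping increment: vecColAOffs0 = {0,1,2,3,0,1,2,3}, i.e. row 0
     * of A (two complex elements) replicated in both vector halves;
     * adding CMPLX_DIM * MATRIX_DIM yields the same pattern for row 1.
     */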
    tmp = 0;
    vecColAOffs0 = viwdupq_u16(tmp, 4, 1);

    vecColAOffs1 = vecColAOffs0 + (uint16_t)(CMPLX_DIM * MATRIX_DIM);


    pInB = (float16_t const *)pSrcB->pData;

    vecA0 = vldrhq_gather_shifted_offset_f16(pInA0, vecColAOffs0);
    vecA1 = vldrhq_gather_shifted_offset_f16(pInA0, vecColAOffs1);


    vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);

    acc0 = vcmulq(vecA0, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA0, vecB);

    acc1 = vcmulq(vecA1, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA1, vecB);


    /*
     * Compute
     *  re0+re1 | im0+im1 | re0+re1 | im0+im1
     *  re2+re3 | im2+im3 | re2+re3 | im2+im3
     */

    vecTmp = (f16x8_t) vrev64q_s32((int32x4_t) acc0);
    vecTmp = vaddq(vecTmp, acc0);

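    /*
     * each 32-bit store below writes one (real, imaginary) float16 pair
     */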
    *(float32_t *)(&pOut[0 * CMPLX_DIM * MATRIX_DIM]) = ((f32x4_t)vecTmp)[0];
    *(float32_t *)(&pOut[0 * CMPLX_DIM * MATRIX_DIM + CMPLX_DIM]) = ((f32x4_t)vecTmp)[2];

    vecTmp = (f16x8_t) vrev64q_s32((int32x4_t) acc1);
    vecTmp = vaddq(vecTmp, acc1);

    *(float32_t *)(&pOut[1 * CMPLX_DIM * MATRIX_DIM]) = ((f32x4_t)vecTmp)[0];
    *(float32_t *)(&pOut[1 * CMPLX_DIM * MATRIX_DIM + CMPLX_DIM]) = ((f32x4_t)vecTmp)[2];

    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
#undef MATRIX_DIM
}



__STATIC_FORCEINLINE arm_status arm_mat_cmplx_mult_f16_3x3_mve(
    const arm_matrix_instance_f16 * pSrcA,
    const arm_matrix_instance_f16 * pSrcB,
    arm_matrix_instance_f16 * pDst)
{
#define MATRIX_DIM 3
    float16_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float16_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
    float16_t       *pOut = pDst->pData;   /* output data matrix pointer */
    uint16x8_t     vecColBOffs0;
    float16_t       *pInA0 = pInA;
    float16_t       *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM;
    float16_t       *pInA2 = pInA1 + CMPLX_DIM * MATRIX_DIM;
    f16x8_t        acc0, acc1, acc2;
    f16x8_t        vecB, vecA0, vecA1, vecA2;
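    /*
     * B gather pattern: one complex element from each of the three rows
     * of a B column; the last two lanes are unused and masked off by
     * predication.
     */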
    static const uint16_t offsetB0[8] = { 0, 1,
        MATRIX_DIM * CMPLX_DIM, MATRIX_DIM * CMPLX_DIM + 1,
        2 * MATRIX_DIM * CMPLX_DIM, 2 * MATRIX_DIM * CMPLX_DIM + 1,
        DONTCARE, DONTCARE
    };


    /* predication: only the low 6 lanes (3 complex elements) are active,
       the 2 upper lanes are disabled */
    mve_pred16_t p0 = vctp16q(MATRIX_DIM * CMPLX_DIM);

    vecColBOffs0 = vldrhq_u16((uint16_t const *) offsetB0);

    pInB = (float16_t const *)pSrcB->pData;

    vecA0 = vldrhq_f16(pInA0);
    vecA1 = vldrhq_f16(pInA1);
    vecA2 = vldrhq_f16(pInA2);

    vecB = vldrhq_gather_shifted_offset_z(pInB, vecColBOffs0, p0);

    acc0 = vcmulq(vecA0, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA0, vecB);

    acc1 = vcmulq(vecA1, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA1, vecB);

    acc2 = vcmulq(vecA2, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA2, vecB);

    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    pOut += CMPLX_DIM;
    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    vecB = vldrhq_gather_shifted_offset_z(pInB, vecColBOffs0, p0);

    acc0 = vcmulq(vecA0, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA0, vecB);

    acc1 = vcmulq(vecA1, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA1, vecB);

    acc2 = vcmulq(vecA2, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA2, vecB);

    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    pOut += CMPLX_DIM;
    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    vecB = vldrhq_gather_shifted_offset_z(pInB, vecColBOffs0, p0);

    acc0 = vcmulq(vecA0, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA0, vecB);

    acc1 = vcmulq(vecA1, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA1, vecB);

    acc2 = vcmulq(vecA2, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA2, vecB);

    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
#undef MATRIX_DIM
}




__STATIC_FORCEINLINE arm_status arm_mat_cmplx_mult_f16_4x4_mve(
    const arm_matrix_instance_f16 * pSrcA,
    const arm_matrix_instance_f16 * pSrcB,
    arm_matrix_instance_f16 * pDst)
{
#define MATRIX_DIM 4
    float16_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float16_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
    float16_t       *pOut = pDst->pData;   /* output data matrix pointer */
    uint16x8_t     vecColBOffs0;
    float16_t       *pInA0 = pInA;
    float16_t       *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM;
    float16_t       *pInA2 = pInA1 + CMPLX_DIM * MATRIX_DIM;
    float16_t       *pInA3 = pInA2 + CMPLX_DIM * MATRIX_DIM;
    f16x8_t        acc0, acc1, acc2, acc3;
    f16x8_t        vecB, vecA;
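    /*
     * B gather pattern: one complex element from each of the four rows
     * of a B column, filling the full vector.
     */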
    static const uint16_t offsetB0[8] = { 0, 1,
        MATRIX_DIM * CMPLX_DIM, MATRIX_DIM * CMPLX_DIM + 1,
        2 * MATRIX_DIM * CMPLX_DIM, 2 * MATRIX_DIM * CMPLX_DIM + 1,
        3 * MATRIX_DIM * CMPLX_DIM, 3 * MATRIX_DIM * CMPLX_DIM + 1
    };

    vecColBOffs0 = vldrhq_u16((uint16_t const *) offsetB0);

    pInB = (float16_t const *)pSrcB->pData;

    vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrhq_f16(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrhq_f16(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrhq_f16(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrhq_f16(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);


    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
    pOut += CMPLX_DIM;
    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrhq_f16(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrhq_f16(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrhq_f16(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrhq_f16(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);


    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
    pOut += CMPLX_DIM;
    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrhq_f16(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrhq_f16(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrhq_f16(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrhq_f16(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);


    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
    pOut += CMPLX_DIM;
    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrhq_f16(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrhq_f16(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrhq_f16(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrhq_f16(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);


    mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
    mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
#undef MATRIX_DIM
}



arm_status arm_mat_cmplx_mult_f16(
  const arm_matrix_instance_f16 * pSrcA,
  const arm_matrix_instance_f16 * pSrcB,
  arm_matrix_instance_f16 * pDst)
{
    float16_t const *pInB = (float16_t const *) pSrcB->pData;   /* input data matrix pointer B */
    float16_t const *pInA = (float16_t const *) pSrcA->pData;   /* input data matrix pointer A */
    float16_t *pOut = pDst->pData;  /* output data matrix pointer */
    float16_t *px;              /* Temporary output data matrix pointer */
    uint16_t  numRowsA = pSrcA->numRows;    /* number of rows of input matrix A    */
    uint16_t  numColsB = pSrcB->numCols;    /* number of columns of input matrix B */
    uint16_t  numColsA = pSrcA->numCols;    /* number of columns of input matrix A */
    uint16_t  col, i = 0U, row = numRowsA;  /* loop counters */
    arm_status status;          /* status of matrix multiplication */
    uint16x8_t vecOffs, vecColBOffs;
    uint32_t  blkCnt, rowCnt;          /* loop counters */

#ifdef ARM_MATH_MATRIX_CHECK

  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
      (pSrcA->numRows != pDst->numRows)  ||
      (pSrcB->numCols != pDst->numCols)    )
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else

#endif /* #ifdef ARM_MATH_MATRIX_CHECK */

  {

    /*
     * small square matrix specialized routines
     */
    if (numRowsA == numColsB && numColsB == numColsA)
    {
        if (numRowsA == 1)
        {
            pOut[0] = (_Float16)pInA[0] * (_Float16)pInB[0] - (_Float16)pInA[1] * (_Float16)pInB[1];
            pOut[1] = (_Float16)pInA[0] * (_Float16)pInB[1] + (_Float16)pInA[1] * (_Float16)pInB[0];
            return (ARM_MATH_SUCCESS);
        }
        else if (numRowsA == 2)
            return arm_mat_cmplx_mult_f16_2x2_mve(pSrcA, pSrcB, pDst);
        else if (numRowsA == 3)
            return arm_mat_cmplx_mult_f16_3x3_mve(pSrcA, pSrcB, pDst);
        else if (numRowsA == 4)
            return arm_mat_cmplx_mult_f16_4x4_mve(pSrcA, pSrcB, pDst);
    }

    vecColBOffs[0] = 0;
    vecColBOffs[1] = 1;
    vecColBOffs[2] = numColsB * CMPLX_DIM;
    vecColBOffs[3] = (numColsB * CMPLX_DIM) + 1;
    vecColBOffs[4] = 2 * numColsB * CMPLX_DIM;
    vecColBOffs[5] = 2 * (numColsB * CMPLX_DIM) + 1;
    vecColBOffs[6] = 3 * numColsB * CMPLX_DIM;
    vecColBOffs[7] = 3 * (numColsB * CMPLX_DIM) + 1;
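
    /*
     * vecColBOffs addresses 4 consecutive rows (4 complex elements) of a
     * single B column; the inner loops walk it down the column by adding
     * numColsB * 4 * CMPLX_DIM per iteration.
     */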

    /*
     * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
     */

    /*
     * row loop
     */
    rowCnt = row >> 2;
    while (rowCnt > 0u)
    {
        /*
         * Output pointer is set to starting address of the row being processed
         */
        px = pOut + i * CMPLX_DIM;
        i = i + 4 * numColsB;
        /*
         * the column loop counter is reinitialized for each row
         */
        col = numColsB;
        /*
         * pInB is reset to the starting address of the pSrcB data
         * for each row
         */
        pInB = (float16_t const *) pSrcB->pData;
        /*
         * column loop
         */
        while (col > 0u)
        {
            /*
             * generate 4 output rows for the current column
             */
            /*
             * numColsA complex MAC operations are performed per output element
             */

            float16_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec;
            float16_t const *pInA0 = pInA;
            float16_t const *pInA1 = pInA0 + numColsA * CMPLX_DIM;
            float16_t const *pInA2 = pInA1 + numColsA * CMPLX_DIM;
            float16_t const *pInA3 = pInA2 + numColsA * CMPLX_DIM;
            f16x8_t acc0, acc1, acc2, acc3;

            acc0 = vdupq_n_f16(0.0f16);
            acc1 = vdupq_n_f16(0.0f16);
            acc2 = vdupq_n_f16(0.0f16);
            acc3 = vdupq_n_f16(0.0f16);

            pSrcA0Vec = (float16_t const *) pInA0;
            pSrcA1Vec = (float16_t const *) pInA1;
            pSrcA2Vec = (float16_t const *) pInA2;
            pSrcA3Vec = (float16_t const *) pInA3;

            vecOffs = vecColBOffs;

            /*
             * process a 4 x 1 block of outputs (4 rows, 1 column)
             */
            blkCnt = (numColsA * CMPLX_DIM) >> 3;
            while (blkCnt > 0U)
            {
                f16x8_t vecB, vecA;

                vecB = vldrhq_gather_shifted_offset_f16(pInB, vecOffs);
                /*
                 * move Matrix B read offsets, 4 rows down
                 */
                vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (numColsB * 4 * CMPLX_DIM));

                vecA = vld1q(pSrcA0Vec);  pSrcA0Vec += 8;
                acc0 = vcmlaq(acc0, vecA, vecB);
                acc0 = vcmlaq_rot90(acc0, vecA, vecB);

                vecA = vld1q(pSrcA1Vec);  pSrcA1Vec += 8;
                acc1 = vcmlaq(acc1, vecA, vecB);
                acc1 = vcmlaq_rot90(acc1, vecA, vecB);

                vecA = vld1q(pSrcA2Vec);  pSrcA2Vec += 8;
                acc2 = vcmlaq(acc2, vecA, vecB);
                acc2 = vcmlaq_rot90(acc2, vecA, vecB);

                vecA = vld1q(pSrcA3Vec);  pSrcA3Vec += 8;
                acc3 = vcmlaq(acc3, vecA, vecB);
                acc3 = vcmlaq_rot90(acc3, vecA, vecB);

                blkCnt--;
            }
            /*
             * note: the offset vector is updated explicitly; an
             * unsupported addressing mode here caused a compiler crash
             */
            /*
             * tail
             * (will be merged through tail predication)
             */
            blkCnt = (numColsA * CMPLX_DIM) & 7;
            if (blkCnt > 0U)
            {
                mve_pred16_t p0 = vctp16q(blkCnt);
                f16x8_t vecB, vecA;

                vecB = vldrhq_gather_shifted_offset_z_f16(pInB, vecOffs, p0);
                /*
                 * move Matrix B read offsets, 4 rows down
                 */
                vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (numColsB * 4 * CMPLX_DIM));

                vecA = vld1q(pSrcA0Vec);
                acc0 = vcmlaq(acc0, vecA, vecB);
                acc0 = vcmlaq_rot90(acc0, vecA, vecB);

                vecA = vld1q(pSrcA1Vec);
                acc1 = vcmlaq(acc1, vecA, vecB);
                acc1 = vcmlaq_rot90(acc1, vecA, vecB);

                vecA = vld1q(pSrcA2Vec);
                acc2 = vcmlaq(acc2, vecA, vecB);
                acc2 = vcmlaq_rot90(acc2, vecA, vecB);

                vecA = vld1q(pSrcA3Vec);
                acc3 = vcmlaq(acc3, vecA, vecB);
                acc3 = vcmlaq_rot90(acc3, vecA, vecB);

            }


            mve_cmplx_sum_intra_vec_f16(acc0, &px[0 * CMPLX_DIM * numColsB + 0]);
            mve_cmplx_sum_intra_vec_f16(acc1, &px[1 * CMPLX_DIM * numColsB + 0]);
            mve_cmplx_sum_intra_vec_f16(acc2, &px[2 * CMPLX_DIM * numColsB + 0]);
            mve_cmplx_sum_intra_vec_f16(acc3, &px[3 * CMPLX_DIM * numColsB + 0]);

            px += CMPLX_DIM;
            /*
             * Decrement the column loop counter
             */
            col--;
            /*
             * Update the pointer pInB to point to the starting address of the next column
             */
            pInB = (float16_t const *) pSrcB->pData + (numColsB - col) * CMPLX_DIM;
        }

        /*
         * Update the pointer pInA to point to the starting address of the next row
         */
        pInA += (numColsA * 4) * CMPLX_DIM;
        /*
         * Decrement the row loop counter
         */
        rowCnt--;

    }

    rowCnt = row & 3;
    while (rowCnt > 0u)
    {
        /*
         * Output pointer is set to starting address of the row being processed
         */
        px = pOut + i * CMPLX_DIM;
        i = i + numColsB;
        /*
         * the column loop counter is reinitialized for each row
         */
        col = numColsB;
        /*
         * pInB is reset to the starting address of the pSrcB data
         * for each row
         */
        pInB = (float16_t const *) pSrcB->pData;
        /*
         * column loop
         */
        while (col > 0u)
        {
            /*
             * generate one output element for the current column
             */
            /*
             * numColsA complex MAC operations are performed per output element
             */

            float16_t const *pSrcA0Vec;
            float16_t const *pInA0 = pInA;
            f16x8_t acc0;

            acc0 = vdupq_n_f16(0.0f16);

            pSrcA0Vec = (float16_t const *) pInA0;

            vecOffs = vecColBOffs;

            /*
             * process a 1 x 1 block output
             */
            blkCnt = (numColsA * CMPLX_DIM) >> 3;
            while (blkCnt > 0U)
            {
                f16x8_t vecB, vecA;

                vecB = vldrhq_gather_shifted_offset(pInB, vecOffs);
                /*
                 * move Matrix B read offsets, 4 rows down
                 */
                vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (4 * numColsB * CMPLX_DIM));

                vecA = vld1q(pSrcA0Vec);
                pSrcA0Vec += 8;
                acc0 = vcmlaq(acc0, vecA, vecB);
                acc0 = vcmlaq_rot90(acc0, vecA, vecB);


                blkCnt--;
            }


            /*
             * tail
             */
            blkCnt = (numColsA * CMPLX_DIM) & 7;
            if (blkCnt > 0U)
            {
                mve_pred16_t p0 = vctp16q(blkCnt);
                f16x8_t vecB, vecA;

                vecB = vldrhq_gather_shifted_offset_z(pInB, vecOffs, p0);

                vecA = vld1q(pSrcA0Vec);
                acc0 = vcmlaq(acc0, vecA, vecB);
                acc0 = vcmlaq_rot90(acc0, vecA, vecB);

            }

            mve_cmplx_sum_intra_vec_f16(acc0, &px[0]);


            px += CMPLX_DIM;
            /*
             * Decrement the column loop counter
             */
            col--;
            /*
             * Update the pointer pInB to point to the starting address of the next column
             */
            pInB = (float16_t const *) pSrcB->pData + (numColsB - col) * CMPLX_DIM;
        }

        /*
         * Update the pointer pInA to point to the starting address of the next row
         */
        pInA += numColsA * CMPLX_DIM;
        rowCnt--;
    }

    /*
     * set status as ARM_MATH_SUCCESS
     */
    status = ARM_MATH_SUCCESS;
  }
    /*
     * Return to application
     */
    return (status);
}
#else

arm_status arm_mat_cmplx_mult_f16(
  const arm_matrix_instance_f16 * pSrcA,
  const arm_matrix_instance_f16 * pSrcB,
        arm_matrix_instance_f16 * pDst)
712   float16_t *pIn1 = pSrcA->pData;                /* Input data matrix pointer A */
713   float16_t *pIn2 = pSrcB->pData;                /* Input data matrix pointer B */
714   float16_t *pInA = pSrcA->pData;                /* Input data matrix pointer A */
715   float16_t *pOut = pDst->pData;                 /* Output data matrix pointer */
716   float16_t *px;                                 /* Temporary output data matrix pointer */
717   uint16_t numRowsA = pSrcA->numRows;            /* Number of rows of input matrix A */
718   uint16_t numColsB = pSrcB->numCols;            /* Number of columns of input matrix B */
719   uint16_t numColsA = pSrcA->numCols;            /* Number of columns of input matrix A */
720   _Float16 sumReal, sumImag;                    /* Accumulator */
721   _Float16 a1, b1, c1, d1;
722   uint32_t col, i = 0U, j, row = numRowsA, colCnt; /* loop counters */
723   arm_status status;                             /* status of matrix multiplication */
724 

#if defined (ARM_MATH_LOOPUNROLL)
  _Float16 a0, b0, c0, d0;
#endif

#ifdef ARM_MATH_MATRIX_CHECK

  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
      (pSrcA->numRows != pDst->numRows)  ||
      (pSrcB->numCols != pDst->numCols)    )
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else

#endif /* #ifdef ARM_MATH_MATRIX_CHECK */

  {
    /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
    /* row loop */
    do
    {
      /* Output pointer is set to starting address of the row being processed */
      px = pOut + 2 * i;

      /* The column loop counter is reinitialized for each row */
      col = numColsB;

      /* For every row, pIn2 is reset to the starting address
       ** of the pSrcB data */
      pIn2 = pSrcB->pData;

      j = 0U;

      /* column loop */
      do
      {
        /* Set the real and imaginary accumulators to zero */
        sumReal = 0.0f16;
        sumImag = 0.0f16;

        /* Point pIn1 to the starting address of the A row being processed */
        pIn1 = pInA;
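
        /*
         * complex multiply-accumulate, expanded as:
         *   (a + jb) * (c + jd) = (a*c - b*d) + j(a*d + b*c)
         * sumReal accumulates the (a*c - b*d) terms, sumImag the
         * (a*d + b*c) terms.
         */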

#if defined (ARM_MATH_LOOPUNROLL)

        /* Apply loop unrolling and compute 4 MACs simultaneously. */
        colCnt = numColsA >> 2U;

        /* matrix multiplication */
        while (colCnt > 0U)
        {

          /* Reading real part of complex matrix A */
          a0 = *pIn1;

          /* Reading real part of complex matrix B */
          c0 = *pIn2;

          /* Reading imaginary part of complex matrix A */
          b0 = *(pIn1 + 1U);

          /* Reading imaginary part of complex matrix B */
          d0 = *(pIn2 + 1U);

          /* Multiply and Accumulate */
          sumReal += a0 * c0;
          sumImag += b0 * c0;

          /* update pointers */
          pIn1 += 2U;
          pIn2 += 2 * numColsB;

          /* Multiply and Accumulate */
          sumReal -= b0 * d0;
          sumImag += a0 * d0;

          /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */

          /* read real and imag values from pSrcA and pSrcB buffer */
          a1 = *(pIn1     );
          c1 = *(pIn2     );
          b1 = *(pIn1 + 1U);
          d1 = *(pIn2 + 1U);

          /* Multiply and Accumulate */
          sumReal += a1 * c1;
          sumImag += b1 * c1;

          /* update pointers */
          pIn1 += 2U;
          pIn2 += 2 * numColsB;

          /* Multiply and Accumulate */
          sumReal -= b1 * d1;
          sumImag += a1 * d1;

          a0 = *(pIn1     );
          c0 = *(pIn2     );
          b0 = *(pIn1 + 1U);
          d0 = *(pIn2 + 1U);

          /* Multiply and Accumulate */
          sumReal += a0 * c0;
          sumImag += b0 * c0;

          /* update pointers */
          pIn1 += 2U;
          pIn2 += 2 * numColsB;

          /* Multiply and Accumulate */
          sumReal -= b0 * d0;
          sumImag += a0 * d0;

          /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */

          a1 = *(pIn1     );
          c1 = *(pIn2     );
          b1 = *(pIn1 + 1U);
          d1 = *(pIn2 + 1U);

          /* Multiply and Accumulate */
          sumReal += a1 * c1;
          sumImag += b1 * c1;

          /* update pointers */
          pIn1 += 2U;
          pIn2 += 2 * numColsB;

          /* Multiply and Accumulate */
          sumReal -= b1 * d1;
          sumImag += a1 * d1;

          /* Decrement loop count */
          colCnt--;
        }

        /* If the number of columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
         ** No loop unrolling is used. */
        colCnt = numColsA % 0x4U;

#else

        /* Initialize colCnt with number of columns */
        colCnt = numColsA;

#endif /* #if defined (ARM_MATH_LOOPUNROLL) */

        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
          a1 = *(pIn1     );
          c1 = *(pIn2     );
          b1 = *(pIn1 + 1U);
          d1 = *(pIn2 + 1U);

          /* Multiply and Accumulate */
          sumReal += a1 * c1;
          sumImag += b1 * c1;

          /* update pointers */
          pIn1 += 2U;
          pIn2 += 2 * numColsB;

          /* Multiply and Accumulate */
          sumReal -= b1 * d1;
          sumImag += a1 * d1;

          /* Decrement loop counter */
          colCnt--;
        }

        /* Store result in destination buffer */
        *px++ = sumReal;
        *px++ = sumImag;

        /* Update pointer pIn2 to point to starting address of next column */
        j++;
        pIn2 = pSrcB->pData + 2U * j;

        /* Decrement column loop counter */
        col--;

      } while (col > 0U);

      /* Update pointer pInA to point to starting address of next row */
      i = i + numColsB;
      pInA = pInA + 2 * numColsA;

      /* Decrement row loop counter */
      row--;

    } while (row > 0U);

    /* Set status as ARM_MATH_SUCCESS */
    status = ARM_MATH_SUCCESS;
  }

  /* Return to application */
  return (status);
}

#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) && !defined(__CMSIS_GCC_H) */

/**
  @} end of CmplxMatrixMult group
 */

#endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
