1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_cmplx_mult_f32.c
4  * Description:  Floating-point matrix multiplication
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/matrix_functions.h"
30 
31 /**
32   @ingroup groupMatrix
33  */
34 
35 /**
36   @defgroup CmplxMatrixMult  Complex Matrix Multiplication
37 
38   Complex Matrix multiplication is only defined if the number of columns of the
39   first matrix equals the number of rows of the second matrix.
40   Multiplying an <code>M x N</code> matrix with an <code>N x P</code> matrix results
41   in an <code>M x P</code> matrix.
42   @par
43   When matrix size checking is enabled, the functions check:
44    - that the inner dimensions of <code>pSrcA</code> and <code>pSrcB</code> are equal;
45    - that the size of the output matrix equals the outer dimensions of <code>pSrcA</code> and <code>pSrcB</code>.
46  */
47 
48 
49 /**
50   @addtogroup CmplxMatrixMult
51   @{
52  */
53 
54 /**
55   @brief         Floating-point Complex matrix multiplication.
56   @param[in]     pSrcA      points to first input complex matrix structure
57   @param[in]     pSrcB      points to second input complex matrix structure
58   @param[out]    pDst       points to output complex matrix structure
59   @return        execution status
60                    - \ref ARM_MATH_SUCCESS       : Operation successful
61                    - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
62  */
63 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
64 
65 #include "arm_helium_utils.h"
66 
67 #define MATRIX_DIM2 2
68 #define MATRIX_DIM3 3
69 #define MATRIX_DIM4 4
70 
/**
  @brief         2x2 complex matrix multiply, fully unrolled MVE specialization.
                 Each 128-bit f32x4 vector holds two complex (real, imag) pairs,
                 i.e. one full row of A or one full column of B.
  @param[in]     pSrcA  first input complex matrix (must be 2x2)
  @param[in]     pSrcB  second input complex matrix (must be 2x2)
  @param[out]    pDst   output complex matrix (must be 2x2)
  @return        ARM_MATH_SUCCESS (no size checking is done here; the caller
                 dispatches only for matching square sizes)
 */
__STATIC_INLINE arm_status arm_mat_cmplx_mult_f32_2x2_mve(
    const arm_matrix_instance_f32 * pSrcA,
    const arm_matrix_instance_f32 * pSrcB,
    arm_matrix_instance_f32 * pDst)
{
    float32_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float32_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
    float32_t       *pOut = pDst->pData;   /* output data matrix pointer */
    uint32x4_t   vecColBOffs0;
    /* pInA0/pInA1 point to row 0 and row 1 of A (row stride = CMPLX_DIM * cols) */
    float32_t       *pInA0 = pInA;
    float32_t       *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM2;
    f32x4_t    acc0, acc1;
    f32x4_t    vecB, vecA;

    /* Gather offsets selecting one column of B: complex element (re, im) of
       row 0 followed by the complex element of row 1 of the same column. */
    static const uint32_t offsetB0[4] = { 0, 1,
        MATRIX_DIM2 * CMPLX_DIM, MATRIX_DIM2 * CMPLX_DIM + 1
    };

    vecColBOffs0 = vldrwq_u32((uint32_t const *) offsetB0);

    pInB = (float32_t const *)pSrcB->pData;

    /* ---- column 0 of B ---- */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);

    /* vcmulq + vcmlaq_rot90 together form the full complex multiply:
       vcmulq contributes (aR*bR, aR*bI), vcmlaq_rot90 adds (-aI*bI, aI*bR). */
    vecA = vldrwq_f32(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    /* Reduce the two complex partial products per accumulator:
       lanes (0,1) and (2,3) each hold a complex product; sum pairwise. */
    pOut[0 * CMPLX_DIM * MATRIX_DIM2 + 0] = acc0[0] + acc0[2];
    pOut[0 * CMPLX_DIM * MATRIX_DIM2 + 1] = acc0[1] + acc0[3];
    pOut[1 * CMPLX_DIM * MATRIX_DIM2 + 0] = acc1[0] + acc1[2];
    pOut[1 * CMPLX_DIM * MATRIX_DIM2 + 1] = acc1[1] + acc1[3];
    pOut += CMPLX_DIM;

    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    /* ---- column 1 of B: same computation against the second column ---- */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrwq_f32(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    pOut[0 * CMPLX_DIM * MATRIX_DIM2 + 0] = acc0[0] + acc0[2];
    pOut[0 * CMPLX_DIM * MATRIX_DIM2 + 1] = acc0[1] + acc0[3];
    pOut[1 * CMPLX_DIM * MATRIX_DIM2 + 0] = acc1[0] + acc1[2];
    pOut[1 * CMPLX_DIM * MATRIX_DIM2 + 1] = acc1[1] + acc1[3];
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}
133 
134 
/**
  @brief         3x3 complex matrix multiply, fully unrolled MVE specialization.
                 A column of B (3 complex elements = 6 floats) is processed as
                 one full vector (rows 0-1) plus a predicated half vector (row 2).
  @param[in]     pSrcA  first input complex matrix (must be 3x3)
  @param[in]     pSrcB  second input complex matrix (must be 3x3)
  @param[out]    pDst   output complex matrix (must be 3x3)
  @return        ARM_MATH_SUCCESS (caller guarantees matching square sizes)
 */
__STATIC_INLINE arm_status arm_mat_cmplx_mult_f32_3x3_mve(
    const arm_matrix_instance_f32 * pSrcA,
    const arm_matrix_instance_f32 * pSrcB,
    arm_matrix_instance_f32 * pDst)
{
    float32_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float32_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
    float32_t       *pOut = pDst->pData;   /* output data matrix pointer */
    uint32x4_t   vecColBOffs0, vecColBOffs1;
    /* Row pointers into A (row stride = CMPLX_DIM * MATRIX_DIM3 floats) */
    float32_t       *pInA0 = pInA;
    float32_t       *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM3;
    float32_t       *pInA2 = pInA1 + CMPLX_DIM * MATRIX_DIM3;
    f32x4_t    acc0, acc1, acc2;
    f32x4_t    vecB, vecA;
    /* enable predication to disable upper half complex vector element */
    mve_pred16_t p0 = vctp32q(CMPLX_DIM);

    /* offsetB0: column elements from B rows 0 and 1 (re, im each). */
    static const uint32_t offsetB0[4] = { 0, 1,
        MATRIX_DIM3 * CMPLX_DIM, MATRIX_DIM3 * CMPLX_DIM + 1
    };
    /* offsetB1: column element from B row 2; upper two lanes are inactive
       (masked by p0) since the column has only 3 complex elements. */
    static const uint32_t offsetB1[4] = { 2 * MATRIX_DIM3 * CMPLX_DIM, 2 * MATRIX_DIM3 * CMPLX_DIM + 1,
       INACTIVELANE, INACTIVELANE
    };

    vecColBOffs0 = vldrwq_u32((uint32_t const *) offsetB0);
    vecColBOffs1 = vldrwq_u32((uint32_t const *) offsetB1);

    pInB = (float32_t const *)pSrcB->pData;

    /* ---- column 0 of B ---- */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);

    /* vcmulq + vcmlaq_rot90 pair implements a full complex multiply
       (real products plus imaginary cross terms) per lane pair. */
    vecA = vldrwq_f32(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);


    /* Tail of the column (B row 2), predicated to the low two lanes. */
    vecB = vldrwq_gather_shifted_offset_z(pInB, vecColBOffs1, p0);

    /* &pInAx[4] skips the two complex A elements already consumed above. */
    vecA = vldrwq_f32(&pInA0[4]);
    acc0 = vcmlaq(acc0, vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(&pInA1[4]);
    acc1 = vcmlaq(acc1, vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(&pInA2[4]);
    acc2 = vcmlaq(acc2, vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);


    /* Pairwise-reduce each accumulator: lanes (0,1)+(2,3) form the
       output complex element for rows 0..2 of the current column. */
    pOut[0 * CMPLX_DIM * MATRIX_DIM3 + 0] = acc0[0] + acc0[2];
    pOut[0 * CMPLX_DIM * MATRIX_DIM3 + 1] = acc0[1] + acc0[3];
    pOut[1 * CMPLX_DIM * MATRIX_DIM3 + 0] = acc1[0] + acc1[2];
    pOut[1 * CMPLX_DIM * MATRIX_DIM3 + 1] = acc1[1] + acc1[3];
    pOut[2 * CMPLX_DIM * MATRIX_DIM3 + 0] = acc2[0] + acc2[2];
    pOut[2 * CMPLX_DIM * MATRIX_DIM3 + 1] = acc2[1] + acc2[3];
    pOut += CMPLX_DIM;

    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    /* ---- column 1 of B: identical schedule ---- */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrwq_f32(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

     vecB = vldrwq_gather_shifted_offset_z(pInB, vecColBOffs1, p0);

    vecA = vldrwq_f32(&pInA0[4]);
    acc0 = vcmlaq(acc0, vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(&pInA1[4]);
    acc1 = vcmlaq(acc1, vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(&pInA2[4]);
    acc2 = vcmlaq(acc2, vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);


    pOut[0 * CMPLX_DIM * MATRIX_DIM3 + 0] = acc0[0] + acc0[2];
    pOut[0 * CMPLX_DIM * MATRIX_DIM3 + 1] = acc0[1] + acc0[3];
    pOut[1 * CMPLX_DIM * MATRIX_DIM3 + 0] = acc1[0] + acc1[2];
    pOut[1 * CMPLX_DIM * MATRIX_DIM3 + 1] = acc1[1] + acc1[3];
    pOut[2 * CMPLX_DIM * MATRIX_DIM3 + 0] = acc2[0] + acc2[2];
    pOut[2 * CMPLX_DIM * MATRIX_DIM3 + 1] = acc2[1] + acc2[3];
    pOut += CMPLX_DIM;

    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    /* ---- column 2 of B: identical schedule ---- */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrwq_f32(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

     vecB = vldrwq_gather_shifted_offset_z(pInB, vecColBOffs1, p0);

    vecA = vldrwq_f32(&pInA0[4]);
    acc0 = vcmlaq(acc0, vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(&pInA1[4]);
    acc1 = vcmlaq(acc1, vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(&pInA2[4]);
    acc2 = vcmlaq(acc2, vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);


    pOut[0 * CMPLX_DIM * MATRIX_DIM3 + 0] = acc0[0] + acc0[2];
    pOut[0 * CMPLX_DIM * MATRIX_DIM3 + 1] = acc0[1] + acc0[3];
    pOut[1 * CMPLX_DIM * MATRIX_DIM3 + 0] = acc1[0] + acc1[2];
    pOut[1 * CMPLX_DIM * MATRIX_DIM3 + 1] = acc1[1] + acc1[3];
    pOut[2 * CMPLX_DIM * MATRIX_DIM3 + 0] = acc2[0] + acc2[2];
    pOut[2 * CMPLX_DIM * MATRIX_DIM3 + 1] = acc2[1] + acc2[3];
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}
289 
290 
291 
/**
  @brief         4x4 complex matrix multiply, fully unrolled MVE specialization.
                 A column of B (4 complex elements = 8 floats) is consumed as
                 two full f32x4 gathers (rows 0-1, then rows 2-3) — no
                 predication is needed since both vectors are full.
  @param[in]     pSrcA  first input complex matrix (must be 4x4)
  @param[in]     pSrcB  second input complex matrix (must be 4x4)
  @param[out]    pDst   output complex matrix (must be 4x4)
  @return        ARM_MATH_SUCCESS (caller guarantees matching square sizes)
 */
__STATIC_INLINE arm_status arm_mat_cmplx_mult_f32_4x4_mve(
    const arm_matrix_instance_f32 * pSrcA,
    const arm_matrix_instance_f32 * pSrcB,
    arm_matrix_instance_f32 * pDst)
{
    float32_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float32_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
    float32_t       *pOut = pDst->pData;   /* output data matrix pointer */
    uint32x4_t   vecColBOffs0, vecColBOffs1;
    /* Row pointers into A (row stride = CMPLX_DIM * MATRIX_DIM4 floats) */
    float32_t       *pInA0 = pInA;
    float32_t       *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM4;
    float32_t       *pInA2 = pInA1 + CMPLX_DIM * MATRIX_DIM4;
    float32_t       *pInA3 = pInA2 + CMPLX_DIM * MATRIX_DIM4;
    f32x4_t    acc0, acc1, acc2, acc3;
    f32x4_t    vecB, vecA;

    /* offsetB0: column elements from B rows 0 and 1 (re, im each). */
    static const uint32_t offsetB0[4] = { 0, 1,
        MATRIX_DIM4 * CMPLX_DIM, MATRIX_DIM4 * CMPLX_DIM + 1
    };
    /* offsetB1: column elements from B rows 2 and 3. */
    static const uint32_t offsetB1[4] = { 2 * MATRIX_DIM4 * CMPLX_DIM, 2 * MATRIX_DIM4 * CMPLX_DIM + 1,
        3 * MATRIX_DIM4 * CMPLX_DIM, 3 * MATRIX_DIM4 * CMPLX_DIM + 1
    };

    vecColBOffs0 = vldrwq_u32((uint32_t const *) offsetB0);
    vecColBOffs1 = vldrwq_u32((uint32_t const *) offsetB1);

    pInB = (float32_t const *)pSrcB->pData;

    /* ---- column 0 of B ---- */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);

    /* vcmulq + vcmlaq_rot90 pair implements a full complex multiply
       per (re, im) lane pair; each acc handles one row of A. */
    vecA = vldrwq_f32(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrwq_f32(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);

    /* Second half of the column (B rows 2-3); &pInAx[4] skips the two
       complex A elements already consumed above. */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs1);

    vecA = vldrwq_f32(&pInA0[4]);
    acc0 = vcmlaq(acc0, vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(&pInA1[4]);
    acc1 = vcmlaq(acc1, vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(&pInA2[4]);
    acc2 = vcmlaq(acc2, vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrwq_f32(&pInA3[4]);
    acc3 = vcmlaq(acc3, vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);

    /* Pairwise-reduce each accumulator: lanes (0,1)+(2,3) give the
       output complex element for rows 0..3 of the current column. */
    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc0[0] + acc0[2];
    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc0[1] + acc0[3];
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc1[0] + acc1[2];
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc1[1] + acc1[3];
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc2[0] + acc2[2];
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc2[1] + acc2[3];
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc3[0] + acc3[2];
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc3[1] + acc3[3];
    pOut += CMPLX_DIM;

    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    /* ---- column 1 of B: identical schedule ---- */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrwq_f32(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrwq_f32(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);

    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs1);

    vecA = vldrwq_f32(&pInA0[4]);
    acc0 = vcmlaq(acc0, vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(&pInA1[4]);
    acc1 = vcmlaq(acc1, vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(&pInA2[4]);
    acc2 = vcmlaq(acc2, vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrwq_f32(&pInA3[4]);
    acc3 = vcmlaq(acc3, vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);

    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc0[0] + acc0[2];
    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc0[1] + acc0[3];
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc1[0] + acc1[2];
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc1[1] + acc1[3];
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc2[0] + acc2[2];
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc2[1] + acc2[3];
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc3[0] + acc3[2];
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc3[1] + acc3[3];
    pOut += CMPLX_DIM;

    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    /* ---- column 2 of B: identical schedule ---- */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrwq_f32(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrwq_f32(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);

    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs1);

    vecA = vldrwq_f32(&pInA0[4]);
    acc0 = vcmlaq(acc0, vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(&pInA1[4]);
    acc1 = vcmlaq(acc1, vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(&pInA2[4]);
    acc2 = vcmlaq(acc2, vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrwq_f32(&pInA3[4]);
    acc3 = vcmlaq(acc3, vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);

    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc0[0] + acc0[2];
    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc0[1] + acc0[3];
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc1[0] + acc1[2];
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc1[1] + acc1[3];
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc2[0] + acc2[2];
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc2[1] + acc2[3];
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc3[0] + acc3[2];
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc3[1] + acc3[3];
    pOut += CMPLX_DIM;

    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    /* ---- column 3 of B: identical schedule ---- */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrwq_f32(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrwq_f32(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);

    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs1);

    vecA = vldrwq_f32(&pInA0[4]);
    acc0 = vcmlaq(acc0, vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(&pInA1[4]);
    acc1 = vcmlaq(acc1, vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(&pInA2[4]);
    acc2 = vcmlaq(acc2, vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrwq_f32(&pInA3[4]);
    acc3 = vcmlaq(acc3, vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);

    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc0[0] + acc0[2];
    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc0[1] + acc0[3];
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc1[0] + acc1[2];
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc1[1] + acc1[3];
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc2[0] + acc2[2];
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc2[1] + acc2[3];
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc3[0] + acc3[2];
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc3[1] + acc3[3];
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}
522 
/**
  @brief         Floating-point complex matrix multiplication (Helium/MVE path).
                 Dispatches to unrolled 1x1/2x2/3x3/4x4 kernels for small square
                 matrices, otherwise runs a general tiled loop: 4 rows of A at a
                 time against each column of B, with a tail-predicated remainder.
  @param[in]     pSrcA  points to first input complex matrix structure
  @param[in]     pSrcB  points to second input complex matrix structure
  @param[out]    pDst   points to output complex matrix structure
  @return        ARM_MATH_SUCCESS or ARM_MATH_SIZE_MISMATCH
 */
ARM_DSP_ATTRIBUTE arm_status arm_mat_cmplx_mult_f32(
  const arm_matrix_instance_f32 * pSrcA,
  const arm_matrix_instance_f32 * pSrcB,
  arm_matrix_instance_f32 * pDst)
{
    float32_t const *pInB = (float32_t const *) pSrcB->pData;   /* input data matrix pointer B */
    float32_t const *pInA = (float32_t const *) pSrcA->pData;   /* input data matrix pointer A */
    float32_t *pOut = pDst->pData;  /* output data matrix pointer */
    float32_t *px;              /* Temporary output data matrix pointer */
    uint16_t  numRowsA = pSrcA->numRows;    /* number of rows of input matrix A    */
    uint16_t  numColsB = pSrcB->numCols;    /* number of columns of input matrix B */
    uint16_t  numColsA = pSrcA->numCols;    /* number of columns of input matrix A */
    uint16_t  col, i = 0U, row = numRowsA;  /* loop counters */
    arm_status status;          /* status of matrix multiplication */
    uint32x4_t vecOffs, vecColBOffs;
    uint32_t  blkCnt, rowCnt;           /* loop counters */

  #ifdef ARM_MATH_MATRIX_CHECK


  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
     (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
  {

    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else
#endif /*      #ifdef ARM_MATH_MATRIX_CHECK    */

  {
    /*
     * small squared matrix specialized routines
     */
    if (numRowsA == numColsB && numColsB == numColsA)
    {
        if (numRowsA == 1)
        {
            /* 1x1: single scalar complex multiply, done inline */
            pOut[0] = pInA[0] * pInB[0] - pInA[1] * pInB[1];
            pOut[1] = pInA[0] * pInB[1] + pInA[1] * pInB[0];
            return (ARM_MATH_SUCCESS);
        }
        else if (numRowsA == 2)
            return arm_mat_cmplx_mult_f32_2x2_mve(pSrcA, pSrcB, pDst);
        else if (numRowsA == 3)
            return arm_mat_cmplx_mult_f32_3x3_mve(pSrcA, pSrcB, pDst);
        else if (numRowsA == 4)
            return arm_mat_cmplx_mult_f32_4x4_mve(pSrcA, pSrcB, pDst);
    }

    /* Base gather offsets selecting two consecutive complex elements
       of a B column: (re, im) of row r and (re, im) of row r+1. */
    vecColBOffs[0] = 0;
    vecColBOffs[1] = 1;
    vecColBOffs[2] = numColsB * CMPLX_DIM;
    vecColBOffs[3] = (numColsB * CMPLX_DIM) + 1;

    /*
     * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
     */

    /*
     * row loop
     */
    /* Main block: process 4 rows of A per iteration. */
    rowCnt = row >> 2;
    while (rowCnt > 0u)
    {
        /*
         * Output pointer is set to starting address of the row being processed
         */
        px = pOut + i * CMPLX_DIM;
        i = i + 4 * numColsB;
        /*
         * For every row wise process, the column loop counter is to be initiated
         */
        col = numColsB;
        /*
         * For every row wise process, the pInB pointer is set
         * to the starting address of the pSrcB data
         */
        pInB = (float32_t const *) pSrcB->pData;
        /*
         * column loop
         */
        while (col > 0u)
        {
            /*
             * generate 4 columns elements
             */
            /*
             * Matrix A columns number of MAC operations are to be performed
             */

            /* Pointers to 4 consecutive rows of A for this tile */
            float32_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec;
            float32_t const *pInA0 = pInA;
            float32_t const *pInA1 = pInA0 + numColsA * CMPLX_DIM;
            float32_t const *pInA2 = pInA1 + numColsA * CMPLX_DIM;
            float32_t const *pInA3 = pInA2 + numColsA * CMPLX_DIM;
            f32x4_t acc0, acc1, acc2, acc3;

            acc0 = vdupq_n_f32(0.0f);
            acc1 = vdupq_n_f32(0.0f);
            acc2 = vdupq_n_f32(0.0f);
            acc3 = vdupq_n_f32(0.0f);

            pSrcA0Vec = (float32_t const *) pInA0;
            pSrcA1Vec = (float32_t const *) pInA1;
            pSrcA2Vec = (float32_t const *) pInA2;
            pSrcA3Vec = (float32_t const *) pInA3;

            /* Restart the B-column gather offsets at the top of the column */
            vecOffs = vecColBOffs;

            /*
             * process 1 x 4 block output
             */
            /* Inner loop over the shared dimension, 2 complex elements
               (4 floats) per iteration. */
            blkCnt = (numColsA * CMPLX_DIM) >> 2;
            while (blkCnt > 0U)
            {
                f32x4_t vecB, vecA;

                vecB = vldrwq_gather_shifted_offset(pInB, vecOffs);
                /*
                 * move Matrix B read offsets, 4 rows down
                 */
                vecOffs = vecOffs + (uint32_t) (numColsB * 2 * CMPLX_DIM);

                /* vcmlaq + vcmlaq_rot90 accumulate the full complex product
                   of 2 complex A elements with 2 complex B elements. */
                vecA = vld1q(pSrcA0Vec);  pSrcA0Vec += 4;
                acc0 = vcmlaq(acc0, vecA, vecB);
                acc0 = vcmlaq_rot90(acc0, vecA, vecB);
                vecA = vld1q(pSrcA1Vec); pSrcA1Vec += 4;
                acc1 = vcmlaq(acc1, vecA, vecB);
                acc1 = vcmlaq_rot90(acc1, vecA, vecB);
                vecA = vld1q(pSrcA2Vec);  pSrcA2Vec += 4;
                acc2 = vcmlaq(acc2, vecA, vecB);
                acc2 = vcmlaq_rot90(acc2, vecA, vecB);
                vecA = vld1q(pSrcA3Vec);  pSrcA3Vec += 4;
                acc3 = vcmlaq(acc3, vecA, vecB);
                acc3 = vcmlaq_rot90(acc3, vecA, vecB);

                blkCnt--;
            }


            /*
             * tail
             * (will be merged thru tail predication)
             */
            /* Remaining floats (numColsA odd -> one complex element left);
               the predicated gather zeroes the inactive lanes. */
            blkCnt = (numColsA * CMPLX_DIM) & 3;
            if (blkCnt > 0U)
            {
                mve_pred16_t p0 = vctp32q(blkCnt);
                f32x4_t vecB, vecA;

                vecB = vldrwq_gather_shifted_offset_z(pInB, vecOffs, p0);
                /*
                 * move Matrix B read offsets, 4 rows down
                 */
                vecOffs = vecOffs + (uint32_t) (numColsB * 2 * CMPLX_DIM);

                /* NOTE(review): the A loads here are full (unpredicated);
                   the masked-out B lanes are zero, so the extra A lanes
                   contribute nothing to the accumulators. */
                vecA = vld1q(pSrcA0Vec);
                acc0 = vcmlaq(acc0, vecA, vecB);
                acc0 = vcmlaq_rot90(acc0, vecA, vecB);
                vecA = vld1q(pSrcA1Vec);
                acc1 = vcmlaq(acc1, vecA, vecB);
                acc1 = vcmlaq_rot90(acc1, vecA, vecB);
                vecA = vld1q(pSrcA2Vec);
                acc2 = vcmlaq(acc2, vecA, vecB);
                acc2 = vcmlaq_rot90(acc2, vecA, vecB);
                vecA = vld1q(pSrcA3Vec);
                acc3 = vcmlaq(acc3, vecA, vecB);
                acc3 = vcmlaq_rot90(acc3, vecA, vecB);

            }

            /* Pairwise-reduce: lanes (0,1)+(2,3) of each accumulator form
               one output complex element; write 4 rows of one column. */
            px[0 * CMPLX_DIM * numColsB + 0] = acc0[0] + acc0[2];
            px[0 * CMPLX_DIM * numColsB + 1] = acc0[1] + acc0[3];
            px[1 * CMPLX_DIM * numColsB + 0] = acc1[0] + acc1[2];
            px[1 * CMPLX_DIM * numColsB + 1] = acc1[1] + acc1[3];
            px[2 * CMPLX_DIM * numColsB + 0] = acc2[0] + acc2[2];
            px[2 * CMPLX_DIM * numColsB + 1] = acc2[1] + acc2[3];
            px[3 * CMPLX_DIM * numColsB + 0] = acc3[0] + acc3[2];
            px[3 * CMPLX_DIM * numColsB + 1] = acc3[1] + acc3[3];
            px += CMPLX_DIM;
            /*
             * Decrement the column loop counter
             */
            col--;
            /*
             * Update the pointer pInB to point to the  starting address of the next column
             */
            pInB = (float32_t const *) pSrcB->pData + (numColsB - col) * CMPLX_DIM;
        }

        /*
         * Update the pointer pInA to point to the  starting address of the next row
         */
        pInA += (numColsA * 4) * CMPLX_DIM;
        /*
         * Decrement the row loop counter
         */
        rowCnt --;

    }

    /* Tail rows: remaining 1-3 rows of A, one row per iteration. */
    rowCnt = row & 3;
    while (rowCnt > 0u)
    {
           /*
         * Output pointer is set to starting address of the row being processed
         */
        px = pOut + i * CMPLX_DIM;
        i = i + numColsB;
        /*
         * For every row wise process, the column loop counter is to be initiated
         */
        col = numColsB;
        /*
         * For every row wise process, the pInB pointer is set
         * to the starting address of the pSrcB data
         */
        pInB = (float32_t const *) pSrcB->pData;
        /*
         * column loop
         */
        while (col > 0u)
        {
            /*
             * generate 4 columns elements
             */
            /*
             * Matrix A columns number of MAC operations are to be performed
             */

            float32_t const *pSrcA0Vec;
            float32_t const *pInA0 = pInA;
            f32x4_t acc0;

            acc0 = vdupq_n_f32(0.0f);

            pSrcA0Vec = (float32_t const *) pInA0;

            vecOffs = vecColBOffs;

            /*
             * process 1 x 4 block output
             */
            /* Same inner schedule as above, single A row. */
            blkCnt = (numColsA * CMPLX_DIM) >> 2;
            while (blkCnt > 0U)
            {
                f32x4_t vecB, vecA;

                vecB = vldrwq_gather_shifted_offset(pInB, vecOffs);
                /*
                 * move Matrix B read offsets, 4 rows down
                 */
                vecOffs = vecOffs + (uint32_t) (numColsB * 2 * CMPLX_DIM);

                vecA = vld1q(pSrcA0Vec);
                pSrcA0Vec += 4;
                acc0 = vcmlaq(acc0, vecA, vecB);
                acc0 = vcmlaq_rot90(acc0, vecA, vecB);


                blkCnt--;
            }


            /*
             * tail
             */
            blkCnt = (numColsA * CMPLX_DIM) & 3;
            if (blkCnt > 0U)
            {
                mve_pred16_t p0 = vctp32q(blkCnt);
                f32x4_t vecB, vecA;

                vecB = vldrwq_gather_shifted_offset_z(pInB, vecOffs, p0);

                vecA = vld1q(pSrcA0Vec);
                acc0 = vcmlaq(acc0, vecA, vecB);
                acc0 = vcmlaq_rot90(acc0, vecA, vecB);

            }

            /* Pairwise reduction of the two complex partial sums */
            px[0] = acc0[0] + acc0[2];
            px[1] = acc0[1] + acc0[3];

            px += CMPLX_DIM;
            /*
             * Decrement the column loop counter
             */
            col--;
            /*
             * Update the pointer pInB to point to the  starting address of the next column
             */
            pInB = (float32_t const *) pSrcB->pData + (numColsB - col) * CMPLX_DIM;
        }

        /*
         * Update the pointer pInA to point to the  starting address of the next row
         */
        pInA += numColsA  * CMPLX_DIM;
        rowCnt--;
    }


      /* Set status as ARM_MATH_SUCCESS */
    status = ARM_MATH_SUCCESS;
  }

  /* Return to application */
  return (status);

}
836 
837 #else
838 #if defined(ARM_MATH_NEON)
arm_mat_cmplx_mult_f32(const arm_matrix_instance_f32 * pSrcA,const arm_matrix_instance_f32 * pSrcB,arm_matrix_instance_f32 * pDst)839 ARM_DSP_ATTRIBUTE arm_status arm_mat_cmplx_mult_f32(
840   const arm_matrix_instance_f32 * pSrcA,
841   const arm_matrix_instance_f32 * pSrcB,
842   arm_matrix_instance_f32 * pDst)
843 {
844   float32_t *pIn1 = pSrcA->pData;                /* input data matrix pointer A */
845   float32_t *pIn2 = pSrcB->pData;                /* input data matrix pointer B */
846   float32_t *pInA = pSrcA->pData;                /* input data matrix pointer A  */
847   float32_t *pOut = pDst->pData;                 /* output data matrix pointer */
848   float32_t *px;                                 /* Temporary output data matrix pointer */
849   uint16_t numRowsA = pSrcA->numRows;            /* number of rows of input matrix A */
850   uint16_t numColsB = pSrcB->numCols;            /* number of columns of input matrix B */
851   uint16_t numColsA = pSrcA->numCols;            /* number of columns of input matrix A */
852   float32_t sumReal1, sumImag1;                  /* accumulator */
853   float32_t a1, a1B,b1, b1B, c1, d1;
854   float32_t sumReal2, sumImag2;                  /* accumulator */
855 
856 
857   float32x4x2_t a0V, a1V;
858   float32x4_t accR0,accI0, accR1,accI1,tempR, tempI;
859   float32x2_t accum = vdup_n_f32(0);
860   float32_t *pIn1B = pSrcA->pData;
861 
862   uint16_t col, i = 0U, j, rowCnt, row = numRowsA, colCnt;      /* loop counters */
863   arm_status status;                             /* status of matrix multiplication */
864   float32_t sumReal1B, sumImag1B;
865   float32_t sumReal2B, sumImag2B;
866   float32_t *pxB;
867 
868 #ifdef ARM_MATH_MATRIX_CHECK
869 
870 
871   /* Check for matrix mismatch condition */
872   if ((pSrcA->numCols != pSrcB->numRows) ||
873      (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
874   {
875 
876     /* Set status as ARM_MATH_SIZE_MISMATCH */
877     status = ARM_MATH_SIZE_MISMATCH;
878   }
879   else
880 #endif /*      #ifdef ARM_MATH_MATRIX_CHECK    */
881 
882   {
883     /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
884 
885     rowCnt = row >> 1;
886 
887     /* Row loop */
888     while (rowCnt > 0U)
889     {
890       /* Output pointer is set to starting address of the row being processed */
891       px = pOut + 2 * i;
892       pxB = px + 2 * numColsB;
893 
894       /* For every row wise process, the column loop counter is to be initiated */
895       col = numColsB;
896 
897       /* For every row wise process, the pIn2 pointer is set
898        ** to the starting address of the pSrcB data */
899       pIn2 = pSrcB->pData;
900 
901       j = 0U;
902 
903       /* Column loop */
904       while (col > 0U)
905       {
906         /* Set the variable sum, that acts as accumulator, to zero */
907         sumReal1 = 0.0f;
908         sumImag1 = 0.0f;
909         sumReal1B = 0.0f;
910         sumImag1B = 0.0f;
911 
912         sumReal2 = 0.0f;
913         sumImag2 = 0.0f;
914         sumReal2B = 0.0f;
915         sumImag2B = 0.0f;
916 
917         /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
918         pIn1 = pInA;
919         pIn1B = pIn1 + 2*numColsA;
920 
921         accR0 = vdupq_n_f32(0.0);
922         accI0 = vdupq_n_f32(0.0);
923         accR1 = vdupq_n_f32(0.0);
924         accI1 = vdupq_n_f32(0.0);
925 
926         /* Compute 4 MACs simultaneously. */
927         colCnt = numColsA >> 2;
928 
929         /* Matrix multiplication */
930         while (colCnt > 0U)
931         {
932           /* Reading real part of complex matrix A */
933           a0V = vld2q_f32(pIn1);  // load & separate real/imag pSrcA (de-interleave 2)
934           a1V = vld2q_f32(pIn1B);  // load & separate real/imag pSrcA (de-interleave 2)
935 
936           pIn1 += 8;
937           pIn1B += 8;
938 
939           tempR = vsetq_lane_f32(*pIn2,tempR,0);
940           tempI = vsetq_lane_f32(*(pIn2 + 1U),tempI,0);
941           pIn2 += 2 * numColsB;
942 
943 
944           tempR = vsetq_lane_f32(*pIn2,tempR,1);
945           tempI = vsetq_lane_f32(*(pIn2 + 1U),tempI,1);
946           pIn2 += 2 * numColsB;
947 
948           tempR = vsetq_lane_f32(*pIn2,tempR,2);
949           tempI = vsetq_lane_f32(*(pIn2 + 1U),tempI,2);
950           pIn2 += 2 * numColsB;
951 
952           tempR = vsetq_lane_f32(*pIn2,tempR,3);
953           tempI = vsetq_lane_f32(*(pIn2 + 1U),tempI,3);
954           pIn2 += 2 * numColsB;
955 
956           accR0 = vmlaq_f32(accR0,a0V.val[0],tempR);
957           accR0 = vmlsq_f32(accR0,a0V.val[1],tempI);
958 
959           accI0 = vmlaq_f32(accI0,a0V.val[1],tempR);
960           accI0 = vmlaq_f32(accI0,a0V.val[0],tempI);
961 
962           accR1 = vmlaq_f32(accR1,a1V.val[0],tempR);
963           accR1 = vmlsq_f32(accR1,a1V.val[1],tempI);
964 
965           accI1 = vmlaq_f32(accI1,a1V.val[1],tempR);
966           accI1 = vmlaq_f32(accI1,a1V.val[0],tempI);
967 
968           /* Decrement the loop count */
969           colCnt--;
970         }
971 
972         accum = vpadd_f32(vget_low_f32(accR0), vget_high_f32(accR0));
973         sumReal1 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
974 
975         accum = vpadd_f32(vget_low_f32(accI0), vget_high_f32(accI0));
976         sumImag1 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
977 
978         accum = vpadd_f32(vget_low_f32(accR1), vget_high_f32(accR1));
979         sumReal1B += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
980 
981         accum = vpadd_f32(vget_low_f32(accI1), vget_high_f32(accI1));
982         sumImag1B += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
983 
984         /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
985          ** No loop unrolling is used. */
986         colCnt = numColsA & 3;
987 
988         while (colCnt > 0U)
989         {
990           /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
991           a1 = *pIn1;
992           a1B = *pIn1B;
993 
994           c1 = *pIn2;
995 
996           b1 = *(pIn1 + 1U);
997           b1B = *(pIn1B + 1U);
998 
999           d1 = *(pIn2 + 1U);
1000 
1001           sumReal1 += a1 * c1;
1002           sumImag1 += b1 * c1;
1003 
1004           sumReal1B += a1B * c1;
1005           sumImag1B += b1B * c1;
1006 
1007           pIn1 += 2U;
1008           pIn1B += 2U;
1009           pIn2 += 2 * numColsB;
1010 
1011           sumReal2 -= b1 * d1;
1012           sumImag2 += a1 * d1;
1013 
1014           sumReal2B -= b1B * d1;
1015           sumImag2B += a1B * d1;
1016 
1017           /* Decrement the loop counter */
1018           colCnt--;
1019         }
1020 
1021         sumReal1 += sumReal2;
1022         sumImag1 += sumImag2;
1023 
1024         sumReal1B += sumReal2B;
1025         sumImag1B += sumImag2B;
1026 
1027         /* Store the result in the destination buffer */
1028         *px++ = sumReal1;
1029         *px++ = sumImag1;
1030         *pxB++ = sumReal1B;
1031         *pxB++ = sumImag1B;
1032 
1033         /* Update the pointer pIn2 to point to the  starting address of the next column */
1034         j++;
1035         pIn2 = pSrcB->pData + 2U * j;
1036 
1037         /* Decrement the column loop counter */
1038         col--;
1039       }
1040 
1041       /* Update the pointer pInA to point to the  starting address of the next 2 row */
1042       i = i + 2*numColsB;
1043       pInA = pInA + 4 * numColsA;
1044 
1045       /* Decrement the row loop counter */
1046       rowCnt--;
1047     }
1048 
1049     rowCnt = row & 1;
1050     while (rowCnt > 0U)
1051     {
1052       /* Output pointer is set to starting address of the row being processed */
1053       px = pOut + 2 * i;
1054 
1055       /* For every row wise process, the column loop counter is to be initiated */
1056       col = numColsB;
1057 
1058       /* For every row wise process, the pIn2 pointer is set
1059        ** to the starting address of the pSrcB data */
1060       pIn2 = pSrcB->pData;
1061 
1062       j = 0U;
1063 
1064       /* Column loop */
1065       while (col > 0U)
1066       {
1067         /* Set the variable sum, that acts as accumulator, to zero */
1068         sumReal1 = 0.0f;
1069         sumImag1 = 0.0f;
1070 
1071         sumReal2 = 0.0f;
1072         sumImag2 = 0.0f;
1073 
1074         /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
1075         pIn1 = pInA;
1076 
1077         accR0 = vdupq_n_f32(0.0);
1078         accI0 = vdupq_n_f32(0.0);
1079 
1080         /* Compute 4 MACs simultaneously. */
1081         colCnt = numColsA >> 2;
1082 
1083         /* Matrix multiplication */
1084         while (colCnt > 0U)
1085         {
1086           /* Reading real part of complex matrix A */
1087           a0V = vld2q_f32(pIn1);  // load & separate real/imag pSrcA (de-interleave 2)
1088           pIn1 += 8;
1089 
1090           tempR = vsetq_lane_f32(*pIn2,tempR,0);
1091           tempI = vsetq_lane_f32(*(pIn2 + 1U),tempI,0);
1092           pIn2 += 2 * numColsB;
1093 
1094           tempR = vsetq_lane_f32(*pIn2,tempR,1);
1095           tempI = vsetq_lane_f32(*(pIn2 + 1U),tempI,1);
1096           pIn2 += 2 * numColsB;
1097 
1098           tempR = vsetq_lane_f32(*pIn2,tempR,2);
1099           tempI = vsetq_lane_f32(*(pIn2 + 1U),tempI,2);
1100           pIn2 += 2 * numColsB;
1101 
1102           tempR = vsetq_lane_f32(*pIn2,tempR,3);
1103           tempI = vsetq_lane_f32(*(pIn2 + 1U),tempI,3);
1104           pIn2 += 2 * numColsB;
1105 
1106           accR0 = vmlaq_f32(accR0,a0V.val[0],tempR);
1107           accR0 = vmlsq_f32(accR0,a0V.val[1],tempI);
1108 
1109           accI0 = vmlaq_f32(accI0,a0V.val[1],tempR);
1110           accI0 = vmlaq_f32(accI0,a0V.val[0],tempI);
1111 
1112           /* Decrement the loop count */
1113           colCnt--;
1114         }
1115 
1116         accum = vpadd_f32(vget_low_f32(accR0), vget_high_f32(accR0));
1117         sumReal1 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
1118 
1119         accum = vpadd_f32(vget_low_f32(accI0), vget_high_f32(accI0));
1120         sumImag1 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
1121 
1122         /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
1123          ** No loop unrolling is used. */
1124         colCnt = numColsA & 3;
1125 
1126         while (colCnt > 0U)
1127         {
1128           /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
1129           a1 = *pIn1;
1130           c1 = *pIn2;
1131 
1132           b1 = *(pIn1 + 1U);
1133           d1 = *(pIn2 + 1U);
1134 
1135           sumReal1 += a1 * c1;
1136           sumImag1 += b1 * c1;
1137 
1138           pIn1 += 2U;
1139           pIn2 += 2 * numColsB;
1140 
1141           sumReal2 -= b1 * d1;
1142           sumImag2 += a1 * d1;
1143 
1144           /* Decrement the loop counter */
1145           colCnt--;
1146         }
1147 
1148         sumReal1 += sumReal2;
1149         sumImag1 += sumImag2;
1150 
1151         /* Store the result in the destination buffer */
1152         *px++ = sumReal1;
1153         *px++ = sumImag1;
1154 
1155         /* Update the pointer pIn2 to point to the  starting address of the next column */
1156         j++;
1157         pIn2 = pSrcB->pData + 2U * j;
1158 
1159         /* Decrement the column loop counter */
1160         col--;
1161 
1162       }
1163 
1164       /* Update the pointer pInA to point to the  starting address of the next row */
1165       i = i + numColsB;
1166       pInA = pInA + 2 * numColsA;
1167 
1168       /* Decrement the row loop counter */
1169       rowCnt--;
1170 
1171     }
1172 
1173     /* Set status as ARM_MATH_SUCCESS */
1174     status = ARM_MATH_SUCCESS;
1175   }
1176 
1177   /* Return to application */
1178   return (status);
1179 }
1180 #else
arm_mat_cmplx_mult_f32(const arm_matrix_instance_f32 * pSrcA,const arm_matrix_instance_f32 * pSrcB,arm_matrix_instance_f32 * pDst)1181 ARM_DSP_ATTRIBUTE arm_status arm_mat_cmplx_mult_f32(
1182   const arm_matrix_instance_f32 * pSrcA,
1183   const arm_matrix_instance_f32 * pSrcB,
1184         arm_matrix_instance_f32 * pDst)
1185 {
1186   float32_t *pIn1 = pSrcA->pData;                /* Input data matrix pointer A */
1187   float32_t *pIn2 = pSrcB->pData;                /* Input data matrix pointer B */
1188   float32_t *pInA = pSrcA->pData;                /* Input data matrix pointer A */
1189   float32_t *pOut = pDst->pData;                 /* Output data matrix pointer */
1190   float32_t *px;                                 /* Temporary output data matrix pointer */
1191   uint16_t numRowsA = pSrcA->numRows;            /* Number of rows of input matrix A */
1192   uint16_t numColsB = pSrcB->numCols;            /* Number of columns of input matrix B */
1193   uint16_t numColsA = pSrcA->numCols;            /* Number of columns of input matrix A */
1194   float32_t sumReal, sumImag;                    /* Accumulator */
1195   float32_t a1, b1, c1, d1;
1196   uint32_t col, i = 0U, j, row = numRowsA, colCnt; /* loop counters */
1197   arm_status status;                             /* status of matrix multiplication */
1198 
1199 #if defined (ARM_MATH_LOOPUNROLL)
1200   float32_t a0, b0, c0, d0;
1201 #endif
1202 
1203 #ifdef ARM_MATH_MATRIX_CHECK
1204 
1205   /* Check for matrix mismatch condition */
1206   if ((pSrcA->numCols != pSrcB->numRows) ||
1207       (pSrcA->numRows != pDst->numRows)  ||
1208       (pSrcB->numCols != pDst->numCols)    )
1209   {
1210     /* Set status as ARM_MATH_SIZE_MISMATCH */
1211     status = ARM_MATH_SIZE_MISMATCH;
1212   }
1213   else
1214 
1215 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
1216 
1217   {
1218     /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
1219     /* row loop */
1220     do
1221     {
1222       /* Output pointer is set to starting address of the row being processed */
1223       px = pOut + 2 * i;
1224 
1225       /* For every row wise process, the column loop counter is to be initiated */
1226       col = numColsB;
1227 
1228       /* For every row wise process, the pIn2 pointer is set
1229        ** to the starting address of the pSrcB data */
1230       pIn2 = pSrcB->pData;
1231 
1232       j = 0U;
1233 
1234       /* column loop */
1235       do
1236       {
1237         /* Set the variable sum, that acts as accumulator, to zero */
1238         sumReal = 0.0f;
1239         sumImag = 0.0f;
1240 
1241         /* Initiate pointer pIn1 to point to starting address of column being processed */
1242         pIn1 = pInA;
1243 
1244 #if defined (ARM_MATH_LOOPUNROLL)
1245 
1246         /* Apply loop unrolling and compute 4 MACs simultaneously. */
1247         colCnt = numColsA >> 2U;
1248 
1249         /* matrix multiplication */
1250         while (colCnt > 0U)
1251         {
1252 
1253           /* Reading real part of complex matrix A */
1254           a0 = *pIn1;
1255 
1256           /* Reading real part of complex matrix B */
1257           c0 = *pIn2;
1258 
1259           /* Reading imaginary part of complex matrix A */
1260           b0 = *(pIn1 + 1U);
1261 
1262           /* Reading imaginary part of complex matrix B */
1263           d0 = *(pIn2 + 1U);
1264 
1265           /* Multiply and Accumlates */
1266           sumReal += a0 * c0;
1267           sumImag += b0 * c0;
1268 
1269           /* update pointers */
1270           pIn1 += 2U;
1271           pIn2 += 2 * numColsB;
1272 
1273           /* Multiply and Accumlates */
1274           sumReal -= b0 * d0;
1275           sumImag += a0 * d0;
1276 
1277           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
1278 
1279           /* read real and imag values from pSrcA and pSrcB buffer */
1280           a1 = *(pIn1     );
1281           c1 = *(pIn2     );
1282           b1 = *(pIn1 + 1U);
1283           d1 = *(pIn2 + 1U);
1284 
1285           /* Multiply and Accumlates */
1286           sumReal += a1 * c1;
1287           sumImag += b1 * c1;
1288 
1289           /* update pointers */
1290           pIn1 += 2U;
1291           pIn2 += 2 * numColsB;
1292 
1293           /* Multiply and Accumlates */
1294           sumReal -= b1 * d1;
1295           sumImag += a1 * d1;
1296 
1297           a0 = *(pIn1     );
1298           c0 = *(pIn2     );
1299           b0 = *(pIn1 + 1U);
1300           d0 = *(pIn2 + 1U);
1301 
1302           /* Multiply and Accumlates */
1303           sumReal += a0 * c0;
1304           sumImag += b0 * c0;
1305 
1306           /* update pointers */
1307           pIn1 += 2U;
1308           pIn2 += 2 * numColsB;
1309 
1310           /* Multiply and Accumlates */
1311           sumReal -= b0 * d0;
1312           sumImag += a0 * d0;
1313 
1314           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
1315 
1316           a1 = *(pIn1     );
1317           c1 = *(pIn2     );
1318           b1 = *(pIn1 + 1U);
1319           d1 = *(pIn2 + 1U);
1320 
1321           /* Multiply and Accumlates */
1322           sumReal += a1 * c1;
1323           sumImag += b1 * c1;
1324 
1325           /* update pointers */
1326           pIn1 += 2U;
1327           pIn2 += 2 * numColsB;
1328 
1329           /* Multiply and Accumlates */
1330           sumReal -= b1 * d1;
1331           sumImag += a1 * d1;
1332 
1333           /* Decrement loop count */
1334           colCnt--;
1335         }
1336 
1337         /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
1338          ** No loop unrolling is used. */
1339         colCnt = numColsA % 0x4U;
1340 
1341 #else
1342 
1343         /* Initialize blkCnt with number of samples */
1344         colCnt = numColsA;
1345 
1346 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
1347 
1348         while (colCnt > 0U)
1349         {
1350           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
1351           a1 = *(pIn1     );
1352           c1 = *(pIn2     );
1353           b1 = *(pIn1 + 1U);
1354           d1 = *(pIn2 + 1U);
1355 
1356           /* Multiply and Accumlates */
1357           sumReal += a1 * c1;
1358           sumImag += b1 * c1;
1359 
1360           /* update pointers */
1361           pIn1 += 2U;
1362           pIn2 += 2 * numColsB;
1363 
1364           /* Multiply and Accumlates */
1365           sumReal -= b1 * d1;
1366           sumImag += a1 * d1;
1367 
1368           /* Decrement loop counter */
1369           colCnt--;
1370         }
1371 
1372         /* Store result in destination buffer */
1373         *px++ = sumReal;
1374         *px++ = sumImag;
1375 
1376         /* Update pointer pIn2 to point to starting address of next column */
1377         j++;
1378         pIn2 = pSrcB->pData + 2U * j;
1379 
1380         /* Decrement column loop counter */
1381         col--;
1382 
1383       } while (col > 0U);
1384 
1385       /* Update pointer pInA to point to starting address of next row */
1386       i = i + numColsB;
1387       pInA = pInA + 2 * numColsA;
1388 
1389       /* Decrement row loop counter */
1390       row--;
1391 
1392     } while (row > 0U);
1393 
1394     /* Set status as ARM_MATH_SUCCESS */
1395     status = ARM_MATH_SUCCESS;
1396   }
1397 
1398   /* Return to application */
1399   return (status);
1400 }
1401 
1402 #endif /* #if defined(ARM_MATH_NEON) */
1403 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
1404 
/**
  @} end of CmplxMatrixMult group
 */
1408