1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_cmplx_dot_prod_f16.c
4  * Description:  Floating-point complex dot product
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/complex_math_functions_f16.h"
30 
31 #if defined(ARM_FLOAT16_SUPPORTED)
32 
33 
34 /**
35   @ingroup groupCmplxMath
36  */
37 
38 /**
39   @defgroup cmplx_dot_prod Complex Dot Product
40 
41   Computes the dot product of two complex vectors.
42   The vectors are multiplied element-by-element and then summed.
43 
44   The <code>pSrcA</code> points to the first complex input vector and
45   <code>pSrcB</code> points to the second complex input vector.
46   <code>numSamples</code> specifies the number of complex samples
47   and the data in each array is stored in an interleaved fashion
48   (real, imag, real, imag, ...).
49   Each array has a total of <code>2*numSamples</code> values.
50 
  The underlying algorithm is:
52 
53   <pre>
54   realResult = 0;
55   imagResult = 0;
56   for (n = 0; n < numSamples; n++) {
57       realResult += pSrcA[(2*n)+0] * pSrcB[(2*n)+0] - pSrcA[(2*n)+1] * pSrcB[(2*n)+1];
58       imagResult += pSrcA[(2*n)+0] * pSrcB[(2*n)+1] + pSrcA[(2*n)+1] * pSrcB[(2*n)+0];
59   }
60   </pre>
61 
62   There are separate functions for floating-point, Q15, and Q31 data types.
63  */
64 
65 /**
66   @addtogroup cmplx_dot_prod
67   @{
68  */
69 
70 /**
71   @brief         Floating-point complex dot product.
72   @param[in]     pSrcA       points to the first input vector
73   @param[in]     pSrcB       points to the second input vector
74   @param[in]     numSamples  number of samples in each vector
75   @param[out]    realResult  real part of the result returned here
76   @param[out]    imagResult  imaginary part of the result returned here
77   @return        none
78  */
79 
80 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
81 
82 #include "arm_helium_utils.h"
83 
/*
 * Helium (MVE) half-precision implementation of the complex dot product.
 *
 * pSrcA and pSrcB are read as interleaved (real, imag) float16_t values and
 * the accumulated result is written to *realResult / *imagResult.
 *
 * One f16x8_t load covers 8 float16_t values, i.e. 4 complex samples.
 *
 * Layout of the main path (taken when numSamples >= 16):
 *   - prologue : load the first vector of each input;
 *   - loop     : per iteration, consume 2 vectors per input (8 complex
 *                samples) while loading the vectors for the next step;
 *   - epilogue : consume the last group of 8 complex samples;
 *   - tail     : predicated handling of the remaining (numSamples & 7)
 *                complex samples.
 * The loads are interleaved with the multiply-accumulates on purpose
 * (manual software pipelining), so the statement order must be preserved.
 *
 * Shorter inputs use a single predicated loop.
 */
void arm_cmplx_dot_prod_f16(
    const float16_t * pSrcA,
    const float16_t * pSrcB,
    uint32_t numSamples,
    float16_t * realResult,
    float16_t * imagResult)
{
    int32_t         blkCnt;
    float16_t       real_sum, imag_sum;
    f16x8_t         vecSrcA, vecSrcB;
    /* Lane-wise accumulator; reduced to real/imag scalars at the end. */
    f16x8_t         vec_acc = vdupq_n_f16(0.0f16);
    f16x8_t         vecSrcC, vecSrcD;

    /*
     * Number of full groups of 8 complex samples, minus one: the last group
     * is processed by the epilogue that follows the loop.
     */
    blkCnt = (numSamples >> 3);
    blkCnt -= 1;
    if (blkCnt > 0) {
        /* should give more freedom to generate stall free code */
        vecSrcA = vld1q( pSrcA);
        vecSrcB = vld1q( pSrcB);
        pSrcA += 8;
        pSrcB += 8;

        while (blkCnt > 0) {
            /*
             * vcmlaq followed by vcmlaq_rot90 on the same operands
             * accumulates the complex product of the two vectors
             * (see the Arm MVE VCMLA description). Each multiply is paired
             * with the load needed by a later multiply.
             */
            vec_acc = vcmlaq(vec_acc, vecSrcA, vecSrcB);
            vecSrcC = vld1q(pSrcA);
            pSrcA += 8;

            vec_acc = vcmlaq_rot90(vec_acc, vecSrcA, vecSrcB);
            vecSrcD = vld1q(pSrcB);
            pSrcB += 8;

            vec_acc = vcmlaq(vec_acc, vecSrcC, vecSrcD);
            vecSrcA = vld1q(pSrcA);
            pSrcA += 8;

            vec_acc = vcmlaq_rot90(vec_acc, vecSrcC, vecSrcD);
            vecSrcB = vld1q(pSrcB);
            pSrcB += 8;
            /*
             * Decrement the block loop counter
             */
            blkCnt--;
        }

        /*
         * Process the last group outside the loop so that armclang does not
         * break the software pipeline. vecSrcA / vecSrcB were already loaded
         * by the final loop iteration (or by the prologue).
         *
         * The pointers are NOT advanced after these two loads: they are left
         * on the last vector consumed, which is why the tail loop below
         * advances them before loading.
         */
        vec_acc = vcmlaq(vec_acc, vecSrcA, vecSrcB);
        vecSrcC = vld1q(pSrcA);

        vec_acc = vcmlaq_rot90(vec_acc, vecSrcA, vecSrcB);
        vecSrcD = vld1q(pSrcB);

        vec_acc = vcmlaq(vec_acc, vecSrcC, vecSrcD);
        vec_acc = vcmlaq_rot90(vec_acc, vecSrcC, vecSrcD);

        /*
         * Tail: remaining scalar values, i.e. CMPLX_DIM values per leftover
         * complex sample (CMPLX_DIM is defined outside this file; presumably
         * 2 for the real/imag pair).
         */
        blkCnt = CMPLX_DIM * (numSamples & 7);
        while (blkCnt > 0) {
            /* Predicate selecting the first blkCnt 16-bit lanes. */
            mve_pred16_t    p = vctp16q(blkCnt);
            /* Step past the vector consumed previously (see epilogue). */
            pSrcA += 8;
            pSrcB += 8;

            vecSrcA = vldrhq_z_f16(pSrcA, p);
            vecSrcB = vldrhq_z_f16(pSrcB, p);
            vec_acc = vcmlaq_m(vec_acc, vecSrcA, vecSrcB, p);
            vec_acc = vcmlaq_rot90_m(vec_acc, vecSrcA, vecSrcB, p);

            blkCnt -= 8;
        }
    } else {
        /*
         * Small vector (numSamples < 16): everything is handled with
         * predicated loads and predicated multiply-accumulates.
         * vec_acc is still zero from its initialisation above.
         */
        blkCnt = numSamples * CMPLX_DIM;

        /*
         * The body runs at least once, also when numSamples is 0; in that
         * case the predicate is built from vctp16q(0).
         */
        do {
            mve_pred16_t    p = vctp16q(blkCnt);

            vecSrcA = vldrhq_z_f16(pSrcA, p);
            vecSrcB = vldrhq_z_f16(pSrcB, p);

            vec_acc = vcmlaq_m(vec_acc, vecSrcA, vecSrcB, p);
            vec_acc = vcmlaq_rot90_m(vec_acc, vecSrcA, vecSrcB, p);

            /*
             * Decrement the blkCnt loop counter
             * Advance the vector source pointers
             */
            pSrcA += 8;
            pSrcB += 8;
            blkCnt -= 8;
        }
        while (blkCnt > 0);
    }

    /*
     * Sum the partial parts. The helper comes from arm_helium_utils.h;
     * presumably it reduces the real lanes into real_sum and the imaginary
     * lanes into imag_sum - confirm against the helper's definition.
     */
    mve_cmplx_sum_intra_r_i_f16(vec_acc, real_sum, imag_sum);

    /*
     * Store the real and imaginary results in the destination buffers
     */
    *realResult = real_sum;
    *imagResult = imag_sum;
}
188 
189 #else
/*
 * Scalar half-precision implementation of the complex dot product.
 *
 * pSrcA and pSrcB are read as interleaved (real, imag) values, numSamples
 * complex samples each. For every sample (a0 + j*b0) from pSrcA and
 * (c0 + j*d0) from pSrcB the function accumulates
 *
 *     real_sum += a0*c0 - b0*d0
 *     imag_sum += a0*d0 + b0*c0
 *
 * and finally stores the sums to *realResult and *imagResult. With
 * numSamples == 0 both results are 0.
 *
 * The accumulation is done in _Float16, one term at a time; the order of the
 * four updates is kept fixed so the rounding sequence is the same in the
 * unrolled and in the plain loop.
 */
void arm_cmplx_dot_prod_f16(
  const float16_t * pSrcA,
  const float16_t * pSrcB,
        uint32_t numSamples,
        float16_t * realResult,
        float16_t * imagResult)
{
  uint32_t blkCnt;                               /* Loop counter */
  _Float16 real_sum = 0.0f, imag_sum = 0.0f;     /* Running real / imaginary sums */
  _Float16 a0, b0, c0, d0;                       /* a0 + j*b0 from pSrcA, c0 + j*d0 from pSrcB */

#if defined (ARM_MATH_LOOPUNROLL) && !defined(ARM_MATH_AUTOVECTORIZE)

  /* Loop unrolling: process 4 complex samples per iteration */
  blkCnt = numSamples >> 2U;

  while (blkCnt > 0U)
  {
    /* Sample 1 */
    a0 = *pSrcA++;
    b0 = *pSrcA++;
    c0 = *pSrcB++;
    d0 = *pSrcB++;

    real_sum += a0 * c0;
    imag_sum += a0 * d0;
    real_sum -= b0 * d0;
    imag_sum += b0 * c0;

    /* Sample 2 */
    a0 = *pSrcA++;
    b0 = *pSrcA++;
    c0 = *pSrcB++;
    d0 = *pSrcB++;

    real_sum += a0 * c0;
    imag_sum += a0 * d0;
    real_sum -= b0 * d0;
    imag_sum += b0 * c0;

    /* Sample 3 */
    a0 = *pSrcA++;
    b0 = *pSrcA++;
    c0 = *pSrcB++;
    d0 = *pSrcB++;

    real_sum += a0 * c0;
    imag_sum += a0 * d0;
    real_sum -= b0 * d0;
    imag_sum += b0 * c0;

    /* Sample 4 */
    a0 = *pSrcA++;
    b0 = *pSrcA++;
    c0 = *pSrcB++;
    d0 = *pSrcB++;

    real_sum += a0 * c0;
    imag_sum += a0 * d0;
    real_sum -= b0 * d0;
    imag_sum += b0 * c0;

    /* Decrement loop counter */
    blkCnt--;
  }

  /* Loop unrolling: the remaining 0..3 samples go through the loop below */
  blkCnt = numSamples % 0x4U;

#else

  /* No unrolling: the loop below handles every sample */
  blkCnt = numSamples;

#endif /* defined (ARM_MATH_LOOPUNROLL) && !defined(ARM_MATH_AUTOVECTORIZE) */

  while (blkCnt > 0U)
  {
    a0 = *pSrcA++;
    b0 = *pSrcA++;
    c0 = *pSrcB++;
    d0 = *pSrcB++;

    real_sum += a0 * c0;
    imag_sum += a0 * d0;
    real_sum -= b0 * d0;
    imag_sum += b0 * c0;

    /* Decrement loop counter */
    blkCnt--;
  }

  /* Store real and imaginary result in destination buffer. */
  *realResult = real_sum;
  *imagResult = imag_sum;
}
#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
283 
284 /**
285   @} end of cmplx_dot_prod group
286  */
287 
288 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */