1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_barycenter_f32.c
4  * Description:  Barycenter
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/support_functions.h"
30 #include <limits.h>
31 #include <math.h>
32 
33 
/**
  @addtogroup barycenter
  @{
 */
37 
38 
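/**
 * @brief Barycenter of a set of weighted vectors
 *
 * Computes out = (sum_i weights[i] * in_i) / (sum_i weights[i]),
 * where in_i is the i-th input vector.
 *
 * @param[in]    in          nbVectors input vectors of dimension vecDim,
 *                           stored contiguously one after the other
 * @param[in]    weights     Weight of each input vector (nbVectors values)
 * @param[out]   out         Barycenter (vecDim values)
 * @param[in]    nbVectors   Number of vectors
 * @param[in]    vecDim      Dimension of space (vector dimension)
 * @return       None
 *
 * @note The weights are not checked: if they sum to zero (or nbVectors is 0)
 *       the result contains infinities or NaNs.
 */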
51 
52 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
/* Helium (MVE) implementation */
void arm_barycenter_f32(const float32_t *in,
  const float32_t *weights,
  float32_t *out,
  uint32_t nbVectors,
  uint32_t vecDim)
{
    const float32_t *pIn = in;      /* first element of the next unprocessed input vector */
    const float32_t *pW = weights;  /* next unprocessed weight */
    float32_t       *pOut;
    uint32_t         blkCntVector, blkCntSample;
    float32_t        accum = 0.0f;  /* running sum of the weights */

    /* Clear the output accumulator */
    arm_fill_f32(0.0f, out, vecDim);

    /*
     * Weighted sum, four input vectors per pass.
     * The four row pointers are derived from pIn at the start of each pass
     * (instead of being pre-computed once and advanced by 3*vecDim), so every
     * pointer stays within in[] or one past its end, as C requires for
     * pointer arithmetic -- also when nbVectors < 4.
     */
    blkCntVector = nbVectors >> 2;
    while (blkCntVector > 0)
    {
        const float32_t *pIn1 = pIn;
        const float32_t *pIn2 = pIn1 + vecDim;
        const float32_t *pIn3 = pIn2 + vecDim;
        const float32_t *pIn4 = pIn3 + vecDim;
        float32_t        w1 = pW[0];
        float32_t        w2 = pW[1];
        float32_t        w3 = pW[2];
        float32_t        w4 = pW[3];
        f32x4_t          outV, inV1, inV2, inV3, inV4;

        pW += 4;
        accum += w1 + w2 + w3 + w4;
        pOut = out;

        /* 4 lanes at a time */
        blkCntSample = vecDim >> 2;
        while (blkCntSample > 0)
        {
            outV = vld1q((const float32_t *) pOut);
            inV1 = vld1q(pIn1);
            inV2 = vld1q(pIn2);
            inV3 = vld1q(pIn3);
            inV4 = vld1q(pIn4);
            outV = vfmaq(outV, inV1, w1);
            outV = vfmaq(outV, inV2, w2);
            outV = vfmaq(outV, inV3, w3);
            outV = vfmaq(outV, inV4, w4);
            vst1q(pOut, outV);

            pOut += 4;
            pIn1 += 4;
            pIn2 += 4;
            pIn3 += 4;
            pIn4 += 4;

            blkCntSample--;
        }

        /* Remaining 0..3 components */
        blkCntSample = vecDim & 3;
        while (blkCntSample > 0)
        {
            *pOut = *pOut + *pIn1++ * w1;
            *pOut = *pOut + *pIn2++ * w2;
            *pOut = *pOut + *pIn3++ * w3;
            *pOut = *pOut + *pIn4++ * w4;
            pOut++;
            blkCntSample--;
        }

        /* pIn4 has walked one full vector, so it now points just past this group */
        pIn = pIn4;

        blkCntVector--;
    }

    /* Remaining 0..3 input vectors, one at a time */
    blkCntVector = nbVectors & 3;
    while (blkCntVector > 0)
    {
        float32_t w = *pW++;
        f32x4_t   inV, outV;

        accum += w;
        pOut = out;

        blkCntSample = vecDim >> 2;
        while (blkCntSample > 0)
        {
            outV = vld1q_f32(pOut);
            inV = vld1q_f32(pIn);
            outV = vfmaq(outV, inV, w);
            vst1q_f32(pOut, outV);
            pOut += 4;
            pIn += 4;

            blkCntSample--;
        }

        blkCntSample = vecDim & 3;
        while (blkCntSample > 0)
        {
            *pOut = *pOut + *pIn++ * w;
            pOut++;
            blkCntSample--;
        }

        blkCntVector--;
    }

    /* Normalize by the total weight; a zero total yields inf/NaN */
    pOut = out;
    accum = 1.0f / accum;

    blkCntSample = vecDim >> 2;
    while (blkCntSample > 0)
    {
        f32x4_t tmp;

        tmp = vld1q((const float32_t *) pOut);
        tmp = vmulq(tmp, accum);
        vst1q(pOut, tmp);
        pOut += 4;
        blkCntSample--;
    }

    blkCntSample = vecDim & 3;
    while (blkCntSample > 0)
    {
        *pOut = *pOut * accum;
        pOut++;
        blkCntSample--;
    }
}
195 #else
196 #if defined(ARM_MATH_NEON)
197 
198 #include "NEMath.h"
/* Neon implementation */
void arm_barycenter_f32(const float32_t *in, const float32_t *weights, float32_t *out, uint32_t nbVectors,uint32_t vecDim)
{
   const float32_t *pIn = in;      /* first element of the next unprocessed input vector */
   const float32_t *pW = weights;  /* next unprocessed weight */
   float32_t *pOut;
   uint32_t blkCntVector, blkCntSample;
   float32_t accum = 0.0f;         /* running sum of the weights */
   float32x4_t tmp, inV, outV, inV1, inV2, inV3, inV4;

   /* Clear the output accumulator */
   tmp = vdupq_n_f32(0.0f);
   pOut = out;

   blkCntSample = vecDim >> 2;
   while(blkCntSample > 0)
   {
      vst1q_f32(pOut, tmp);
      pOut += 4;
      blkCntSample--;
   }

   blkCntSample = vecDim & 3;
   while(blkCntSample > 0)
   {
      *pOut = 0.0f;
      pOut++;
      blkCntSample--;
   }

   /*
    * Weighted sum, four input vectors per pass.
    * The four row pointers are derived from pIn at the start of each pass
    * (instead of being pre-computed once and advanced by 3*vecDim), so every
    * pointer stays within in[] or one past its end, as C requires for
    * pointer arithmetic -- also when nbVectors < 4.
    */
   blkCntVector = nbVectors >> 2;
   while(blkCntVector > 0)
   {
      const float32_t *pIn1 = pIn;
      const float32_t *pIn2 = pIn1 + vecDim;
      const float32_t *pIn3 = pIn2 + vecDim;
      const float32_t *pIn4 = pIn3 + vecDim;
      float32_t w1 = pW[0];
      float32_t w2 = pW[1];
      float32_t w3 = pW[2];
      float32_t w4 = pW[3];

      pW += 4;
      accum += w1 + w2 + w3 + w4;
      pOut = out;

      /* 4 lanes at a time */
      blkCntSample = vecDim >> 2;
      while(blkCntSample > 0)
      {
         outV = vld1q_f32(pOut);
         inV1 = vld1q_f32(pIn1);
         inV2 = vld1q_f32(pIn2);
         inV3 = vld1q_f32(pIn3);
         inV4 = vld1q_f32(pIn4);
         outV = vmlaq_n_f32(outV,inV1,w1);
         outV = vmlaq_n_f32(outV,inV2,w2);
         outV = vmlaq_n_f32(outV,inV3,w3);
         outV = vmlaq_n_f32(outV,inV4,w4);
         vst1q_f32(pOut, outV);
         pOut += 4;
         pIn1 += 4;
         pIn2 += 4;
         pIn3 += 4;
         pIn4 += 4;

         blkCntSample--;
      }

      /* Remaining 0..3 components */
      blkCntSample = vecDim & 3;
      while(blkCntSample > 0)
      {
         *pOut = *pOut + *pIn1++ * w1;
         *pOut = *pOut + *pIn2++ * w2;
         *pOut = *pOut + *pIn3++ * w3;
         *pOut = *pOut + *pIn4++ * w4;
         pOut++;
         blkCntSample--;
      }

      /* pIn4 has walked one full vector, so it now points just past this group */
      pIn = pIn4;

      blkCntVector--;
   }

   /* Remaining 0..3 input vectors, one at a time */
   blkCntVector = nbVectors & 3;
   while(blkCntVector > 0)
   {
      float32_t w = *pW++;

      accum += w;
      pOut = out;

      blkCntSample = vecDim >> 2;
      while(blkCntSample > 0)
      {
         outV = vld1q_f32(pOut);
         inV = vld1q_f32(pIn);
         outV = vmlaq_n_f32(outV,inV,w);
         vst1q_f32(pOut, outV);
         pOut += 4;
         pIn += 4;

         blkCntSample--;
      }

      blkCntSample = vecDim & 3;
      while(blkCntSample > 0)
      {
         *pOut = *pOut + *pIn++ * w;
         pOut++;
         blkCntSample--;
      }

      blkCntVector--;
   }

   /* Normalize by the total weight; a zero total yields inf/NaN */
   pOut = out;
   accum = 1.0f / accum;

   blkCntSample = vecDim >> 2;
   while(blkCntSample > 0)
   {
      tmp = vld1q_f32(pOut);
      tmp = vmulq_n_f32(tmp,accum);
      vst1q_f32(pOut, tmp);
      pOut += 4;
      blkCntSample--;
   }

   blkCntSample = vecDim & 3;
   while(blkCntSample > 0)
   {
      *pOut = *pOut * accum;
      pOut++;
      blkCntSample--;
   }
}
351 #else
/* Portable scalar implementation */
void arm_barycenter_f32(const float32_t *in, const float32_t *weights, float32_t *out, uint32_t nbVectors,uint32_t vecDim)
{
   const float32_t *row = in;   /* current input vector */
   float32_t sumW = 0.0f;       /* sum of the weights consumed so far */
   uint32_t i, k;

   /* Start from the zero vector */
   for (k = 0U; k < vecDim; k++)
   {
      out[k] = 0.0f;
   }

   /* out += weights[i] * in_i, one input vector at a time */
   for (i = 0U; i < nbVectors; i++)
   {
      const float32_t wi = weights[i];

      sumW += wi;
      for (k = 0U; k < vecDim; k++)
      {
         out[k] = out[k] + row[k] * wi;
      }
      row += vecDim;
   }

   /* Divide by the total weight; a zero total yields inf/NaN */
   for (k = 0U; k < vecDim; k++)
   {
      out[k] = out[k] / sumW;
   }
}
409 #endif
410 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
411 
412 /**
413  * @} end of barycenter group
414  */
415