1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_barycenter_f16.c
4  * Description:  Barycenter
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/support_functions_f16.h"
30 
31 #if defined(ARM_FLOAT16_SUPPORTED)
32 
33 #include <limits.h>
34 #include <math.h>
35 
36 /**
37   @ingroup groupSupport
38  */
39 
40 /**
41   @defgroup barycenter Barycenter
42 
43   Barycenter of weighted vectors
44  */
45 
46 /**
47   @addtogroup barycenter
48   @{
49  */
50 
51 
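/**
 * @brief Barycenter
 *
 * Computes the weighted mean of nbVectors vectors:
 *   out[j] = sum_i(weights[i] * in[i * vecDim + j]) / sum_i(weights[i])
 *
 * @param[in]    in          List of vectors, stored one after another
 *                           (nbVectors * vecDim samples in total)
 * @param[in]    weights     Weights of the vectors (nbVectors values)
 * @param[out]   out         Barycenter (vecDim samples)
 * @param[in]    nbVectors   Number of vectors
 * @param[in]    vecDim      Dimension of space (vector dimension)
 * @return       None
 *
 * @note The sum of the weights is used as a divisor and is not checked
 *       for zero.
 */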
64 
65 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
66 
void arm_barycenter_f16(const float16_t *in,
  const float16_t *weights,
  float16_t *out,
  uint32_t nbVectors,
  uint32_t vecDim)
{
    /*
     * Helium (MVE) implementation of
     *   out[j] = sum_i(weights[i] * in[i * vecDim + j]) / sum_i(weights[i])
     *
     * Input vectors are accumulated four at a time, then the remaining
     * (nbVectors & 3) vectors one at a time. Inside each pass, samples are
     * processed 8 lanes at a time, followed by a scalar tail of (vecDim & 7)
     * samples. Normalization multiplies by the reciprocal of the weight sum;
     * no check is made that this sum is non-zero.
     */
    const float16_t *pIn = in;      /* first vector not yet accumulated */
    const float16_t *pW = weights;  /* next weight to consume */
    float16_t       *pOut;
    uint32_t         blkCntVector, blkCntSample;
    float16_t        accum = 0.0f;  /* running sum of the weights */

    arm_fill_f16(0.0f, out, vecDim);

    /* Sum: four vectors per iteration */
    blkCntVector = nbVectors >> 2;
    while (blkCntVector > 0U)
    {
        /*
         * The four row pointers are derived from pIn inside the loop so they
         * only ever address vectors that exist. Setting them up before the
         * loop (or stepping them past the last block) would form pointers
         * beyond the end of `in` whenever fewer than four vectors remain,
         * which is undefined pointer arithmetic in C.
         */
        const float16_t *pIn1 = pIn;
        const float16_t *pIn2 = pIn1 + vecDim;
        const float16_t *pIn3 = pIn2 + vecDim;
        const float16_t *pIn4 = pIn3 + vecDim;
        f16x8_t          outV, inV1, inV2, inV3, inV4;
        float16_t        w1, w2, w3, w4;

        pOut = out;
        w1 = *pW++;
        w2 = *pW++;
        w3 = *pW++;
        w4 = *pW++;
        accum += w1 + w2 + w3 + w4;

        blkCntSample = vecDim >> 3;
        while (blkCntSample > 0U)
        {
            outV = vld1q((const float16_t *) pOut);
            inV1 = vld1q(pIn1);
            inV2 = vld1q(pIn2);
            inV3 = vld1q(pIn3);
            inV4 = vld1q(pIn4);
            outV = vfmaq(outV, inV1, w1);
            outV = vfmaq(outV, inV2, w2);
            outV = vfmaq(outV, inV3, w3);
            outV = vfmaq(outV, inV4, w4);
            vst1q(pOut, outV);

            pOut += 8;
            pIn1 += 8;
            pIn2 += 8;
            pIn3 += 8;
            pIn4 += 8;

            blkCntSample--;
        }

        blkCntSample = vecDim & 7U;
        while (blkCntSample > 0U)
        {
            *pOut = *pOut + *pIn1++ * w1;
            *pOut = *pOut + *pIn2++ * w2;
            *pOut = *pOut + *pIn3++ * w3;
            *pOut = *pOut + *pIn4++ * w4;
            pOut++;
            blkCntSample--;
        }

        /* pIn4 has walked to the end of the fourth vector, which is the
           start of the next group of vectors. */
        pIn = pIn4;

        blkCntVector--;
    }

    /* Sum: remaining (nbVectors & 3) vectors, one per iteration */
    blkCntVector = nbVectors & 3U;
    while (blkCntVector > 0U)
    {
        f16x8_t   inV, outV;
        float16_t w;

        pOut = out;
        w = *pW++;
        accum += w;

        blkCntSample = vecDim >> 3;
        while (blkCntSample > 0U)
        {
            outV = vld1q_f16(pOut);
            inV = vld1q_f16(pIn);
            outV = vfmaq(outV, inV, w);
            vst1q_f16(pOut, outV);
            pOut += 8;
            pIn += 8;

            blkCntSample--;
        }

        blkCntSample = vecDim & 7U;
        while (blkCntSample > 0U)
        {
            *pOut = *pOut + *pIn++ * w;
            pOut++;
            blkCntSample--;
        }

        blkCntVector--;
    }

    /* Normalize: scale by the reciprocal of the weight sum */
    pOut = out;
    accum = 1.0f / accum;

    blkCntSample = vecDim >> 3;
    while (blkCntSample > 0U)
    {
        f16x8_t tmp;

        tmp = vld1q((const float16_t *) pOut);
        tmp = vmulq(tmp, accum);
        vst1q(pOut, tmp);
        pOut += 8;
        blkCntSample--;
    }

    blkCntSample = vecDim & 7U;
    while (blkCntSample > 0U)
    {
        *pOut = *pOut * accum;
        pOut++;
        blkCntSample--;
    }
}
209 #else
void arm_barycenter_f16(const float16_t *in, const float16_t *weights, float16_t *out, uint32_t nbVectors,uint32_t vecDim)
{
    /*
     * Portable scalar implementation of
     *   out[j] = sum_i(weights[i] * in[i * vecDim + j]) / sum_i(weights[i])
     * The output buffer doubles as the accumulator; the weight sum is used
     * as a divisor without a zero check.
     */
    const float16_t *pRow = in;   /* start of the current input vector */
    float16_t weightSum = 0.0f;

    /* Clear the accumulator */
    for (uint32_t j = 0U; j < vecDim; j++)
    {
        out[j] = 0.0f;
    }

    /* Accumulate weights[i] * (vector i) into out */
    for (uint32_t i = 0U; i < nbVectors; i++)
    {
        const float16_t wi = weights[i];

        weightSum += wi;
        for (uint32_t j = 0U; j < vecDim; j++)
        {
            out[j] = out[j] + pRow[j] * wi;
        }
        pRow += vecDim;
    }

    /* Divide by the total weight */
    for (uint32_t j = 0U; j < vecDim; j++)
    {
        out[j] = out[j] / weightSum;
    }
}
#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
268 
269 /**
270  * @} end of barycenter group
271  */
272 
273 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
274 
275