1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_barycenter_f16.c
4  * Description:  Barycenter
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/support_functions_f16.h"
30 
31 #if defined(ARM_FLOAT16_SUPPORTED)
32 
33 #include <limits.h>
34 #include <math.h>
35 
36 /**
37   @ingroup groupSupport
38  */
39 
40 /**
41   @defgroup barycenter Barycenter
42 
43   Barycenter of weighted vectors
44  */
45 
46 /**
47   @addtogroup barycenter
48   @{
49  */
50 
51 
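/**
 * @brief Barycenter (weighted mean) of a set of vectors
 *
 * Computes, for j = 0 .. vecDim-1:
 *   out[j] = sum_i(weights[i] * in[i * vecDim + j]) / sum_i(weights[i])
 * The weighted sum and the weight sum are accumulated in float16_t, so both
 * precision and range are limited to half precision.
 *
 * @param[in]    *in         List of vectors: nbVectors vectors of vecDim
 *                           elements each, stored contiguously one after another
 * @param[in]    *weights    Weights of the vectors (nbVectors values)
 * @param[out]   *out        Barycenter (vecDim values); overwritten
 * @param[in]    nbVectors   Number of vectors
 * @param[in]    vecDim      Dimension of space (vector dimension)
 *
 * @par Note
 * The sum of the weights is not checked. If it is zero, the normalization
 * step divides by zero.
 */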
63 
64 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
65 
arm_barycenter_f16(const float16_t * in,const float16_t * weights,float16_t * out,uint32_t nbVectors,uint32_t vecDim)66 void arm_barycenter_f16(const float16_t *in,
67   const float16_t *weights,
68   float16_t *out,
69   uint32_t nbVectors,
70   uint32_t vecDim)
71 {
72     const float16_t *pIn, *pW;
73     const float16_t *pIn1, *pIn2, *pIn3, *pIn4;
74     float16_t      *pOut;
75     uint32_t        blkCntVector, blkCntSample;
76     float16_t       accum, w;
77 
78     blkCntVector = nbVectors;
79     blkCntSample = vecDim;
80 
81     accum = 0.0f;
82 
83     pW = weights;
84     pIn = in;
85 
86 
87     arm_fill_f16(0.0f, out, vecDim);
88 
89 
90     /* Sum */
91     pIn1 = pIn;
92     pIn2 = pIn1 + vecDim;
93     pIn3 = pIn2 + vecDim;
94     pIn4 = pIn3 + vecDim;
95 
96     blkCntVector = nbVectors >> 2;
97     while (blkCntVector > 0)
98     {
99         f16x8_t         outV, inV1, inV2, inV3, inV4;
100         float16_t       w1, w2, w3, w4;
101 
102         pOut = out;
103         w1 = *pW++;
104         w2 = *pW++;
105         w3 = *pW++;
106         w4 = *pW++;
107         accum += (_Float16)w1 + (_Float16)w2 + (_Float16)w3 + (_Float16)w4;
108 
109         blkCntSample = vecDim >> 3;
110         while (blkCntSample > 0) {
111             outV = vld1q((const float16_t *) pOut);
112             inV1 = vld1q(pIn1);
113             inV2 = vld1q(pIn2);
114             inV3 = vld1q(pIn3);
115             inV4 = vld1q(pIn4);
116             outV = vfmaq(outV, inV1, w1);
117             outV = vfmaq(outV, inV2, w2);
118             outV = vfmaq(outV, inV3, w3);
119             outV = vfmaq(outV, inV4, w4);
120             vst1q(pOut, outV);
121 
122             pOut += 8;
123             pIn1 += 8;
124             pIn2 += 8;
125             pIn3 += 8;
126             pIn4 += 8;
127 
128             blkCntSample--;
129         }
130 
131         blkCntSample = vecDim & 7;
132         while (blkCntSample > 0) {
133             *pOut = (_Float16)*pOut + (_Float16)*pIn1++ * (_Float16)w1;
134             *pOut = (_Float16)*pOut + (_Float16)*pIn2++ * (_Float16)w2;
135             *pOut = (_Float16)*pOut + (_Float16)*pIn3++ * (_Float16)w3;
136             *pOut = (_Float16)*pOut + (_Float16)*pIn4++ * (_Float16)w4;
137             pOut++;
138             blkCntSample--;
139         }
140 
141         pIn1 += 3 * vecDim;
142         pIn2 += 3 * vecDim;
143         pIn3 += 3 * vecDim;
144         pIn4 += 3 * vecDim;
145 
146         blkCntVector--;
147     }
148 
149     pIn = pIn1;
150 
151     blkCntVector = nbVectors & 3;
152     while (blkCntVector > 0)
153     {
154         f16x8_t         inV, outV;
155 
156         pOut = out;
157         w = *pW++;
158         accum += (_Float16)w;
159 
160         blkCntSample = vecDim >> 3;
161         while (blkCntSample > 0)
162         {
163             outV = vld1q_f16(pOut);
164             inV = vld1q_f16(pIn);
165             outV = vfmaq(outV, inV, w);
166             vst1q_f16(pOut, outV);
167             pOut += 8;
168             pIn += 8;
169 
170             blkCntSample--;
171         }
172 
173         blkCntSample = vecDim & 7;
174         while (blkCntSample > 0)
175         {
176             *pOut = (_Float16)*pOut + (_Float16)*pIn++ * (_Float16)w;
177             pOut++;
178             blkCntSample--;
179         }
180 
181         blkCntVector--;
182     }
183 
184     /* Normalize */
185     pOut = out;
186     accum = 1.0f16 / (_Float16)accum;
187 
188     blkCntSample = vecDim >> 3;
189     while (blkCntSample > 0)
190     {
191         f16x8_t         tmp;
192 
193         tmp = vld1q((const float16_t *) pOut);
194         tmp = vmulq(tmp, accum);
195         vst1q(pOut, tmp);
196         pOut += 8;
197         blkCntSample--;
198     }
199 
200     blkCntSample = vecDim & 7;
201     while (blkCntSample > 0)
202     {
203         *pOut = (_Float16)*pOut * (_Float16)accum;
204         pOut++;
205         blkCntSample--;
206     }
207 }
208 #else
void arm_barycenter_f16(const float16_t *in,
  const float16_t *weights,
  float16_t *out,
  uint32_t nbVectors,
  uint32_t vecDim)
{
    /*
     * Portable reference path: out = sum_v(weights[v] * in[v]) / sum_v(weights[v]).
     * Vectors are read consecutively from in, vecDim elements each.
     */
    const float16_t *row = in;          /* walks through all input elements */
    float16_t        weightSum = 0.0f16;
    uint32_t         v, k;

    /* Start from a zeroed accumulator in the output buffer. */
    for (k = 0U; k < vecDim; k++)
    {
        out[k] = 0.0f16;
    }

    /* For every vector v: out[k] += weights[v] * in[v][k]. */
    for (v = 0U; v < nbVectors; v++)
    {
        const float16_t wv = weights[v];

        weightSum += (_Float16)wv;
        for (k = 0U; k < vecDim; k++)
        {
            out[k] = (_Float16)out[k] + (_Float16)*row++ * (_Float16)wv;
        }
    }

    /* Divide every component by the total weight (a zero sum is not guarded). */
    for (k = 0U; k < vecDim; k++)
    {
        out[k] = (_Float16)out[k] / (_Float16)weightSum;
    }
}
#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
267 
268 /**
269  * @} end of barycenter group
270  */
271 
272 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
273 
274