1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_barycenter_f16.c
4 * Description: Barycenter
5 *
6 * $Date: 23 April 2021
7 * $Revision: V1.9.0
8 *
9 * Target Processor: Cortex-M and Cortex-A cores
10 * -------------------------------------------------------------------- */
11 /*
12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13 *
14 * SPDX-License-Identifier: Apache-2.0
15 *
16 * Licensed under the Apache License, Version 2.0 (the License); you may
17 * not use this file except in compliance with the License.
18 * You may obtain a copy of the License at
19 *
20 * www.apache.org/licenses/LICENSE-2.0
21 *
22 * Unless required by applicable law or agreed to in writing, software
23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 * See the License for the specific language governing permissions and
26 * limitations under the License.
27 */
28
29 #include "dsp/support_functions_f16.h"
30
31 #if defined(ARM_FLOAT16_SUPPORTED)
32
33 #include <limits.h>
34 #include <math.h>
35
36 /**
37 @ingroup groupSupport
38 */
39
40 /**
41 @defgroup barycenter Barycenter
42
43 Barycenter of weighted vectors
44 */
45
46 /**
47 @addtogroup barycenter
48 @{
49 */
50
51
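/**
 * @brief         Barycenter (weighted mean) of a set of vectors
 *
 * Computes, for j = 0 .. vecDim-1:
 *
 *     out[j] = sum_i(weights[i] * in[i * vecDim + j]) / sum_i(weights[i])
 *
 * @param[in]     in          List of vectors, stored contiguously:
 *                            nbVectors vectors of vecDim elements each
 * @param[in]     weights     Weights of the vectors (nbVectors elements)
 * @param[out]    out         Barycenter (vecDim elements)
 * @param[in]     nbVectors   Number of vectors
 * @param[in]     vecDim      Dimension of space (vector dimension)
 * @return        none
 *
 * @note No check is made for weights summing to zero; the final
 *       normalization divides by that sum.
 */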
64
65 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
66
/*
 * MVE (Helium) implementation.
 *
 * out[j] = sum_i(weights[i] * in[i * vecDim + j]) / sum_i(weights[i])
 *
 * Vectors are processed four at a time (one load/store of out[] per four
 * fused multiply-adds), then the remaining nbVectors % 4 vectors one at a
 * time. Within each vector, eight float16 lanes are handled per MVE
 * iteration and the vecDim % 8 tail is done in scalar code.
 * No check is made for a zero weight sum: the normalization divides by it.
 */
void arm_barycenter_f16(const float16_t *in,
                        const float16_t *weights,
                        float16_t *out,
                        uint32_t nbVectors,
                        uint32_t vecDim)
{
    const float16_t *pIn = in;       /* first input vector not yet accumulated */
    const float16_t *pW = weights;   /* weight of that vector */
    float16_t *pOut;
    uint32_t blkCntVector, blkCntSample;
    float16_t accum = 0.0f;          /* running sum of the weights */

    arm_fill_f16(0.0f, out, vecDim);

    /* Sum: accumulate four weighted vectors per pass over out[] */
    blkCntVector = nbVectors >> 2;
    while (blkCntVector > 0)
    {
        f16x8_t outV, inV1, inV2, inV3, inV4;
        float16_t w1, w2, w3, w4;
        /*
         * The four row pointers are derived from pIn here, and only when four
         * vectors really remain. Computing them up front, or skipping them
         * ahead after the last group, would form pointers more than one past
         * the end of in[], which is undefined behavior (C11 6.5.6).
         */
        const float16_t *pIn1 = pIn;
        const float16_t *pIn2 = pIn1 + vecDim;
        const float16_t *pIn3 = pIn2 + vecDim;
        const float16_t *pIn4 = pIn3 + vecDim;

        pOut = out;
        w1 = *pW++;
        w2 = *pW++;
        w3 = *pW++;
        w4 = *pW++;
        accum += w1 + w2 + w3 + w4;

        blkCntSample = vecDim >> 3;
        while (blkCntSample > 0) {
            outV = vld1q((const float16_t *) pOut);
            inV1 = vld1q(pIn1);
            inV2 = vld1q(pIn2);
            inV3 = vld1q(pIn3);
            inV4 = vld1q(pIn4);
            outV = vfmaq(outV, inV1, w1);
            outV = vfmaq(outV, inV2, w2);
            outV = vfmaq(outV, inV3, w3);
            outV = vfmaq(outV, inV4, w4);
            vst1q(pOut, outV);

            pOut += 8;
            pIn1 += 8;
            pIn2 += 8;
            pIn3 += 8;
            pIn4 += 8;

            blkCntSample--;
        }

        /* Scalar tail: the last vecDim % 8 elements of each of the four vectors */
        blkCntSample = vecDim & 7;
        while (blkCntSample > 0) {
            *pOut = *pOut + *pIn1++ * w1;
            *pOut = *pOut + *pIn2++ * w2;
            *pOut = *pOut + *pIn3++ * w3;
            *pOut = *pOut + *pIn4++ * w4;
            pOut++;
            blkCntSample--;
        }

        /* pIn4 now points just past the fourth vector, i.e. at the next group */
        pIn = pIn4;

        blkCntVector--;
    }

    /* Remaining nbVectors % 4 vectors, one at a time */
    blkCntVector = nbVectors & 3;
    while (blkCntVector > 0)
    {
        f16x8_t inV, outV;
        float16_t w;

        pOut = out;
        w = *pW++;
        accum += w;

        blkCntSample = vecDim >> 3;
        while (blkCntSample > 0)
        {
            outV = vld1q_f16(pOut);
            inV = vld1q_f16(pIn);
            outV = vfmaq(outV, inV, w);
            vst1q_f16(pOut, outV);
            pOut += 8;
            pIn += 8;

            blkCntSample--;
        }

        blkCntSample = vecDim & 7;
        while (blkCntSample > 0)
        {
            *pOut = *pOut + *pIn++ * w;
            pOut++;
            blkCntSample--;
        }

        blkCntVector--;
    }

    /* Normalize: multiply by the reciprocal of the weight sum */
    pOut = out;
    accum = 1.0f / accum;

    blkCntSample = vecDim >> 3;
    while (blkCntSample > 0)
    {
        f16x8_t tmp;

        tmp = vld1q((const float16_t *) pOut);
        tmp = vmulq(tmp, accum);
        vst1q(pOut, tmp);
        pOut += 8;
        blkCntSample--;
    }

    blkCntSample = vecDim & 7;
    while (blkCntSample > 0)
    {
        *pOut = *pOut * accum;
        pOut++;
        blkCntSample--;
    }
}
209 #else
/*
 * Generic (scalar) implementation.
 *
 * Clears out[], adds weights[i] * (vector i) into it for every input
 * vector (vectors stored back to back, vecDim elements each), then divides
 * every element by the sum of the weights. No check is made for a zero
 * weight sum.
 */
void arm_barycenter_f16(const float16_t *in, const float16_t *weights, float16_t *out, uint32_t nbVectors,uint32_t vecDim)
{
    const float16_t *row = in;   /* start of the vector being accumulated */
    float16_t weightSum = 0.0f;
    uint32_t i, k;

    /* Clear the accumulator */
    for (k = 0; k < vecDim; k++)
    {
        out[k] = 0.0f;
    }

    /* Weighted sum of all vectors */
    for (i = 0; i < nbVectors; i++)
    {
        const float16_t wi = weights[i];

        weightSum += wi;
        for (k = 0; k < vecDim; k++)
        {
            out[k] = out[k] + row[k] * wi;
        }
        row += vecDim;
    }

    /* Divide by the total weight */
    for (k = 0; k < vecDim; k++)
    {
        out[k] = out[k] / weightSum;
    }
}
#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
268
269 /**
270 * @} end of barycenter group
271 */
272
273 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
274
275