1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_barycenter_f16.c
4 * Description: Barycenter
5 *
6 * $Date: 23 April 2021
7 * $Revision: V1.9.0
8 *
9 * Target Processor: Cortex-M and Cortex-A cores
10 * -------------------------------------------------------------------- */
11 /*
12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13 *
14 * SPDX-License-Identifier: Apache-2.0
15 *
16 * Licensed under the Apache License, Version 2.0 (the License); you may
17 * not use this file except in compliance with the License.
18 * You may obtain a copy of the License at
19 *
20 * www.apache.org/licenses/LICENSE-2.0
21 *
22 * Unless required by applicable law or agreed to in writing, software
23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 * See the License for the specific language governing permissions and
26 * limitations under the License.
27 */
28
29 #include "dsp/support_functions_f16.h"
30
31 #if defined(ARM_FLOAT16_SUPPORTED)
32
33 #include <limits.h>
34 #include <math.h>
35
36 /**
37 @ingroup groupSupport
38 */
39
40 /**
41 @defgroup barycenter Barycenter
42
43 Barycenter of weighted vectors
44 */
45
46 /**
47 @addtogroup barycenter
48 @{
49 */
50
51
52 /**
53 * @brief Barycenter
54 *
55 *
56 * @param[in] *in List of vectors
57 * @param[in] *weights Weights of the vectors
58 * @param[out] *out Barycenter
59 * @param[in] nbVectors Number of vectors
60 * @param[in] vecDim Dimension of space (vector dimension)
61 *
62 */
63
64 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
65
/*
 * Barycenter of weighted float16 vectors -- Helium (MVE) implementation.
 *
 * Accumulates sum_v weights[v] * in[v * vecDim + d] into out[d], then
 * multiplies every out[d] by the reciprocal of the sum of the weights.
 * `in` is read as nbVectors consecutive vectors of vecDim elements each.
 * Vectors are consumed four at a time (the remaining 0..3 one at a time);
 * each vector is processed eight lanes at a time with a scalar tail.
 *
 * No check is made for a zero weight sum.
 */
void arm_barycenter_f16(const float16_t *in,
                        const float16_t *weights,
                        float16_t *out,
                        uint32_t nbVectors,
                        uint32_t vecDim)
{
    const float16_t *pIn = in;      /* start of the next unread input vector */
    const float16_t *pW = weights;  /* next unread weight */
    float16_t *pOut;
    uint32_t blkCntVector, blkCntSample;
    float16_t accum = 0.0f16;       /* running sum of the weights */
    float16_t w;

    /* The accumulation below reads `out`, so it must start at zero. */
    arm_fill_f16(0.0f16, out, vecDim);

    /* Sum: groups of four input vectors */
    blkCntVector = nbVectors >> 2;
    while (blkCntVector > 0)
    {
        f16x8_t outV, inV1, inV2, inV3, inV4;
        float16_t w1, w2, w3, w4;

        /*
         * The four row pointers are derived here, inside the loop, so that
         * they are only ever formed when all four vectors exist. Computing
         * them once before the loop would create pointers past the end of
         * `in` whenever nbVectors < 4.
         */
        const float16_t *pIn1 = pIn;
        const float16_t *pIn2 = pIn1 + vecDim;
        const float16_t *pIn3 = pIn2 + vecDim;
        const float16_t *pIn4 = pIn3 + vecDim;

        pOut = out;
        w1 = *pW++;
        w2 = *pW++;
        w3 = *pW++;
        w4 = *pW++;
        accum += (_Float16)w1 + (_Float16)w2 + (_Float16)w3 + (_Float16)w4;

        /* Eight output lanes per iteration */
        blkCntSample = vecDim >> 3;
        while (blkCntSample > 0)
        {
            outV = vld1q((const float16_t *) pOut);
            inV1 = vld1q(pIn1);
            inV2 = vld1q(pIn2);
            inV3 = vld1q(pIn3);
            inV4 = vld1q(pIn4);
            outV = vfmaq(outV, inV1, w1);
            outV = vfmaq(outV, inV2, w2);
            outV = vfmaq(outV, inV3, w3);
            outV = vfmaq(outV, inV4, w4);
            vst1q(pOut, outV);

            pOut += 8;
            pIn1 += 8;
            pIn2 += 8;
            pIn3 += 8;
            pIn4 += 8;

            blkCntSample--;
        }

        /* Scalar tail: the last vecDim % 8 components */
        blkCntSample = vecDim & 7;
        while (blkCntSample > 0)
        {
            *pOut = (_Float16)*pOut + (_Float16)*pIn1++ * (_Float16)w1;
            *pOut = (_Float16)*pOut + (_Float16)*pIn2++ * (_Float16)w2;
            *pOut = (_Float16)*pOut + (_Float16)*pIn3++ * (_Float16)w3;
            *pOut = (_Float16)*pOut + (_Float16)*pIn4++ * (_Float16)w4;
            pOut++;
            blkCntSample--;
        }

        /*
         * pIn4 has advanced by exactly vecDim elements, i.e. to the end of
         * the fourth vector, which is where the next group begins.
         */
        pIn = pIn4;

        blkCntVector--;
    }

    /* Sum: the remaining nbVectors % 4 vectors, one at a time */
    blkCntVector = nbVectors & 3;
    while (blkCntVector > 0)
    {
        f16x8_t inV, outV;

        pOut = out;
        w = *pW++;
        accum += (_Float16)w;

        blkCntSample = vecDim >> 3;
        while (blkCntSample > 0)
        {
            outV = vld1q_f16(pOut);
            inV = vld1q_f16(pIn);
            outV = vfmaq(outV, inV, w);
            vst1q_f16(pOut, outV);
            pOut += 8;
            pIn += 8;

            blkCntSample--;
        }

        blkCntSample = vecDim & 7;
        while (blkCntSample > 0)
        {
            *pOut = (_Float16)*pOut + (_Float16)*pIn++ * (_Float16)w;
            pOut++;
            blkCntSample--;
        }

        blkCntVector--;
    }

    /* Normalize: multiply by the reciprocal of the weight sum */
    pOut = out;
    accum = 1.0f16 / (_Float16)accum;

    blkCntSample = vecDim >> 3;
    while (blkCntSample > 0)
    {
        f16x8_t tmp;

        tmp = vld1q((const float16_t *) pOut);
        tmp = vmulq(tmp, accum);
        vst1q(pOut, tmp);
        pOut += 8;
        blkCntSample--;
    }

    blkCntSample = vecDim & 7;
    while (blkCntSample > 0)
    {
        *pOut = (_Float16)*pOut * (_Float16)accum;
        pOut++;
        blkCntSample--;
    }
}
208 #else
/*
 * Barycenter of weighted float16 vectors -- portable scalar implementation.
 *
 * `in` is read as nbVectors consecutive vectors of vecDim elements each.
 * out[d] receives (sum over v of weights[v] * in[v * vecDim + d]) divided
 * by the sum of all weights. No check is made for a zero weight sum.
 */
void arm_barycenter_f16(const float16_t *in, const float16_t *weights, float16_t *out, uint32_t nbVectors, uint32_t vecDim)
{
    const float16_t *src = in;      /* walks linearly through every input sample */
    float16_t weightSum = 0.0f16;   /* running sum of the weights */
    uint32_t v, d;

    /* Clear the output vector; it doubles as the accumulator below. */
    for (d = 0; d < vecDim; d++)
    {
        out[d] = 0.0f16;
    }

    /* Sum: add each input vector, scaled by its weight, into the output. */
    for (v = 0; v < nbVectors; v++)
    {
        const float16_t wv = weights[v];

        weightSum += (_Float16)wv;

        for (d = 0; d < vecDim; d++)
        {
            out[d] = (_Float16)out[d] + (_Float16)*src++ * (_Float16)wv;
        }
    }

    /* Normalize: divide every component by the total weight. */
    for (d = 0; d < vecDim; d++)
    {
        out[d] = (_Float16)out[d] / (_Float16)weightSum;
    }
}
#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
267
268 /**
269 * @} end of barycenter group
270 */
271
272 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
273
274