1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_logsumexp_f32.c
4 * Description: LogSumExp
5 *
6 * $Date: 23 April 2021
7 * $Revision: V1.9.0
8 *
9 * Target Processor: Cortex-M and Cortex-A cores
10 * -------------------------------------------------------------------- */
11 /*
12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13 *
14 * SPDX-License-Identifier: Apache-2.0
15 *
16 * Licensed under the Apache License, Version 2.0 (the License); you may
17 * not use this file except in compliance with the License.
18 * You may obtain a copy of the License at
19 *
20 * www.apache.org/licenses/LICENSE-2.0
21 *
22 * Unless required by applicable law or agreed to in writing, software
23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 * See the License for the specific language governing permissions and
26 * limitations under the License.
27 */
28
29 #include "dsp/statistics_functions.h"
30 #include <limits.h>
31 #include <math.h>
32
33
34 /**
35 * @addtogroup LogSumExp
36 * @{
37 */
38
39
/**
 * @brief Computation of the LogSumExp
 *
 * In probabilistic computations, probability values can span a very wide
 * dynamic range (for instance when they are produced by Gaussian functions).
 * To avoid underflow and overflow, such values are represented by their logarithms.
 * In this representation, multiplying the original values is easy: their logs are
 * simply added. Adding the original values, however, requires special handling,
 * which is the purpose of the LogSumExp function.
 *
 * If the values are x1...xn, the function computes:
 *
 * ln(exp(x1) + ... + exp(xn))
 *
 * and the computation is arranged so that overflow and rounding errors are minimised.
 *
 * The maximum xm of the values is extracted first, and the function then computes:
 * xm + ln(exp(x1 - xm) + ... + exp(xn - xm))
 *
 * For finite inputs every exponent xi - xm is <= 0, so each exp() term lies in [0, 1]
 * and the term for xm itself equals 1. The sum therefore lies in [1, n]: it cannot
 * overflow and its logarithm is always well defined.
 *
 * @param[in] in         Pointer to an array of input values.
 * @param[in] blockSize  Number of samples in the input array.
 * @return LogSumExp of the input values
 *
 */
63
64 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
65
66 #include "arm_helium_utils.h"
67 #include "arm_vec_math.h"
68
arm_logsumexp_f32(const float32_t * in,uint32_t blockSize)69 float32_t arm_logsumexp_f32(const float32_t *in, uint32_t blockSize)
70 {
71 float32_t maxVal;
72 const float32_t *pIn;
73 int32_t blkCnt;
74 float32_t accum=0.0f;
75 float32_t tmp;
76
77
78 arm_max_no_idx_f32((float32_t *) in, blockSize, &maxVal);
79
80
81 blkCnt = blockSize;
82 pIn = in;
83
84
85 f32x4_t vSum = vdupq_n_f32(0.0f);
86 blkCnt = blockSize >> 2;
87 while(blkCnt > 0)
88 {
89 f32x4_t vecIn = vld1q(pIn);
90 f32x4_t vecExp;
91
92 vecExp = vexpq_f32(vsubq_n_f32(vecIn, maxVal));
93
94 vSum = vaddq_f32(vSum, vecExp);
95
96 /*
97 * Decrement the blockSize loop counter
98 * Advance vector source and destination pointers
99 */
100 pIn += 4;
101 blkCnt --;
102 }
103
104 /* sum + log */
105 accum = vecAddAcrossF32Mve(vSum);
106
107 blkCnt = blockSize & 0x3;
108 while(blkCnt > 0)
109 {
110 tmp = *pIn++;
111 accum += expf(tmp - maxVal);
112 blkCnt--;
113
114 }
115
116 accum = maxVal + log(accum);
117
118 return (accum);
119 }
120
121 #else
122 #if defined(ARM_MATH_NEON) && !defined(ARM_MATH_AUTOVECTORIZE)
123
124 #include "NEMath.h"
arm_logsumexp_f32(const float32_t * in,uint32_t blockSize)125 float32_t arm_logsumexp_f32(const float32_t *in, uint32_t blockSize)
126 {
127 float32_t maxVal;
128 float32_t tmp;
129 float32x4_t tmpV, tmpVb;
130 float32x4_t maxValV;
131 uint32x4_t idxV;
132 float32x4_t accumV;
133 float32x2_t accumV2;
134
135 const float32_t *pIn;
136 uint32_t blkCnt;
137 float32_t accum;
138
139 pIn = in;
140
141 blkCnt = blockSize;
142
143 if (blockSize <= 3)
144 {
145 maxVal = *pIn++;
146 blkCnt--;
147
148 while(blkCnt > 0)
149 {
150 tmp = *pIn++;
151
152 if (tmp > maxVal)
153 {
154 maxVal = tmp;
155 }
156 blkCnt--;
157 }
158 }
159 else
160 {
161 maxValV = vld1q_f32(pIn);
162 pIn += 4;
163 blkCnt = (blockSize - 4) >> 2;
164
165 while(blkCnt > 0)
166 {
167 tmpVb = vld1q_f32(pIn);
168 pIn += 4;
169
170 idxV = vcgtq_f32(tmpVb, maxValV);
171 maxValV = vbslq_f32(idxV, tmpVb, maxValV );
172
173 blkCnt--;
174 }
175
176 accumV2 = vpmax_f32(vget_low_f32(maxValV),vget_high_f32(maxValV));
177 accumV2 = vpmax_f32(accumV2,accumV2);
178 maxVal = vget_lane_f32(accumV2, 0) ;
179
180 blkCnt = (blockSize - 4) & 3;
181
182 while(blkCnt > 0)
183 {
184 tmp = *pIn++;
185
186 if (tmp > maxVal)
187 {
188 maxVal = tmp;
189 }
190 blkCnt--;
191 }
192
193 }
194
195
196
197 maxValV = vdupq_n_f32(maxVal);
198 pIn = in;
199 accum = 0;
200 accumV = vdupq_n_f32(0.0f);
201
202 blkCnt = blockSize >> 2;
203
204 while(blkCnt > 0)
205 {
206 tmpV = vld1q_f32(pIn);
207 pIn += 4;
208 tmpV = vsubq_f32(tmpV, maxValV);
209 tmpV = vexpq_f32(tmpV);
210 accumV = vaddq_f32(accumV, tmpV);
211
212 blkCnt--;
213
214 }
215 accumV2 = vpadd_f32(vget_low_f32(accumV),vget_high_f32(accumV));
216 accum = vget_lane_f32(accumV2, 0) + vget_lane_f32(accumV2, 1);
217
218 blkCnt = blockSize & 0x3;
219 while(blkCnt > 0)
220 {
221 tmp = *pIn++;
222 accum += expf(tmp - maxVal);
223 blkCnt--;
224
225 }
226
227 accum = maxVal + logf(accum);
228
229 return(accum);
230 }
231 #else
arm_logsumexp_f32(const float32_t * in,uint32_t blockSize)232 float32_t arm_logsumexp_f32(const float32_t *in, uint32_t blockSize)
233 {
234 float32_t maxVal;
235 float32_t tmp;
236 const float32_t *pIn;
237 uint32_t blkCnt;
238 float32_t accum;
239
240 pIn = in;
241 blkCnt = blockSize;
242
243 maxVal = *pIn++;
244 blkCnt--;
245
246 while(blkCnt > 0)
247 {
248 tmp = *pIn++;
249
250 if (tmp > maxVal)
251 {
252 maxVal = tmp;
253 }
254 blkCnt--;
255
256 }
257
258 blkCnt = blockSize;
259 pIn = in;
260 accum = 0;
261 while(blkCnt > 0)
262 {
263 tmp = *pIn++;
264 accum += expf(tmp - maxVal);
265 blkCnt--;
266
267 }
268 accum = maxVal + logf(accum);
269
270 return(accum);
271 }
272 #endif
273 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
274
275 /**
276 * @} end of LogSumExp group
277 */
278