1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_logsumexp_f32.c
4  * Description:  LogSumExp
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/statistics_functions.h"
30 #include <limits.h>
31 #include <math.h>
32 
33 
34 /**
35  * @addtogroup LogSumExp
36  * @{
37  */
38 
39 
40 /**
41  * @brief Computation of the LogSumExp
42  *
 * In probabilistic computations, the dynamic range of the probability values can be
 * very wide because they come from Gaussian functions.
 * To avoid underflow and overflow issues, the values are represented by their logs.
 * In this representation, multiplying the original exp values is easy: their logs are added.
 * Adding the original exp values, however, requires special handling, which is the
 * goal of the LogSumExp function.
49  *
50  * If the values are x1...xn, the function is computing:
51  *
52  * ln(exp(x1) + ... + exp(xn)) and the computation is done in such a way that
53  * rounding issues are minimised.
54  *
55  * The max xm of the values is extracted and the function is computing:
56  * xm + ln(exp(x1 - xm) + ... + exp(xn - xm))
57  *
58  * @param[in]  *in         Pointer to an array of input values.
59  * @param[in]  blockSize   Number of samples in the input array.
60  * @return LogSumExp
61  *
62  */
63 
64 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
65 
66 #include "arm_helium_utils.h"
67 #include "arm_vec_math.h"
68 
arm_logsumexp_f32(const float32_t * in,uint32_t blockSize)69 float32_t arm_logsumexp_f32(const float32_t *in, uint32_t blockSize)
70 {
71     float32_t       maxVal;
72     const float32_t *pIn;
73     int32_t         blkCnt;
74     float32_t       accum=0.0f;
75     float32_t       tmp;
76 
77 
78     arm_max_no_idx_f32((float32_t *) in, blockSize, &maxVal);
79 
80 
81     blkCnt = blockSize;
82     pIn = in;
83 
84 
85     f32x4_t         vSum = vdupq_n_f32(0.0f);
86     blkCnt = blockSize >> 2;
87     while(blkCnt > 0)
88     {
89         f32x4_t         vecIn = vld1q(pIn);
90         f32x4_t         vecExp;
91 
92         vecExp = vexpq_f32(vsubq_n_f32(vecIn, maxVal));
93 
94         vSum = vaddq_f32(vSum, vecExp);
95 
96         /*
97          * Decrement the blockSize loop counter
98          * Advance vector source and destination pointers
99          */
100         pIn += 4;
101         blkCnt --;
102     }
103 
104     /* sum + log */
105     accum = vecAddAcrossF32Mve(vSum);
106 
107     blkCnt = blockSize & 0x3;
108     while(blkCnt > 0)
109     {
110        tmp = *pIn++;
111        accum += expf(tmp - maxVal);
112        blkCnt--;
113 
114     }
115 
116     accum = maxVal + log(accum);
117 
118     return (accum);
119 }
120 
121 #else
122 #if defined(ARM_MATH_NEON) && !defined(ARM_MATH_AUTOVECTORIZE)
123 
124 #include "NEMath.h"
arm_logsumexp_f32(const float32_t * in,uint32_t blockSize)125 float32_t arm_logsumexp_f32(const float32_t *in, uint32_t blockSize)
126 {
127     float32_t maxVal;
128     float32_t tmp;
129     float32x4_t tmpV, tmpVb;
130     float32x4_t maxValV;
131     uint32x4_t idxV;
132     float32x4_t accumV;
133     float32x2_t accumV2;
134 
135     const float32_t *pIn;
136     uint32_t blkCnt;
137     float32_t accum;
138 
139     pIn = in;
140 
141     blkCnt = blockSize;
142 
143     if (blockSize <= 3)
144     {
145       maxVal = *pIn++;
146       blkCnt--;
147 
148       while(blkCnt > 0)
149       {
150          tmp = *pIn++;
151 
152          if (tmp > maxVal)
153          {
154             maxVal = tmp;
155          }
156          blkCnt--;
157       }
158     }
159     else
160     {
161       maxValV = vld1q_f32(pIn);
162       pIn += 4;
163       blkCnt = (blockSize - 4) >> 2;
164 
165       while(blkCnt > 0)
166       {
167          tmpVb = vld1q_f32(pIn);
168          pIn += 4;
169 
170          idxV = vcgtq_f32(tmpVb, maxValV);
171          maxValV = vbslq_f32(idxV, tmpVb, maxValV );
172 
173          blkCnt--;
174       }
175 
176       accumV2 = vpmax_f32(vget_low_f32(maxValV),vget_high_f32(maxValV));
177       accumV2 = vpmax_f32(accumV2,accumV2);
178       maxVal = vget_lane_f32(accumV2, 0) ;
179 
180       blkCnt = (blockSize - 4) & 3;
181 
182       while(blkCnt > 0)
183       {
184          tmp = *pIn++;
185 
186          if (tmp > maxVal)
187          {
188             maxVal = tmp;
189          }
190          blkCnt--;
191       }
192 
193     }
194 
195 
196 
197     maxValV = vdupq_n_f32(maxVal);
198     pIn = in;
199     accum = 0;
200     accumV = vdupq_n_f32(0.0f);
201 
202     blkCnt = blockSize >> 2;
203 
204     while(blkCnt > 0)
205     {
206        tmpV = vld1q_f32(pIn);
207        pIn += 4;
208        tmpV = vsubq_f32(tmpV, maxValV);
209        tmpV = vexpq_f32(tmpV);
210        accumV = vaddq_f32(accumV, tmpV);
211 
212        blkCnt--;
213 
214     }
215     accumV2 = vpadd_f32(vget_low_f32(accumV),vget_high_f32(accumV));
216     accum = vget_lane_f32(accumV2, 0) + vget_lane_f32(accumV2, 1);
217 
218     blkCnt = blockSize & 0x3;
219     while(blkCnt > 0)
220     {
221        tmp = *pIn++;
222        accum += expf(tmp - maxVal);
223        blkCnt--;
224 
225     }
226 
227     accum = maxVal + logf(accum);
228 
229     return(accum);
230 }
231 #else
arm_logsumexp_f32(const float32_t * in,uint32_t blockSize)232 float32_t arm_logsumexp_f32(const float32_t *in, uint32_t blockSize)
233 {
234     float32_t maxVal;
235     float32_t tmp;
236     const float32_t *pIn;
237     uint32_t blkCnt;
238     float32_t accum;
239 
240     pIn = in;
241     blkCnt = blockSize;
242 
243     maxVal = *pIn++;
244     blkCnt--;
245 
246     while(blkCnt > 0)
247     {
248        tmp = *pIn++;
249 
250        if (tmp > maxVal)
251        {
252           maxVal = tmp;
253        }
254        blkCnt--;
255 
256     }
257 
258     blkCnt = blockSize;
259     pIn = in;
260     accum = 0;
261     while(blkCnt > 0)
262     {
263        tmp = *pIn++;
264        accum += expf(tmp - maxVal);
265        blkCnt--;
266 
267     }
268     accum = maxVal + logf(accum);
269 
270     return(accum);
271 }
272 #endif
273 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
274 
275 /**
276  * @} end of LogSumExp group
277  */
278