1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_logsumexp_f16.c
4 * Description: LogSumExp
5 *
6 * $Date: 23 April 2021
7 * $Revision: V1.9.0
8 *
9 * Target Processor: Cortex-M and Cortex-A cores
10 * -------------------------------------------------------------------- */
11 /*
12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13 *
14 * SPDX-License-Identifier: Apache-2.0
15 *
16 * Licensed under the Apache License, Version 2.0 (the License); you may
17 * not use this file except in compliance with the License.
18 * You may obtain a copy of the License at
19 *
20 * www.apache.org/licenses/LICENSE-2.0
21 *
22 * Unless required by applicable law or agreed to in writing, software
23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 * See the License for the specific language governing permissions and
26 * limitations under the License.
27 */
28
29 #include "dsp/statistics_functions_f16.h"
30
31 #if defined(ARM_FLOAT16_SUPPORTED)
32
33 #include <limits.h>
34 #include <math.h>
35
36
37 /**
38 * @addtogroup LogSumExp
39 * @{
40 */
41
42
/**
 * @brief         Computation of the LogSumExp
 *
 * In probabilistic computations, the dynamic range of the probability values can be
 * very wide because they come from Gaussian functions.
 * To avoid underflow and overflow issues, the values are represented by their logarithms.
 * In this representation, multiplying the original exp values is easy: their logs are added.
 * Adding the original exp values, however, requires special handling, and that is the
 * goal of the LogSumExp function.
 *
 * If the values are x1...xn, the function computes:
 *
 * ln(exp(x1) + ... + exp(xn)), and the computation is done in such a way that
 * rounding issues are minimised.
 *
 * The maximum xm of the values is extracted and the function computes:
 * xm + ln(exp(x1 - xm) + ... + exp(xn - xm))
 *
 * Subtracting xm makes every exponent argument less than or equal to 0, so no
 * exp term can overflow and the largest term is exactly exp(0) = 1.
 *
 * @param[in]  in         Pointer to an array of input values.
 * @param[in]  blockSize  Number of samples in the input array.
 * @return     LogSumExp
 *
 */
66
67 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
68
69 #include "arm_helium_utils.h"
70 #include "arm_vec_math_f16.h"
71
arm_logsumexp_f16(const float16_t * in,uint32_t blockSize)72 float16_t arm_logsumexp_f16(const float16_t *in, uint32_t blockSize)
73 {
74 float16_t maxVal;
75 const float16_t *pIn;
76 int32_t blkCnt;
77 _Float16 accum=0.0f16;
78 _Float16 tmp;
79
80
81 arm_max_no_idx_f16((float16_t *) in, blockSize, &maxVal);
82
83
84 blkCnt = blockSize;
85 pIn = in;
86
87
88 f16x8_t vSum = vdupq_n_f16(0.0f16);
89 blkCnt = blockSize >> 3;
90 while(blkCnt > 0)
91 {
92 f16x8_t vecIn = vld1q(pIn);
93 f16x8_t vecExp;
94
95 vecExp = vexpq_f16(vsubq_n_f16(vecIn, maxVal));
96
97 vSum = vaddq_f16(vSum, vecExp);
98
99 /*
100 * Decrement the blockSize loop counter
101 * Advance vector source and destination pointers
102 */
103 pIn += 8;
104 blkCnt --;
105 }
106
107 /* sum + log */
108 accum = vecAddAcrossF16Mve(vSum);
109
110 blkCnt = blockSize & 0x7;
111 while(blkCnt > 0)
112 {
113 tmp = *pIn++;
114 accum += expf(tmp - maxVal);
115 blkCnt--;
116
117 }
118
119 accum = maxVal + logf(accum);
120
121 return (accum);
122 }
123
124 #else
arm_logsumexp_f16(const float16_t * in,uint32_t blockSize)125 float16_t arm_logsumexp_f16(const float16_t *in, uint32_t blockSize)
126 {
127 _Float16 maxVal;
128 _Float16 tmp;
129 const float16_t *pIn;
130 uint32_t blkCnt;
131 _Float16 accum;
132
133 pIn = in;
134 blkCnt = blockSize;
135
136 maxVal = *pIn++;
137 blkCnt--;
138
139 while(blkCnt > 0)
140 {
141 tmp = *pIn++;
142
143 if (tmp > maxVal)
144 {
145 maxVal = tmp;
146 }
147 blkCnt--;
148
149 }
150
151 blkCnt = blockSize;
152 pIn = in;
153 accum = 0;
154 while(blkCnt > 0)
155 {
156 tmp = *pIn++;
157 accum += expf(tmp - maxVal);
158 blkCnt--;
159
160 }
161 accum = maxVal + logf(accum);
162
163 return(accum);
164 }
#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
166
167 /**
168 * @} end of LogSumExp group
169 */
170
171 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
172
173