1 /*
2 * SPDX-FileCopyrightText: Copyright 2010-2020, 2022-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_softmax_u8.c
22 * Description: U8 softmax function
23 *
24 * $Date: 5 January 2023
25 * $Revision: V.1.1.0
26 *
27 * Target : Arm(R) M-Profile Architecture
28 *
29 * -------------------------------------------------------------------- */
30
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33
34 #define ACCUM_BITS 12
35
36 /**
37 * @ingroup Public
38 */
39
40 /**
41 * @addtogroup Softmax
42 * @{
43 */
arm_softmax_u8(const uint8_t * input,const int32_t num_rows,const int32_t row_size,const int32_t mult,const int32_t shift,const int32_t diff_min,uint8_t * output)44 void arm_softmax_u8(const uint8_t *input,
45 const int32_t num_rows,
46 const int32_t row_size,
47 const int32_t mult,
48 const int32_t shift,
49 const int32_t diff_min,
50 uint8_t *output)
51 {
52 const int32_t mask = (1 << shift);
53
54 int32_t col = 0;
55 int32_t row_idx;
56
57 for (row_idx = 0; row_idx < num_rows; ++row_idx)
58 {
59 // Find the maximum value in order to ensure numerical stability
60 uint8_t max = *input;
61
62 for (col = 1; col < row_size; ++col)
63 {
64 max = MAX(max, input[col]);
65 }
66
67 int32_t diff = 0;
68 int32_t sum = 0;
69
70 for (col = 0; col < row_size; ++col)
71 {
72 diff = input[col] - max;
73 if (diff >= diff_min)
74 {
75 sum += DIV_POW2(EXP_ON_NEG(MUL_SAT(diff * mask, mult)), ACCUM_BITS);
76 }
77 }
78
79 const int32_t headroom = CLZ((uint32_t)sum);
80 const int32_t bits_over_unit = ACCUM_BITS - headroom + 23;
81 const int32_t shifted_scale = ONE_OVER1((sum << headroom) - (1 << 31));
82
83 for (col = 0; col < row_size; ++col)
84 {
85 diff = input[col] - max;
86 if (diff >= diff_min)
87 {
88 const int32_t res =
89 DIV_POW2(MUL_SAT(shifted_scale, EXP_ON_NEG(MUL_SAT(diff * mask, mult))), bits_over_unit);
90 output[col] = (uint8_t)CLAMP(res, (int32_t)255, (int32_t)0);
91 }
92 else
93 {
94 output[col] = 0;
95 }
96 }
97 input += row_size;
98 output += row_size;
99 }
100 }
101 /**
102 * @} end of Softmax group
103 */