1 /*
2 * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_nn_softmax_common_s8.c
22 * Description: Softmax with s8 input and output of s8 or s16.
23 *
24 * $Date: 5 January 2023
25 * $Revision: V.1.1.0
26 *
27 * Target : Arm(R) M-Profile Architecture
28 * -------------------------------------------------------------------- */
29
30 #include "arm_nnsupportfunctions.h"
31
32 #define ACCUM_BITS 12
33
34 /**
35 * @ingroup groupSupport
36 */
37
38 /**
39 * @defgroup supportSoftmax Softmax
40 *
41 * Support functions for Softmax
42 *
43 */
44
45 /**
46 * @addtogroup supportSoftmax
47 * @{
48 */
49
50 /*
51 * Softmax function with s8 input and output of s8 or s16.
52 *
53 * Refer header file for details.
54 *
55 */
arm_nn_softmax_common_s8(const int8_t * input,const int32_t num_rows,const int32_t row_size,const int32_t mult,const int32_t shift,const int32_t diff_min,const bool int16_output,void * output)56 void arm_nn_softmax_common_s8(const int8_t *input,
57 const int32_t num_rows,
58 const int32_t row_size,
59 const int32_t mult,
60 const int32_t shift,
61 const int32_t diff_min,
62 const bool int16_output,
63 void *output)
64 {
65 const int32_t mask = (1 << shift);
66
67 int32_t col = 0;
68 int32_t row_idx;
69
70 for (row_idx = 0; row_idx < num_rows; ++row_idx)
71 {
72 // Find the maximum value in order to ensure numerical stability
73 int8_t max = *input;
74
75 for (col = 1; col < row_size; ++col)
76 {
77 max = MAX(max, input[col]);
78 }
79
80 int32_t diff = 0;
81 int32_t sum = 0;
82
83 for (col = 0; col < row_size; ++col)
84 {
85 diff = input[col] - max;
86 if (diff >= diff_min)
87 {
88 sum += DIV_POW2(EXP_ON_NEG(MUL_SAT(diff * mask, mult)), ACCUM_BITS);
89 }
90 }
91
92 const int32_t headroom = CLZ(sum);
93 const int32_t shifted_scale = ONE_OVER1((sum > 0 ? sum << headroom : 0) - (1 << 31));
94 int32_t bits_over_unit;
95
96 if (int16_output)
97 {
98 int16_t *output_s16 = (int16_t *)output + row_idx * row_size;
99
100 bits_over_unit = ACCUM_BITS - headroom + 15;
101
102 for (col = 0; col < row_size; ++col)
103 {
104 diff = input[col] - max;
105
106 if (diff >= diff_min)
107 {
108 const int32_t res =
109 DIV_POW2(MUL_SAT(shifted_scale, EXP_ON_NEG(MUL_SAT(diff * mask, mult))), bits_over_unit) +
110 NN_Q15_MIN;
111 output_s16[col] = (int16_t)CLAMP(res, (int32_t)NN_Q15_MAX, (int32_t)NN_Q15_MIN);
112 }
113 else
114 {
115 output_s16[col] = NN_Q15_MIN;
116 }
117 }
118 }
119 else
120 {
121 int8_t *output_s8 = (int8_t *)output + row_idx * row_size;
122
123 bits_over_unit = ACCUM_BITS - headroom + 23;
124
125 for (col = 0; col < row_size; ++col)
126 {
127 diff = input[col] - max;
128 if (diff >= diff_min)
129 {
130 const int32_t res =
131 DIV_POW2(MUL_SAT(shifted_scale, EXP_ON_NEG(MUL_SAT(diff * mask, mult))), bits_over_unit) +
132 NN_Q7_MIN;
133 output_s8[col] = (int8_t)CLAMP(res, (int32_t)NN_Q7_MAX, (int32_t)NN_Q7_MIN);
134 }
135 else
136 {
137 output_s8[col] = NN_Q7_MIN;
138 }
139 }
140 }
141
142 input += row_size;
143 }
144 }
145
146 /**
147 * @} end of Doxygen group
148 */
149