1 /*
2 * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_softmax_s16.c
22 * Description: S16 softmax function
23 *
24 * $Date: 5 January 2023
25 * $Revision: V.2.1.0
26 *
27 * Target : Arm(R) M-Profile Architecture
28 *
29 * -------------------------------------------------------------------- */
30
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33
34 /**
35 * @addtogroup Softmax
36 * @{
37 */
38
arm_softmax_s16(const int16_t * input,const int32_t num_rows,const int32_t row_size,const int32_t mult,const int32_t shift,const cmsis_nn_softmax_lut_s16 * softmax_params,int16_t * output)39 arm_cmsis_nn_status arm_softmax_s16(const int16_t *input,
40 const int32_t num_rows,
41 const int32_t row_size,
42 const int32_t mult,
43 const int32_t shift,
44 const cmsis_nn_softmax_lut_s16 *softmax_params,
45 int16_t *output)
46 {
47 int32_t col = 0;
48 int32_t row_idx;
49
50 if (softmax_params->exp_lut == NULL || softmax_params->one_by_one_lut == NULL)
51 {
52 return ARM_CMSIS_NN_ARG_ERROR;
53 }
54
55 for (row_idx = 0; row_idx < num_rows; ++row_idx)
56 {
57 // Find the maximum value in order to ensure numerical stability
58 int16_t max = *input;
59 for (col = 1; col < row_size; ++col)
60 {
61 max = MAX(max, input[col]);
62 }
63
64 int32_t diff = 0;
65 int32_t sum = 0;
66 int16_t *cached_exp_results = output;
67
68 for (col = 0; col < row_size; ++col)
69 {
70 diff = input[col] - max;
71 const int32_t scaled_diff = arm_nn_requantize(diff, mult, shift);
72 const int32_t symmetric_scaled_diff = scaled_diff + NN_Q15_MAX;
73 const int16_t saturated_symmetric_scaled_diff = MIN(MAX(symmetric_scaled_diff, NN_Q15_MIN), NN_Q15_MAX);
74
75 // Lookup from exp table and cache result for next step
76 const int16_t index = (256 + (saturated_symmetric_scaled_diff >> 7));
77 const int16_t offset = saturated_symmetric_scaled_diff & 0x7f;
78 const int16_t base = softmax_params->exp_lut[index];
79 const int16_t slope = softmax_params->exp_lut[index + 1] - softmax_params->exp_lut[index];
80 const int16_t delta = (slope * offset + 64) >> 7;
81 const int16_t result = (base + delta);
82 cached_exp_results[col] = result;
83
84 sum += cached_exp_results[col];
85 }
86
87 const int32_t headroom = CLZ(sum);
88
89 // Compute the reciprocal 1/sum
90 const int32_t shifted_sum = (((sum) << (headroom - 1)) + (1 << 13)) >> 14;
91
92 // Since LUT computes 1/(1 + x), compute x = (sum - 1) => -65536
93 // Since LUT expects a symmetrical input, recenter from [UINT16_MIN, UINT16_MAX] to [INT16_MIN, INT16_MAX] =>
94 // -32768 ==> So in total -65536 -32768 => -98304
95 const int16_t symmetric_shifted_sum = shifted_sum - 98304;
96
97 // Lookup from one by one table
98 const int16_t index = (256 + (symmetric_shifted_sum >> 7));
99 const int16_t offset = symmetric_shifted_sum & 0x7f;
100 const int16_t base = softmax_params->one_by_one_lut[index];
101 const int16_t slope = softmax_params->one_by_one_lut[index + 1] - softmax_params->one_by_one_lut[index];
102 const int16_t delta = (slope * offset + 64) >> 7;
103 const int16_t one_by_one_result = (base + delta);
104
105 for (col = 0; col < row_size; ++col)
106 {
107 const int16_t right_shift = 30 - headroom;
108 int32_t result = (cached_exp_results[col] * one_by_one_result) >> right_shift;
109 result = (result + 1) >> 1; // Last shift position and insert round
110 output[col] = (int16_t)result;
111 }
112
113 output += row_size;
114 input += row_size;
115 }
116
117 return ARM_CMSIS_NN_SUCCESS;
118 }
119
120 /**
121 * @} end of Softmax group
122 */
123