1 /*
2 * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_nn_lstm_step_s8_s16.c
22 * Description: Update LSTM function for a single iteration step.
23 *
24 * $Date: 9 February 2023
25 * $Revision: V.1.1.0
26 *
27 * Target : Arm(R) M-Profile Architecture
28 *
29 * -------------------------------------------------------------------- */
30 #include "arm_nnsupportfunctions.h"
31 /**
32 * @ingroup groupSupport
33 */
34
35 /**
36 * @addtogroup supportLSTM
37 * @{
38 */
39
40 /*
41 * Calculate the output state tensor of an LSTM step: s8 input, output and weights, with an s16 cell state.
42 * Refer to header file for details.
43 */
arm_nn_lstm_step_s8_s16(const int8_t * input,const int8_t * input_to_input_weight,const int8_t * input_to_forget_weight,const int8_t * input_to_cell_weight,const int8_t * input_to_output_weight,const int8_t * recurrent_to_input_weight,const int8_t * recurrent_to_forget_weight,const int8_t * recurrent_to_cell_weight,const int8_t * recurrent_to_output_weight,const cmsis_nn_lstm_params * lstm,const int n_batch,const int n_cell,const int n_input,const int n_output,int8_t * output_state,int16_t * cell_state,int8_t * output,cmsis_nn_lstm_context * scratch_buffers)44 arm_cmsis_nn_status arm_nn_lstm_step_s8_s16(const int8_t *input,
45 const int8_t *input_to_input_weight,
46 const int8_t *input_to_forget_weight,
47 const int8_t *input_to_cell_weight,
48 const int8_t *input_to_output_weight,
49 const int8_t *recurrent_to_input_weight,
50 const int8_t *recurrent_to_forget_weight,
51 const int8_t *recurrent_to_cell_weight,
52 const int8_t *recurrent_to_output_weight,
53 const cmsis_nn_lstm_params *lstm,
54 const int n_batch,
55 const int n_cell,
56 const int n_input,
57 const int n_output,
58 int8_t *output_state,
59 int16_t *cell_state,
60 int8_t *output,
61 cmsis_nn_lstm_context *scratch_buffers)
62 {
63 const int32_t n_block = n_batch * n_cell;
64
65 // Calculate the input gate
66 arm_nn_lstm_calculate_gate_s8_s16(input,
67 input_to_input_weight,
68 lstm->i2i_effective_bias,
69 lstm->input_to_input_scaling,
70 output_state,
71 recurrent_to_input_weight,
72 lstm->r2i_effective_bias,
73 lstm->recurrent_to_input_scaling,
74 n_batch,
75 n_input,
76 n_output,
77 n_cell,
78 ARM_SIGMOID,
79 scratch_buffers->input_gate);
80
81 // Calculate the forget gate
82 arm_nn_lstm_calculate_gate_s8_s16(input,
83 input_to_forget_weight,
84 lstm->i2f_effective_bias,
85 lstm->input_to_forget_scaling,
86 output_state,
87 recurrent_to_forget_weight,
88 lstm->r2f_effective_bias,
89 lstm->recurrent_to_forget_scaling,
90 n_batch,
91 n_input,
92 n_output,
93 n_cell,
94 ARM_SIGMOID,
95 scratch_buffers->forget_gate);
96
97 // Calculate the cell update gate
98 arm_nn_lstm_calculate_gate_s8_s16(input,
99 input_to_cell_weight,
100 lstm->i2c_effective_bias,
101 lstm->input_to_cell_scaling,
102 output_state,
103 recurrent_to_cell_weight,
104 lstm->r2c_effective_bias,
105 lstm->recurrent_to_cell_scaling,
106 n_batch,
107 n_input,
108 n_output,
109 n_cell,
110 ARM_TANH,
111 scratch_buffers->cell_gate);
112
113 // Update the cell state
114 arm_nn_lstm_update_cell_state_s16(n_block,
115 lstm->cell_state_shift,
116 cell_state,
117 scratch_buffers->input_gate,
118 scratch_buffers->forget_gate,
119 scratch_buffers->cell_gate);
120
121 // Calculate the output gate
122 arm_nn_lstm_calculate_gate_s8_s16(input,
123 input_to_output_weight,
124 lstm->i2o_effective_bias,
125 lstm->input_to_output_scaling,
126 output_state,
127 recurrent_to_output_weight,
128 lstm->r2o_effective_bias,
129 lstm->recurrent_to_output_scaling,
130 n_batch,
131 n_input,
132 n_output,
133 n_cell,
134 ARM_SIGMOID,
135 scratch_buffers->output_gate);
136
137 // Update the output state
138 arm_nn_lstm_update_output_s8_s16(n_batch,
139 n_cell,
140 cell_state,
141 lstm->cell_state_shift,
142 scratch_buffers->output_gate,
143 lstm->hidden_scaling,
144 lstm->hidden_offset,
145 output_state,
146 scratch_buffers->input_gate);
147
148 arm_memcpy_s8(output, output_state, n_batch * n_output * sizeof(int8_t));
149
150 return ARM_CMSIS_NN_SUCCESS;
151 }
152 /**
153 * @} end of supportLSTM group
154 */
155