/*
 * SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_lstm_unidirectional_s16_s8.c
 * Description:  S8 LSTM function with S16 gate output
 *
 * $Date:        4 November 2022
 * $Revision:    V.1.0.0
 *
 * Target Processor:  Cortex-M processors
 *
 * -------------------------------------------------------------------- */

#include "arm_nnfunctions.h"
#include "arm_nnsupportfunctions.h"

/**
 * @ingroup Public
 */

/**
 * @addtogroup LSTM
 * @{
 */

/*
 * S8 LSTM function for TensorFlow Lite with S16 gate output
 *
 * Refer to header file for details.
 *
 */
arm_cmsis_nn_status arm_lstm_unidirectional_s16_s8(cmsis_nn_lstm_context *scratch_buffers,
                                                    const int8_t *input_data,
                                                    const cmsis_nn_lstm_dims *lstm_dims,
                                                    const int8_t *in_to_in_weights,
                                                    const int8_t *in_to_forget_weights,
                                                    const int8_t *in_to_cell_weights,
                                                    const int8_t *in_to_out_weights,
                                                    const int8_t *recurrent_to_in_weights,
                                                    const int8_t *recurrent_to_forget_weights,
                                                    const int8_t *recurrent_to_cell_weights,
                                                    const int8_t *recurrent_to_out_weights,
                                                    const int16_t *cell_to_in_weights,
                                                    const int16_t *cell_to_forget_weights,
                                                    const int16_t *cell_to_out_weights,
                                                    const int8_t *projection_weights,
                                                    const cmsis_nn_lstm_params *lstm,
                                                    int8_t *output_state,
                                                    int16_t *cell_state,
                                                    int8_t *output_data)
{
    (void)cell_to_in_weights;
    (void)cell_to_forget_weights;
    (void)cell_to_out_weights;

    const int32_t num_batch = lstm_dims->num_batches;
    const int32_t num_input = lstm_dims->num_inputs;
    const int32_t max_time = lstm_dims->max_time;

    const int32_t num_output = lstm_dims->num_outputs;
    const int32_t out_batch_leading_dim = num_output;

    // The code assumes num_cell == num_output, which holds only when the projection weights are NULL.
    const int32_t num_cell = num_output;

    if (projection_weights != NULL)
    {
        return ARM_CMSIS_NN_ARG_ERROR;
    }

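    /* The i2*_effective_bias and r2*_effective_bias fields are expected to hold each gate's bias with the
       input/recurrent zero-point contributions already folded in (precomputed by the caller). Every gate
       needs both of them, so missing pointers are rejected below. */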
    if (lstm->i2f_effective_bias == NULL || lstm->i2c_effective_bias == NULL || lstm->i2o_effective_bias == NULL)
    {
        return ARM_CMSIS_NN_ARG_ERROR;
    }

    if (lstm->r2f_effective_bias == NULL || lstm->r2c_effective_bias == NULL || lstm->r2o_effective_bias == NULL)
    {
        return ARM_CMSIS_NN_ARG_ERROR;
    }

    if (lstm->i2i_effective_bias == NULL || lstm->r2i_effective_bias == NULL)
    {
        return ARM_CMSIS_NN_ARG_ERROR;
    }

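    /* Time-major input: data is laid out as [max_time, num_batch, num_input], so every time step covers all
       batches and a single LSTM step call is made per time step. */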
    if (lstm->time_major)
    {
        const int32_t in_step = num_batch * num_input;
        const int32_t out_step = num_batch * out_batch_leading_dim;
        for (int i_max_time = 0; i_max_time < max_time; i_max_time++)
        {
            arm_cmsis_nn_status status = arm_nn_lstm_step_s8_s16(input_data + i_max_time * in_step,
                                                                 in_to_in_weights,
                                                                 in_to_forget_weights,
                                                                 in_to_cell_weights,
                                                                 in_to_out_weights,
                                                                 recurrent_to_in_weights,
                                                                 recurrent_to_forget_weights,
                                                                 recurrent_to_cell_weights,
                                                                 recurrent_to_out_weights,
                                                                 lstm,
                                                                 num_batch,
                                                                 num_cell,
                                                                 num_input,
                                                                 num_output,
                                                                 output_state,
                                                                 cell_state,
                                                                 output_data + i_max_time * out_step,
                                                                 scratch_buffers);
            if (status != ARM_CMSIS_NN_SUCCESS)
            {
                return status;
            }
        }
    }
    else
    {
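        /* Batch-major input: data is laid out as [num_batch, max_time, num_input], so each batch is stepped
           through time independently with an effective batch size of 1. */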
        for (int i_num_batch = 0; i_num_batch < num_batch; i_num_batch++)
        {
            const int32_t in_step = num_input;
            const int32_t out_step = out_batch_leading_dim;
            for (int i_max_time = 0; i_max_time < max_time; i_max_time++)
            {
                const int32_t time_offset = i_num_batch * max_time + i_max_time;

                arm_cmsis_nn_status status = arm_nn_lstm_step_s8_s16(input_data + time_offset * in_step,
                                                                     in_to_in_weights,
                                                                     in_to_forget_weights,
                                                                     in_to_cell_weights,
                                                                     in_to_out_weights,
                                                                     recurrent_to_in_weights,
                                                                     recurrent_to_forget_weights,
                                                                     recurrent_to_cell_weights,
                                                                     recurrent_to_out_weights,
                                                                     lstm,
                                                                     /*num_batch=*/1,
                                                                     num_cell,
                                                                     num_input,
                                                                     num_output,
                                                                     output_state + i_num_batch * out_batch_leading_dim,
                                                                     cell_state + i_num_batch * num_cell,
                                                                     output_data + time_offset * out_step,
                                                                     scratch_buffers);
                if (status != ARM_CMSIS_NN_SUCCESS)
                {
                    return status;
                }
            }
        }
    }

    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of LSTM group
 */