1 /*
2  * SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_lstm_unidirectional_s16_s8.c
22  * Description:  S8 LSTM function with S16 gate output
23  *
24  * $Date:        4 November 2022
25  * $Revision:    V.1.0.0
26  *
27  * Target Processor:  Cortex-M processors
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33 
34 /**
35  * @ingroup Public
36  */
37 
38 /**
39  * @addtogroup LSTM
40  * @{
41  */
42 
43 /*
44  * S8 LSTM function for TensorFlow Lite with S16 gate output
45  *
46  * Refer to header file for details.
47  *
48  */
49 
50 #include "arm_nnfunctions.h"
51 #include "arm_nnsupportfunctions.h"
52 
/*
 * LSTM unidirectional function with 8 bit input, output and gate weights, and a 16 bit cell state
 *
 * Refer to header file for details.
 *
 */
arm_lstm_unidirectional_s16_s8(cmsis_nn_lstm_context * scratch_buffers,const int8_t * input_data,const cmsis_nn_lstm_dims * lstm_dims,const int8_t * in_to_in_weights,const int8_t * in_to_forget_weights,const int8_t * in_to_cell_weights,const int8_t * in_to_out_weights,const int8_t * recurrent_to_in_weights,const int8_t * recurrent_to_forget_weights,const int8_t * recurrent_to_cell_weights,const int8_t * recurrent_to_out_weights,const int16_t * cell_to_in_weights,const int16_t * cell_to_forget_weights,const int16_t * cell_to_out_weights,const int8_t * projection_weights,const cmsis_nn_lstm_params * lstm,int8_t * output_state,int16_t * cell_state,int8_t * output_data)59 arm_cmsis_nn_status arm_lstm_unidirectional_s16_s8(cmsis_nn_lstm_context *scratch_buffers,
60                                                    const int8_t *input_data,
61                                                    const cmsis_nn_lstm_dims *lstm_dims,
62                                                    const int8_t *in_to_in_weights,
63                                                    const int8_t *in_to_forget_weights,
64                                                    const int8_t *in_to_cell_weights,
65                                                    const int8_t *in_to_out_weights,
66                                                    const int8_t *recurrent_to_in_weights,
67                                                    const int8_t *recurrent_to_forget_weights,
68                                                    const int8_t *recurrent_to_cell_weights,
69                                                    const int8_t *recurrent_to_out_weights,
70                                                    const int16_t *cell_to_in_weights,
71                                                    const int16_t *cell_to_forget_weights,
72                                                    const int16_t *cell_to_out_weights,
73                                                    const int8_t *projection_weights,
74                                                    const cmsis_nn_lstm_params *lstm,
75                                                    int8_t *output_state,
76                                                    int16_t *cell_state,
77                                                    int8_t *output_data)
78 {
79     (void)cell_to_in_weights;
80     (void)cell_to_forget_weights;
81     (void)cell_to_out_weights;
82 
83     const int32_t num_batch = lstm_dims->num_batches;
84     const int32_t num_input = lstm_dims->num_inputs;
85     const int32_t max_time = lstm_dims->max_time;
86 
87     const int32_t num_output = lstm_dims->num_outputs;
88     const int32_t out_batch_leading_dim = num_output;
89 
90     // num_cell = num_output is considered in the code under the assumption that projection is NULL.
91     const int32_t num_cell = num_output;
92 
93     if (projection_weights != NULL)
94     {
95         return ARM_CMSIS_NN_ARG_ERROR;
96     }
97 
98     if (lstm->i2f_effective_bias == NULL || lstm->i2c_effective_bias == NULL || lstm->i2o_effective_bias == NULL)
99     {
100         return ARM_CMSIS_NN_ARG_ERROR;
101     }
102 
103     if (lstm->r2f_effective_bias == NULL || lstm->r2c_effective_bias == NULL || lstm->r2o_effective_bias == NULL)
104     {
105         return ARM_CMSIS_NN_ARG_ERROR;
106     }
107 
108     if (lstm->i2i_effective_bias == NULL || lstm->r2i_effective_bias == NULL)
109     {
110         return ARM_CMSIS_NN_ARG_ERROR;
111     }
112 
113     if (lstm->time_major)
114     {
115         const int32_t in_step = num_batch * num_input;
116         const int32_t out_step = num_batch * out_batch_leading_dim;
117         for (int i_max_time = 0; i_max_time < max_time; i_max_time++)
118         {
119             arm_cmsis_nn_status status = arm_nn_lstm_step_s8_s16(input_data + i_max_time * in_step,
120                                                                  in_to_in_weights,
121                                                                  in_to_forget_weights,
122                                                                  in_to_cell_weights,
123                                                                  in_to_out_weights,
124                                                                  recurrent_to_in_weights,
125                                                                  recurrent_to_forget_weights,
126                                                                  recurrent_to_cell_weights,
127                                                                  recurrent_to_out_weights,
128                                                                  lstm,
129                                                                  num_batch,
130                                                                  num_cell,
131                                                                  num_input,
132                                                                  num_output,
133                                                                  output_state,
134                                                                  cell_state,
135                                                                  output_data + i_max_time * out_step,
136                                                                  scratch_buffers);
137             if (status != ARM_CMSIS_NN_SUCCESS)
138             {
139                 return status;
140             }
141         }
142     }
143     else
144     {
145         for (int i_num_batch = 0; i_num_batch < num_batch; i_num_batch++)
146         {
147             const int32_t in_step = num_input;
148             const int32_t out_step = out_batch_leading_dim;
149             for (int i_max_time = 0; i_max_time < max_time; i_max_time++)
150             {
151                 const int32_t time_offset = i_num_batch * max_time + i_max_time;
152 
153                 arm_cmsis_nn_status status = arm_nn_lstm_step_s8_s16(input_data + time_offset * in_step,
154                                                                      in_to_in_weights,
155                                                                      in_to_forget_weights,
156                                                                      in_to_cell_weights,
157                                                                      in_to_out_weights,
158                                                                      recurrent_to_in_weights,
159                                                                      recurrent_to_forget_weights,
160                                                                      recurrent_to_cell_weights,
161                                                                      recurrent_to_out_weights,
162                                                                      lstm,
163                                                                      /*num_batch=*/1,
164                                                                      num_cell,
165                                                                      num_input,
166                                                                      num_output,
167                                                                      output_state + i_num_batch * out_batch_leading_dim,
168                                                                      cell_state + i_num_batch * num_cell,
169                                                                      output_data + time_offset * out_step,
170                                                                      scratch_buffers);
171                 if (status != ARM_CMSIS_NN_SUCCESS)
172                 {
173                     return status;
174                 }
175             }
176         }
177     }
178 
179     return ARM_CMSIS_NN_SUCCESS;
180 }
181 
182 /**
183  * @} end of LSTM group
184  */
185