/*
 * SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nnsupportfunctions.h
 * Description:  Public header file of support functions for CMSIS NN Library
 *
 * $Date:        23 March 2023
 * $Revision:    V.16.0.0
 *
 * Target :  Arm(R) M-Profile Architecture
 * -------------------------------------------------------------------- */

#ifndef _ARM_NNSUPPORTFUNCTIONS_H_
#define _ARM_NNSUPPORTFUNCTIONS_H_

#include "Internal/arm_nn_compiler.h"
#include "arm_nn_math_types.h"
#include "arm_nn_types.h"

#include <stdbool.h>

#ifdef __cplusplus
extern "C" {
#endif

#define USE_FAST_DW_CONV_S16_FUNCTION(dw_conv_params, filter_dims, input_dims)                                         \
    (dw_conv_params->ch_mult == 1 && dw_conv_params->dilation.w == 1 && dw_conv_params->dilation.h == 1 &&             \
     filter_dims->w * filter_dims->h < 512)

#define LEFT_SHIFT(_shift) (_shift > 0 ? _shift : 0)
#define RIGHT_SHIFT(_shift) (_shift > 0 ? 0 : -_shift)
#define MASK_IF_ZERO(x) (x) == 0 ? ~0 : 0
#define MASK_IF_NON_ZERO(x) (x) != 0 ? ~0 : 0
#define SELECT_USING_MASK(mask, a, b) ((mask) & (a)) ^ (~(mask) & (b))

#define MAX(A, B) ((A) > (B) ? (A) : (B))
#define MIN(A, B) ((A) < (B) ? (A) : (B))
#define CLAMP(x, h, l) MAX(MIN((x), (h)), (l))
#define REDUCE_MULTIPLIER(_mult) ((_mult < 0x7FFF0000) ? ((_mult + (1 << 15)) >> 16) : 0x7FFF)
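
/*
 * Illustrative note (not part of the API): LEFT_SHIFT/RIGHT_SHIFT split a
 * signed requantization shift into its two directions, e.g. LEFT_SHIFT(3) == 3
 * and RIGHT_SHIFT(3) == 0, while LEFT_SHIFT(-3) == 0 and RIGHT_SHIFT(-3) == 3.
 * REDUCE_MULTIPLIER rounds a q31 multiplier to q15, e.g.
 * REDUCE_MULTIPLIER(0x40004000) == 0x4000.
 */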

// Number of channels processed in a block for DW Conv(MVE)
// Requirement: Greater than 0 & less than 128
// This can be fine-tuned to match the number of input channels for best performance.
// A layer with a lower number of channels than CH_IN_BLOCK_MVE will result in higher
// scratch buffer usage and a layer with a higher number of channels than CH_IN_BLOCK_MVE
// will result in lower scratch buffer usage.
#define CH_IN_BLOCK_MVE (124)

/**
 * @brief definition to pack four 8 bit values.
 */
#define PACK_S8x4_32x1(v0, v1, v2, v3)                                                                                 \
    ((((int32_t)(v0) << 0) & (int32_t)0x000000FF) | (((int32_t)(v1) << 8) & (int32_t)0x0000FF00) |                     \
     (((int32_t)(v2) << 16) & (int32_t)0x00FF0000) | (((int32_t)(v3) << 24) & (int32_t)0xFF000000))

/**
 * @brief definition to pack two 16 bit values.
 */
#define PACK_Q15x2_32x1(v0, v1) (((int32_t)v0 & (int32_t)0xFFFF) | ((int32_t)v1 << 16))
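
/*
 * Illustrative example (not part of the API): the pack macros build the 32-bit
 * word arithmetically, so the first argument always lands in the least
 * significant byte/half-word regardless of target endianness:
 *
 *   int32_t b = PACK_S8x4_32x1(0x01, 0x02, 0x03, 0x04); // 0x04030201
 *   int32_t h = PACK_Q15x2_32x1(0x1111, 0x2222);        // 0x22221111
 */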

/**
 * @brief Union for SIMD access of q31/s16/s8 types
 */
union arm_nnword
{
    int32_t word;
    /**< q31 type */
    int16_t half_words[2];
    /**< s16 type */
    int8_t bytes[4];
    /**< s8 type */
};

/**
 * @brief Union for data type long long
 */
struct arm_nn_double
{
    uint32_t low;
    int32_t high;
};

union arm_nn_long_long
{
    int64_t long_long;
    struct arm_nn_double word;
};

/**
 * @defgroup groupSupport Private
 *
 * Internal Support functions. Not intended to be called directly by a CMSIS-NN user.
 *
 */

/**
 * @defgroup supportConversion Data Conversion
 *
 * Perform data type conversion in-between neural network operations
 *
 */

/**
 * @brief Converts the elements from a s8 vector to a s16 vector with an added offset
 * @param[in]    src        pointer to the s8 input vector
 * @param[out]   dst        pointer to the s16 output vector
 * @param[in]    block_size length of the input vector
 * @param[in]    offset     s16 offset to be added to each input vector element.
 *
 * \par Description:
 *
 * Output elements are ordered.
 * The equation used for the conversion process is:
 *
 * <pre>
 *  dst[n] = (int16_t) src[n] + offset;   0 <= n < block_size.
 * </pre>
 *
 */
void arm_q7_to_q15_with_offset(const int8_t *src, int16_t *dst, int32_t block_size, int16_t offset);
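
/*
 * Usage sketch (hypothetical buffers):
 *
 *   int8_t src[8] = {-128, -1, 0, 1, 2, 3, 4, 127};
 *   int16_t dst[8];
 *   // Widens while adding the offset, e.g. offset 128 maps -128..127 to 0..255.
 *   arm_q7_to_q15_with_offset(src, dst, 8, 128);
 */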

#if defined(ARM_MATH_DSP)
/**
 * @brief Converts the elements from a s8 vector to a s16 vector with an added offset
 * @param[in]    src        pointer to the s8 input vector
 * @param[out]   dst        pointer to the s16 output vector
 * @param[in]    block_size length of the input vector
 * @param[in]    offset     s16 offset to be added to each input vector element.
 *
 * \par Description:
 *
 * No additional ordering is done, with the result that output elements are not in order.
 * Instead of ABCD, the output order will be ACBD.
 * Note this is for processors with DSP extension only.
 * The equation used for the conversion process is:
 *
 * <pre>
 *  dst[n - 0] = (int16_t) src[n - 0] + offset;   0 <= n < block_size.
 *  dst[n - 1] = (int16_t) src[n - 2] + offset;   0 <= n < block_size.
 *  dst[n - 2] = (int16_t) src[n - 1] + offset;   0 <= n < block_size.
 *  dst[n - 3] = (int16_t) src[n - 3] + offset;   0 <= n < block_size.
 * </pre>
 *
 */
void arm_s8_to_s16_unordered_with_offset(const int8_t *src, int16_t *dst, int32_t block_size, int16_t offset);
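
/*
 * Illustrative example (DSP extension builds only): for src = {A, B, C, D} the
 * unordered variant produces dst = {A, C, B, D}. This is sufficient when both
 * operands of a dual 16-bit multiply-accumulate are read with the same
 * reordering, since the per-lane products are summed either way.
 */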
#endif

/**
 * @brief Depthwise conv on an im2col buffer where the input channel equals output channel.
 * @param[in]    row     pointer to row
 * @param[in]    col     pointer to im2col buffer, always consists of 2 columns.
 * @param[in]    num_ch   number of channels
 * @param[in]    out_shift  pointer to per output channel requantization shift parameter.
 * @param[in]    out_mult   pointer to per output channel requantization multiplier parameter.
 * @param[in]    out_offset      output tensor offset.
 * @param[in]    activation_min   minimum value to clamp the output to. Range : int8
 * @param[in]    activation_max   maximum value to clamp the output to. Range : int8
 * @param[in]    kernel_size   number of elements in one column.
 * @param[in]    output_bias per output channel bias. Range : int32
 * @param[out]   out         pointer to output
 * @return     The function returns one of the two
 *              1. The incremented output pointer for a successful operation or
 *              2. NULL if implementation is not available.
 *
 * @details     Supported framework: TensorFlow Lite micro.
 */
int8_t *arm_nn_depthwise_conv_s8_core(const int8_t *row,
                                      const int16_t *col,
                                      const uint16_t num_ch,
                                      const int32_t *out_shift,
                                      const int32_t *out_mult,
                                      const int32_t out_offset,
                                      const int32_t activation_min,
                                      const int32_t activation_max,
                                      const uint16_t kernel_size,
                                      const int32_t *const output_bias,
                                      int8_t *out);

/**
 * @brief General Matrix-multiplication function with per-channel requantization.
 * @param[in]       input_row    pointer to row operand
 * @param[in]       input_col    pointer to col operand
 * @param[in]       output_ch    number of rows of input_row
 * @param[in]       col_batches  number of column batches. Range: 1 to 4
 * @param[in]       output_shift  pointer to per output channel requantization shift parameter.
 * @param[in]       output_mult   pointer to per output channel requantization multiplier parameter.
 * @param[in]       out_offset    output tensor offset.
 * @param[in]       col_offset    input tensor (col) offset.
 * @param[in]       row_offset    kernel offset (row). Not used.
 * @param[in]       out_activation_min   minimum value to clamp the output to. Range : int8
 * @param[in]       out_activation_max   maximum value to clamp the output to. Range : int8
 * @param[in]       row_len       number of elements in each row
 * @param[in]       bias          per output channel bias. Range : int32
 * @param[in,out]   out           pointer to output
 * @return     The function returns one of the two
 *              1. The incremented output pointer for a successful operation or
 *              2. NULL if implementation is not available.
 *
 * @details   Supported framework: TensorFlow Lite
 */
int8_t *arm_nn_mat_mult_s8(const int8_t *input_row,
                           const int8_t *input_col,
                           const uint16_t output_ch,
                           const uint16_t col_batches,
                           const int32_t *output_shift,
                           const int32_t *output_mult,
                           const int32_t out_offset,
                           const int32_t col_offset,
                           const int32_t row_offset,
                           const int16_t out_activation_min,
                           const int16_t out_activation_max,
                           const uint16_t row_len,
                           const int32_t *const bias,
                           int8_t *out);
/**
 * @brief Matrix-multiplication function for convolution with per-channel requantization for 16 bits convolution.
 * @param[in]       input_a     pointer to operand A
 * @param[in]       input_b     pointer to operand B, always consists of 2 vectors.
 * @param[in]       output_ch   number of rows of A
 * @param[in]       out_shift  pointer to per output channel requantization shift parameter.
 * @param[in]       out_mult   pointer to per output channel requantization multiplier parameter.
 * @param[in]       activation_min   minimum value to clamp the output to. Range : int16
 * @param[in]       activation_max   maximum value to clamp the output to. Range : int16
 * @param[in]       num_col_a   number of columns of A
 * @param[in]       output_bias per output channel bias. Range : int64
 * @param[in,out]   out_0       pointer to output
 * @return     The function returns one of the two
 *              1. The incremented output pointer for a successful operation or
 *              2. NULL if implementation is not available.
 *
 * @details   This function does the matrix multiplication of the weight matrix for all output channels
 *            with 2 columns from im2col and produces two elements/output_channel. The outputs are
 *            clamped in the range provided by activation min and max.
 *            Supported framework: TensorFlow Lite micro.
 */
int16_t *arm_nn_mat_mult_kernel_s16(const int8_t *input_a,
                                    const int16_t *input_b,
                                    const int32_t output_ch,
                                    const int32_t *out_shift,
                                    const int32_t *out_mult,
                                    const int16_t activation_min,
                                    const int16_t activation_max,
                                    const int32_t num_col_a,
                                    const int64_t *const output_bias,
                                    int16_t *out_0);

/**
 * @brief General Vector by Matrix multiplication with requantization and storage of result.
 * @param[in]       row_elements          number of row elements
 * @param[in]       skipped_row_elements  number of row elements skipped due to padding.
 *                                        row_elements + skipped_row_elements = (kernel_x * kernel_y) * input_ch
 * @param[in]       row_base_ref          pointer to row operand
 * @param[in]       col_base_ref          pointer to col operand
 * @param[in]       out_ch                Number of output channels
 * @param[in]       conv_params           Pointer to convolution parameters like offsets and activation values
 * @param[in]       quant_params          Pointer to per-channel quantization parameters
 * @param[in]       bias                  Pointer to optional per-channel bias
 * @param[out]      output                Pointer to output where int8 results are stored.
 * @return     The function performs matrix (row_base_ref) multiplication with vector (col_base_ref) and
 *             the scaled result is stored in memory.
 *
 * @details Pseudo-code
 *      *output = 0
 *      sum_col = 0
 *      for (j = 0; j < out_ch; j++)
 *          for (i = 0; i < row_elements; i++)
 *              *output += row_base_ref[i] * col_base_ref[i]
 *              sum_col += col_base_ref[i]
 *          scale sum_col using quant_params and bias
 *          store result in 'output'
 *
 *
 */
arm_cmsis_nn_status arm_nn_mat_mul_core_1x_s8(int32_t row_elements,
                                              const int32_t skipped_row_elements,
                                              const int8_t *row_base_ref,
                                              const int8_t *col_base_ref,
                                              const int32_t out_ch,
                                              const cmsis_nn_conv_params *conv_params,
                                              const cmsis_nn_per_channel_quant_params *quant_params,
                                              const int32_t *bias,
                                              int8_t *output);

/**
 * @brief Matrix-multiplication with requantization & activation function for four rows and one column
 * @param[in]       row_elements  number of row elements
 * @param[in]       offset        offset between rows. Can be the same as row_elements,
 *                                e.g. in a 1x1 conv scenario with a stride of 1.
 * @param[in]       row_base      pointer to row operand
 * @param[in]       col_base      pointer to col operand
 * @param[in]       out_ch        Number of output channels
 * @param[in]       conv_params   Pointer to convolution parameters like offsets and activation values
 * @param[in]       quant_params  Pointer to per-channel quantization parameters
 * @param[in]       bias          Pointer to per-channel bias
 * @param[out]      output        Pointer to output where int8 results are stored.
 *
 * @return     The function returns the updated output pointer or NULL if implementation is not available.
 *
 * @details Compliant with the TFLM int8 specification. MVE implementation only.
 */
int8_t *arm_nn_mat_mul_core_4x_s8(const int32_t row_elements,
                                  const int32_t offset,
                                  const int8_t *row_base,
                                  const int8_t *col_base,
                                  const int32_t out_ch,
                                  const cmsis_nn_conv_params *conv_params,
                                  const cmsis_nn_per_channel_quant_params *quant_params,
                                  const int32_t *bias,
                                  int8_t *output);

/**
 * @brief General Matrix-multiplication function with per-channel requantization.
 *        This function assumes:
 *        - LHS input matrix NOT transposed (nt)
 *        - RHS input matrix transposed (t)
 *
 *  @note This operation also performs the broadcast bias addition before the requantization
 *
 * @param[in]  lhs                Pointer to the LHS input matrix
 * @param[in]  rhs                Pointer to the RHS input matrix
 * @param[in]  bias               Pointer to the bias vector. The length of this vector is equal to the number of
 *                                output columns (or RHS input rows)
 * @param[out] dst                Pointer to the output matrix with "m" rows and "n" columns
 * @param[in]  dst_multipliers    Pointer to the multipliers vector needed for the per-channel requantization.
 *                                The length of this vector is equal to the number of output columns (or RHS input
 *                                rows)
 * @param[in]  dst_shifts         Pointer to the shifts vector needed for the per-channel requantization. The length
 *                                of this vector is equal to the number of output columns (or RHS input rows)
 * @param[in]  lhs_rows           Number of LHS input rows
 * @param[in]  rhs_rows           Number of RHS input rows
 * @param[in]  rhs_cols           Number of LHS/RHS input columns
 * @param[in]  lhs_offset         Offset to be applied to the LHS input value
 * @param[in]  dst_offset         Offset to be applied to the output result
 * @param[in]  activation_min     Minimum value to clamp down the output. Range : int8
 * @param[in]  activation_max     Maximum value to clamp up the output. Range : int8
 * @param[in]  lhs_cols_offset    Column offset between subsequent lhs_rows
 *
 * @return     The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
 *
 */
arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8(const int8_t *lhs,
                                            const int8_t *rhs,
                                            const int32_t *bias,
                                            int8_t *dst,
                                            const int32_t *dst_multipliers,
                                            const int32_t *dst_shifts,
                                            const int32_t lhs_rows,
                                            const int32_t rhs_rows,
                                            const int32_t rhs_cols,
                                            const int32_t lhs_offset,
                                            const int32_t dst_offset,
                                            const int32_t activation_min,
                                            const int32_t activation_max,
                                            const int32_t lhs_cols_offset);

/**
 * @brief s8 Vector by Matrix (transposed) multiplication
 *
 * @param[in]      lhs             Input left-hand side vector
 * @param[in]      rhs             Input right-hand side matrix (transposed)
 * @param[in]      bias            Input bias
 * @param[out]     dst             Output vector
 * @param[in]      lhs_offset      Offset to be added to the input values of the left-hand side vector.
 *                                 Range: -127 to 128
 * @param[in]      dst_offset      Offset to be added to the output values. Range: -127 to 128
 * @param[in]      dst_multiplier  Output multiplier
 * @param[in]      dst_shift       Output shift
 * @param[in]      rhs_cols        Number of columns in the right-hand side input matrix
 * @param[in]      rhs_rows        Number of rows in the right-hand side input matrix
 * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
 * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
 * @param[in]      address_offset  Memory position offset for dst. First output is stored at 'dst', the
 *                                 second at 'dst + address_offset' and so on. Default value is typically 1.
 *
 * @return         The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
 *
 */
arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s8(const int8_t *lhs,
                                             const int8_t *rhs,
                                             const int32_t *bias,
                                             int8_t *dst,
                                             const int32_t lhs_offset,
                                             const int32_t dst_offset,
                                             const int32_t dst_multiplier,
                                             const int32_t dst_shift,
                                             const int32_t rhs_cols,
                                             const int32_t rhs_rows,
                                             const int32_t activation_min,
                                             const int32_t activation_max,
                                             const int32_t address_offset);
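
/*
 * Usage sketch (hypothetical sizes and buffer names): a fully connected layer
 * with 3 outputs and 4 inputs, results written contiguously
 * (address_offset = 1):
 *
 *   // lhs: 1x4 input vector, rhs: 3x4 weight matrix stored row-major
 *   arm_nn_vec_mat_mult_t_s8(lhs, rhs, bias, dst,
 *                            lhs_offset, dst_offset, dst_multiplier, dst_shift,
 *                            4, 3, -128, 127, 1);
 *
 * With address_offset = 2 the three results would land at dst[0], dst[2] and
 * dst[4] instead, which is how interleaved output layouts can be produced.
 */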

/**
 * @brief s16 Vector by Matrix (transposed) multiplication
 *
 * @param[in]      lhs             Input left-hand side vector
 * @param[in]      rhs             Input right-hand side matrix (transposed)
 * @param[in]      bias            Input bias
 * @param[out]     dst             Output vector
 * @param[in]      dst_multiplier  Output multiplier
 * @param[in]      dst_shift       Output shift
 * @param[in]      rhs_cols        Number of columns in the right-hand side input matrix
 * @param[in]      rhs_rows        Number of rows in the right-hand side input matrix
 * @param[in]      activation_min  Minimum value to clamp the output to. Range: int16
 * @param[in]      activation_max  Maximum value to clamp the output to. Range: int16
 *
 * @return         The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
 *
 */
arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s16(const int16_t *lhs,
                                              const int8_t *rhs,
                                              const int64_t *bias,
                                              int16_t *dst,
                                              const int32_t dst_multiplier,
                                              const int32_t dst_shift,
                                              const int32_t rhs_cols,
                                              const int32_t rhs_rows,
                                              const int32_t activation_min,
                                              const int32_t activation_max);

/**
 * @brief s8 Vector by Matrix (transposed) multiplication with s16 output
 *
 * @param[in]      lhs             Input left-hand side vector
 * @param[in]      rhs             Input right-hand side matrix (transposed)
 * @param[out]     dst             Output vector
 * @param[in]      lhs_offset      Offset to be added to the input values of the left-hand side
 *                                 vector. Range: -127 to 128
 * @param[in]      scatter_offset  Address offset for dst. First output is stored at 'dst', the
 *                                 second at 'dst + scatter_offset' and so on.
 * @param[in]      dst_multiplier  Output multiplier
 * @param[in]      dst_shift       Output shift
 * @param[in]      rhs_cols        Number of columns in the right-hand side input matrix
 * @param[in]      rhs_rows        Number of rows in the right-hand side input matrix
 * @param[in]      activation_min  Minimum value to clamp the output to. Range: int16
 * @param[in]      activation_max  Maximum value to clamp the output to. Range: int16
 *
 * @return         The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
 *
 */
arm_cmsis_nn_status arm_nn_vec_mat_mult_t_svdf_s8(const int8_t *lhs,
                                                  const int8_t *rhs,
                                                  int16_t *dst,
                                                  const int32_t lhs_offset,
                                                  const int32_t scatter_offset,
                                                  const int32_t dst_multiplier,
                                                  const int32_t dst_shift,
                                                  const int32_t rhs_cols,
                                                  const int32_t rhs_rows,
                                                  const int32_t activation_min,
                                                  const int32_t activation_max);

/**
 * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in padded cases where
 *        the padding is -lhs_offset (Range: int8). Dimensions are the same for lhs and rhs.
 *
 * @param[in]      lhs             Input left-hand side matrix
 * @param[in]      rhs             Input right-hand side matrix (transposed)
 * @param[in]      lhs_offset      LHS matrix offset (input offset). Range: -127 to 128
 * @param[in]      active_ch       Subset of total_ch processed
 * @param[in]      total_ch        Number of channels in LHS/RHS
 * @param[in]      out_shift       Per channel output shift. Length of vector is equal to number of channels
 * @param[in]      out_mult        Per channel output multiplier. Length of vector is equal to number of channels
 * @param[in]      out_offset      Offset to be added to the output values. Range: -127 to 128
 * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
 * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
 * @param[in]      row_x_col       (row_dimension * col_dimension) of LHS/RHS matrix
 * @param[in]      output_bias     Per channel output bias. Length of vector is equal to number of channels
 * @param[out]     out             Output pointer
 *
 * @return         The function returns one of the two
 *                  - Updated output pointer if an implementation is available
 *                  - NULL if no implementation is available.
 *
 * @note           If the number of channels is not a multiple of 4, up to 3 elements outside the boundary will be
 *                 read for the following:
 *                  - Output shift
 *                  - Output multiplier
 *                  - Output bias
 *                  - rhs
 */
arm_cmsis_nn_status arm_nn_depthwise_conv_nt_t_padded_s8(const int8_t *lhs,
                                                         const int8_t *rhs,
                                                         const int32_t lhs_offset,
                                                         const int32_t active_ch,
                                                         const int32_t total_ch,
                                                         const int32_t *out_shift,
                                                         const int32_t *out_mult,
                                                         const int32_t out_offset,
                                                         const int32_t activation_min,
                                                         const int32_t activation_max,
                                                         const uint16_t row_x_col,
                                                         const int32_t *const output_bias,
                                                         int8_t *out);

/**
 * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in non-padded cases.
 *        Dimensions are the same for lhs and rhs.
 *
 * @param[in]      lhs             Input left-hand side matrix
 * @param[in]      rhs             Input right-hand side matrix (transposed)
 * @param[in]      lhs_offset      LHS matrix offset (input offset). Range: -127 to 128
 * @param[in]      active_ch       Subset of total_ch processed
 * @param[in]      total_ch        Number of channels in LHS/RHS
 * @param[in]      out_shift       Per channel output shift. Length of vector is equal to number of channels.
 * @param[in]      out_mult        Per channel output multiplier. Length of vector is equal to number of channels.
 * @param[in]      out_offset      Offset to be added to the output values. Range: -127 to 128
 * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
 * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
 * @param[in]      row_x_col       (row_dimension * col_dimension) of LHS/RHS matrix
 * @param[in]      output_bias     Per channel output bias. Length of vector is equal to number of channels.
 * @param[out]     out             Output pointer
 *
 * @return         The function returns one of the two
 *                  - Updated output pointer if an implementation is available
 *                  - NULL if no implementation is available.
 *
 * @note           If the number of channels is not a multiple of 4, up to 3 elements outside the boundary will be
 *                 read for the following:
 *                  - Output shift
 *                  - Output multiplier
 *                  - Output bias
 *                  - rhs
 */
arm_cmsis_nn_status arm_nn_depthwise_conv_nt_t_s8(const int8_t *lhs,
                                                  const int8_t *rhs,
                                                  const int32_t lhs_offset,
                                                  const int32_t active_ch,
                                                  const int32_t total_ch,
                                                  const int32_t *out_shift,
                                                  const int32_t *out_mult,
                                                  const int32_t out_offset,
                                                  const int32_t activation_min,
                                                  const int32_t activation_max,
                                                  const uint16_t row_x_col,
                                                  const int32_t *const output_bias,
                                                  int8_t *out);

/**
 * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in non-padded cases.
 *        Dimensions are the same for lhs and rhs.
 *
 * @param[in]      lhs             Input left-hand side matrix
 * @param[in]      rhs             Input right-hand side matrix (transposed)
 * @param[in]      num_ch          Number of channels in LHS/RHS
 * @param[in]      out_shift       Per channel output shift. Length of vector is equal to number of channels.
 * @param[in]      out_mult        Per channel output multiplier. Length of vector is equal to number of channels.
 * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
 * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
 * @param[in]      row_x_col       (row_dimension * col_dimension) of LHS/RHS matrix
 * @param[in]      output_bias     Per channel output bias. Length of vector is equal to number of channels.
 * @param[out]     out             Output pointer
 *
 * @return         The function returns one of the two
 *                  - Updated output pointer if an implementation is available
 *                  - NULL if no implementation is available.
 *
 * @note           If the number of channels is not a multiple of 4, up to 3 elements outside the boundary will be
 *                 read for the following:
 *                  - Output shift
 *                  - Output multiplier
 *                  - Output bias
 *                  - rhs
 */
int16_t *arm_nn_depthwise_conv_nt_t_s16(const int16_t *lhs,
                                        const int8_t *rhs,
                                        const uint16_t num_ch,
                                        const int32_t *out_shift,
                                        const int32_t *out_mult,
                                        const int32_t activation_min,
                                        const int32_t activation_max,
                                        const uint16_t row_x_col,
                                        const int64_t *const output_bias,
                                        int16_t *out);

/**
  @brief         Read 2 s16 elements and post increment pointer.
  @param[in]     in_q15   Pointer to pointer that holds address of input.
  @return        q31 value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_q15x2_ia(const int16_t **in_q15)
{
    int32_t val;

    memcpy(&val, *in_q15, 4);
    *in_q15 += 2;

    return (val);
}

/**
  @brief         Read 4 s8 from s8 pointer and post increment pointer.
  @param[in]     in_s8       Pointer to pointer that holds address of input.
  @return        q31 value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_s8x4_ia(const int8_t **in_s8)
{
    int32_t val;
    memcpy(&val, *in_s8, 4);
    *in_s8 += 4;

    return (val);
}

/**
  @brief         Read 2 int16 values from int16 pointer.
  @param[in]     in     pointer to address of input.
  @return        s32 value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_s16x2(const int16_t *in)
{
    int32_t val;
    memcpy(&val, in, 4);

    return (val);
}

/**
  @brief         Read 4 s8 values.
  @param[in]     in_s8       pointer to address of input.
  @return        s32 value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_s8x4(const int8_t *in_s8)
{
    int32_t val;
    memcpy(&val, in_s8, 4);

    return (val);
}

/**
  @brief         Write four s8 to s8 pointer and increment pointer afterwards.
  @param[in]     in       Double pointer to input value
  @param[in]     value    Four bytes to copy
 */
__STATIC_FORCEINLINE void arm_nn_write_s8x4_ia(int8_t **in, int32_t value)
{
    memcpy(*in, &value, 4);
    *in += 4;
}

/**
 * @brief           memset optimized for MVE
 * @param[in, out]  dst         Destination pointer
 * @param[in]       val         Value to set
 * @param[in]       block_size  Number of bytes to set.
 *
 */
__STATIC_FORCEINLINE void arm_memset_s8(int8_t *dst, const int8_t val, uint32_t block_size)
{
#if defined(ARM_MATH_MVEI)
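    // Tail-predicated low-overhead loop: WLSTP/LETP store 16 bytes of the
    // splatted value per iteration and predicate the final, partial iteration,
    // so no scalar tail loop is needed.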
    __asm volatile("   vdup.8                  q0, %[set_val]             \n"
                   "   wlstp.8                 lr, %[cnt], 1f             \n"
                   "2:                                                    \n"
                   "   vstrb.8                 q0, [%[in]], #16            \n"
                   "   letp                    lr, 2b                     \n"
                   "1:                                                    \n"
                   : [in] "+r"(dst)
                   : [cnt] "r"(block_size), [set_val] "r"(val)
                   : "q0", "memory", "r14");
#else
    memset(dst, val, block_size);
#endif
}

#if defined(ARM_MATH_DSP)

/**
 * @brief read and expand one s8 word into two s16 words with ordering.
 */
__STATIC_FORCEINLINE const int8_t *read_and_pad(const int8_t *source, int32_t *out1, int32_t *out2)
{
    int32_t inA = arm_nn_read_s8x4_ia(&source);
    int32_t inAbuf1 = SXTB16_RORn((uint32_t)inA, 8);
    int32_t inAbuf2 = SXTB16(inA);

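    // inAbuf1 holds the sign-extended bytes 1 and 3, inAbuf2 bytes 0 and 2;
    // the PKH pack operations below restore the original element order.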
    #ifndef ARM_MATH_BIG_ENDIAN
    *out2 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16));
    *out1 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16));
    #else
    *out1 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16));
    *out2 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16));
    #endif

    return source;
}

/**
 * @brief read and expand one s8 word into two s16 words with no additional ordering.
 */
__STATIC_FORCEINLINE const int8_t *read_and_pad_reordered(const int8_t *source, int32_t *out1, int32_t *out2)
{
    int32_t inA = arm_nn_read_s8x4_ia(&source);
    #ifndef ARM_MATH_BIG_ENDIAN
    *out2 = SXTB16(ROR((uint32_t)inA, 8));
    *out1 = SXTB16(inA);
    #else
    *out1 = SXTB16(ROR((uint32_t)inA, 8));
    *out2 = SXTB16(inA);
    #endif

    return source;
}

#endif

/**
 * @brief Matrix-multiplication function for convolution with per-channel requantization.
 * @param[in]       input_a     pointer to operand A
 * @param[in]       input_b     pointer to operand B, always consists of 2 vectors.
 * @param[in]       output_ch   number of rows of A
 * @param[in]       out_shift  pointer to per output channel requantization shift parameter.
 * @param[in]       out_mult   pointer to per output channel requantization multiplier parameter.
 * @param[in]       out_offset      output tensor offset.
 * @param[in]       activation_min   minimum value to clamp the output to. Range : int8
 * @param[in]       activation_max   maximum value to clamp the output to. Range : int8
 * @param[in]       num_col_a   number of columns of A
 * @param[in]       output_bias per output channel bias. Range : int32
 * @param[in,out]   out_0       pointer to output
 * @return     The function returns one of the two
 *              1. The incremented output pointer for a successful operation or
 *              2. NULL if implementation is not available.
 *
 * @details   This function does the matrix multiplication of the weight matrix for all output channels
 *            with 2 columns from im2col and produces two elements/output_channel. The outputs are
 *            clamped in the range provided by activation min and max.
 *            Supported framework: TensorFlow Lite micro.
 */
int8_t *arm_nn_mat_mult_kernel_s8_s16(const int8_t *input_a,
                                      const int16_t *input_b,
                                      const uint16_t output_ch,
                                      const int32_t *out_shift,
                                      const int32_t *out_mult,
                                      const int32_t out_offset,
                                      const int16_t activation_min,
                                      const int16_t activation_max,
                                      const int32_t num_col_a,
                                      const int32_t *const output_bias,
                                      int8_t *out_0);

/**
 * @brief Common softmax function for s8 input and s8 or s16 output
 * @param[in]  input          Pointer to the input tensor
 * @param[in]  num_rows       Number of rows in the input tensor
 * @param[in]  row_size       Number of elements in each input row
 * @param[in]  mult           Input quantization multiplier
 * @param[in]  shift          Input quantization shift within the range [0, 31]
 * @param[in]  diff_min       Minimum difference with max in row. Used to check if
 *                            the quantized exponential operation can be performed
 * @param[in]  int16_output   Indicating s8 output if 0 else s16 output
 * @param[out] output         Pointer to the output tensor
 *
 * @note Supported framework: TensorFlow Lite micro (bit-accurate)
 *
 */
void arm_nn_softmax_common_s8(const int8_t *input,
                              const int32_t num_rows,
                              const int32_t row_size,
                              const int32_t mult,
                              const int32_t shift,
                              const int32_t diff_min,
                              const bool int16_output,
                              void *output);

/**
 * @brief macro for adding rounding offset
 */
#ifndef ARM_NN_TRUNCATE
    #define NN_ROUND(out_shift) ((0x1 << out_shift) >> 1)
#else
    #define NN_ROUND(out_shift) 0
#endif
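
/*
 * Illustrative example (not part of the API): with rounding enabled,
 * NN_ROUND(3) evaluates to 4, so (val + NN_ROUND(3)) >> 3 yields a rounded
 * divide by 8, e.g. val = 20 gives (20 + 4) >> 3 == 3 rather than the
 * truncated 2. With ARM_NN_TRUNCATE defined, NN_ROUND expands to 0 and the
 * same expression simply truncates.
 */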

// Macros for shortening quantization functions' names and avoiding long lines
#define MUL_SAT(a, b) arm_nn_doubling_high_mult((a), (b))
#define MUL_SAT_MVE(a, b) arm_doubling_high_mult_mve_32x4((a), (b))
#define MUL_POW2(a, b) arm_nn_mult_by_power_of_two((a), (b))

#define DIV_POW2(a, b) arm_nn_divide_by_power_of_two((a), (b))
#define DIV_POW2_MVE(a, b) arm_divide_by_power_of_two_mve((a), (b))

#define EXP_ON_NEG(x) arm_nn_exp_on_negative_values((x))
#define ONE_OVER1(x) arm_nn_one_over_one_plus_x_for_x_in_0_1((x))

/**
 * @brief           Saturating doubling high multiply. Result matches
 *                  NEON instruction VQRDMULH.
 * @param[in]       m1        Multiplicand. Range: {NN_Q31_MIN, NN_Q31_MAX}
 * @param[in]       m2        Multiplier. Range: {NN_Q31_MIN, NN_Q31_MAX}
 * @return          Result of multiplication.
 *
 */
__STATIC_FORCEINLINE int32_t arm_nn_doubling_high_mult(const int32_t m1, const int32_t m2)
{
    int32_t result = 0;
    // Rounding offset to add for a right shift of 31
    int64_t mult = 1 << 30;

    if ((m1 < 0) ^ (m2 < 0))
    {
        mult = 1 - mult;
    }
    // Gets resolved as an SMLAL instruction
    mult = mult + (int64_t)m1 * m2;

    // Utilize all of the upper 32 bits. This is the doubling step
    // as well.
    result = (int32_t)(mult / (1ll << 31));

    if ((m1 == m2) && (m1 == (int32_t)NN_Q31_MIN))
    {
        result = NN_Q31_MAX;
    }
    return result;
}
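
/*
 * Worked example (values are illustrative): interpreting both operands as q31
 * fractions, 0.5 * 0.5 should give 0.25:
 *
 *   int32_t r = arm_nn_doubling_high_mult(1 << 30, 1 << 30); // 1 << 29
 *
 * i.e. the function returns round((m1 * m2) / 2^31), saturating only for
 * m1 == m2 == NN_Q31_MIN.
 */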

/**
 * @brief           Doubling high multiply without saturation. This is intended
 *                  for requantization where the scale is a positive integer
 *
 * @param[in]       m1        Multiplicand. Range: {NN_Q31_MIN, NN_Q31_MAX}
 * @param[in]       m2        Multiplier. Range: {NN_Q31_MIN, NN_Q31_MAX}
 * @return          Result of multiplication.
 * @note            The result of this matches that of the NEON instruction
 *                  VQRDMULH for m1 in range {NN_Q31_MIN, NN_Q31_MAX} and m2 in
 *                  range {NN_Q31_MIN + 1, NN_Q31_MAX}. Saturation occurs when
 *                  m1 equals m2 equals NN_Q31_MIN and that is not handled by
 *                  this function.
 *
 */
__STATIC_FORCEINLINE int32_t arm_nn_doubling_high_mult_no_sat(const int32_t m1, const int32_t m2)
{
    int32_t result = 0;
    union arm_nn_long_long mult;

    // Rounding offset to add for a right shift of 31
    mult.word.low = 1 << 30;
    mult.word.high = 0;

    // Gets resolved as an SMLAL instruction
    mult.long_long = mult.long_long + (int64_t)m1 * m2;

    // Utilize all of the upper 32 bits. This is the doubling step
    // as well.
    result = (int32_t)(mult.long_long >> 31);

    return result;
}

/**
 * @brief           Rounding divide by power of two.
 * @param[in]       dividend - Dividend
 * @param[in]       exponent - Divisor = power(2, exponent)
 *                             Range: [0, 31]
 * @return          Rounded result of division. Midpoint is rounded away from zero.
 *
 */
__STATIC_FORCEINLINE int32_t arm_nn_divide_by_power_of_two(const int32_t dividend, const int32_t exponent)
{
    int32_t result = 0;
    const int32_t remainder_mask = (1 << exponent) - 1;
    int32_t remainder = remainder_mask & dividend;

    // Basic division
    result = dividend >> exponent;

    // Adjust 'result' for rounding (midpoint away from zero)
    int32_t threshold = remainder_mask >> 1;
    if (result < 0)
    {
        threshold++;
    }
    if (remainder > threshold)
    {
        result++;
    }

    return result;
}
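
/*
 * Worked example (values are illustrative): the result is rounded to nearest,
 * with the midpoint rounded away from zero, unlike a plain arithmetic shift:
 *
 *   arm_nn_divide_by_power_of_two(7, 1);  // 4  (3.5 rounds away from zero)
 *   arm_nn_divide_by_power_of_two(-9, 2); // -2 (nearest; -9 >> 2 would floor to -3)
 */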

/**
 * @brief           Requantize a given value.
 * @param[in]       val         Value to be requantized
 * @param[in]       multiplier  Multiplier. Range {NN_Q31_MIN + 1, NN_Q31_MAX}
 * @param[in]       shift       Left or right shift for 'val * multiplier'
 *
 * @return          Returns (val * multiplier)/(2 ^ shift)
 *
 */
__STATIC_FORCEINLINE int32_t arm_nn_requantize(const int32_t val, const int32_t multiplier, const int32_t shift)
{
#ifdef CMSIS_NN_USE_SINGLE_ROUNDING
    const int64_t total_shift = 31 - shift;
    const int64_t new_val = val * (int64_t)multiplier;

    int32_t result = new_val >> (total_shift - 1);
    result = (result + 1) >> 1;

    return result;
#else
    return arm_nn_divide_by_power_of_two(arm_nn_doubling_high_mult_no_sat(val * (1 << LEFT_SHIFT(shift)), multiplier),
                                         RIGHT_SHIFT(shift));
#endif
}
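
/*
 * Worked example (values are illustrative): a TFLite-style requantization with
 * a real scale of 0.25, encoded as multiplier = 0.5 in q31 and shift = -1:
 *
 *   int32_t r = arm_nn_requantize(100, 1073741824, -1); // 25
 *
 * Both the double-rounding default and the CMSIS_NN_USE_SINGLE_ROUNDING
 * variant return 25 here; the two paths may differ by one LSB in rare cases.
 */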

/**
 * @brief           Requantize a given 64 bit value.
 * @param[in]       val                 Value to be requantized in the range {-(1<<47)} to {(1<<47) - 1}
 * @param[in]       reduced_multiplier  Reduced multiplier, mapped from the range {NN_Q31_MIN + 1, NN_Q31_MAX}
 *                                      to {NN_Q15_MIN + 1, NN_Q15_MAX}
 * @param[in]       shift               Left or right shift for 'val * multiplier' in the range {-31} to {7}
 *
 * @return          Returns (val * multiplier)/(2 ^ shift)
 *
 */
__STATIC_FORCEINLINE int32_t arm_nn_requantize_s64(const int64_t val,
                                                   const int32_t reduced_multiplier,
                                                   const int32_t shift)
{
    const int64_t new_val = val * reduced_multiplier;

    int32_t result = new_val >> (14 - shift); // 64->32 bit reduction
    result = (result + 1) >> 1;               // Last shift position and insert round

    return result;
}
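
/*
 * Note on the shift accounting above (explanatory, assuming the multiplier was
 * produced by REDUCE_MULTIPLIER): reducing a q31 multiplier to q15 drops 16
 * fractional bits, so the remaining right shift is 31 - 16 = 15. It is applied
 * as '>> (14 - shift)' followed by the final rounding '>> 1'.
 */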

/**
 * @brief           memcpy optimized for MVE
 * @param[in, out]  dst         Destination pointer
 * @param[in]       src         Source pointer.
 * @param[in]       block_size  Number of bytes to copy.
 *
 */
__STATIC_FORCEINLINE void arm_memcpy_s8(int8_t *__RESTRICT dst, const int8_t *__RESTRICT src, uint32_t block_size)
{
#if defined(ARM_MATH_MVEI)
    __asm volatile("   wlstp.8                 lr, %[cnt], 1f             \n"
                   "2:                                                    \n"
                   "   vldrb.8                 q0, [%[in]], #16            \n"
                   "   vstrb.8                 q0, [%[out]], #16           \n"
                   "   letp                    lr, 2b                     \n"
                   "1:                                                    \n"
                   : [in] "+r"(src), [out] "+r"(dst)
                   : [cnt] "r"(block_size)
                   : "q0", "memory", "r14");
#else
    memcpy(dst, src, block_size);
#endif
}

/**
 * @brief           memcpy wrapper for int16
 * @param[in, out]  dst         Destination pointer
 * @param[in]       src         Source pointer.
 * @param[in]       block_size  Number of bytes to copy.
 *
 */
__STATIC_FORCEINLINE void arm_memcpy_q15(int16_t *__RESTRICT dst, const int16_t *__RESTRICT src, uint32_t block_size)
{
    memcpy(dst, src, block_size);
}

#if defined(ARM_MATH_MVEI)
/**
 * @brief           Vector saturating doubling high multiply returning high half.
 * @param[in]       m1        Multiplicand
 * @param[in]       m2        Multiplier
 * @return          Result of multiplication.
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_doubling_high_mult_mve(const int32x4_t m1, const int32_t m2)
{
    return vqrdmulhq_n_s32(m1, m2);
}

/**
 * @brief           Vector rounding divide by power of two.
 * @param[in]       dividend - Dividend vector
 * @param[in]       exponent - Divisor = power(2, exponent)
 *                             Range: [0, 31]
 * @return          Rounded result of division. Midpoint is rounded away from zero.
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_divide_by_power_of_two_mve(const int32x4_t dividend, const int32_t exponent)
{
    const int32x4_t shift = vdupq_n_s32(-exponent);
    const int32x4_t fixup = vshrq_n_s32(vandq_s32(dividend, shift), 31);
    const int32x4_t fixed_up_dividend = vqaddq_s32(dividend, fixup);
    return vrshlq_s32(fixed_up_dividend, shift);
}

/**
 * @brief           Requantize a given vector.
 * @param[in]       val         Vector to be requantized
 * @param[in]       multiplier  Multiplier
 * @param[in]       shift       Shift
 *
 * @return          Returns (val * multiplier)/(2 ^ shift)
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_requantize_mve(const int32x4_t val, const int32_t multiplier, const int32_t shift)
{
    #ifdef CMSIS_NN_USE_SINGLE_ROUNDING
    const int right_shift = MIN(-1, shift);
    const int left_shift = shift - right_shift;

    const int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
    const int32x4_t right_shift_dup = vdupq_n_s32(right_shift);

    int32x4_t result = vqdmulhq_n_s32(vshlq_s32(val, left_shift_dup), multiplier);
    result = vrshlq_s32(result, right_shift_dup);

    return result;
    #else
    return arm_divide_by_power_of_two_mve(
        arm_doubling_high_mult_mve(vshlq_s32(val, vdupq_n_s32(LEFT_SHIFT(shift))), multiplier), RIGHT_SHIFT(shift));
    #endif
}

__STATIC_FORCEINLINE int32x4_t arm_doubling_high_mult_mve_32x4(const int32x4_t m1, const int32x4_t m2)
{
    return vqrdmulhq_s32(m1, m2);
}

__STATIC_FORCEINLINE int32x4_t arm_divide_by_power_of_two_mve_32x4(const int32x4_t dividend, const int32x4_t exponent)
{
    const int32x4_t shift = -exponent;
    const int32x4_t fixup = vshrq_n_s32(vandq_s32(dividend, shift), 31);
    const int32x4_t fixed_up_dividend = vqaddq_s32(dividend, fixup);
    return vrshlq_s32(fixed_up_dividend, shift);
}

__STATIC_FORCEINLINE int32x4_t arm_requantize_mve_32x4(const int32x4_t val,
                                                       const int32x4_t multiplier,
                                                       const int32x4_t shift)
{
    #ifdef CMSIS_NN_USE_SINGLE_ROUNDING
    const int32x4_t right_shift = vminq_s32(vdupq_n_s32(-1), shift);
    const int32x4_t left_shift = vqsubq_s32(shift, right_shift);

    int32x4_t result = vqdmulhq_s32(vshlq_s32(val, left_shift), multiplier);
    result = vrshlq_s32(result, right_shift);

    return result;
    #else
    const int32x4_t zz = vdupq_n_s32(0);
    const mve_pred16_t p = vcmpgtq_n_s32(shift, 0);

    const int32x4_t left_shift = vpselq_s32(shift, zz, p);
    const int32x4_t right_shift = -vpselq_s32(zz, shift, p);

    return arm_divide_by_power_of_two_mve_32x4(arm_doubling_high_mult_mve_32x4(vshlq_s32(val, left_shift), multiplier),
                                               right_shift);
    #endif
}
#endif

// @note The following functions are used only for softmax layer, scaled bits = 5 assumed

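/*
 * Background (explanatory): arm_nn_exp_on_negative_values below evaluates
 * exp(x) for x <= 0 in fixed point. It reduces x modulo 1/4 (in the 5-integer-
 * bit input format, where 1/4 corresponds to 1 << 24), approximates exp on
 * that small remainder with a polynomial, then multiplies by precomputed q31
 * constants exp(-1/4), exp(-1/2), ..., exp(-16) for each bit set in the
 * remaining integer part; e.g. 1672461947 is exp(-1/4) in q31.
 */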
__STATIC_FORCEINLINE int32_t arm_nn_exp_on_negative_values(int32_t val)
{
    int32_t mask = 0;
    int32_t shift = 24;

    const int32_t val_mod_minus_quarter = (val & ((1 << shift) - 1)) - (1 << shift);
    const int32_t remainder = val_mod_minus_quarter - val;
    const int32_t x = (val_mod_minus_quarter << 5) + (1 << 28);
    const int32_t x2 = MUL_SAT(x, x);

    int32_t result = 1895147668 +
        MUL_SAT(1895147668, x + DIV_POW2(MUL_SAT(DIV_POW2(MUL_SAT(x2, x2), 2) + MUL_SAT(x2, x), 715827883) + x2, 1));

#define SELECT_IF_NON_ZERO(x)                                                                                          \
    {                                                                                                                  \
        mask = MASK_IF_NON_ZERO(remainder & (1 << shift++));                                                           \
        result = SELECT_USING_MASK(mask, MUL_SAT(result, x), result);                                                  \
    }

    SELECT_IF_NON_ZERO(1672461947)
    SELECT_IF_NON_ZERO(1302514674)
    SELECT_IF_NON_ZERO(790015084)
    SELECT_IF_NON_ZERO(290630308)
    SELECT_IF_NON_ZERO(39332535)
    SELECT_IF_NON_ZERO(720401)
    SELECT_IF_NON_ZERO(242)

#undef SELECT_IF_NON_ZERO

    mask = MASK_IF_ZERO(val);
    return SELECT_USING_MASK(mask, NN_Q31_MAX, result);
}

__STATIC_FORCEINLINE int32_t arm_nn_mult_by_power_of_two(const int32_t val, const int32_t exp)
{
    const int32_t thresh = ((1 << (31 - exp)) - 1);
    int32_t result = val << exp;
    result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val > thresh), NN_Q31_MAX, result);
    result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val < -thresh), NN_Q31_MIN, result);
    return result;
}

__STATIC_FORCEINLINE int32_t arm_nn_one_over_one_plus_x_for_x_in_0_1(int32_t val)
{
    const int64_t sum = (int64_t)val + (int64_t)NN_Q31_MAX;
    const int32_t half_denominator = (int32_t)((sum + (sum >= 0 ? 1 : -1)) / 2L);
    int32_t x = 1515870810 + MUL_SAT(half_denominator, -1010580540);

    const int32_t shift = (1 << 29);
    x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
    x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
    x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);

    return MUL_POW2(x, 1);
}
1132 
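/*
 * Usage sketch (illustrative only): the argument is x in [0, 1] as a Q0.31
 * value and the result is 1 / (1 + x), also in Q0.31.
 *
 *     const int32_t r = arm_nn_one_over_one_plus_x_for_x_in_0_1(NN_Q31_MAX);
 *     // r is approximately 1 << 30, i.e. 1 / (1 + 1) = 0.5 in Q0.31.
 */
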
/**
  @brief         Write 2 s16 elements and post increment pointer.
  @param[in]     dest_q15  Pointer to pointer that holds address of destination.
  @param[in]     src_q31   Input value to be written.
 */
__STATIC_FORCEINLINE void arm_nn_write_q15x2_ia(int16_t **dest_q15, int32_t src_q31)
{
    int32_t val = src_q31;

    memcpy(*dest_q15, &val, 4);
    *dest_q15 += 2;
}

/**
  @brief         Write 2 s8 elements and post increment pointer.
  @param[in]     dst  Pointer to pointer that holds address of destination.
  @param[in]     src  Input value to be written.
 */
__STATIC_FORCEINLINE void arm_nn_write_s8x2_ia(int8_t **dst, int16_t src)
{
    memcpy(*dst, &src, 2);
    *dst += 2;
}

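/*
 * Usage sketch (illustrative only): both writers copy with memcpy, so the
 * destination may be unaligned, and the destination pointer is advanced past
 * the bytes written. PACK_Q15x2_32x1 is the packing helper defined earlier in
 * this file; the in-memory element order below assumes a little-endian target.
 *
 *     int16_t buf[4];
 *     int16_t *dst = buf;
 *     arm_nn_write_q15x2_ia(&dst, PACK_Q15x2_32x1(-100, 200)); // buf[0] = -100, buf[1] = 200
 *     arm_nn_write_q15x2_ia(&dst, PACK_Q15x2_32x1(300, -400)); // buf[2] = 300,  buf[3] = -400
 */
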
// Support functions for LSTM
/**
 * @brief Update LSTM function for an iteration step
 *
 * @param[in]    input                           Input data
 * @param[in]    input_to_input_weight           Input to input gate weights
 * @param[in]    input_to_forget_weight          Input to forget gate weights
 * @param[in]    input_to_cell_weight            Input to cell gate weights
 * @param[in]    input_to_output_weight          Input to output gate weights
 * @param[in]    recurrent_to_input_weight       Recurrent signal to input gate weights
 * @param[in]    recurrent_to_forget_weight      Recurrent signal to forget gate weights
 * @param[in]    recurrent_to_cell_weight        Recurrent signal to cell gate weights
 * @param[in]    recurrent_to_output_weight      Recurrent signal to output gate weights
 * @param[in]    lstm                            LSTM parameters
 * @param[in]    n_batch                         Batch size
 * @param[in]    n_cell                          Cell size
 * @param[in]    n_input                         Input size
 * @param[in]    n_output                        Output size
 * @param[out]   output_state                    Output state
 * @param[out]   cell_state                      Internal state
 * @param[out]   output                          Output signal
 * @param[in]    scratch_buffers                 Struct containing scratch buffers
 */
arm_cmsis_nn_status arm_nn_lstm_step_s8_s16(const int8_t *input,
                                            const int8_t *input_to_input_weight,
                                            const int8_t *input_to_forget_weight,
                                            const int8_t *input_to_cell_weight,
                                            const int8_t *input_to_output_weight,
                                            const int8_t *recurrent_to_input_weight,
                                            const int8_t *recurrent_to_forget_weight,
                                            const int8_t *recurrent_to_cell_weight,
                                            const int8_t *recurrent_to_output_weight,
                                            const cmsis_nn_lstm_params *lstm,
                                            const int n_batch,
                                            const int n_cell,
                                            const int n_input,
                                            const int n_output,
                                            int8_t *output_state,
                                            int16_t *cell_state,
                                            int8_t *output,
                                            cmsis_nn_lstm_context *scratch_buffers);

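/*
 * Note (descriptive sketch, not normative): one call advances the standard LSTM
 * recurrence by a single time step,
 *
 *     i, f, g, o   = gate activations of (input, output_state)
 *     cell_state   = f .* cell_state + i .* g
 *     output_state = o .* tanh(cell_state)
 *
 * using the gate, cell-state and output-state helpers declared below.
 */
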
/**
 * @brief         Updates an LSTM gate for an iteration step of LSTM function, int8x8_16 version.
 *
 * @param[in]    input                           Input data
 * @param[in]    input_to_gate_weights           Input to gate weights
 * @param[in]    input_to_gate_bias              Input to gate bias
 * @param[in]    input_to_gate_scaling           Input to gate scaling
 * @param[in]    output_state                    Output state
 * @param[in]    recurrent_to_gate_weights       Recurrent to gate weights
 * @param[in]    recurrent_to_gate_bias          Recurrent to gate bias
 * @param[in]    recurrent_to_gate_scaling       Recurrent to gate scaling
 * @param[in]    n_batch                         Batch size
 * @param[in]    n_input                         Input size
 * @param[in]    n_output                        Output size
 * @param[in]    n_cell                          Cell size
 * @param[in]    activation_type                 Activation type (sigmoid or tanh)
 * @param[out]   gate                            Gate output
 */
void arm_nn_lstm_calculate_gate_s8_s16(const int8_t *input,
                                       const int8_t *input_to_gate_weights,
                                       const int32_t *input_to_gate_bias,
                                       const cmsis_nn_scaling input_to_gate_scaling,
                                       const int8_t *output_state,
                                       const int8_t *recurrent_to_gate_weights,
                                       const int32_t *recurrent_to_gate_bias,
                                       const cmsis_nn_scaling recurrent_to_gate_scaling,
                                       const int32_t n_batch,
                                       const int32_t n_input,
                                       const int32_t n_output,
                                       const int32_t n_cell,
                                       const arm_nn_activation_type activation_type,
                                       int16_t *gate);

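/*
 * Note (descriptive sketch, not normative): each gate is computed as
 *
 *     gate = activation(input * input_to_gate_weights + input_to_gate_bias
 *                       + output_state * recurrent_to_gate_weights + recurrent_to_gate_bias)
 *
 * where the two matrix products are requantized with their respective
 * cmsis_nn_scaling pairs and activation_type selects sigmoid or tanh.
 */
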
/**
 * @brief       Update cell state for a single LSTM iteration step, int8x8_16 version.
 * @param[in]       n_block             Total number of cells for all batches
 * @param[in]       cell_state_scale    Scaling factor of cell state
 * @param[in,out]   cell_state          Input/output vector, size n_batch * n_cell
 * @param[in]       input_gate          Input vector of size n_block
 * @param[in]       forget_gate         Input vector of size n_block
 * @param[in]       cell_gate           Input vector of size n_block
 */
void arm_nn_lstm_update_cell_state_s16(const int32_t n_block,
                                       const int32_t cell_state_scale,
                                       int16_t *cell_state,
                                       const int16_t *input_gate,
                                       const int16_t *forget_gate,
                                       const int16_t *cell_gate);

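/*
 * Note (descriptive sketch, not normative): this implements the usual LSTM cell
 * update in s16 arithmetic,
 *
 *     cell_state = forget_gate .* cell_state + input_gate .* cell_gate
 *
 * with the products rescaled according to cell_state_scale and the result
 * saturated to the s16 range.
 */
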
/**
 * @brief       Calculate the output state tensor of an LSTM step, s8 input/output and s16 weight version.
 *
 * @param[in]       n_batch                     The number of distinct vectors in each array
 * @param[in]       n_cell                      Number of cells
 * @param[in,out]   cell_state                  Cell state, size n_batch * n_cell
 * @param[in]       cell_state_scale            Scaling of cell_state
 * @param[in]       output_gate                 Output gate
 * @param[in]       hidden_scale                Effective scaling of cell_state .* output_gate
 * @param[in]       hidden_offset               Zero point for cell_state .* output_gate
 * @param[out]      output_state                Output state
 * @param[in]       cell_gate_scratch           Scratch buffer
 */
void arm_nn_lstm_update_output_s8_s16(const int n_batch,
                                      const int n_cell,
                                      int16_t *cell_state,
                                      const int32_t cell_state_scale,
                                      const int16_t *output_gate,
                                      const cmsis_nn_scaling hidden_scale,
                                      const int32_t hidden_offset,
                                      int8_t *output_state,
                                      int16_t *cell_gate_scratch);

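/*
 * Note (descriptive sketch, not normative): the new output state is the
 * elementwise product
 *
 *     output_state = output_gate .* tanh(cell_state)
 *
 * where tanh(cell_state) is computed via cell_gate_scratch and the product is
 * requantized to s8 using hidden_scale and hidden_offset.
 */
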
/**
 * @brief Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch dimension composed of independent
 * input vectors) and accumulates the result of the multiplication into the passed output buffer.
 *
 * @param[in]   lhs_in           Batched vector
 * @param[in]   rhs_in           Weights - input matrix (H(Rows)xW(Columns))
 * @param[in]   bias             Bias vector
 * @param[out]  dst              Output
 * @param[in]   dst_offset       Output offset
 * @param[in]   dst_multiplier   Multiplier for quantization
 * @param[in]   dst_shift        Shift for quantization
 * @param[in]   rhs_cols         Vector/matrix column length
 * @param[in]   rhs_rows         Row count of matrix
 * @param[in]   batch            Batch size
 */
void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in,
                                      const int8_t *rhs_in,
                                      const int32_t *bias,
                                      int16_t *dst,
                                      const int32_t dst_offset,
                                      const int32_t dst_multiplier,
                                      const int32_t dst_shift,
                                      const int32_t rhs_cols,
                                      const int32_t rhs_rows,
                                      const int32_t batch);

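/*
 * Reference semantics (illustrative scalar sketch under assumptions, not the
 * optimized implementation): requantize() below is a stand-in for the library's
 * fixed-point rescale, and the exact placement of dst_offset and any final
 * clamping follow the library's requantization conventions.
 *
 *     for (int32_t b = 0; b < batch; b++) {
 *         for (int32_t r = 0; r < rhs_rows; r++) {
 *             int32_t acc = bias[r];
 *             for (int32_t c = 0; c < rhs_cols; c++) {
 *                 acc += rhs_in[r * rhs_cols + c] * lhs_in[b * rhs_cols + c];
 *             }
 *             dst[b * rhs_rows + r] += (int16_t)(requantize(acc, dst_multiplier, dst_shift) + dst_offset);
 *         }
 *     }
 */
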
/**
 * @brief s16 elementwise multiplication with s8 output
 * @param[in]       input_1_vect        pointer to input vector 1
 * @param[in]       input_2_vect        pointer to input vector 2
 * @param[in,out]   output              pointer to output vector
 * @param[in]       out_offset          output offset
 * @param[in]       out_mult            output multiplier
 * @param[in]       out_shift           output shift
 * @param[in]       block_size          number of samples
 * @return          The function returns ARM_CMSIS_NN_SUCCESS
 *
 * @details   Supported framework: TensorFlow Lite micro
 */
arm_cmsis_nn_status arm_elementwise_mul_s16_s8(const int16_t *input_1_vect,
                                               const int16_t *input_2_vect,
                                               int8_t *output,
                                               const int32_t out_offset,
                                               const int32_t out_mult,
                                               const int32_t out_shift,
                                               const int32_t block_size);

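/*
 * Usage sketch (illustrative only; the quantization parameters below are
 * hypothetical, not taken from a real model):
 *
 *     int16_t in1[4] = {100, 200, -300, 400};
 *     int16_t in2[4] = {50, -60, 70, 80};
 *     int8_t out[4];
 *     arm_cmsis_nn_status status = arm_elementwise_mul_s16_s8(in1, in2, out, 0, 1073741824, -10, 4);
 *     // out[i] = clamp_s8(requantize(in1[i] * in2[i], out_mult, out_shift) + out_offset)
 */
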
#ifdef __cplusplus
}
#endif

#endif