1 /*
2  * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nnsupportfunctions.h
22  * Description:  Public header file of support functions for CMSIS NN Library
23  *
24  * $Date:        08 Nov 2024
25  * $Revision:    V.22.7.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  * -------------------------------------------------------------------- */
29 
30 #ifndef ARM_NNSUPPORTFUNCTIONS_H
31 #define ARM_NNSUPPORTFUNCTIONS_H
32 
33 #include "Internal/arm_nn_compiler.h"
34 #include "arm_nn_math_types.h"
35 #include "arm_nn_types.h"
36 
37 #include <stdbool.h>
38 
39 #ifdef __cplusplus
40 extern "C" {
41 #endif
42 
43 #define USE_FAST_DW_CONV_S16_FUNCTION(dw_conv_params, filter_dims, input_dims)                                         \
44     (dw_conv_params->ch_mult == 1 && dw_conv_params->dilation.w == 1 && dw_conv_params->dilation.h == 1 &&             \
45      filter_dims->w * filter_dims->h < 512)
46 
#define LEFT_SHIFT(_shift) ((_shift) > 0 ? (_shift) : 0)
#define RIGHT_SHIFT(_shift) ((_shift) > 0 ? 0 : -(_shift))
#define MASK_IF_ZERO(x) ((x) == 0 ? ~0 : 0)
#define MASK_IF_NON_ZERO(x) ((x) != 0 ? ~0 : 0)
#define SELECT_USING_MASK(mask, a, b) (((mask) & (a)) ^ (~(mask) & (b)))
52 
53 #define MAX(A, B) ((A) > (B) ? (A) : (B))
54 #define MIN(A, B) ((A) < (B) ? (A) : (B))
55 #define CLAMP(x, h, l) MAX(MIN((x), (h)), (l))
#define REDUCE_MULTIPLIER(_mult) (((_mult) < 0x7FFF0000) ? (((_mult) + (1 << 15)) >> 16) : 0x7FFF)
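
/*
 * Example: the mask macros above enable branchless selection, a common pattern in
 * fixed-point kernels. A minimal usage sketch (illustrative only, not part of the API):
 *
 *  int32_t mask = MASK_IF_NON_ZERO(flag);       // all ones if flag != 0, else all zeros
 *  int32_t res = SELECT_USING_MASK(mask, a, b); // res = (flag != 0) ? a : b, without a branch
 *  int32_t red = REDUCE_MULTIPLIER(0x40000000); // reduces a Q31 multiplier to Q15: 0x4000
 */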
57 
// Number of channels processed in a block for DW Conv with int8 weights (MVE).
// Requirement: greater than 0 and less than 128.
// This can be fine-tuned to match the number of input channels for best performance.
// A layer with fewer channels than CH_IN_BLOCK_MVE will result in higher
// scratch buffer usage, and a layer with more channels than CH_IN_BLOCK_MVE
// will result in lower scratch buffer usage.
64 #define CH_IN_BLOCK_MVE (124)
65 
// Number of channels processed in a block for DW Conv with int4 weights (MVE).
// Requirement: see CH_IN_BLOCK_MVE.
// An additional requirement for this int4 variant is that the value must be even.
69 #define S4_CH_IN_BLOCK_MVE (124)
70 
// For int16 input, when the number of columns exceeds this limit, int64 accumulation is needed
// to avoid losing precision.
73 #define MAX_COL_COUNT (512)
74 
// CMSIS-NN has two implementations of the transpose conv operator, selected depending on the number of input
// channels. This is based on heuristics and may be fine-tuned depending on other parameters of the operator.
77 #define REVERSE_TCOL_EFFICIENT_THRESHOLD (16)
78 
// Threshold on the number of output channels that decides whether to convert a depthwise conv to a
// regular conv operation when the number of input channels is one.
81 // Only applicable for processors with MVE extension.
82 #if defined(__ARMCC_VERSION) && (__ARMCC_VERSION >= 6010050)
83     #define CONVERT_DW_CONV_WITH_ONE_INPUT_CH_AND_OUTPUT_CH_ABOVE_THRESHOLD (8)
84 #else
85     #define CONVERT_DW_CONV_WITH_ONE_INPUT_CH_AND_OUTPUT_CH_ABOVE_THRESHOLD (1)
86 #endif
87 
// By default this will have no effect. During compilation this may be set to __restrict,
// which may be beneficial for performance. See README.md for more information.
90 #ifndef OPTIONAL_RESTRICT_KEYWORD
91     #define OPTIONAL_RESTRICT_KEYWORD
92 #endif
93 
94 /**
 * @brief Definition to pack four 8-bit values.
96  */
97 #define PACK_S8x4_32x1(v0, v1, v2, v3)                                                                                 \
98     ((((int32_t)(v0) << 0) & (int32_t)0x000000FF) | (((int32_t)(v1) << 8) & (int32_t)0x0000FF00) |                     \
99      (((int32_t)(v2) << 16) & (int32_t)0x00FF0000) | (((int32_t)(v3) << 24) & (int32_t)0xFF000000))
100 
101 /**
 * @brief Definition to pack two 16-bit values.
103  */
#define PACK_Q15x2_32x1(v0, v1) (((int32_t)(v0) & (int32_t)0xFFFF) | ((int32_t)(v1) << 16))
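
/*
 * Example usage of the packing macros (illustrative only): the packed layout is
 * little-endian, i.e. v0 occupies the least significant byte/half-word.
 *
 *  int32_t p8 = PACK_S8x4_32x1(1, 2, 3, 4); // p8 == 0x04030201
 *  int32_t p16 = PACK_Q15x2_32x1(-1, 2);    // p16 == 0x0002FFFF
 */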
105 
106 /**
107  * @defgroup groupSupport Private
108  *
 * Internal support functions. Not intended to be called directly by a CMSIS-NN user.
110  *
111  */
112 
113 /**
114  * @defgroup genPrivTypes Structure Types
115  * @ingroup groupSupport
116  * @brief Data structure types used by private functions.
117  * @{
118  */
119 
120 /**
121  * @brief Union for SIMD access of q31/s16/s8 types
122  */
union arm_nnword
{
    int32_t word;          /**< q31 type */
    int16_t half_words[2]; /**< s16 type */
    int8_t bytes[4];       /**< s8 type */
};
132 
133 /**
134  * @brief Union for data type long long
135  */
136 struct arm_nn_double
137 {
138     uint32_t low;
139     int32_t high;
140 };
141 
142 union arm_nn_long_long
143 {
144     int64_t long_long;
145     struct arm_nn_double word;
146 };
147 
148 /**
 * @} // end group genPrivTypes
150  */
151 
152 /**
153  * @defgroup supportConversion Data Conversion
154  *
 * Perform data type conversion between neural network operations
156  *
157  */
158 
159 /**
160  * @brief Converts the elements from a s8 vector to a s16 vector with an added offset
161  * @param[in]    src        pointer to the s8 input vector
162  * @param[out]   dst        pointer to the s16 output vector
163  * @param[in]    block_size length of the input vector
164  * @param[in]    offset     s16 offset to be added to each input vector element.
165  *
166  * \par Description:
167  *
168  * Output elements are ordered.
169  * The equation used for the conversion process is:
170  *
171  * <pre>
172  *  dst[n] = (int16_t) src[n] + offset;   0 <= n < block_size.
173  * </pre>
174  *
175  */
176 void arm_q7_to_q15_with_offset(const int8_t *src, int16_t *dst, int32_t block_size, int16_t offset);
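
/*
 * A minimal scalar sketch of the conversion performed by arm_q7_to_q15_with_offset
 * (reference only; the library implementation is optimized per target):
 *
 *  for (int32_t n = 0; n < block_size; n++)
 *  {
 *      dst[n] = (int16_t)src[n] + offset;
 *  }
 */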
177 
178 #if defined(ARM_MATH_DSP)
179 /**
180  * @brief Converts the elements from a s8 vector to a s16 vector with an added offset
181  * @param[in]    src        pointer to the s8 input vector
182  * @param[out]   dst        pointer to the s16 output vector
183  * @param[in]    block_size length of the input vector
184  * @param[in]    offset     s16 offset to be added to each input vector element.
185  *
186  * \par Description:
187  *
 * No additional reordering is done, with the result that the output elements are not in order.
 * Instead of ABCD, the output order will be ACBD.
 * Note this is for processors with the DSP extension only.
191  * The equation used for the conversion process is:
192  *
193  * <pre>
194  *  dst[n - 0] = (int16_t) src[n - 0] + offset;   0 <= n < block_size.
195  *  dst[n - 1] = (int16_t) src[n - 2] + offset;   0 <= n < block_size.
196  *  dst[n - 2] = (int16_t) src[n - 1] + offset;   0 <= n < block_size.
197  *  dst[n - 3] = (int16_t) src[n - 3] + offset;   0 <= n < block_size.
198  * </pre>
199  *
200  */
201 void arm_s8_to_s16_unordered_with_offset(const int8_t *src, int16_t *dst, int32_t block_size, int16_t offset);
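
/*
 * Example (illustrative): for src = {A, B, C, D}, the output is
 * dst = {A + offset, C + offset, B + offset, D + offset}, i.e. the two middle
 * elements of each group of four are swapped relative to the ordered variant.
 */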
202 
203 #endif
204 
205 /**
206  * @brief Get the required buffer size for optimized s8 depthwise convolution
207  *        function with constraint that in_channel equals out_channel.
208  *        This is for processors with MVE extension.
209  *        Refer to arm_depthwise_conv_s8_opt_get_buffer_size() for function argument details.
210  *
 * @note  Intended for compilation on Host. If compiling for an Arm target, use
 *        arm_depthwise_conv_s8_opt_get_buffer_size(). Note also that this is a support function,
 *        so calling it directly is not recommended even on Host.
214  *
215  */
216 int32_t arm_depthwise_conv_s8_opt_get_buffer_size_mve(const cmsis_nn_dims *input_dims,
217                                                       const cmsis_nn_dims *filter_dims);
218 
219 /**
220  * @brief Get the required buffer size for optimized s8 depthwise convolution
221  *        function with constraint that in_channel equals out_channel.
222  *        This is for processors with DSP extension.
223  *        Refer to arm_depthwise_conv_s8_opt_get_buffer_size() for function argument details.
224  *
 * @note  Intended for compilation on Host. If compiling for an Arm target, use
 *        arm_depthwise_conv_s8_opt_get_buffer_size(). Note also that this is a support function,
 *        so calling it directly is not recommended even on Host.
228  *
229  */
230 int32_t arm_depthwise_conv_s8_opt_get_buffer_size_dsp(const cmsis_nn_dims *input_dims,
231                                                       const cmsis_nn_dims *filter_dims);
232 
233 /**
234  * @brief Depthwise conv on an im2col buffer where the input channel equals output channel.
235  * @param[in]    row     pointer to row
236  * @param[in]    col     pointer to im2col buffer, always consists of 2 columns.
237  * @param[in]    num_ch   number of channels
238  * @param[in]    out_shift  pointer to per output channel requantization shift parameter.
239  * @param[in]    out_mult   pointer to per output channel requantization multiplier parameter.
240  * @param[in]    out_offset      output tensor offset.
241  * @param[in]    activation_min   minimum value to clamp the output to. Range : int8
242  * @param[in]    activation_max   maximum value to clamp the output to. Range : int8
243  * @param[in]    kernel_size   number of elements in one column.
244  * @param[in]    output_bias per output channel bias. Range : int32
245  * @param[out]   out         pointer to output
246  * @return     The function returns one of the two
247  *              1. The incremented output pointer for a successful operation or
248  *              2. NULL if implementation is not available.
249  *
250  * @details     Supported framework: TensorFlow Lite micro.
251  */
252 int8_t *arm_nn_depthwise_conv_s8_core(const int8_t *row,
253                                       const int16_t *col,
254                                       const uint16_t num_ch,
255                                       const int32_t *out_shift,
256                                       const int32_t *out_mult,
257                                       const int32_t out_offset,
258                                       const int32_t activation_min,
259                                       const int32_t activation_max,
260                                       const uint16_t kernel_size,
261                                       const int32_t *const output_bias,
262                                       int8_t *out);
263 
264 /**
265  * @brief General Matrix-multiplication function with per-channel requantization.
266  * @param[in]       input_row    pointer to row operand
267  * @param[in]       input_col    pointer to col operand
268  * @param[in]       output_ch    number of rows of input_row
269  * @param[in]       col_batches  number of column batches. Range: 1 to 4
270  * @param[in]       output_shift  pointer to per output channel requantization shift parameter.
271  * @param[in]       output_mult   pointer to per output channel requantization multiplier parameter.
272  * @param[in]       out_offset    output tensor offset.
273  * @param[in]       col_offset    input tensor(col) offset.
 * @param[in]       row_offset    kernel offset (row). Not used.
275  * @param[in]       out_activation_min   minimum value to clamp the output to. Range : int8
276  * @param[in]       out_activation_max   maximum value to clamp the output to. Range : int8
277  * @param[in]       row_len       number of elements in each row
278  * @param[in]       bias          per output channel bias. Range : int32
279  * @param[in,out]   out           pointer to output
280  * @return     The function returns one of the two
281  *              1. The incremented output pointer for a successful operation or
282  *              2. NULL if implementation is not available.
283  *
284  * @details   Supported framework: TensorFlow Lite
285  */
286 int8_t *arm_nn_mat_mult_s8(const int8_t *input_row,
287                            const int8_t *input_col,
288                            const uint16_t output_ch,
289                            const uint16_t col_batches,
290                            const int32_t *output_shift,
291                            const int32_t *output_mult,
292                            const int32_t out_offset,
293                            const int32_t col_offset,
294                            const int32_t row_offset,
295                            const int16_t out_activation_min,
296                            const int16_t out_activation_max,
297                            const uint16_t row_len,
298                            const int32_t *const bias,
299                            int8_t *out);
300 /**
301  * @brief Matrix-multiplication function for convolution with per-channel requantization for 16 bits convolution.
302  * @param[in]       input_a     pointer to operand A
303  * @param[in]       input_b     pointer to operand B, always consists of 2 vectors.
304  * @param[in]       output_ch   number of rows of A
305  * @param[in]       out_shift   pointer to per output channel requantization shift parameter.
306  * @param[in]       out_mult    pointer to per output channel requantization multiplier parameter.
307  * @param[in]       activation_min   minimum value to clamp the output to. Range : int16
308  * @param[in]       activation_max   maximum value to clamp the output to. Range : int16
309  * @param[in]       num_col_a   number of columns of A
 * @param[in]       bias_data   pointer to struct with bias vector. The length of this vector is equal to the number
 *                              of output columns (or RHS input rows). The vector can be int32 or int64, as indicated
 *                              by a flag in the struct.
313  * @param[in,out]   out_0       pointer to output
314  * @return     The function returns one of the two
315  *              1. The incremented output pointer for a successful operation or
316  *              2. NULL if implementation is not available.
317  *
318  * @details   This function does the matrix multiplication of weight matrix for all output channels
319  *            with 2 columns from im2col and produces two elements/output_channel. The outputs are
320  *            clamped in the range provided by activation min and max.
321  *            Supported framework: TensorFlow Lite micro.
322  */
323 int16_t *arm_nn_mat_mult_kernel_s16(const int8_t *input_a,
324                                     const int16_t *input_b,
325                                     const int32_t output_ch,
326                                     const int32_t *out_shift,
327                                     const int32_t *out_mult,
328                                     const int32_t activation_min,
329                                     const int32_t activation_max,
330                                     const int32_t num_col_a,
331                                     const cmsis_nn_bias_data *const bias_data,
332                                     int16_t *out_0);
333 
334 /**
335  * @brief General Vector by Matrix multiplication with requantization and storage of result.
336  * @param[in]       row_elements          number of row elements
337  * @param[in]       skipped_row_elements  number of row elements skipped due to padding.
338  *                                        row_elements + skipped_row_elements = (kernel_x * kernel_y) * input_ch
339  * @param[in]       row_base_ref          pointer to row operand
340  * @param[in]       col_base_ref          pointer to col operand
 * @param[in]       out_ch                Number of output channels
342  * @param[in]       conv_params           Pointer to convolution parameters like offsets and activation values
343  * @param[in]       quant_params          Pointer to per-channel quantization parameters
344  * @param[in]       bias                  Pointer to optional per-channel bias
345  * @param[out]      output                Pointer to output where int8 results are stored.
 * @return     The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
 *
 * @details The function performs a matrix (row_base_ref) multiplication with a vector (col_base_ref)
 *          and stores the scaled result in memory. Pseudo-code:
 *      for (j = 0; j < out_ch; j++)
 *          *output = 0
 *          sum_col = 0
 *          for (i = 0; i < row_elements; i++)
 *              *output += row_base_ref[i] * col_base_ref[i]
 *              sum_col += col_base_ref[i]
 *          scale sum_col using quant_params and bias
 *          store result in 'output'
 *
359  *
360  */
361 arm_cmsis_nn_status arm_nn_mat_mul_core_1x_s8(int32_t row_elements,
362                                               const int32_t skipped_row_elements,
363                                               const int8_t *row_base_ref,
364                                               const int8_t *col_base_ref,
365                                               const int32_t out_ch,
366                                               const cmsis_nn_conv_params *conv_params,
367                                               const cmsis_nn_per_channel_quant_params *quant_params,
368                                               const int32_t *bias,
369                                               int8_t *output);
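
/*
 * A simplified scalar sketch of the computation described above (illustrative only;
 * requantize() stands in for the library's fixed-point requantization helper, and the
 * handling of skipped_row_elements is reduced to the row stride):
 *
 *  const int32_t row_stride = row_elements + skipped_row_elements;
 *  for (int32_t ch = 0; ch < out_ch; ch++)
 *  {
 *      const int8_t *row = row_base_ref + ch * row_stride;
 *      int32_t acc = 0;
 *      for (int32_t i = 0; i < row_elements; i++)
 *      {
 *          acc += row[i] * (col_base_ref[i] + conv_params->input_offset);
 *      }
 *      if (bias)
 *      {
 *          acc += bias[ch];
 *      }
 *      acc = requantize(acc, quant_params->multiplier[ch], quant_params->shift[ch]);
 *      acc += conv_params->output_offset;
 *      acc = CLAMP(acc, conv_params->activation.max, conv_params->activation.min);
 *      output[ch] = (int8_t)acc;
 *  }
 */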
370 
371 /**
372  * @brief General Vector by Matrix multiplication with requantization, storage of result and int4 weights packed into an
373  * int8 buffer.
374  * @param[in]       row_elements          number of row elements
375  * @param[in]       skipped_row_elements  number of row elements skipped due to padding.
376  *                                        row_elements + skipped_row_elements = (kernel_x * kernel_y) * input_ch
377  * @param[in]       row_base_ref          pointer to row operand
378  * @param[in]       col_base_ref          pointer to col operand as packed int4
 * @param[in]       out_ch                Number of output channels
380  * @param[in]       conv_params           Pointer to convolution parameters like offsets and activation values
381  * @param[in]       quant_params          Pointer to per-channel quantization parameters
382  * @param[in]       bias                  Pointer to optional per-channel bias
383  * @param[out]      output                Pointer to output where int8 results are stored.
 * @return     The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
 *
 * @details The function performs a matrix (row_base_ref) multiplication with a vector (col_base_ref)
 *          and stores the scaled result in memory. Pseudo-code as an int8 example; int4 filter data
 *          will be unpacked:
 *      for (j = 0; j < out_ch; j++)
 *          *output = 0
 *          sum_col = 0
 *          for (i = 0; i < row_elements; i++)
 *              *output += row_base_ref[i] * col_base_ref[i]
 *              sum_col += col_base_ref[i]
 *          scale sum_col using quant_params and bias
 *          store result in 'output'
 *
397  *
398  */
399 arm_cmsis_nn_status arm_nn_mat_mul_core_1x_s4(int32_t row_elements,
400                                               const int32_t skipped_row_elements,
401                                               const int8_t *row_base_ref,
402                                               const int8_t *col_base_ref,
403                                               const int32_t out_ch,
404                                               const cmsis_nn_conv_params *conv_params,
405                                               const cmsis_nn_per_channel_quant_params *quant_params,
406                                               const int32_t *bias,
407                                               int8_t *output);
408 
409 /**
410  * @brief Matrix-multiplication with requantization & activation function for four rows and one column
411  * @param[in]       row_elements  number of row elements
 * @param[in]       offset        offset between rows. Can be the same as row_elements,
 *                                e.g. in a 1x1 conv scenario with a stride of 1.
414  * @param[in]       row_base      pointer to row operand
415  * @param[in]       col_base      pointer to col operand
416  * @param[in]       out_ch        Number of output channels
417  * @param[in]       conv_params   Pointer to convolution parameters like offsets and activation values
418  * @param[in]       quant_params  Pointer to per-channel quantization parameters
419  * @param[in]       bias          Pointer to per-channel bias
420  * @param[out]      output        Pointer to output where int8 results are stored.
421  *
422  * @return     The function returns the updated output pointer or NULL if implementation is not available.
423  *
424  * @details Compliant to TFLM int8 specification. MVE implementation only
425  */
426 int8_t *arm_nn_mat_mul_core_4x_s8(const int32_t row_elements,
427                                   const int32_t offset,
428                                   const int8_t *row_base,
429                                   const int8_t *col_base,
430                                   const int32_t out_ch,
431                                   const cmsis_nn_conv_params *conv_params,
432                                   const cmsis_nn_per_channel_quant_params *quant_params,
433                                   const int32_t *bias,
434                                   int8_t *output);
435 
436 /**
437  * @brief General Matrix-multiplication function with per-channel requantization.
438  *        This function assumes:
439  *        - LHS input matrix NOT transposed (nt)
440  *        - RHS input matrix transposed (t)
441  *        - RHS is int8 packed with 2x int4
442  *        - LHS is int8
443  *
444  *  @note This operation also performs the broadcast bias addition before the requantization
445  *
446  * @param[in]  lhs                Pointer to the LHS input matrix
447  * @param[in]  rhs                Pointer to the RHS input matrix
448  * @param[in]  bias               Pointer to the bias vector. The length of this vector is equal to the number of
449  *                                output columns (or RHS input rows)
450  * @param[out] dst                Pointer to the output matrix with "m" rows and "n" columns
451  * @param[in]  dst_multipliers    Pointer to the multipliers vector needed for the per-channel requantization.
452  *                                The length of this vector is equal to the number of output columns (or RHS input
453  *                                rows)
454  * @param[in]  dst_shifts         Pointer to the shifts vector needed for the per-channel requantization. The length
455  *                                of this vector is equal to the number of output columns (or RHS input rows)
456  * @param[in]  lhs_rows           Number of LHS input rows
457  * @param[in]  rhs_rows           Number of RHS input rows
458  * @param[in]  rhs_cols           Number of LHS/RHS input columns
459  * @param[in]  lhs_offset         Offset to be applied to the LHS input value
 * @param[in]  dst_offset         Offset to be applied to the output result
461  * @param[in]  activation_min     Minimum value to clamp down the output. Range : int8
462  * @param[in]  activation_max     Maximum value to clamp up the output. Range : int8
463  * @param[in]  lhs_cols_offset    Column offset between subsequent lhs_rows
464  *
465  * @return     The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
466  *
467  */
468 arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s4(const int8_t *lhs,
469                                             const int8_t *rhs,
470                                             const int32_t *bias,
471                                             int8_t *dst,
472                                             const int32_t *dst_multipliers,
473                                             const int32_t *dst_shifts,
474                                             const int32_t lhs_rows,
475                                             const int32_t rhs_rows,
476                                             const int32_t rhs_cols,
477                                             const int32_t lhs_offset,
478                                             const int32_t dst_offset,
479                                             const int32_t activation_min,
480                                             const int32_t activation_max,
481                                             const int32_t lhs_cols_offset);
482 
483 /**
484  * @brief General Matrix-multiplication function with per-channel requantization.
485  *        This function assumes:
486  *        - LHS input matrix NOT transposed (nt)
487  *        - RHS input matrix transposed (t)
488  *        - RHS is int8 packed with 2x int4
489  *        - LHS is int8
 *        - The number of LHS/RHS input columns must be even
491  *        - LHS must be interleaved. Compare to arm_nn_mat_mult_nt_t_s4 where LHS is not interleaved.
492  *
493  *  @note This operation also performs the broadcast bias addition before the requantization
494  *
495  * @param[in]  lhs                Pointer to the LHS input matrix
496  * @param[in]  rhs                Pointer to the RHS input matrix
497  * @param[in]  bias               Pointer to the bias vector. The length of this vector is equal to the number of
498  *                                output columns (or RHS input rows)
499  * @param[out] dst                Pointer to the output matrix with "m" rows and "n" columns
500  * @param[in]  dst_multipliers    Pointer to the multipliers vector needed for the per-channel requantization.
501  *                                The length of this vector is equal to the number of output columns (or RHS input
502  *                                rows)
503  * @param[in]  dst_shifts         Pointer to the shifts vector needed for the per-channel requantization. The length
504  *                                of this vector is equal to the number of output columns (or RHS input rows)
505  * @param[in]  lhs_rows           Number of LHS input rows
506  * @param[in]  rhs_rows           Number of RHS input rows
507  * @param[in]  rhs_cols           Number of LHS/RHS input columns. Note this must be even.
508  * @param[in]  lhs_offset         Offset to be applied to the LHS input value
 * @param[in]  dst_offset         Offset to be applied to the output result
510  * @param[in]  activation_min     Minimum value to clamp down the output. Range : int8
511  * @param[in]  activation_max     Maximum value to clamp up the output. Range : int8
512  * @param[in]  lhs_cols_offset    Column offset between subsequent lhs_rows
513  *
514  * @return     The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
515  *
516  */
517 arm_cmsis_nn_status arm_nn_mat_mult_nt_interleaved_t_even_s4(const int8_t *lhs,
518                                                              const int8_t *rhs,
519                                                              const int32_t *bias,
520                                                              int8_t *dst,
521                                                              const int32_t *dst_multipliers,
522                                                              const int32_t *dst_shifts,
523                                                              const int32_t lhs_rows,
524                                                              const int32_t rhs_rows,
525                                                              const int32_t rhs_cols,
526                                                              const int32_t lhs_offset,
527                                                              const int32_t dst_offset,
528                                                              const int32_t activation_min,
529                                                              const int32_t activation_max,
530                                                              const int32_t lhs_cols_offset);
531 
532 /**
533  * @brief General Matrix-multiplication function with per-channel requantization.
534  *        This function assumes:
535  *        - LHS input matrix NOT transposed (nt)
536  *        - RHS input matrix transposed (t)
537  *
538  *  @note This operation also performs the broadcast bias addition before the requantization
539  *
540  * @param[in]  lhs                Pointer to the LHS input matrix
541  * @param[in]  rhs                Pointer to the RHS input matrix
542  * @param[in]  bias               Pointer to the bias vector. The length of this vector is equal to the number of
543  *                                output columns (or RHS input rows)
544  * @param[out] dst                Pointer to the output matrix with "m" rows and "n" columns
545  * @param[in]  dst_multipliers    Pointer to the multipliers vector needed for the per-channel requantization.
546  *                                The length of this vector is equal to the number of output columns (or RHS input
547  *                                rows)
548  * @param[in]  dst_shifts         Pointer to the shifts vector needed for the per-channel requantization. The length
549  *                                of this vector is equal to the number of output columns (or RHS input rows)
550  * @param[in]  lhs_rows           Number of LHS input rows
551  * @param[in]  rhs_rows           Number of RHS input rows
552  * @param[in]  rhs_cols           Number of LHS/RHS input columns
553  * @param[in]  lhs_offset         Offset to be applied to the LHS input value
 * @param[in]  dst_offset         Offset to be applied to the output result
555  * @param[in]  activation_min     Minimum value to clamp down the output. Range : int8
556  * @param[in]  activation_max     Maximum value to clamp up the output. Range : int8
557  * @param[in]  row_address_offset Address offset between rows in output. NOTE: Only used for MVEI extension.
558  * @param[in]  lhs_cols_offset    Column offset between subsequent lhs_rows
559  *
560  * @return     The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
561  *
562  */
563 arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8(const int8_t *lhs,
564                                             const int8_t *rhs,
565                                             const int32_t *bias,
566                                             int8_t *dst,
567                                             const int32_t *dst_multipliers,
568                                             const int32_t *dst_shifts,
569                                             const int32_t lhs_rows,
570                                             const int32_t rhs_rows,
571                                             const int32_t rhs_cols,
572                                             const int32_t lhs_offset,
573                                             const int32_t dst_offset,
574                                             const int32_t activation_min,
575                                             const int32_t activation_max,
576                                             const int32_t row_address_offset,
577                                             const int32_t lhs_cols_offset);
578 
579 /**
580  * @brief General Matrix-multiplication function with per-channel requantization and int16 input (LHS) and output.
581  *        This function assumes:
582  *        - LHS input matrix NOT transposed (nt)
583  *        - RHS input matrix transposed (t)
584  *
585  *  @note This operation also performs the broadcast bias addition before the requantization
586  *
587  * @param[in]  lhs                Pointer to the LHS input matrix
588  * @param[in]  rhs                Pointer to the RHS input matrix
 * @param[in]  bias_data          Pointer to struct with bias vector. The length of this vector is equal to the number
 *                                of output columns (or RHS input rows). The vector can be int32 or int64, as indicated
 *                                by a flag in the struct.
592  * @param[out] dst                Pointer to the output matrix with "m" rows and "n" columns
593  * @param[in]  dst_multipliers    Pointer to the multipliers vector needed for the per-channel requantization.
594  *                                The length of this vector is equal to the number of output columns (or RHS input
595  *                                rows)
596  * @param[in]  dst_shifts         Pointer to the shifts vector needed for the per-channel requantization. The length
597  *                                of this vector is equal to the number of output columns (or RHS input rows)
598  * @param[in]  lhs_rows           Number of LHS input rows
599  * @param[in]  rhs_rows           Number of RHS input rows
600  * @param[in]  rhs_cols           Number of LHS/RHS input columns
601  * @param[in]  activation_min     Minimum value to clamp down the output. Range : int16
602  * @param[in]  activation_max     Maximum value to clamp up the output. Range : int16
603  *
604  * @details MVE implementation only.
605  *
 * @return     The function returns <code>ARM_CMSIS_NN_SUCCESS</code> or
 *                                  <code>ARM_CMSIS_NN_NO_IMPL_ERROR</code> if no MVE implementation is available
608  *
609  */
610 arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s16(const int16_t *lhs,
611                                              const int8_t *rhs,
612                                              const cmsis_nn_bias_data *bias_data,
613                                              int16_t *dst,
614                                              const int32_t *dst_multipliers,
615                                              const int32_t *dst_shifts,
616                                              const int32_t lhs_rows,
617                                              const int32_t rhs_rows,
618                                              const int32_t rhs_cols,
619                                              const int32_t activation_min,
620                                              const int32_t activation_max);
621 
622 /**
623  * @brief General Matrix-multiplication function with int8 input and int32 output.
624  *        This function assumes:
625  *        - LHS input matrix NOT transposed (nt)
626  *        - RHS input matrix transposed (t)
627  *
628  * @note  Dst/output buffer must be zeroed out before calling this function.
629  *
630  * @param[in]  lhs                Pointer to the LHS input matrix
631  * @param[in]  rhs                Pointer to the RHS input matrix
632  * @param[out] dst                Pointer to the output matrix with "m" rows and "n" columns
633  * @param[in]  lhs_rows           Number of LHS input rows
634  * @param[in]  rhs_rows           Number of LHS input columns/RHS input rows
635  * @param[in]  rhs_cols           Number of RHS input columns
636  * @param[in]  lhs_offset         Offset to be applied to the LHS input value
637  * @param[in]  dst_idx_offset     Offset between subsequent output results
638  *
639  * @return     The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
640  *
641  */
642 arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8_s32(const int8_t *lhs,
643                                                 const int8_t *rhs,
644                                                 int32_t *dst,
645                                                 const int32_t lhs_rows,
646                                                 const int32_t rhs_rows,
647                                                 const int32_t rhs_cols,
648                                                 const int32_t lhs_offset,
649                                                 const int32_t dst_idx_offset);
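
/*
 * Usage sketch (illustrative): the output buffer is accumulated into, so it must be
 * cleared first. The size shown assumes dst_idx_offset == 1, giving lhs_rows * rhs_cols
 * output values.
 *
 *  memset(dst, 0, (size_t)lhs_rows * rhs_cols * sizeof(int32_t));
 *  arm_nn_mat_mult_nt_t_s8_s32(lhs, rhs, dst, lhs_rows, rhs_rows, rhs_cols, lhs_offset, 1);
 */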
650 
651 /**
652  * @brief s4 Vector by Matrix (transposed) multiplication
653  *
654  * @param[in]      lhs             Input left-hand side vector
655  * @param[in]      packed_rhs      Input right-hand side matrix (transposed)
656  * @param[in]      bias            Input bias
657  * @param[out]     dst             Output vector
658  * @param[in]      lhs_offset      Offset to be added to the input values of the left-hand side vector.
659  *                                 Range: -127 to 128
660  * @param[in]      dst_offset      Offset to be added to the output values. Range: -127 to 128
661  * @param[in]      dst_multiplier  Output multiplier
662  * @param[in]      dst_shift       Output shift
663  * @param[in]      rhs_cols        Number of columns in the right-hand side input matrix
664  * @param[in]      rhs_rows        Number of rows in the right-hand side input matrix
665  * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
666  * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
667  *
668  * @return         The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
669  *
670  */
671 arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s4(const int8_t *lhs,
672                                              const int8_t *packed_rhs,
673                                              const int32_t *bias,
674                                              int8_t *dst,
675                                              const int32_t lhs_offset,
676                                              const int32_t dst_offset,
677                                              const int32_t dst_multiplier,
678                                              const int32_t dst_shift,
679                                              const int32_t rhs_cols,
680                                              const int32_t rhs_rows,
681                                              const int32_t activation_min,
682                                              const int32_t activation_max);
683 
684 /**
685  * @brief s8 Vector by Matrix (transposed) multiplication
686  *
687  * @param[in]      lhs             Input left-hand side vector
688  * @param[in]      rhs             Input right-hand side matrix (transposed)
689  * @param[in]      kernel_sum      Kernel sums of the kernels (rhs). See arm_vector_sum_s8 for more info.
690  * @param[in]      bias            Input bias
691  * @param[out]     dst             Output vector
692  * @param[in]      lhs_offset      Offset to be added to the input values of the left-hand side vector.
693  *                                 Range: -127 to 128
694  * @param[in]      dst_offset      Offset to be added to the output values. Range: -127 to 128
695  * @param[in]      dst_multiplier  Output multiplier
696  * @param[in]      dst_shift       Output shift
697  * @param[in]      rhs_cols        Number of columns in the right-hand side input matrix
698  * @param[in]      rhs_rows        Number of rows in the right-hand side input matrix
699  * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
700  * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
701  * @param[in]      address_offset  Memory position offset for dst. First output is stored at 'dst', the
702  *                                 second at 'dst + address_offset' and so on. Default value is typically 1.
 * @param[in]      rhs_offset      Offset to be added to the input values of the right-hand side matrix.
704  *                                 Range: -127 to 128
705  *
706  * @return         The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
707  *
708  */
709 arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s8(const int8_t *lhs,
710                                              const int8_t *rhs,
711                                              const int32_t *kernel_sum,
712                                              const int32_t *bias,
713                                              int8_t *dst,
714                                              const int32_t lhs_offset,
715                                              const int32_t dst_offset,
716                                              const int32_t dst_multiplier,
717                                              const int32_t dst_shift,
718                                              const int32_t rhs_cols,
719                                              const int32_t rhs_rows,
720                                              const int32_t activation_min,
721                                              const int32_t activation_max,
722                                              const int32_t address_offset,
723                                              const int32_t rhs_offset);
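
/*
 * Usage sketch (illustrative): a fully connected layer with input_size inputs and
 * output_size outputs maps onto a single call. All names below are placeholders;
 * kernel_sum is precomputed from the weights (see arm_vector_sum_s8).
 *
 *  arm_nn_vec_mat_mult_t_s8(input, weights, kernel_sum, bias, output,
 *                           input_offset, output_offset, out_multiplier, out_shift,
 *                           input_size, output_size, -128, 127, 1, 0);
 */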
724 
725 /**
726  * @brief s8 Vector by Matrix (transposed) multiplication using per channel quantization for output
727  *
728  * @param[in]      lhs             Input left-hand side vector
729  * @param[in]      rhs             Input right-hand side matrix (transposed)
730  * @param[in]      kernel_sum      Kernel sums of the kernels (rhs). See arm_vector_sum_s8 for more info.
731  * @param[in]      bias            Input bias
732  * @param[out]     dst             Output vector
733  * @param[in]      lhs_offset      Offset to be added to the input values of the left-hand side vector.
734  *                                 Range: -127 to 128
735  * @param[in]      dst_offset      Offset to be added to the output values. Range: -127 to 128
736  * @param[in]      dst_multiplier  Output multipliers
737  * @param[in]      dst_shift       Output shifts
738  * @param[in]      rhs_cols        Number of columns in the right-hand side input matrix
739  * @param[in]      rhs_rows        Number of rows in the right-hand side input matrix
740  * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
741  * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
742  * @param[in]      address_offset  Memory position offset for dst. First output is stored at 'dst', the
743  *                                 second at 'dst + address_offset' and so on. Default value is typically 1.
 * @param[in]      rhs_offset      Offset to be added to the input values of the right-hand side matrix.
745  *                                 Range: -127 to 128
746  *
747  * @return         The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
748  *
749  */
750 arm_cmsis_nn_status arm_nn_vec_mat_mult_t_per_ch_s8(const int8_t *lhs,
751                                                     const int8_t *rhs,
752                                                     const int32_t *kernel_sum,
753                                                     const int32_t *bias,
754                                                     int8_t *dst,
755                                                     const int32_t lhs_offset,
756                                                     const int32_t dst_offset,
757                                                     const int32_t *dst_multiplier,
758                                                     const int32_t *dst_shift,
759                                                     const int32_t rhs_cols,
760                                                     const int32_t rhs_rows,
761                                                     const int32_t activation_min,
762                                                     const int32_t activation_max,
763                                                     const int32_t address_offset,
764                                                     const int32_t rhs_offset);
765 
766 /**
767  * @brief s16 Vector by s8 Matrix (transposed) multiplication
768  *
769  * @param[in]      lhs             Input left-hand side vector
770  * @param[in]      rhs             Input right-hand side matrix (transposed)
771  * @param[in]      bias            Input bias
772  * @param[out]     dst             Output vector
773  * @param[in]      dst_multiplier  Output multiplier
774  * @param[in]      dst_shift       Output shift
775  * @param[in]      rhs_cols        Number of columns in the right-hand side input matrix
776  * @param[in]      rhs_rows        Number of rows in the right-hand side input matrix
777  * @param[in]      activation_min  Minimum value to clamp the output to. Range: int16
778  * @param[in]      activation_max  Maximum value to clamp the output to. Range: int16
779  *
780  * @return         The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
781  *
782  */
783 arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s16(const int16_t *lhs,
784                                               const int8_t *rhs,
785                                               const int64_t *bias,
786                                               int16_t *dst,
787                                               const int32_t dst_multiplier,
788                                               const int32_t dst_shift,
789                                               const int32_t rhs_cols,
790                                               const int32_t rhs_rows,
791                                               const int32_t activation_min,
792                                               const int32_t activation_max);
793 
794 /**
795  * @brief s16 Vector by s16 Matrix (transposed) multiplication
796  *
797  * @param[in]      lhs             Input left-hand side vector
798  * @param[in]      rhs             Input right-hand side matrix (transposed)
799  * @param[in]      bias            Input bias
800  * @param[out]     dst             Output vector
801  * @param[in]      dst_multiplier  Output multiplier
802  * @param[in]      dst_shift       Output shift
803  * @param[in]      rhs_cols        Number of columns in the right-hand side input matrix
804  * @param[in]      rhs_rows        Number of rows in the right-hand side input matrix
805  * @param[in]      activation_min  Minimum value to clamp the output to. Range: int16
806  * @param[in]      activation_max  Maximum value to clamp the output to. Range: int16
807  *
808  * @return         The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
809  *
810  */
811 arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s16_s16(const int16_t *lhs,
812                                                   const int16_t *rhs,
813                                                   const int64_t *bias,
814                                                   int16_t *dst,
815                                                   const int32_t dst_multiplier,
816                                                   const int32_t dst_shift,
817                                                   const int32_t rhs_cols,
818                                                   const int32_t rhs_rows,
819                                                   const int32_t activation_min,
820                                                   const int32_t activation_max);
821 
822 /**
823  * @brief s8 Vector by Matrix (transposed) multiplication with s16 output
824  *
825  * @param[in]      lhs             Input left-hand side vector
826  * @param[in]      rhs             Input right-hand side matrix (transposed)
827  * @param[out]     dst             Output vector
828  * @param[in]      lhs_offset      Offset to be added to the input values of the left-hand side
829  *                                 vector. Range: -127 to 128
830  * @param[in]      scatter_offset  Address offset for dst. First output is stored at 'dst', the
831  *                                 second at 'dst + scatter_offset' and so on.
832  * @param[in]      dst_multiplier  Output multiplier
833  * @param[in]      dst_shift       Output shift
834  * @param[in]      rhs_cols        Number of columns in the right-hand side input matrix
835  * @param[in]      rhs_rows        Number of rows in the right-hand side input matrix
836  * @param[in]      activation_min  Minimum value to clamp the output to. Range: int16
837  * @param[in]      activation_max  Maximum value to clamp the output to. Range: int16
838  *
839  * @return         The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
840  *
841  */
842 arm_cmsis_nn_status arm_nn_vec_mat_mult_t_svdf_s8(const int8_t *lhs,
843                                                   const int8_t *rhs,
844                                                   int16_t *dst,
845                                                   const int32_t lhs_offset,
846                                                   const int32_t scatter_offset,
847                                                   const int32_t dst_multiplier,
848                                                   const int32_t dst_shift,
849                                                   const int32_t rhs_cols,
850                                                   const int32_t rhs_rows,
851                                                   const int32_t activation_min,
852                                                   const int32_t activation_max);
853 
854 /**
855  * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in padded cases where
 *        the padding is -lhs_offset (Range: int8). Dimensions are the same for lhs and rhs.
857  *
858  * @param[in]      lhs             Input left-hand side matrix
859  * @param[in]      rhs             Input right-hand side matrix (transposed)
 * @param[in]      lhs_offset      LHS matrix offset (input offset). Range: -127 to 128
861  * @param[in]      active_ch       Subset of total_ch processed
862  * @param[in]      total_ch        Number of channels in LHS/RHS
863  * @param[in]      out_shift       Per channel output shift. Length of vector is equal to number of channels
864  * @param[in]      out_mult        Per channel output multiplier. Length of vector is equal to number of channels
865  * @param[in]      out_offset      Offset to be added to the output values. Range: -127 to 128
866  * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
867  * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
 * @param[in]      row_x_col       (row_dimension * col_dimension) of LHS/RHS matrix
869  * @param[in]      output_bias     Per channel output bias. Length of vector is equal to number of channels
 * @param[out]     out             Output pointer
871  *
872  * @return         The function returns one of the two
873  *                  - Updated output pointer if an implementation is available
874  *                  - NULL if no implementation is available.
875  *
 * @note           If the number of channels is not a multiple of 4, up to 3 elements outside the boundary will be
 *                 read for the following:
878  *                  - Output shift
879  *                  - Output multiplier
880  *                  - Output bias
881  *                  - rhs
882  */
883 arm_cmsis_nn_status arm_nn_depthwise_conv_nt_t_padded_s8(const int8_t *lhs,
884                                                          const int8_t *rhs,
885                                                          const int32_t lhs_offset,
886                                                          const int32_t active_ch,
887                                                          const int32_t total_ch,
888                                                          const int32_t *out_shift,
889                                                          const int32_t *out_mult,
890                                                          const int32_t out_offset,
891                                                          const int32_t activation_min,
892                                                          const int32_t activation_max,
893                                                          const uint16_t row_x_col,
894                                                          const int32_t *const output_bias,
895                                                          int8_t *out);
896 
897 /**
898  * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in non-padded cases.
899  *        Dimensions are the same for lhs and rhs.
900  *
901  * @param[in]      lhs             Input left-hand side matrix
902  * @param[in]      rhs             Input right-hand side matrix (transposed)
 * @param[in]      lhs_offset      LHS matrix offset (input offset). Range: -127 to 128
904  * @param[in]      active_ch       Subset of total_ch processed
905  * @param[in]      total_ch        Number of channels in LHS/RHS
906  * @param[in]      out_shift       Per channel output shift. Length of vector is equal to number of channels.
907  * @param[in]      out_mult        Per channel output multiplier. Length of vector is equal to number of channels.
908  * @param[in]      out_offset      Offset to be added to the output values. Range: -127 to 128
909  * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
910  * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
 * @param[in]      row_x_col       (row_dimension * col_dimension) of LHS/RHS matrix
912  * @param[in]      output_bias     Per channel output bias. Length of vector is equal to number of channels.
 * @param[out]     out             Output pointer
914  *
915  * @return         The function returns one of the two
916  *                  - Updated output pointer if an implementation is available
917  *                  - NULL if no implementation is available.
918  *
 * @note           If the number of channels is not a multiple of 4, up to 3 elements outside the boundary will be
 *                 read for the following:
921  *                  - Output shift
922  *                  - Output multiplier
923  *                  - Output bias
924  *                  - rhs
925  */
926 arm_cmsis_nn_status arm_nn_depthwise_conv_nt_t_s8(const int8_t *lhs,
927                                                   const int8_t *rhs,
928                                                   const int32_t lhs_offset,
929                                                   const int32_t active_ch,
930                                                   const int32_t total_ch,
931                                                   const int32_t *out_shift,
932                                                   const int32_t *out_mult,
933                                                   const int32_t out_offset,
934                                                   const int32_t activation_min,
935                                                   const int32_t activation_max,
936                                                   const uint16_t row_x_col,
937                                                   const int32_t *const output_bias,
938                                                   int8_t *out);
939 
940 /**
941  * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in non-padded cases. rhs
942  * consists of packed int4 data. Dimensions are the same for lhs and rhs.
943  *
944  * @param[in]      lhs             Input left-hand side matrix
945  * @param[in]      rhs             Input right-hand side matrix (transposed). Consists of int4 data packed in an int8
946  * buffer.
 * @param[in]      lhs_offset      LHS matrix offset (input offset). Range: -127 to 128
948  * @param[in]      active_ch       Subset of total_ch processed
949  * @param[in]      total_ch        Number of channels in LHS/RHS
950  * @param[in]      out_shift       Per channel output shift. Length of vector is equal to number of channels.
951  * @param[in]      out_mult        Per channel output multiplier. Length of vector is equal to number of channels.
952  * @param[in]      out_offset      Offset to be added to the output values. Range: -127 to 128
953  * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
954  * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
 * @param[in]      row_x_col       (row_dimension * col_dimension) of LHS/RHS matrix
956  * @param[in]      output_bias     Per channel output bias. Length of vector is equal to number of channels.
957  * @param[in]      out             Output pointer
958  *
959  * @return         The function returns one of the two
960  *                  - Updated output pointer if an implementation is available
961  *                  - NULL if no implementation is available.
962  *
963  * @note           If number of channels is not a multiple of 4, upto 3 elements outside the boundary will be read
964  * out for the following.
965  *                  - Output shift
966  *                  - Output multiplier
967  *                  - Output bias
968  *                  - rhs
969  */
970 arm_cmsis_nn_status arm_nn_depthwise_conv_nt_t_s4(const int8_t *lhs,
971                                                   const int8_t *rhs,
972                                                   const int32_t lhs_offset,
973                                                   const int32_t active_ch,
974                                                   const int32_t total_ch,
975                                                   const int32_t *out_shift,
976                                                   const int32_t *out_mult,
977                                                   const int32_t out_offset,
978                                                   const int32_t activation_min,
979                                                   const int32_t activation_max,
980                                                   const uint16_t row_x_col,
981                                                   const int32_t *const output_bias,
982                                                   int8_t *out);

/**
 * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in non-padded cases.
 *        Dimensions are the same for lhs and rhs.
 *
 * @param[in]      lhs             Input left-hand side matrix
 * @param[in]      rhs             Input right-hand side matrix (transposed)
 * @param[in]      num_ch          Number of channels in LHS/RHS
 * @param[in]      out_shift       Per channel output shift. Length of vector is equal to number of channels.
 * @param[in]      out_mult        Per channel output multiplier. Length of vector is equal to number of channels.
 * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
 * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
 * @param[in]      row_x_col       (row_dimension * col_dimension) of LHS/RHS matrix
 * @param[in]      output_bias     Per channel output bias. Length of vector is equal to number of channels.
 * @param[in]      out             Output pointer
 *
 * @return         The function returns one of the following:
 *                  - Updated output pointer if an implementation is available
 *                  - NULL if no implementation is available.
 *
 * @note           If number of channels is not a multiple of 4, up to 3 elements outside the boundary will be read
 * for the following:
 *                  - Output shift
 *                  - Output multiplier
 *                  - Output bias
 *                  - rhs
 */
int16_t *arm_nn_depthwise_conv_nt_t_s16(const int16_t *lhs,
                                        const int8_t *rhs,
                                        const uint16_t num_ch,
                                        const int32_t *out_shift,
                                        const int32_t *out_mult,
                                        const int32_t activation_min,
                                        const int32_t activation_max,
                                        const uint16_t row_x_col,
                                        const int64_t *const output_bias,
                                        int16_t *out);

/**
 * @brief A row of s8 scalars is multiplied with an s8 matrix and accumulated into an s32 rolling scratch buffer.
 * Helper function for transposed convolution.
 *
 * @param[in]      lhs             Input left-hand side scalars
 * @param[in]      rhs             Input right-hand side matrix
 * @param[out]     output_start    Output buffer start
 * @param[in]      output_index    Output buffer current index
 * @param[in]      output_max      Output buffer size
 * @param[in]      rhs_rows        Number of rows in rhs matrix
 * @param[in]      rhs_cols        Number of columns in rhs matrix
 * @param[in]      input_channels  Number of input channels
 * @param[in]      output_channels Number of output channels
 * @param[in]      lhs_offset      Offset added to lhs before multiplication
 * @param[in]      row_offset      Address offset between each row of data output
 * @param[in]      input_x         Length of lhs scalar row.
 * @param[in]      stride_x        Address offset between each scalar-matrix multiplication result.
 * @param[in]      skip_row_top    Number of rows to skip at the top of the filter, used for padding.
 * @param[in]      skip_row_bottom Number of rows to skip at the bottom of the filter, used for padding.
 *
 * @return         The function returns ARM_CMSIS_NN_SUCCESS
 *
 * @note           Rolling buffer refers to how the function wraps around the scratch buffer, e.g. it starts writing at
 * [output_start + output_index], writes up to [output_start + output_max] and then continues at [output_start] again.
 */
arm_cmsis_nn_status arm_nn_transpose_conv_row_s8_s32(const int8_t *lhs,
                                                     const int8_t *rhs,
                                                     int32_t *output_start,
                                                     const int32_t output_index,
                                                     const int32_t output_max,
                                                     const int32_t rhs_rows,
                                                     const int32_t rhs_cols,
                                                     const int32_t input_channels,
                                                     const int32_t output_channels,
                                                     const int32_t lhs_offset,
                                                     const int32_t row_offset,
                                                     const int32_t input_x,
                                                     const int32_t stride_x,
                                                     const int32_t skip_row_top,
                                                     const int32_t skip_row_bottom);

/**
  @brief         Read 2 s16 elements and post increment pointer.
  @param[in]     in_q15   Pointer to pointer that holds address of input.
  @return        q31 value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_q15x2_ia(const int16_t **in_q15)
{
    int32_t val;

    memcpy(&val, *in_q15, 4);
    *in_q15 += 2;

    return (val);
}
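
/*
 * Usage sketch (illustrative): the _ia readers copy via memcpy, which is safe for
 * unaligned addresses, and post increment the source pointer by the number of
 * elements consumed. For example, assuming a hypothetical input buffer:
 *
 *     const int16_t *src = input;
 *     int32_t packed = arm_nn_read_q15x2_ia(&src); // two s16 in one word, src += 2
 */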

/**
  @brief         Read 4 s8 from s8 pointer and post increment pointer.
  @param[in]     in_s8       Pointer to pointer that holds address of input.
  @return        q31 value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_s8x4_ia(const int8_t **in_s8)
{
    int32_t val;
    memcpy(&val, *in_s8, 4);
    *in_s8 += 4;

    return (val);
}

/**
  @brief         Read 2 s8 from s8 pointer and post increment pointer.
  @param[in]     in_s8    Pointer to pointer that holds address of input.
  @return        q31      value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_s8x2_ia(const int8_t **in_s8)
{
    int32_t val;
    memcpy(&val, *in_s8, 2);
    *in_s8 += 2;

    return (val);
}

/**
  @brief         Read 2 int16 values from int16 pointer.
  @param[in]     in     pointer to address of input.
  @return        s32    value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_s16x2(const int16_t *in)
{
    int32_t val;
    memcpy(&val, in, 4);

    return (val);
}

/**
  @brief         Read 4 s8 values.
  @param[in]     in_s8       pointer to address of input.
  @return        s32 value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_s8x4(const int8_t *in_s8)
{
    int32_t val;
    memcpy(&val, in_s8, 4);

    return (val);
}

/**
  @brief         Read 2 s8 values.
  @param[in]     in_s8    pointer to address of input.
  @return        s32      value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_s8x2(const int8_t *in_s8)
{
    int32_t val;
    memcpy(&val, in_s8, 2);

    return (val);
}

/**
  @brief         Write four s8 to s8 pointer and increment pointer afterwards.
  @param[in]     in       Double pointer to input value
  @param[in]     value    Four bytes to copy
 */
__STATIC_FORCEINLINE void arm_nn_write_s8x4_ia(int8_t **in, int32_t value)
{
    memcpy(*in, &value, 4);
    *in += 4;
}

/**
 * @brief           memset optimized for MVE
 * @param[in, out]  dst         Destination pointer
 * @param[in]       val         Value to set
 * @param[in]       block_size  Number of bytes to set.
 *
 */
__STATIC_FORCEINLINE void arm_memset_s8(int8_t *dst, const int8_t val, uint32_t block_size)
{
#if defined(ARM_MATH_MVEI)
    __asm volatile("   vdup.8                  q0, %[set_val]             \n"
                   "   wlstp.8                 lr, %[cnt], 1f             \n"
                   "2:                                                    \n"
                   "   vstrb.8                 q0, [%[in]], #16            \n"
                   "   letp                    lr, 2b                     \n"
                   "1:                                                    \n"
                   : [in] "+r"(dst)
                   : [cnt] "r"(block_size), [set_val] "r"(val)
                   : "q0", "memory", "r14");
#else
    memset(dst, val, block_size);
#endif
}

#if defined(ARM_MATH_DSP)

/**
 * @brief read and expand four packed s4 values into two words, each holding two s16 values.
 */
__STATIC_FORCEINLINE void read_and_pad_s4(const int8_t *source, int32_t *out1, int32_t *out2)
{
    int16_t in = arm_nn_read_s8x2(source);
    int32_t inA = (in & 0x00FF) | ((in & 0xFF00) << 8);

    *out1 = SXTB16_RORn(__sxtb16(inA << 4), 4);
    *out2 = SXTB16_RORn(__sxtb16(inA), 4);
}
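
/*
 * Worked example (illustrative): for source bytes {0x2F, 0x4C} the packed s4 elements
 * are, low nibble first: -1, 2, -4, 4. read_and_pad_s4 then yields the sign-extended
 * s16 lane pairs
 *
 *     out1 = {-1, -4}  // low nibbles of each byte
 *     out2 = { 2,  4}  // high nibbles of each byte
 */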

/**
 * @brief read and expand four packed s4 values into two words, each holding two s16 values.
 * @details   The s4 elements are not evenly aligned on the byte boundary, so 3 bytes need to be read instead of 2.
 *            In other words, the first nibble to read starts in the middle of a byte.
 *            byte index, s4 element
 *            0,          s4_x
 *            0,          s4_0
 *            1,          s4_1
 *            1,          s4_2
 *            2,          s4_3
 *            2,          s4_x
 */
__STATIC_FORCEINLINE void read_and_pad_s4_uneven(const int8_t *source, int32_t *out1, int32_t *out2)
{
    int32_t inA1 = (source[0] & 0xFF) | ((source[1] & 0xFF) << 16);
    int32_t inA2 = (source[1] & 0xFF) | ((source[2] & 0xFF) << 16);

    *out1 = SXTB16_RORn(__sxtb16(inA2 << 4), 4);
    *out2 = SXTB16_RORn(__sxtb16(inA1), 4);
}

/**
 * @brief read and expand four packed s4 values into two words, each holding two s16 values, with ordering.
 */
__STATIC_FORCEINLINE void read_and_pad_s4_ordered(const int8_t *source, int32_t *out1, int32_t *out2)
{
    int16_t in = arm_nn_read_s8x2(source);
    int32_t inA = (in & 0x00FF) | ((in & 0xFF00) << 8);
    int32_t inAbuf1 = SXTB16_RORn(__sxtb16(inA), 4);
    int32_t inAbuf2 = SXTB16_RORn(__sxtb16(inA << 4), 4);
    #ifndef ARM_MATH_BIG_ENDIAN
    *out2 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16));
    *out1 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16));
    #else
    *out1 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16));
    *out2 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16));
    #endif
}

/**
 * @brief read and expand one word of four s8 values into two words, each holding two s16 values, with ordering.
 */
__STATIC_FORCEINLINE const int8_t *read_and_pad(const int8_t *source, int32_t *out1, int32_t *out2)
{
    int32_t inA = arm_nn_read_s8x4_ia(&source);
    int32_t inAbuf1 = SXTB16_RORn((uint32_t)inA, 8);
    int32_t inAbuf2 = SXTB16(inA);

    #ifndef ARM_MATH_BIG_ENDIAN
    *out2 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16));
    *out1 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16));
    #else
    *out1 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16));
    *out2 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16));
    #endif

    return source;
}
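
/*
 * Worked example (illustrative, little-endian): for source bytes {1, -2, 3, -4},
 * read_and_pad returns sign-extended s16 lane pairs in the original element order:
 *
 *     out1 = {1, -2}   // elements 0 and 1
 *     out2 = {3, -4}   // elements 2 and 3
 */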

/**
 * @brief read and expand one word of four s8 values into two words, each holding two s16 values, with ordering and
 *        addition.
 */
__STATIC_FORCEINLINE void read_pad_and_add_s8(const int8_t *source, int32_t *out1, int32_t *out2, const uint32_t add)
{
    int32_t inA = arm_nn_read_s8x4(source);
    int32_t inAbuf1 = SXTAB16_RORn(add, (uint32_t)inA, 8);
    int32_t inAbuf2 = SXTAB16(add, inA);

    #ifndef ARM_MATH_BIG_ENDIAN
    *out2 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16));
    *out1 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16));
    #else
    *out1 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16));
    *out2 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16));
    #endif
}

/**
 * @brief read and expand two bytes into one word with ordering.
 */
__STATIC_FORCEINLINE void read_and_pad_s8x2(const int8_t *source, int32_t *out)
{
    int16_t in = arm_nn_read_s8x2(source);
    int32_t inA = (in & 0x00FF) | ((in & 0xFF00) << 8);
    *out = SXTB16(inA);
}

/**
 * @brief read and expand two bytes into one word with ordering and addition.
 */
__STATIC_FORCEINLINE void read_pad_and_add_s8x2(const int8_t *source, int32_t *out, const uint32_t add)
{
    int16_t in = arm_nn_read_s8x2(source);
    int32_t inA = (in & 0x00FF) | ((in & 0xFF00) << 8);
    *out = SXTAB16(add, inA);
}

/**
 * @brief read and expand one word of four s8 values into two words, each holding two s16 values, with no additional
 *        ordering.
 */
__STATIC_FORCEINLINE const int8_t *read_and_pad_reordered(const int8_t *source, int32_t *out1, int32_t *out2)
{
    int32_t inA = arm_nn_read_s8x4_ia(&source);
    #ifndef ARM_MATH_BIG_ENDIAN
    *out2 = SXTB16(ROR((uint32_t)inA, 8));
    *out1 = SXTB16(inA);
    #else
    *out1 = SXTB16(ROR((uint32_t)inA, 8));
    *out2 = SXTB16(inA);
    #endif

    return source;
}
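
/*
 * Worked example (illustrative, little-endian): read_and_pad_reordered skips the
 * PKHBT/PKHTB re-ordering step, so for the same source bytes {1, -2, 3, -4}:
 *
 *     out1 = {1, 3}    // even-indexed elements
 *     out2 = {-2, -4}  // odd-indexed elements
 */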

#endif

/**
 * @brief Matrix-multiplication function for convolution with per-channel requantization and 4 bit weights.
 * @param[in]       input_a            pointer to operand A, int8 packed with 2x int4.
 * @param[in]       input_b            pointer to operand B, always consists of 2 vectors.
 * @param[in]       output_ch          number of rows of A
 * @param[in]       out_shift          pointer to per output channel requantization shift parameter.
 * @param[in]       out_mult           pointer to per output channel requantization multiplier parameter.
 * @param[in]       out_offset         output tensor offset.
 * @param[in]       activation_min     minimum value to clamp the output to. Range : int8
 * @param[in]       activation_max     maximum value to clamp the output to. Range : int8
 * @param[in]       num_col_a          number of columns of A
 * @param[in]       output_bias        per output channel bias. Range : int32
 * @param[in,out]   out_0              pointer to output
 * @return     The function returns one of the two
 *              1. The incremented output pointer for a successful operation or
 *              2. NULL if implementation is not available.
 *
 * @details   This function does the matrix multiplication of weight matrix for all output channels
 *            with 2 columns from im2col and produces two elements/output_channel. The outputs are
 *            clamped in the range provided by activation min and max.
 *            Supported framework: TensorFlow Lite micro.
 */
int8_t *arm_nn_mat_mult_kernel_s4_s16(const int8_t *input_a,
                                      const int16_t *input_b,
                                      const uint16_t output_ch,
                                      const int32_t *out_shift,
                                      const int32_t *out_mult,
                                      const int32_t out_offset,
                                      const int32_t activation_min,
                                      const int32_t activation_max,
                                      const int32_t num_col_a,
                                      const int32_t *const output_bias,
                                      int8_t *out_0);

/**
 * @brief Matrix-multiplication function for convolution with per-channel requantization.
 * @param[in]       input_a            pointer to operand A
 * @param[in]       input_b            pointer to operand B, always consists of 2 vectors.
 * @param[in]       output_ch          number of rows of A
 * @param[in]       out_shift          pointer to per output channel requantization shift parameter.
 * @param[in]       out_mult           pointer to per output channel requantization multiplier parameter.
 * @param[in]       out_offset         output tensor offset.
 * @param[in]       activation_min     minimum value to clamp the output to. Range : int8
 * @param[in]       activation_max     maximum value to clamp the output to. Range : int8
 * @param[in]       num_col_a          number of columns of A
 * @param[in]       aligned_num_col_a  number of columns of A aligned by 4
 * @param[in]       output_bias        per output channel bias. Range : int32
 * @param[in,out]   out_0              pointer to output
 * @return     The function returns one of the two
 *              1. The incremented output pointer for a successful operation or
 *              2. NULL if implementation is not available.
 *
 * @details   This function does the matrix multiplication of weight matrix for all output channels
 *            with 2 columns from im2col and produces two elements/output_channel. The outputs are
 *            clamped in the range provided by activation min and max.
 *            Supported framework: TensorFlow Lite micro.
 */
int8_t *arm_nn_mat_mult_kernel_s8_s16(const int8_t *input_a,
                                      const int16_t *input_b,
                                      const uint16_t output_ch,
                                      const int32_t *out_shift,
                                      const int32_t *out_mult,
                                      const int32_t out_offset,
                                      const int16_t activation_min,
                                      const int16_t activation_max,
                                      const int32_t num_col_a,
                                      const int32_t aligned_num_col_a,
                                      const int32_t *const output_bias,
                                      int8_t *out_0);

/**
 * @brief Matrix-multiplication function for convolution with per-channel requantization, supporting an address offset
 * between rows.
 * @param[in]       input_a            pointer to operand A
 * @param[in]       input_b            pointer to operand B, always consists of 2 vectors.
 * @param[in]       output_ch          number of rows of A
 * @param[in]       out_shift          pointer to per output channel requantization shift parameter.
 * @param[in]       out_mult           pointer to per output channel requantization multiplier parameter.
 * @param[in]       out_offset         output tensor offset.
 * @param[in]       activation_min     minimum value to clamp the output to. Range : int8
 * @param[in]       activation_max     maximum value to clamp the output to. Range : int8
 * @param[in]       num_col_a          number of columns of A
 * @param[in]       aligned_num_col_a  number of columns of A aligned by 4
 * @param[in]       output_bias        per output channel bias. Range : int32
 * @param[in]       row_address_offset address offset between rows in the output
 * @param[in,out]   out_0              pointer to output
 * @return     The function returns one of the two
 *              1. The incremented output pointer for a successful operation or
 *              2. NULL if implementation is not available.
 *
 * @details   This function does the matrix multiplication of weight matrix for all output channels
 *            with 2 columns from im2col and produces two elements/output_channel. The outputs are
 *            clamped in the range provided by activation min and max.
 *
 *            This function is slightly less performant than arm_nn_mat_mult_kernel_s8_s16, but allows support for
 * grouped convolution. Supported framework: TensorFlow Lite micro.
 */
int8_t *arm_nn_mat_mult_kernel_row_offset_s8_s16(const int8_t *input_a,
                                                 const int16_t *input_b,
                                                 const uint16_t output_ch,
                                                 const int32_t *out_shift,
                                                 const int32_t *out_mult,
                                                 const int32_t out_offset,
                                                 const int16_t activation_min,
                                                 const int16_t activation_max,
                                                 const int32_t num_col_a,
                                                 const int32_t aligned_num_col_a,
                                                 const int32_t *const output_bias,
                                                 const int32_t row_address_offset,
                                                 int8_t *out_0);

/**
 * @brief Common softmax function for s8 input and s8 or s16 output
 * @param[in]  input          Pointer to the input tensor
 * @param[in]  num_rows       Number of rows in the input tensor
 * @param[in]  row_size       Number of elements in each input row
 * @param[in]  mult           Input quantization multiplier
 * @param[in]  shift          Input quantization shift within the range [0, 31]
 * @param[in]  diff_min       Minimum difference with max in row. Used to check if
 *                            the quantized exponential operation can be performed
 * @param[in]  int16_output   Indicating s8 output if 0 else s16 output
 * @param[out] output         Pointer to the output tensor
 *
 * @note Supported framework: TensorFlow Lite micro (bit-accurate)
 *
 */
void arm_nn_softmax_common_s8(const int8_t *input,
                              const int32_t num_rows,
                              const int32_t row_size,
                              const int32_t mult,
                              const int32_t shift,
                              const int32_t diff_min,
                              const bool int16_output,
                              void *output);

/**
 * @brief macro for adding rounding offset
 */
#ifndef ARM_NN_TRUNCATE
    #define NN_ROUND(out_shift) ((0x1 << out_shift) >> 1)
#else
    #define NN_ROUND(out_shift) 0
#endif
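
/*
 * Example (illustrative): NN_ROUND supplies the offset that turns a truncating right
 * shift into a round-to-nearest one, e.g. NN_ROUND(3) == 4, so that
 *
 *     acc = (acc + NN_ROUND(out_shift)) >> out_shift;
 */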

// Macros for shortening quantization functions' names and avoiding long lines
#define MUL_SAT(a, b) arm_nn_doubling_high_mult((a), (b))
#define MUL_SAT_MVE(a, b) arm_doubling_high_mult_mve_32x4((a), (b))
#define MUL_POW2(a, b) arm_nn_mult_by_power_of_two((a), (b))

#define DIV_POW2(a, b) arm_nn_divide_by_power_of_two((a), (b))
#define DIV_POW2_MVE(a, b) arm_divide_by_power_of_two_mve((a), (b))

#define EXP_ON_NEG(x) arm_nn_exp_on_negative_values((x))
#define ONE_OVER1(x) arm_nn_one_over_one_plus_x_for_x_in_0_1((x))

/**
 * @brief           Saturating doubling high multiply. Result matches
 *                  NEON instruction VQRDMULH.
 * @param[in]       m1        Multiplicand. Range: {NN_Q31_MIN, NN_Q31_MAX}
 * @param[in]       m2        Multiplier. Range: {NN_Q31_MIN, NN_Q31_MAX}
 * @return          Result of multiplication.
 *
 */
__STATIC_FORCEINLINE int32_t arm_nn_doubling_high_mult(const int32_t m1, const int32_t m2)
{
    int32_t result = 0;
    // Rounding offset to add for a right shift of 31
    int64_t mult = 1 << 30;

    if ((m1 < 0) ^ (m2 < 0))
    {
        mult = 1 - mult;
    }
    // Gets resolved as a SMLAL instruction
    mult = mult + (int64_t)m1 * m2;

    // Utilize all of the upper 32 bits. This is the doubling step
    // as well.
    result = (int32_t)(mult / (1ll << 31));

    if ((m1 == m2) && (m1 == (int32_t)NN_Q31_MIN))
    {
        result = NN_Q31_MAX;
    }
    return result;
}
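
/*
 * Worked example (illustrative): the result equals round((2 * m1 * m2) / 2^32), as
 * produced by VQRDMULH. For m1 = m2 = 1 << 30, i.e. 0.5 in Q31:
 *
 *     arm_nn_doubling_high_mult(1 << 30, 1 << 30) == 1 << 29   // 0.25 in Q31
 */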

/**
 * @brief           Doubling high multiply without saturation. This is intended
 *                  for requantization where the scale is a positive integer
 *
 * @param[in]       m1        Multiplicand. Range: {NN_Q31_MIN, NN_Q31_MAX}
 * @param[in]       m2        Multiplier. Range: {NN_Q31_MIN, NN_Q31_MAX}
 * @return          Result of multiplication.
 * @note            The result of this matches that of neon instruction
 *                  VQRDMULH for m1 in range {NN_Q31_MIN, NN_Q31_MAX} and m2 in
 *                  range {NN_Q31_MIN + 1, NN_Q31_MAX}. Saturation occurs when
 *                  m1 equals m2 equals NN_Q31_MIN and that is not handled by
 *                  this function.
 *
 */
__STATIC_FORCEINLINE int32_t arm_nn_doubling_high_mult_no_sat(const int32_t m1, const int32_t m2)
{
    int32_t result = 0;
    union arm_nn_long_long mult;

    // Rounding offset to add for a right shift of 31
    mult.word.low = 1 << 30;
    mult.word.high = 0;

    // Gets resolved as a SMLAL instruction
    mult.long_long = mult.long_long + (int64_t)m1 * m2;

    // Utilize all of the upper 32 bits. This is the doubling step
    // as well.
    result = (int32_t)(mult.long_long >> 31);

    return result;
}

/**
 * @brief           Rounding divide by power of two.
 * @param[in]       dividend - Dividend
 * @param[in]       exponent - Divisor = power(2, exponent)
 *                             Range: [0, 31]
 * @return          Rounded result of division. Midpoint is rounded away from zero.
 *
 */
__STATIC_FORCEINLINE int32_t arm_nn_divide_by_power_of_two(const int32_t dividend, const int32_t exponent)
{
    int32_t result = 0;
    const int32_t remainder_mask = (1 << exponent) - 1;
    int32_t remainder = remainder_mask & dividend;

    // Basic division
    result = dividend >> exponent;

    // Adjust 'result' for rounding (mid point away from zero)
    int32_t threshold = remainder_mask >> 1;
    if (result < 0)
    {
        threshold++;
    }
    if (remainder > threshold)
    {
        result++;
    }

    return result;
}
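
/*
 * Worked example (illustrative): midpoints round away from zero, so for exponent = 1
 * (divide by 2):
 *
 *     arm_nn_divide_by_power_of_two( 5, 1) ==  3   //  2.5 ->  3
 *     arm_nn_divide_by_power_of_two(-5, 1) == -3   // -2.5 -> -3
 *     arm_nn_divide_by_power_of_two( 4, 1) ==  2   //  exact
 */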

/**
 * @brief           Requantize a given value.
 * @details         Essentially returns (val * multiplier)/(2 ^ shift) with different rounding depending on whether
 *                  CMSIS_NN_USE_SINGLE_ROUNDING is defined or not.
 * @param[in]       val         Value to be requantized
 * @param[in]       multiplier  Multiplier. Range {NN_Q31_MIN + 1, Q32_MAX}
 * @param[in]       shift       Shift. Range: {-31, 30}
 *                              Default branch:
 *                                  If shift is positive left shift 'val * multiplier' with shift
 *                                  If shift is negative right shift 'val * multiplier' with abs(shift)
 *                              Single round branch:
 *                                  Input for total_shift in divide by '2 ^ total_shift'
 *
 * @return          Default branch:
 *                      Returns (val * multiplier) with rounding divided by (2 ^ shift) with rounding
 *                  Single round branch:
 *                      Returns (val * multiplier)/(2 ^ (31 - shift)) with rounding
 *
 */
__STATIC_FORCEINLINE int32_t arm_nn_requantize(const int32_t val, const int32_t multiplier, const int32_t shift)
{
#ifdef CMSIS_NN_USE_SINGLE_ROUNDING
    const int64_t total_shift = 31 - shift;
    const int64_t new_val = val * (int64_t)multiplier;

    int32_t result = new_val >> (total_shift - 1);
    result = (result + 1) >> 1;

    return result;
#else
    return arm_nn_divide_by_power_of_two(arm_nn_doubling_high_mult_no_sat(val * (1 << LEFT_SHIFT(shift)), multiplier),
                                         RIGHT_SHIFT(shift));
#endif
}
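
/*
 * Worked example (illustrative): a real scale of 0.5 is expressed as
 * multiplier = 1 << 30 (0.5 in Q31) and shift = 0; both rounding branches then give
 *
 *     arm_nn_requantize(100, 1 << 30, 0) == 50   // 100 * 0.5
 */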

/**
 * @brief           Requantize a given 64 bit value.
 * @param[in]       val                 Value to be requantized in the range {-(1<<47)} to {(1<<47) - 1}
 * @param[in]       reduced_multiplier  Multiplier reduced from the range {NN_Q31_MIN + 1, Q32_MAX} to {Q16_MIN + 1,
 * Q16_MAX}
 * @param[in]       shift               Left or right shift for 'val * multiplier' in the range {-31} to {7}
 *
 * @return          Returns (val * multiplier)/(2 ^ shift)
 *
 */
__STATIC_FORCEINLINE int32_t arm_nn_requantize_s64(const int64_t val,
                                                   const int32_t reduced_multiplier,
                                                   const int32_t shift)
{
    const int64_t new_val = val * reduced_multiplier;

    int32_t result = new_val >> (14 - shift); // 64->32 bit reduction
    result = (result + 1) >> 1;               // Last shift position and insert round

    return result;
}

/**
 * @brief           memcpy optimized for MVE
 * @param[in, out]  dst         Destination pointer
 * @param[in]       src         Source pointer.
 * @param[in]       block_size  Number of bytes to copy.
 *
 */
__STATIC_FORCEINLINE void arm_memcpy_s8(int8_t *__RESTRICT dst, const int8_t *__RESTRICT src, uint32_t block_size)
{
#if defined(ARM_MATH_MVEI)
    __asm volatile("   wlstp.8                 lr, %[cnt], 1f             \n"
                   "2:                                                    \n"
                   "   vldrb.8                 q0, [%[in]], #16            \n"
                   "   vstrb.8                 q0, [%[out]], #16           \n"
                   "   letp                    lr, 2b                     \n"
                   "1:                                                    \n"
                   : [in] "+r"(src), [out] "+r"(dst)
                   : [cnt] "r"(block_size)
                   : "q0", "memory", "r14");
#else
    memcpy(dst, src, block_size);
#endif
}

/**
 * @brief           memcpy wrapper for int16
 * @param[in, out]  dst         Destination pointer
 * @param[in]       src         Source pointer.
 * @param[in]       block_size  Number of bytes to copy.
 *
 */
__STATIC_FORCEINLINE void arm_memcpy_q15(int16_t *__RESTRICT dst, const int16_t *__RESTRICT src, uint32_t block_size)
{
    memcpy(dst, src, block_size);
}

#if defined(ARM_MATH_MVEI)
/**
 * @brief           Vector saturating doubling high multiply returning high half.
 * @param[in]       m1        Multiplicand
 * @param[in]       m2        Multiplier
 * @return          Result of multiplication.
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_doubling_high_mult_mve(const int32x4_t m1, const int32_t m2)
{
    return vqrdmulhq_n_s32(m1, m2);
}

/**
 * @brief           Vector rounding divide by power of two.
 * @param[in]       dividend - Dividend vector
 * @param[in]       exponent - Divisor = power(2, exponent)
 *                             Range: [0, 31]
 * @return          Rounded result of division. Midpoint is rounded away from zero.
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_divide_by_power_of_two_mve(const int32x4_t dividend, const int32_t exponent)
{
    const int32x4_t shift = vdupq_n_s32(-exponent);
    const int32x4_t fixup = vshrq_n_s32(vandq_s32(dividend, shift), 31);
    const int32x4_t fixed_up_dividend = vqaddq_s32(dividend, fixup);
    return vrshlq_s32(fixed_up_dividend, shift);
}

/**
 * @brief           Requantize a given vector.
 * @param[in]       val         Vector to be requantized
 * @param[in]       multiplier  multiplier
 * @param[in]       shift       shift
 *
 * @return          Returns (val * multiplier)/(2 ^ shift) with different rounding. See arm_nn_requantize for details.
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_requantize_mve(const int32x4_t val, const int32_t multiplier, const int32_t shift)
{
    #ifdef CMSIS_NN_USE_SINGLE_ROUNDING
    const int right_shift = MIN(-1, shift);
    const int left_shift = shift - right_shift;

    const int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
    const int32x4_t right_shift_dup = vdupq_n_s32(right_shift);

    int32x4_t result = vqdmulhq_n_s32(vshlq_s32(val, left_shift_dup), multiplier);
    result = vrshlq_s32(result, right_shift_dup);

    return result;
    #else
    return arm_divide_by_power_of_two_mve(
        arm_doubling_high_mult_mve(vshlq_s32(val, vdupq_n_s32(LEFT_SHIFT(shift))), multiplier), RIGHT_SHIFT(shift));
    #endif
}

/**
 * @brief           Vector saturating doubling high multiply with predication returning high half.
 * @param[in]       m1        Multiplicand
 * @param[in]       m2        Multiplier
 * @param[in]       p         Vector predication mask
 * @param[in]       v_zero    Vector of zeroes for merging predication intrinsic
 * @return          Result of multiplication.
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_doubling_high_mult_mve_pred(const int32x4_t m1,
                                                               const int32_t m2,
                                                               const mve_pred16_t p,
                                                               const int32x4_t v_zero)
{
    return vqrdmulhq_m_n_s32(v_zero, m1, m2, p);
}

/**
 * @brief           Vector rounding divide by power of two with predication.
 * @param[in]       dividend - Dividend vector
 * @param[in]       exponent - Divisor = power(2, exponent)
 *                             Range: [0, 31]
 * @param[in]       p        - Vector predication mask
 * @param[in]       v_zero   - Vector of zeroes for merging predication intrinsic
 * @return          Rounded result of division. Midpoint is rounded away from zero.
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_divide_by_power_of_two_mve_pred(const int32x4_t dividend,
                                                                   const int32_t exponent,
                                                                   const mve_pred16_t p,
                                                                   const int32x4_t v_zero)
{
    const int32x4_t shift = vdupq_x_n_s32(-exponent, p);
    const int32x4_t fixup = vshrq_x_n_s32(vandq_x_s32(dividend, shift, p), 31, p);
    const int32x4_t fixed_up_dividend = vqaddq_m_s32(v_zero, dividend, fixup, p);
    return vrshlq_m_s32(v_zero, fixed_up_dividend, shift, p);
}

/**
 * @brief           Requantize a given vector with predication.
 * @param[in]       val         Vector to be requantized
 * @param[in]       multiplier  multiplier
 * @param[in]       shift       shift
 * @param[in]       p           Vector predication mask
 *
 * @return          Returns (val * multiplier)/(2 ^ shift)
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_requantize_mve_pred(const int32x4_t val,
                                                       const int32_t multiplier,
                                                       const int32_t shift,
                                                       const mve_pred16_t p)
{
    #ifdef CMSIS_NN_USE_SINGLE_ROUNDING
    const int right_shift = MIN(-1, shift);
    const int left_shift = shift - right_shift;
    const int32x4_t v_zero = vcreateq_s32(0, 0);

    const int32x4_t left_shift_dup = vdupq_x_n_s32(left_shift, p);
    const int32x4_t right_shift_dup = vdupq_x_n_s32(right_shift, p);

    int32x4_t result = vqrdmulhq_m_n_s32(v_zero, vshlq_m_s32(v_zero, val, left_shift_dup, p), multiplier, p);
    result = vrshlq_m_s32(v_zero, result, right_shift_dup, p);

    return result;
    #else
    const int32x4_t v_zero = vcreateq_s32(0, 0);
    return arm_divide_by_power_of_two_mve_pred(
        arm_doubling_high_mult_mve_pred(
            vshlq_m_s32(v_zero, val, vdupq_x_n_s32(LEFT_SHIFT(shift), p), p), multiplier, p, v_zero),
        RIGHT_SHIFT(shift),
        p,
        v_zero);
    #endif
}

__STATIC_FORCEINLINE int32x4_t arm_doubling_high_mult_mve_32x4(const int32x4_t m1, const int32x4_t m2)
{
    return vqrdmulhq_s32(m1, m2);
}

__STATIC_FORCEINLINE int32x4_t arm_divide_by_power_of_two_mve_32x4(const int32x4_t dividend, const int32x4_t exponent)
{
    const int32x4_t shift = -exponent;
    const int32x4_t fixup = vshrq_n_s32(vandq_s32(dividend, shift), 31);
    const int32x4_t fixed_up_dividend = vqaddq_s32(dividend, fixup);
    return vrshlq_s32(fixed_up_dividend, shift);
}

__STATIC_FORCEINLINE int32x4_t arm_requantize_mve_32x4(const int32x4_t val,
                                                       const int32x4_t multiplier,
                                                       const int32x4_t shift)
{
    #ifdef CMSIS_NN_USE_SINGLE_ROUNDING
    const int32x4_t right_shift = vminq_s32(vdupq_n_s32(-1), shift);
    const int32x4_t left_shift = vqsubq_s32(shift, right_shift);

    int32x4_t result = vqdmulhq_s32(vshlq_s32(val, left_shift), multiplier);
    result = vrshlq_s32(result, right_shift);

    return result;
    #else
    const int32x4_t zz = vdupq_n_s32(0);
    const mve_pred16_t p = vcmpgtq_n_s32(shift, 0);

    const int32x4_t left_shift = vpselq_s32(shift, zz, p);
    const int32x4_t right_shift = -vpselq_s32(zz, shift, p);

    return arm_divide_by_power_of_two_mve_32x4(arm_doubling_high_mult_mve_32x4(vshlq_s32(val, left_shift), multiplier),
                                               right_shift);
    #endif
}
#endif

// @note The following functions are used only for softmax layer, scaled bits = 5 assumed

__STATIC_FORCEINLINE int32_t arm_nn_exp_on_negative_values(int32_t val)
{
    int32_t mask = 0;
    int32_t shift = 24;

    const int32_t val_mod_minus_quarter = (val & ((1 << shift) - 1)) - (1 << shift);
    const int32_t remainder = val_mod_minus_quarter - val;
    const int32_t x = (val_mod_minus_quarter << 5) + (1 << 28);
    const int32_t x2 = MUL_SAT(x, x);

    int32_t result = 1895147668 +
        MUL_SAT(1895147668, x + DIV_POW2(MUL_SAT(DIV_POW2(MUL_SAT(x2, x2), 2) + MUL_SAT(x2, x), 715827883) + x2, 1));

#define SELECT_IF_NON_ZERO(x)                                                                                          \
    {                                                                                                                  \
        mask = MASK_IF_NON_ZERO(remainder & (1 << shift++));                                                           \
        result = SELECT_USING_MASK(mask, MUL_SAT(result, x), result);                                                  \
    }

    SELECT_IF_NON_ZERO(1672461947)
    SELECT_IF_NON_ZERO(1302514674)
    SELECT_IF_NON_ZERO(790015084)
    SELECT_IF_NON_ZERO(290630308)
    SELECT_IF_NON_ZERO(39332535)
    SELECT_IF_NON_ZERO(720401)
    SELECT_IF_NON_ZERO(242)

#undef SELECT_IF_NON_ZERO

    mask = MASK_IF_ZERO(val);
    return SELECT_USING_MASK(mask, NN_Q31_MAX, result);
}
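
/*
 * Note (illustrative): with scaled bits = 5 the input is treated as Q5.26 and the
 * return value as Q0.31, so val = 0 (i.e. exp(0) = 1) maps to the saturated maximum:
 *
 *     arm_nn_exp_on_negative_values(0) == NN_Q31_MAX
 */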

__STATIC_FORCEINLINE int32_t arm_nn_mult_by_power_of_two(const int32_t val, const int32_t exp)
{
    const int32_t thresh = ((1 << (31 - exp)) - 1);
    int32_t result = val << exp;
    result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val > thresh), NN_Q31_MAX, result);
    result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val < -thresh), NN_Q31_MIN, result);
    return result;
}

__STATIC_FORCEINLINE int32_t arm_nn_one_over_one_plus_x_for_x_in_0_1(int32_t val)
{
    const int64_t sum = (int64_t)val + (int64_t)NN_Q31_MAX;
    const int32_t half_denominator = (int32_t)((sum + (sum >= 0 ? 1 : -1)) / 2L);
    int32_t x = 1515870810 + MUL_SAT(half_denominator, -1010580540);

    const int32_t shift = (1 << 29);
    x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
    x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
    x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);

    return MUL_POW2(x, 1);
}

/**
  @brief         Write 2 s16 elements and post increment pointer.
  @param[in]     dest_q15  Pointer to pointer that holds address of destination.
  @param[in]     src_q31   Input value to be written.
 */
__STATIC_FORCEINLINE void arm_nn_write_q15x2_ia(int16_t **dest_q15, int32_t src_q31)
{
    int32_t val = src_q31;

    memcpy(*dest_q15, &val, 4);
    *dest_q15 += 2;
}

/**
  @brief         Write 2 s8 elements and post increment pointer.
  @param[in]     dst  Pointer to pointer that holds address of destination.
  @param[in]     src  Input value to be written.
 */
__STATIC_FORCEINLINE void arm_nn_write_s8x2_ia(int8_t **dst, int16_t src)
{
    memcpy(*dst, &src, 2);
    *dst += 2;
}

// Support functions for LSTM
/**
 * @brief Update LSTM function for an iteration step using s8 input and output, and s16 internally.
 *
 * @param[in]   data_in                         Data input pointer
 * @param[in]   hidden_in                       Hidden state/ recurrent input pointer
 * @param[out]  hidden_out                      Hidden state/ recurrent output pointer
 * @param[in]   params                          Struct containing all information about the lstm operator, see
 * arm_nn_types.
 * @param[in]   buffers                         Struct containing pointers to all temporary scratch buffers needed for
 * the lstm operator, see arm_nn_types.
 * @param[in]   batch_offset                    Number of timesteps between consecutive batches.
 * E.g. for params->time_major = true, all batches for t=0 are stored sequentially, so batch offset = 1.
 * For params->time_major = false, all time steps are stored contiguously before the next batch, so
 * batch offset = params->time_steps.
 * @return                                      The function returns ARM_CMSIS_NN_SUCCESS
 */
arm_cmsis_nn_status arm_nn_lstm_step_s8(const int8_t *data_in,
                                        const int8_t *hidden_in,
                                        int8_t *hidden_out,
                                        const cmsis_nn_lstm_params *params,
                                        cmsis_nn_lstm_context *buffers,
                                        const int32_t batch_offset);
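
/*
 * Layout sketch (illustrative): for 2 batches (B0, B1) and 3 time steps (t0..t2),
 * batch_offset is the stride in timesteps between consecutive batches of one step:
 *
 *     time major,  batch_offset = 1:          B0t0 B1t0 B0t1 B1t1 B0t2 B1t2
 *     batch major, batch_offset = time_steps: B0t0 B0t1 B0t2 B1t0 B1t1 B1t2
 */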

/**
 * @brief Update LSTM function for an iteration step using s16 input and output, and s16 internally.
 *
 * @param[in]   data_in                         Data input pointer
 * @param[in]   hidden_in                       Hidden state/ recurrent input pointer
 * @param[out]  hidden_out                      Hidden state/ recurrent output pointer
 * @param[in]   params                          Struct containing all information about the lstm operator, see
 * arm_nn_types.
 * @param[in]   buffers                         Struct containing pointers to all temporary scratch buffers needed for
 * the lstm operator, see arm_nn_types.
 * @param[in]   batch_offset                    Number of timesteps between consecutive batches.
 * E.g. for params->time_major = true, all batches for t=0 are stored sequentially, so batch offset = 1.
 * For params->time_major = false, all time steps are stored contiguously before the next batch, so
 * batch offset = params->time_steps.
 * @return                                      The function returns ARM_CMSIS_NN_SUCCESS
 */
arm_cmsis_nn_status arm_nn_lstm_step_s16(const int16_t *data_in,
                                         const int16_t *hidden_in,
                                         int16_t *hidden_out,
                                         const cmsis_nn_lstm_params *params,
                                         cmsis_nn_lstm_context *buffers,
                                         const int32_t batch_offset);

/**
 * @brief Updates an LSTM gate for an iteration step of LSTM function, int8x8_16 version.
 *
 * @param[in]   data_in                         Data input pointer
 * @param[in]   hidden_in                       Hidden state/ recurrent input pointer
 * @param[in]   gate_data                       Struct containing all information about the gate calculation, see
 * arm_nn_types.
 * @param[in]   params                          Struct containing all information about the lstm_operation, see
 * arm_nn_types
 * @param[out]  output                          Hidden state/ recurrent output pointer
 * @param[in]   batch_offset                    Number of timesteps between consecutive batches, see
 * arm_nn_lstm_step_s8.
 * @return                                      The function returns ARM_CMSIS_NN_SUCCESS
 */
arm_cmsis_nn_status arm_nn_lstm_calculate_gate_s8_s16(const int8_t *data_in,
                                                      const int8_t *hidden_in,
                                                      const cmsis_nn_lstm_gate *gate_data,
                                                      const cmsis_nn_lstm_params *params,
                                                      int16_t *output,
                                                      const int32_t batch_offset);

/**
 * @brief Updates an LSTM gate for an iteration step of LSTM function, int16x8_16 version.
 *
 * @param[in]   data_in                         Data input pointer
 * @param[in]   hidden_in                       Hidden state/ recurrent input pointer
 * @param[in]   gate_data                       Struct containing all information about the gate calculation, see
 * arm_nn_types.
 * @param[in]   params                          Struct containing all information about the lstm_operation, see
 * arm_nn_types
 * @param[out]  output                          Hidden state/ recurrent output pointer
 * @param[in]   batch_offset                    Number of timesteps between consecutive batches, see
 * arm_nn_lstm_step_s16.
 * @return                                      The function returns ARM_CMSIS_NN_SUCCESS
 */
arm_cmsis_nn_status arm_nn_lstm_calculate_gate_s16(const int16_t *data_in,
                                                   const int16_t *hidden_in,
                                                   const cmsis_nn_lstm_gate *gate_data,
                                                   const cmsis_nn_lstm_params *params,
                                                   int16_t *output,
                                                   const int32_t batch_offset);

/**
 * @brief Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch dimension composed of input vectors
 * independent from each other). The result of the multiplication is accumulated into the passed result buffer.
 *
 * @param[in]   lhs              Batched vector
 * @param[in]   rhs              Weights - input matrix (H(Rows)xW(Columns))
 * @param[in]   effective_bias   Bias + lhs_offset * kernel_sum term precalculated into a constant vector.
 * @param[out]  dst              Output
 * @param[in]   dst_multiplier   Multiplier for quantization
 * @param[in]   dst_shift        Shift for quantization
 * @param[in]   rhs_cols         Vector/matrix column length
 * @param[in]   rhs_rows         Row count of matrix
 * @param[in]   batches          Batch size
 * @param[in]   batch_offset     Number of timesteps between consecutive batches in input, see arm_nn_lstm_step_s8.
 * Note that the output is always stored with sequential batches.
 * @return                       The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
 */
arm_cmsis_nn_status arm_nn_vec_mat_mul_result_acc_s8_s16(const int8_t *lhs,
                                                         const int8_t *rhs,
                                                         const int32_t *effective_bias,
                                                         int16_t *dst,
                                                         const int32_t dst_multiplier,
                                                         const int32_t dst_shift,
                                                         const int32_t rhs_cols,
                                                         const int32_t rhs_rows,
                                                         const int32_t batches,
                                                         const int32_t batch_offset);

/**
 * @brief Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch dimension composed of input vectors
 * independent from each other). The result of the multiplication is accumulated into the passed result buffer.
 *
 * @param[in]   lhs              Batched vector
 * @param[in]   rhs              Weights - input matrix (H(Rows)xW(Columns))
 * @param[in]   effective_bias   Bias + lhs_offset * kernel_sum term precalculated into a constant vector.
 * @param[out]  dst              Output
 * @param[in]   dst_multiplier   Multiplier for quantization
 * @param[in]   dst_shift        Shift for quantization
 * @param[in]   rhs_cols         Vector/matrix column length
 * @param[in]   rhs_rows         Row count of matrix
 * @param[in]   batches          Batch size
 * @param[in]   batch_offset     Number of timesteps between consecutive batches in input, see arm_nn_lstm_step_s16.
 * Note that the output is always stored with sequential batches.
 * @return                       The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
 */
arm_cmsis_nn_status arm_nn_vec_mat_mul_result_acc_s16(const int16_t *lhs,
                                                      const int8_t *rhs,
                                                      const int64_t *effective_bias,
                                                      int16_t *dst,
                                                      const int32_t dst_multiplier,
                                                      const int32_t dst_shift,
                                                      const int32_t rhs_cols,
                                                      const int32_t rhs_rows,
                                                      const int32_t batches,
                                                      const int32_t batch_offset);
2054 
/**
 * @brief s16 elementwise multiplication with s8 output
 * @param[in]       input_1_vect        pointer to input vector 1
 * @param[in]       input_2_vect        pointer to input vector 2
 * @param[in,out]   output              pointer to output vector
 * @param[in]       out_offset          output offset
 * @param[in]       out_mult            output multiplier
 * @param[in]       out_shift           output shift
 * @param[in]       block_size          number of samples per batch
 * @param[in]       batch_size          number of batches
 * @param[in]       batch_offset        Number of timesteps between consecutive batches in output, see
 * arm_nn_lstm_step_s8. Note that it is assumed that the input is stored with sequential batches.
 * @return          The function returns ARM_CMSIS_NN_SUCCESS
 *
 * @details   Supported framework: TensorFlow Lite micro
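 *
 * A minimal usage sketch (the buffer names and sizes below are illustrative assumptions):
 * @code
 * // Input batches are stored sequentially; output batches are batch_offset timesteps apart.
 * const int32_t block_size = 32;   // samples per batch
 * const int32_t batch_size = 2;    // number of batches
 * const int32_t batch_offset = 4;  // timesteps between consecutive output batches
 * arm_elementwise_mul_s16_s8(input_1, input_2, output,
 *                            out_offset, out_mult, out_shift,
 *                            block_size, batch_size, batch_offset);
 * @endcode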
 */
arm_cmsis_nn_status arm_elementwise_mul_s16_s8(const int16_t *input_1_vect,
                                               const int16_t *input_2_vect,
                                               int8_t *output,
                                               const int32_t out_offset,
                                               const int32_t out_mult,
                                               const int32_t out_shift,
                                               const int32_t block_size,
                                               const int32_t batch_size,
                                               const int32_t batch_offset);

/**
 * @brief s16 elementwise multiplication with s16 output
 * @param[in]       input_1_vect        pointer to input vector 1
 * @param[in]       input_2_vect        pointer to input vector 2
 * @param[in,out]   output              pointer to output vector
 * @param[in]       out_offset          output offset
 * @param[in]       out_mult            output multiplier
 * @param[in]       out_shift           output shift
 * @param[in]       block_size          number of samples per batch
 * @param[in]       batch_size          number of batches
 * @param[in]       batch_offset        Number of timesteps between consecutive batches in output, see
 * arm_nn_lstm_step_s16. Note that it is assumed that the input is stored with sequential batches.
 * @return          The function returns ARM_CMSIS_NN_SUCCESS
 *
 * @details   Supported framework: TensorFlow Lite micro
 */
arm_cmsis_nn_status arm_elementwise_mul_s16_batch_offset(const int16_t *input_1_vect,
                                                         const int16_t *input_2_vect,
                                                         int16_t *output,
                                                         const int32_t out_offset,
                                                         const int32_t out_mult,
                                                         const int32_t out_shift,
                                                         const int32_t block_size,
                                                         const int32_t batch_size,
                                                         const int32_t batch_offset);

/**
 * @brief s16 elementwise multiplication. The result of the multiplication is accumulated into the passed result
 * buffer.
 * @param[in]       input_1_vect        pointer to input vector 1
 * @param[in]       input_2_vect        pointer to input vector 2
 * @param[in]       input_1_offset      offset for input 1. Not used.
 * @param[in]       input_2_offset      offset for input 2. Not used.
 * @param[in,out]   output              pointer to output vector
 * @param[in]       out_offset          output offset. Not used.
 * @param[in]       out_mult            output multiplier
 * @param[in]       out_shift           output shift
 * @param[in]       out_activation_min  minimum value to clamp output to. Min: -32768
 * @param[in]       out_activation_max  maximum value to clamp output to. Max: 32767
 * @param[in]       block_size          number of samples
 * @return          The function returns ARM_CMSIS_NN_SUCCESS
 *
 * @details   Supported framework: TensorFlow Lite micro
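 *
 * A minimal usage sketch (the buffer names are illustrative assumptions; the unused offsets are passed
 * as 0 and the clamp range matches the documented limits):
 * @code
 * arm_elementwise_mul_acc_s16(input_1, input_2,
 *                             0, 0,            // input offsets, not used
 *                             output,          // existing contents are accumulated into
 *                             0,               // out_offset, not used
 *                             out_mult, out_shift,
 *                             -32768, 32767,   // out_activation_min/max
 *                             block_size);
 * @endcode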
 */
arm_cmsis_nn_status arm_elementwise_mul_acc_s16(const int16_t *input_1_vect,
                                                const int16_t *input_2_vect,
                                                const int32_t input_1_offset,
                                                const int32_t input_2_offset,
                                                int16_t *output,
                                                const int32_t out_offset,
                                                const int32_t out_mult,
                                                const int32_t out_shift,
                                                const int32_t out_activation_min,
                                                const int32_t out_activation_max,
                                                const int32_t block_size);

/**
 * @brief Check if a broadcast is required between 2 cmsis_nn_dims.
 * @param[in]       shape_1             pointer to input tensor 1
 * @param[in]       shape_2             pointer to input tensor 2
 * @return          The function returns 1 if a broadcast is required, or 0 if not.
 *
 * @details   Compares each dimension and returns 1 if any dimension does not match.
 *            This function does not check that broadcast rules are met.
 */
__STATIC_FORCEINLINE int32_t arm_check_broadcast_required(const cmsis_nn_dims *shape_1, const cmsis_nn_dims *shape_2)
{
    if ((shape_1->n != shape_2->n) || (shape_1->h != shape_2->h) || (shape_1->w != shape_2->w) ||
        (shape_1->c != shape_2->c))
    {
        return 1;
    }

    return 0;
}

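/*
 * A minimal usage sketch for arm_check_broadcast_required (the shapes below are illustrative assumptions):
 *
 *   const cmsis_nn_dims shape_a = {.n = 1, .h = 4, .w = 4, .c = 8};
 *   const cmsis_nn_dims shape_b = {.n = 1, .h = 4, .w = 4, .c = 1};
 *   if (arm_check_broadcast_required(&shape_a, &shape_b))
 *   {
 *       // The shapes differ in at least one dimension (here: c), so a broadcast
 *       // would be required. Note that broadcast rules are not validated here.
 *   }
 */
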
#ifdef __cplusplus
}
#endif

#endif /* ARM_NNSUPPORTFUNCTIONS_H */