1 /*
2 * SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_nnsupportfunctions.h
22 * Description: Public header file of support functions for CMSIS NN Library
23 *
 * $Date:        23 March 2023
25 * $Revision: V.16.0.0
26 *
27 * Target : Arm(R) M-Profile Architecture
28 * -------------------------------------------------------------------- */
29
30 #ifndef _ARM_NNSUPPORTFUNCTIONS_H_
31 #define _ARM_NNSUPPORTFUNCTIONS_H_
32
33 #include "Internal/arm_nn_compiler.h"
34 #include "arm_nn_math_types.h"
35 #include "arm_nn_types.h"
36
37 #include <stdbool.h>
38
39 #ifdef __cplusplus
40 extern "C" {
41 #endif
42
43 #define USE_FAST_DW_CONV_S16_FUNCTION(dw_conv_params, filter_dims, input_dims) \
44 (dw_conv_params->ch_mult == 1 && dw_conv_params->dilation.w == 1 && dw_conv_params->dilation.h == 1 && \
45 filter_dims->w * filter_dims->h < 512)
46
/* Select the positive (left) / negative (right) part of a requantization shift. */
#define LEFT_SHIFT(_shift) ((_shift) > 0 ? (_shift) : 0)
#define RIGHT_SHIFT(_shift) ((_shift) > 0 ? 0 : -(_shift))
/* All-ones mask when the condition holds, zero otherwise. Fully parenthesized so the
   macros compose safely with &, ^ and == at the call site. */
#define MASK_IF_ZERO(x) ((x) == 0 ? ~0 : 0)
#define MASK_IF_NON_ZERO(x) ((x) != 0 ? ~0 : 0)
#define SELECT_USING_MASK(mask, a, b) (((mask) & (a)) ^ (~(mask) & (b)))

#define MAX(A, B) ((A) > (B) ? (A) : (B))
#define MIN(A, B) ((A) < (B) ? (A) : (B))
#define CLAMP(x, h, l) MAX(MIN((x), (h)), (l))
/* Reduce a 32-bit fixed-point multiplier to 16 bits with rounding, saturating at 0x7FFF. */
#define REDUCE_MULTIPLIER(_mult) (((_mult) < 0x7FFF0000) ? (((_mult) + (1 << 15)) >> 16) : 0x7FFF)
57
58 // Number of channels processed in a block for DW Conv(MVE)
59 // Requirement: Greater than 0 & less than 128
60 // This can be fine tuned to match number of input channels for best performance.
61 // A layer with lower number of channels than CH_IN_BLOCK_MVE will result in higher
62 // scratch buffer usage and a layer with higher number of channels than CH_IN_BLOCK_MVE
63 // will result in lower scratch buffer usage.
64 #define CH_IN_BLOCK_MVE (124)
65
66 /**
67 * @brief definition to pack four 8 bit values.
68 */
69 #define PACK_S8x4_32x1(v0, v1, v2, v3) \
70 ((((int32_t)(v0) << 0) & (int32_t)0x000000FF) | (((int32_t)(v1) << 8) & (int32_t)0x0000FF00) | \
71 (((int32_t)(v2) << 16) & (int32_t)0x00FF0000) | (((int32_t)(v3) << 24) & (int32_t)0xFF000000))
72
73 /**
74 * @brief definition to pack two 16 bit values.
75 */
76 #define PACK_Q15x2_32x1(v0, v1) (((int32_t)v0 & (int32_t)0xFFFF) | ((int32_t)v1 << 16))
77
/**
 * @brief Union for SIMD access of q31/s16/s8 types
 */
union arm_nnword
{
    int32_t word; /**< q31 type */
    int16_t half_words[2]; /**< s16 type */
    int8_t bytes[4]; /**< s8 type */
};

/**
 * @brief Split view of a 64-bit value as two 32-bit words.
 *        NOTE(review): low/high member order assumes a little-endian layout of the
 *        64-bit value — confirm before use on big-endian targets.
 */
struct arm_nn_double
{
    uint32_t low; /**< Least significant 32 bits */
    int32_t high; /**< Most significant 32 bits (signed) */
};

/**
 * @brief Union for SIMD access of the data type long long
 */
union arm_nn_long_long
{
    int64_t long_long; /**< Full 64-bit value */
    struct arm_nn_double word; /**< Same value viewed as low/high 32-bit words */
};
105
106 /**
107 * @defgroup groupSupport Private
108 *
 * Internal Support functions. Not intended to be called directly by a CMSIS-NN user.
110 *
111 */
112
113 /**
114 * @defgroup supportConversion Data Conversion
115 *
116 * Perform data type conversion in-between neural network operations
117 *
118 */
119
120 /**
121 * @brief Converts the elements from a s8 vector to a s16 vector with an added offset
122 * @param[in] src pointer to the s8 input vector
123 * @param[out] dst pointer to the s16 output vector
124 * @param[in] block_size length of the input vector
125 * @param[in] offset s16 offset to be added to each input vector element.
126 *
127 * \par Description:
128 *
129 * Output elements are ordered.
130 * The equation used for the conversion process is:
131 *
132 * <pre>
133 * dst[n] = (int16_t) src[n] + offset; 0 <= n < block_size.
134 * </pre>
135 *
136 */
137 void arm_q7_to_q15_with_offset(const int8_t *src, int16_t *dst, int32_t block_size, int16_t offset);
138
139 #if defined(ARM_MATH_DSP)
140 /**
141 * @brief Converts the elements from a s8 vector to a s16 vector with an added offset
142 * @param[in] src pointer to the s8 input vector
143 * @param[out] dst pointer to the s16 output vector
144 * @param[in] block_size length of the input vector
145 * @param[in] offset s16 offset to be added to each input vector element.
146 *
147 * \par Description:
148 *
 * No additional ordering is applied, so the output elements are not in input
 * order: input order ABCD becomes ACBD in the output.
151 * Note this is for processors with DSP extension only.
152 * The equation used for the conversion process is:
153 *
154 * <pre>
155 * dst[n - 0] = (int16_t) src[n - 0] + offset; 0 <= n < block_size.
156 * dst[n - 1] = (int16_t) src[n - 2] + offset; 0 <= n < block_size.
157 * dst[n - 2] = (int16_t) src[n - 1] + offset; 0 <= n < block_size.
158 * dst[n - 3] = (int16_t) src[n - 3] + offset; 0 <= n < block_size.
159 * </pre>
160 *
161 */
162 void arm_s8_to_s16_unordered_with_offset(const int8_t *src, int16_t *dst, int32_t block_size, int16_t offset);
163 #endif
164
165 /**
166 * @brief Depthwise conv on an im2col buffer where the input channel equals output channel.
167 * @param[in] row pointer to row
168 * @param[in] col pointer to im2col buffer, always consists of 2 columns.
169 * @param[in] num_ch number of channels
170 * @param[in] out_shift pointer to per output channel requantization shift parameter.
171 * @param[in] out_mult pointer to per output channel requantization multiplier parameter.
172 * @param[in] out_offset output tensor offset.
173 * @param[in] activation_min minimum value to clamp the output to. Range : int8
174 * @param[in] activation_max maximum value to clamp the output to. Range : int8
175 * @param[in] kernel_size number of elements in one column.
176 * @param[in] output_bias per output channel bias. Range : int32
177 * @param[out] out pointer to output
178 * @return The function returns one of the two
179 * 1. The incremented output pointer for a successful operation or
180 * 2. NULL if implementation is not available.
181 *
182 * @details Supported framework: TensorFlow Lite micro.
183 */
184 int8_t *arm_nn_depthwise_conv_s8_core(const int8_t *row,
185 const int16_t *col,
186 const uint16_t num_ch,
187 const int32_t *out_shift,
188 const int32_t *out_mult,
189 const int32_t out_offset,
190 const int32_t activation_min,
191 const int32_t activation_max,
192 const uint16_t kernel_size,
193 const int32_t *const output_bias,
194 int8_t *out);
195
196 /**
197 * @brief General Matrix-multiplication function with per-channel requantization.
198 * @param[in] input_row pointer to row operand
199 * @param[in] input_col pointer to col operand
200 * @param[in] output_ch number of rows of input_row
201 * @param[in] col_batches number of column batches. Range: 1 to 4
202 * @param[in] output_shift pointer to per output channel requantization shift parameter.
203 * @param[in] output_mult pointer to per output channel requantization multiplier parameter.
204 * @param[in] out_offset output tensor offset.
205 * @param[in] col_offset input tensor(col) offset.
206 * @param[in] row_offset kernel offset(row). Not used.
207 * @param[in] out_activation_min minimum value to clamp the output to. Range : int8
208 * @param[in] out_activation_max maximum value to clamp the output to. Range : int8
209 * @param[in] row_len number of elements in each row
210 * @param[in] bias per output channel bias. Range : int32
211 * @param[in,out] out pointer to output
212 * @return The function returns one of the two
213 * 1. The incremented output pointer for a successful operation or
214 * 2. NULL if implementation is not available.
215 *
216 * @details Supported framework: TensorFlow Lite
217 */
218 int8_t *arm_nn_mat_mult_s8(const int8_t *input_row,
219 const int8_t *input_col,
220 const uint16_t output_ch,
221 const uint16_t col_batches,
222 const int32_t *output_shift,
223 const int32_t *output_mult,
224 const int32_t out_offset,
225 const int32_t col_offset,
226 const int32_t row_offset,
227 const int16_t out_activation_min,
228 const int16_t out_activation_max,
229 const uint16_t row_len,
230 const int32_t *const bias,
231 int8_t *out);
232 /**
233 * @brief Matrix-multiplication function for convolution with per-channel requantization for 16 bits convolution.
234 * @param[in] input_a pointer to operand A
235 * @param[in] input_b pointer to operand B, always consists of 2 vectors.
236 * @param[in] output_ch number of rows of A
237 * @param[in] out_shift pointer to per output channel requantization shift parameter.
238 * @param[in] out_mult pointer to per output channel requantization multiplier parameter.
239 * @param[in] activation_min minimum value to clamp the output to. Range : int16
240 * @param[in] activation_max maximum value to clamp the output to. Range : int16
241 * @param[in] num_col_a number of columns of A
242 * @param[in] output_bias per output channel bias. Range : int64
243 * @param[in,out] out_0 pointer to output
244 * @return The function returns one of the two
245 * 1. The incremented output pointer for a successful operation or
246 * 2. NULL if implementation is not available.
247 *
248 * @details This function does the matrix multiplication of weight matrix for all output channels
249 * with 2 columns from im2col and produces two elements/output_channel. The outputs are
250 * clamped in the range provided by activation min and max.
251 * Supported framework: TensorFlow Lite micro.
252 */
253 int16_t *arm_nn_mat_mult_kernel_s16(const int8_t *input_a,
254 const int16_t *input_b,
255 const int32_t output_ch,
256 const int32_t *out_shift,
257 const int32_t *out_mult,
258 const int16_t activation_min,
259 const int16_t activation_max,
260 const int32_t num_col_a,
261 const int64_t *const output_bias,
262 int16_t *out_0);
263
264 /**
265 * @brief General Vector by Matrix multiplication with requantization and storage of result.
266 * @param[in] row_elements number of row elements
267 * @param[in] skipped_row_elements number of row elements skipped due to padding.
268 * row_elements + skipped_row_elements = (kernel_x * kernel_y) * input_ch
269 * @param[in] row_base_ref pointer to row operand
270 * @param[in] col_base_ref pointer to col operand
271 * @param[out] out_ch Number of output channels
272 * @param[in] conv_params Pointer to convolution parameters like offsets and activation values
273 * @param[in] quant_params Pointer to per-channel quantization parameters
274 * @param[in] bias Pointer to optional per-channel bias
275 * @param[out] output Pointer to output where int8 results are stored.
276 * @return The function performs matrix(row_base_ref) multiplication with vector(col_base_ref) and
277 * scaled result is stored in memory.
278 *
279 * @details Pseudo-code
280 * *output = 0
281 * sum_col = 0
282 * for (j = 0; j < out_ch; j++)
283 * for (i = 0; i < row_elements; i++)
284 * *output += row_base_ref[i] * col_base_ref[i]
285 * sum_col += col_base_ref[i]
286 * scale sum_col using quant_params and bias
287 * store result in 'output'
288 *
289 *
290 */
291 arm_cmsis_nn_status arm_nn_mat_mul_core_1x_s8(int32_t row_elements,
292 const int32_t skipped_row_elements,
293 const int8_t *row_base_ref,
294 const int8_t *col_base_ref,
295 const int32_t out_ch,
296 const cmsis_nn_conv_params *conv_params,
297 const cmsis_nn_per_channel_quant_params *quant_params,
298 const int32_t *bias,
299 int8_t *output);
300
301 /**
302 * @brief Matrix-multiplication with requantization & activation function for four rows and one column
303 * @param[in] row_elements number of row elements
304 * @param[in] offset offset between rows. Can be the same as row_elements.
305 * For e.g, in a 1x1 conv scenario with stride as 1.
306 * @param[in] row_base pointer to row operand
307 * @param[in] col_base pointer to col operand
308 * @param[in] out_ch Number of output channels
309 * @param[in] conv_params Pointer to convolution parameters like offsets and activation values
310 * @param[in] quant_params Pointer to per-channel quantization parameters
311 * @param[in] bias Pointer to per-channel bias
312 * @param[out] output Pointer to output where int8 results are stored.
313 *
314 * @return The function returns the updated output pointer or NULL if implementation is not available.
315 *
316 * @details Compliant to TFLM int8 specification. MVE implementation only
317 */
318 int8_t *arm_nn_mat_mul_core_4x_s8(const int32_t row_elements,
319 const int32_t offset,
320 const int8_t *row_base,
321 const int8_t *col_base,
322 const int32_t out_ch,
323 const cmsis_nn_conv_params *conv_params,
324 const cmsis_nn_per_channel_quant_params *quant_params,
325 const int32_t *bias,
326 int8_t *output);
327
328 /**
329 * @brief General Matrix-multiplication function with per-channel requantization.
330 * This function assumes:
331 * - LHS input matrix NOT transposed (nt)
332 * - RHS input matrix transposed (t)
333 *
334 * @note This operation also performs the broadcast bias addition before the requantization
335 *
336 * @param[in] lhs Pointer to the LHS input matrix
337 * @param[in] rhs Pointer to the RHS input matrix
338 * @param[in] bias Pointer to the bias vector. The length of this vector is equal to the number of
339 * output columns (or RHS input rows)
340 * @param[out] dst Pointer to the output matrix with "m" rows and "n" columns
341 * @param[in] dst_multipliers Pointer to the multipliers vector needed for the per-channel requantization.
342 * The length of this vector is equal to the number of output columns (or RHS input
343 * rows)
344 * @param[in] dst_shifts Pointer to the shifts vector needed for the per-channel requantization. The length
345 * of this vector is equal to the number of output columns (or RHS input rows)
346 * @param[in] lhs_rows Number of LHS input rows
347 * @param[in] rhs_rows Number of RHS input rows
348 * @param[in] rhs_cols Number of LHS/RHS input columns
349 * @param[in] lhs_offset Offset to be applied to the LHS input value
350 * @param[in] dst_offset Offset to be applied the output result
351 * @param[in] activation_min Minimum value to clamp down the output. Range : int8
352 * @param[in] activation_max Maximum value to clamp up the output. Range : int8
353 * @param[in] lhs_cols_offset Column offset between subsequent lhs_rows
354 *
355 * @return The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
356 *
357 */
358 arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8(const int8_t *lhs,
359 const int8_t *rhs,
360 const int32_t *bias,
361 int8_t *dst,
362 const int32_t *dst_multipliers,
363 const int32_t *dst_shifts,
364 const int32_t lhs_rows,
365 const int32_t rhs_rows,
366 const int32_t rhs_cols,
367 const int32_t lhs_offset,
368 const int32_t dst_offset,
369 const int32_t activation_min,
370 const int32_t activation_max,
371 const int32_t lhs_cols_offset);
372
373 /**
374 * @brief s8 Vector by Matrix (transposed) multiplication
375 *
376 * @param[in] lhs Input left-hand side vector
377 * @param[in] rhs Input right-hand side matrix (transposed)
378 * @param[in] bias Input bias
379 * @param[out] dst Output vector
380 * @param[in] lhs_offset Offset to be added to the input values of the left-hand side vector.
381 * Range: -127 to 128
382 * @param[in] dst_offset Offset to be added to the output values. Range: -127 to 128
383 * @param[in] dst_multiplier Output multiplier
384 * @param[in] dst_shift Output shift
385 * @param[in] rhs_cols Number of columns in the right-hand side input matrix
386 * @param[in] rhs_rows Number of rows in the right-hand side input matrix
387 * @param[in] activation_min Minimum value to clamp the output to. Range: int8
388 * @param[in] activation_max Maximum value to clamp the output to. Range: int8
389 * @param[in] address_offset Memory position offset for dst. First output is stored at 'dst', the
390 * second at 'dst + address_offset' and so on. Default value is typically 1.
391 *
392 * @return The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
393 *
394 */
395 arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s8(const int8_t *lhs,
396 const int8_t *rhs,
397 const int32_t *bias,
398 int8_t *dst,
399 const int32_t lhs_offset,
400 const int32_t dst_offset,
401 const int32_t dst_multiplier,
402 const int32_t dst_shift,
403 const int32_t rhs_cols,
404 const int32_t rhs_rows,
405 const int32_t activation_min,
406 const int32_t activation_max,
407 const int32_t address_offset);
408
409 /**
410 * @brief s16 Vector by Matrix (transposed) multiplication
411 *
412 * @param[in] lhs Input left-hand side vector
413 * @param[in] rhs Input right-hand side matrix (transposed)
414 * @param[in] bias Input bias
415 * @param[out] dst Output vector
416 * @param[in] dst_multiplier Output multiplier
417 * @param[in] dst_shift Output shift
418 * @param[in] rhs_cols Number of columns in the right-hand side input matrix
419 * @param[in] rhs_rows Number of rows in the right-hand side input matrix
420 * @param[in] activation_min Minimum value to clamp the output to. Range: int16
421 * @param[in] activation_max Maximum value to clamp the output to. Range: int16
422 *
423 * @return The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
424 *
425 */
426 arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s16(const int16_t *lhs,
427 const int8_t *rhs,
428 const int64_t *bias,
429 int16_t *dst,
430 const int32_t dst_multiplier,
431 const int32_t dst_shift,
432 const int32_t rhs_cols,
433 const int32_t rhs_rows,
434 const int32_t activation_min,
435 const int32_t activation_max);
436
437 /**
438 * @brief s8 Vector by Matrix (transposed) multiplication with s16 output
439 *
440 * @param[in] lhs Input left-hand side vector
441 * @param[in] rhs Input right-hand side matrix (transposed)
442 * @param[out] dst Output vector
443 * @param[in] lhs_offset Offset to be added to the input values of the left-hand side
444 * vector. Range: -127 to 128
445 * @param[in] scatter_offset Address offset for dst. First output is stored at 'dst', the
446 * second at 'dst + scatter_offset' and so on.
447 * @param[in] dst_multiplier Output multiplier
448 * @param[in] dst_shift Output shift
449 * @param[in] rhs_cols Number of columns in the right-hand side input matrix
450 * @param[in] rhs_rows Number of rows in the right-hand side input matrix
451 * @param[in] activation_min Minimum value to clamp the output to. Range: int16
452 * @param[in] activation_max Maximum value to clamp the output to. Range: int16
453 *
454 * @return The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
455 *
456 */
457 arm_cmsis_nn_status arm_nn_vec_mat_mult_t_svdf_s8(const int8_t *lhs,
458 const int8_t *rhs,
459 int16_t *dst,
460 const int32_t lhs_offset,
461 const int32_t scatter_offset,
462 const int32_t dst_multiplier,
463 const int32_t dst_shift,
464 const int32_t rhs_cols,
465 const int32_t rhs_rows,
466 const int32_t activation_min,
467 const int32_t activation_max);
468
469 /**
470 * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in padded cases where
471 * the padding is -lhs_offset(Range: int8). Dimensions are the same for lhs and rhs.
472 *
473 * @param[in] lhs Input left-hand side matrix
474 * @param[in] rhs Input right-hand side matrix (transposed)
475 * @param[in] lhs_offset LHS matrix offset(input offset). Range: -127 to 128
476 * @param[in] active_ch Subset of total_ch processed
477 * @param[in] total_ch Number of channels in LHS/RHS
478 * @param[in] out_shift Per channel output shift. Length of vector is equal to number of channels
479 * @param[in] out_mult Per channel output multiplier. Length of vector is equal to number of channels
480 * @param[in] out_offset Offset to be added to the output values. Range: -127 to 128
481 * @param[in] activation_min Minimum value to clamp the output to. Range: int8
482 * @param[in] activation_max Maximum value to clamp the output to. Range: int8
483 * @param[in] row_x_col (row_dimension * col_dimension) of LHS/RHS matrix
484 * @param[in] output_bias Per channel output bias. Length of vector is equal to number of channels
485 * @param[in] out Output pointer
486 *
487 * @return The function returns one of the two
488 * - Updated output pointer if an implementation is available
489 * - NULL if no implementation is available.
490 *
 * @note If number of channels is not a multiple of 4, up to 3 elements outside the boundary will be read
492 * out for the following.
493 * - Output shift
494 * - Output multiplier
495 * - Output bias
496 * - rhs
497 */
498 arm_cmsis_nn_status arm_nn_depthwise_conv_nt_t_padded_s8(const int8_t *lhs,
499 const int8_t *rhs,
500 const int32_t lhs_offset,
501 const int32_t active_ch,
502 const int32_t total_ch,
503 const int32_t *out_shift,
504 const int32_t *out_mult,
505 const int32_t out_offset,
506 const int32_t activation_min,
507 const int32_t activation_max,
508 const uint16_t row_x_col,
509 const int32_t *const output_bias,
510 int8_t *out);
511
512 /**
513 * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in non-padded cases.
514 * Dimensions are the same for lhs and rhs.
515 *
516 * @param[in] lhs Input left-hand side matrix
517 * @param[in] rhs Input right-hand side matrix (transposed)
518 * @param[in] lhs_offset LHS matrix offset(input offset). Range: -127 to 128
519 * @param[in] active_ch Subset of total_ch processed
520 * @param[in] total_ch Number of channels in LHS/RHS
521 * @param[in] out_shift Per channel output shift. Length of vector is equal to number of channels.
522 * @param[in] out_mult Per channel output multiplier. Length of vector is equal to number of channels.
523 * @param[in] out_offset Offset to be added to the output values. Range: -127 to 128
524 * @param[in] activation_min Minimum value to clamp the output to. Range: int8
525 * @param[in] activation_max Maximum value to clamp the output to. Range: int8
526 * @param[in] row_x_col (row_dimension * col_dimension) of LHS/RHS matrix
527 * @param[in] output_bias Per channel output bias. Length of vector is equal to number of channels.
528 * @param[in] out Output pointer
529 *
530 * @return The function returns one of the two
531 * - Updated output pointer if an implementation is available
532 * - NULL if no implementation is available.
533 *
 * @note If number of channels is not a multiple of 4, up to 3 elements outside the boundary will be read
535 * out for the following.
536 * - Output shift
537 * - Output multiplier
538 * - Output bias
539 * - rhs
540 */
541 arm_cmsis_nn_status arm_nn_depthwise_conv_nt_t_s8(const int8_t *lhs,
542 const int8_t *rhs,
543 const int32_t lhs_offset,
544 const int32_t active_ch,
545 const int32_t total_ch,
546 const int32_t *out_shift,
547 const int32_t *out_mult,
548 const int32_t out_offset,
549 const int32_t activation_min,
550 const int32_t activation_max,
551 const uint16_t row_x_col,
552 const int32_t *const output_bias,
553 int8_t *out);
554
555 /**
556 * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in non-padded cases.
557 * Dimensions are the same for lhs and rhs.
558 *
559 * @param[in] lhs Input left-hand side matrix
560 * @param[in] rhs Input right-hand side matrix (transposed)
561 * @param[in] num_ch Number of channels in LHS/RHS
562 * @param[in] out_shift Per channel output shift. Length of vector is equal to number of channels.
563 * @param[in] out_mult Per channel output multiplier. Length of vector is equal to number of channels.
564 * @param[in] activation_min Minimum value to clamp the output to. Range: int8
565 * @param[in] activation_max Maximum value to clamp the output to. Range: int8
566 * @param[in] row_x_col (row_dimension * col_dimension) of LHS/RHS matrix
567 * @param[in] output_bias Per channel output bias. Length of vector is equal to number of channels.
568 * @param[in] out Output pointer
569 *
570 * @return The function returns one of the two
571 * - Updated output pointer if an implementation is available
572 * - NULL if no implementation is available.
573 *
 * @note If number of channels is not a multiple of 4, up to 3 elements outside the boundary will be read
575 * out for the following.
576 * - Output shift
577 * - Output multiplier
578 * - Output bias
579 * - rhs
580 */
581 int16_t *arm_nn_depthwise_conv_nt_t_s16(const int16_t *lhs,
582 const int8_t *rhs,
583 const uint16_t num_ch,
584 const int32_t *out_shift,
585 const int32_t *out_mult,
586 const int32_t activation_min,
587 const int32_t activation_max,
588 const uint16_t row_x_col,
589 const int64_t *const output_bias,
590 int16_t *out);
591
592 /**
593 @brief Read 2 s16 elements and post increment pointer.
594 @param[in] in_q15 Pointer to pointer that holds address of input.
595 @return q31 value
596 */
arm_nn_read_q15x2_ia(const int16_t ** in_q15)597 __STATIC_FORCEINLINE int32_t arm_nn_read_q15x2_ia(const int16_t **in_q15)
598 {
599 int32_t val;
600
601 memcpy(&val, *in_q15, 4);
602 *in_q15 += 2;
603
604 return (val);
605 }
606
/**
  @brief         Read 4 s8 from s8 pointer and post increment pointer.
  @param[in]     in_s8    Pointer to pointer that holds address of input.
  @return        q31 value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_s8x4_ia(const int8_t **in_s8)
{
    int32_t packed;
    const int8_t *src = *in_s8;

    /* memcpy keeps the 32-bit access alignment-safe and aliasing-safe. */
    memcpy(&packed, src, sizeof(packed));
    *in_s8 = src + 4;

    return packed;
}
620
/**
  @brief         Read 2 int16 values from int16 pointer.
  @param[in]     in     pointer to address of input.
  @return        s32 value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_s16x2(const int16_t *in)
{
    int32_t packed;

    /* memcpy avoids unaligned-access and strict-aliasing issues. */
    memcpy(&packed, in, sizeof(packed));

    return packed;
}
633
/**
  @brief         Read 4 s8 values.
  @param[in]     in_s8    pointer to address of input.
  @return        s32 value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_s8x4(const int8_t *in_s8)
{
    int32_t packed;

    /* memcpy avoids unaligned-access and strict-aliasing issues. */
    memcpy(&packed, in_s8, sizeof(packed));

    return packed;
}
646
/**
  @brief         Write four s8 to s8 pointer and increment pointer afterwards.
  @param[in]     in       Double pointer to input value
  @param[in]     value    Four bytes to copy
 */
__STATIC_FORCEINLINE void arm_nn_write_s8x4_ia(int8_t **in, int32_t value)
{
    int8_t *dst = *in;

    /* memcpy keeps the 32-bit store alignment-safe and aliasing-safe. */
    memcpy(dst, &value, sizeof(value));
    *in = dst + 4;
}
657
/**
 * @brief memset optimized for MVE
 * @param[in, out] dst         Destination pointer
 * @param[in]      val         Value to set
 * @param[in]      block_size  Number of bytes to set.
 *
 */
__STATIC_FORCEINLINE void arm_memset_s8(int8_t *dst, const int8_t val, uint32_t block_size)
{
#if defined(ARM_MATH_MVEI)
    /* Broadcast val into q0, then use a tail-predicated loop (WLSTP/LETP) that stores
       16 bytes per iteration; the predication handles any trailing partial vector. */
    __asm volatile(" vdup.8 q0, %[set_val] \n"
                   " wlstp.8 lr, %[cnt], 1f \n"
                   "2: \n"
                   " vstrb.8 q0, [%[in]], #16 \n"
                   " letp lr, 2b \n"
                   "1: \n"
                   : [in] "+r"(dst)
                   : [cnt] "r"(block_size), [set_val] "r"(val)
                   : "q0", "memory", "r14");
#else
    /* Portable fallback for targets without MVE. */
    memset(dst, val, block_size);
#endif
}
681
682 #if defined(ARM_MATH_DSP)
683
/**
 * @brief read and expand one s8 word into two s16 words with ordering.
 * @param[in]  source  Pointer to four consecutive s8 values; read via a 32-bit load.
 * @param[out] out1    Two expanded s16 values packed into one 32-bit word.
 * @param[out] out2    Two expanded s16 values packed into one 32-bit word.
 * @return     source advanced by 4 bytes.
 */
__STATIC_FORCEINLINE const int8_t *read_and_pad(const int8_t *source, int32_t *out1, int32_t *out2)
{
    int32_t inA = arm_nn_read_s8x4_ia(&source);
    /* SXTB16_RORn sign-extends bytes 1 and 3, SXTB16 sign-extends bytes 0 and 2,
       each into a pair of 16-bit lanes. */
    int32_t inAbuf1 = SXTB16_RORn((uint32_t)inA, 8);
    int32_t inAbuf2 = SXTB16(inA);

#ifndef ARM_MATH_BIG_ENDIAN
    /* PKHTB/PKHBT recombine the halfwords so out1/out2 preserve the original
       element order (endianness decides which combination is which). */
    *out2 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16));
    *out1 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16));
#else
    *out1 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16));
    *out2 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16));
#endif

    return source;
}
703
/**
 * @brief read and expand one s8 word into two s16 words with no additional ordering.
 * @param[in]  source  Pointer to four consecutive s8 values; read via a 32-bit load.
 * @param[out] out1    Two expanded s16 values packed into one 32-bit word.
 * @param[out] out2    Two expanded s16 values packed into one 32-bit word.
 * @return     source advanced by 4 bytes.
 *
 * Unlike read_and_pad(), no PKH recombination is done, so the s16 elements come
 * out interleaved (ACBD for input ABCD) — cheaper when the consumer does not
 * care about element order.
 */
__STATIC_FORCEINLINE const int8_t *read_and_pad_reordered(const int8_t *source, int32_t *out1, int32_t *out2)
{
    int32_t inA = arm_nn_read_s8x4_ia(&source);
#ifndef ARM_MATH_BIG_ENDIAN
    /* ROR by 8 puts bytes 1 and 3 in the even byte positions for SXTB16. */
    *out2 = SXTB16(ROR((uint32_t)inA, 8));
    *out1 = SXTB16(inA);
#else
    *out1 = SXTB16(ROR((uint32_t)inA, 8));
    *out2 = SXTB16(inA);
#endif

    return source;
}
720
721 #endif
722
723 /**
724 * @brief Matrix-multiplication function for convolution with per-channel requantization.
725 * @param[in] input_a pointer to operand A
726 * @param[in] input_b pointer to operand B, always consists of 2 vectors.
727 * @param[in] output_ch number of rows of A
728 * @param[in] out_shift pointer to per output channel requantization shift parameter.
729 * @param[in] out_mult pointer to per output channel requantization multiplier parameter.
730 * @param[in] out_offset output tensor offset.
731 * @param[in] activation_min minimum value to clamp the output to. Range : int8
732 * @param[in] activation_max maximum value to clamp the output to. Range : int8
733 * @param[in] num_col_a number of columns of A
734 * @param[in] output_bias per output channel bias. Range : int32
735 * @param[in,out] out_0 pointer to output
736 * @return The function returns one of the two
737 * 1. The incremented output pointer for a successful operation or
738 * 2. NULL if implementation is not available.
739 *
740 * @details This function does the matrix multiplication of weight matrix for all output channels
741 * with 2 columns from im2col and produces two elements/output_channel. The outputs are
742 * clamped in the range provided by activation min and max.
743 * Supported framework: TensorFlow Lite micro.
744 */
745 int8_t *arm_nn_mat_mult_kernel_s8_s16(const int8_t *input_a,
746 const int16_t *input_b,
747 const uint16_t output_ch,
748 const int32_t *out_shift,
749 const int32_t *out_mult,
750 const int32_t out_offset,
751 const int16_t activation_min,
752 const int16_t activation_max,
753 const int32_t num_col_a,
754 const int32_t *const output_bias,
755 int8_t *out_0);
756
757 /**
758 * @brief Common softmax function for s8 input and s8 or s16 output
759 * @param[in] input Pointer to the input tensor
760 * @param[in] num_rows Number of rows in the input tensor
761 * @param[in] row_size Number of elements in each input row
762 * @param[in] mult Input quantization multiplier
763 * @param[in] shift Input quantization shift within the range [0, 31]
764 * @param[in] diff_min Minimum difference with max in row. Used to check if
765 * the quantized exponential operation can be performed
766 * @param[in] int16_output Indicating s8 output if 0 else s16 output
767 * @param[out] output Pointer to the output tensor
768 *
769 * @note Supported framework: TensorFlow Lite micro (bit-accurate)
770 *
771 */
772 void arm_nn_softmax_common_s8(const int8_t *input,
773 const int32_t num_rows,
774 const int32_t row_size,
775 const int32_t mult,
776 const int32_t shift,
777 const int32_t diff_min,
778 const bool int16_output,
779 void *output);
780
/**
 * @brief macro for adding rounding offset
 */
#ifndef ARM_NN_TRUNCATE
/* Half of (1 << out_shift): yields round-to-nearest when followed by an arithmetic
   right shift of out_shift. Argument parenthesized so expression arguments are safe. */
#define NN_ROUND(out_shift) ((0x1 << (out_shift)) >> 1)
#else
/* Truncating build: no rounding offset. */
#define NN_ROUND(out_shift) 0
#endif
789
// Macros for shortening quantization functions' names and avoid long lines
// Saturating doubling high multiply (scalar and MVE 32x4 vector variants)
#define MUL_SAT(a, b) arm_nn_doubling_high_mult((a), (b))
#define MUL_SAT_MVE(a, b) arm_doubling_high_mult_mve_32x4((a), (b))
#define MUL_POW2(a, b) arm_nn_mult_by_power_of_two((a), (b))

// Rounding divide by power of two (scalar and MVE vector variants)
#define DIV_POW2(a, b) arm_nn_divide_by_power_of_two((a), (b))
#define DIV_POW2_MVE(a, b) arm_divide_by_power_of_two_mve((a), (b))

// Fixed-point helpers used by the softmax implementation below
#define EXP_ON_NEG(x) arm_nn_exp_on_negative_values((x))
#define ONE_OVER1(x) arm_nn_one_over_one_plus_x_for_x_in_0_1((x))
800
801 /**
802 * @brief Saturating doubling high multiply. Result matches
803 * NEON instruction VQRDMULH.
804 * @param[in] m1 Multiplicand. Range: {NN_Q31_MIN, NN_Q31_MAX}
805 * @param[in] m2 Multiplier. Range: {NN_Q31_MIN, NN_Q31_MAX}
806 * @return Result of multiplication.
807 *
808 */
arm_nn_doubling_high_mult(const int32_t m1,const int32_t m2)809 __STATIC_FORCEINLINE int32_t arm_nn_doubling_high_mult(const int32_t m1, const int32_t m2)
810 {
811 int32_t result = 0;
812 // Rounding offset to add for a right shift of 31
813 int64_t mult = 1 << 30;
814
815 if ((m1 < 0) ^ (m2 < 0))
816 {
817 mult = 1 - mult;
818 }
819 // Gets resolved as a SMLAL instruction
820 mult = mult + (int64_t)m1 * m2;
821
822 // Utilize all of the upper 32 bits. This is the doubling step
823 // as well.
824 result = (int32_t)(mult / (1ll << 31));
825
826 if ((m1 == m2) && (m1 == (int32_t)NN_Q31_MIN))
827 {
828 result = NN_Q31_MAX;
829 }
830 return result;
831 }
832
833 /**
834 * @brief Doubling high multiply without saturation. This is intended
835 * for requantization where the scale is a positive integer
836 *
837 * @param[in] m1 Multiplicand. Range: {NN_Q31_MIN, NN_Q31_MAX}
838 * @param[in] m2 Multiplier Range: {NN_Q31_MIN, NN_Q31_MAX}
839 * @return Result of multiplication.
840 * @note The result of this matches that of neon instruction
841 * VQRDMULH for m1 in range {NN_Q31_MIN, NN_Q31_MAX} and m2 in
842 * range {NN_Q31_MIN + 1, NN_Q31_MAX}. Saturation occurs when
843 * m1 equals m2 equals NN_Q31_MIN and that is not handled by
844 * this function.
845 *
846 */
arm_nn_doubling_high_mult_no_sat(const int32_t m1,const int32_t m2)847 __STATIC_FORCEINLINE int32_t arm_nn_doubling_high_mult_no_sat(const int32_t m1, const int32_t m2)
848 {
849 int32_t result = 0;
850 union arm_nn_long_long mult;
851
852 // Rounding offset to add for a right shift of 31
853 mult.word.low = 1 << 30;
854 mult.word.high = 0;
855
856 // Gets resolved as a SMLAL instruction
857 mult.long_long = mult.long_long + (int64_t)m1 * m2;
858
859 // Utilize all of the upper 32 bits. This is the doubling step
860 // as well.
861 result = (int32_t)(mult.long_long >> 31);
862
863 return result;
864 }
865
866 /**
867 * @brief Rounding divide by power of two.
868 * @param[in] dividend - Dividend
869 * @param[in] exponent - Divisor = power(2, exponent)
870 * Range: [0, 31]
871 * @return Rounded result of division. Midpoint is rounded away from zero.
872 *
873 */
arm_nn_divide_by_power_of_two(const int32_t dividend,const int32_t exponent)874 __STATIC_FORCEINLINE int32_t arm_nn_divide_by_power_of_two(const int32_t dividend, const int32_t exponent)
875 {
876 int32_t result = 0;
877 const int32_t remainder_mask = (1 << exponent) - 1;
878 int32_t remainder = remainder_mask & dividend;
879
880 // Basic division
881 result = dividend >> exponent;
882
883 // Adjust 'result' for rounding (mid point away from zero)
884 int32_t threshold = remainder_mask >> 1;
885 if (result < 0)
886 {
887 threshold++;
888 }
889 if (remainder > threshold)
890 {
891 result++;
892 }
893
894 return result;
895 }
896
897 /**
898 * @brief Requantize a given value.
899 * @param[in] val Value to be requantized
900 * @param[in] multiplier multiplier. Range {NN_Q31_MIN + 1, Q32_MAX}
901 * @param[in] shift left or right shift for 'val * multiplier'
902 *
903 * @return Returns (val * multiplier)/(2 ^ shift)
904 *
905 */
arm_nn_requantize(const int32_t val,const int32_t multiplier,const int32_t shift)906 __STATIC_FORCEINLINE int32_t arm_nn_requantize(const int32_t val, const int32_t multiplier, const int32_t shift)
907 {
908 #ifdef CMSIS_NN_USE_SINGLE_ROUNDING
909 const int64_t total_shift = 31 - shift;
910 const int64_t new_val = val * (int64_t)multiplier;
911
912 int32_t result = new_val >> (total_shift - 1);
913 result = (result + 1) >> 1;
914
915 return result;
916 #else
917 return arm_nn_divide_by_power_of_two(arm_nn_doubling_high_mult_no_sat(val * (1 << LEFT_SHIFT(shift)), multiplier),
918 RIGHT_SHIFT(shift));
919 #endif
920 }
921
922 /**
923 * @brief Requantize a given 64 bit value.
924 * @param[in] val Value to be requantized in the range {-(1<<47)} to {(1<<47) - 1}
925 * @param[in] reduced_multiplier Reduced multiplier in the range {NN_Q31_MIN + 1, Q32_MAX} to {Q16_MIN + 1,
926 * Q16_MAX}
927 * @param[in] shift Left or right shift for 'val * multiplier' in the range {-31} to {7}
928 *
929 * @return Returns (val * multiplier)/(2 ^ shift)
930 *
931 */
arm_nn_requantize_s64(const int64_t val,const int32_t reduced_multiplier,const int32_t shift)932 __STATIC_FORCEINLINE int32_t arm_nn_requantize_s64(const int64_t val,
933 const int32_t reduced_multiplier,
934 const int32_t shift)
935 {
936 const int64_t new_val = val * reduced_multiplier;
937
938 int32_t result = new_val >> (14 - shift); // 64->32 bit reduction
939 result = (result + 1) >> 1; // Last shift position and insert round
940
941 return result;
942 }
943
/**
 * @brief memcpy optimized for MVE
 * @param[in, out] dst         Destination pointer
 * @param[in]      src         Source pointer.
 * @param[in]      block_size  Number of bytes to copy.
 *
 */
__STATIC_FORCEINLINE void arm_memcpy_s8(int8_t *__RESTRICT dst, const int8_t *__RESTRICT src, uint32_t block_size)
{
#if defined(ARM_MATH_MVEI)
    // Tail-predicated low-overhead loop (WLSTP/LETP): copies 16 bytes per
    // iteration, with the final iteration predicated to the remaining count.
    __asm volatile(" wlstp.8 lr, %[cnt], 1f \n"
                   "2: \n"
                   " vldrb.8 q0, [%[in]], #16 \n"
                   " vstrb.8 q0, [%[out]], #16 \n"
                   " letp lr, 2b \n"
                   "1: \n"
                   : [in] "+r"(src), [out] "+r"(dst)
                   : [cnt] "r"(block_size)
                   : "q0", "memory", "r14");
#else
    memcpy(dst, src, block_size);
#endif
}
967
/**
 * @brief memcpy wrapper for int16
 * @param[in, out] dst         Destination pointer
 * @param[in]      src         Source pointer.
 * @param[in]      block_size  Number of bytes (not elements) to copy.
 *
 */
__STATIC_FORCEINLINE void arm_memcpy_q15(int16_t *__RESTRICT dst, const int16_t *__RESTRICT src, uint32_t block_size)
{
    memcpy(dst, src, block_size);
}
979
#if defined(ARM_MATH_MVEI)
/**
 * @brief Vector saturating doubling high multiply returning high half.
 * @param[in] m1        Multiplicand vector
 * @param[in] m2        Multiplier (scalar, applied to all lanes)
 * @return              Result of multiplication.
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_doubling_high_mult_mve(const int32x4_t m1, const int32_t m2)
{
    // Saturating rounding doubling multiply, returning the high halves.
    return vqrdmulhq_n_s32(m1, m2);
}
992
/**
 * @brief Vector rounding divide by power of two.
 * @param[in] dividend - Dividend vector
 * @param[in] exponent - Divisor = power(2, exponent)
 *                       Range: [0, 31]
 * @return              Rounded result of division. Midpoint is rounded away from zero.
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_divide_by_power_of_two_mve(const int32x4_t dividend, const int32_t exponent)
{
    const int32x4_t shift = vdupq_n_s32(-exponent);
    // Fixup is -1 for negative dividends (sign bit survives the AND with the
    // negated exponent), 0 otherwise; adding it before the rounding shift
    // makes midpoints round away from zero instead of towards +infinity.
    // NOTE(review): this matches gemmlowp's RoundingDivideByPOT pattern.
    const int32x4_t fixup = vshrq_n_s32(vandq_s32(dividend, shift), 31);
    const int32x4_t fixed_up_dividend = vqaddq_s32(dividend, fixup);
    // Rounding shift left by a negative amount == rounding shift right.
    return vrshlq_s32(fixed_up_dividend, shift);
}
1008
/**
 * @brief Requantize a given vector.
 * @param[in] val         Vector to be requantized
 * @param[in] multiplier  multiplier (scalar, applied to all lanes)
 * @param[in] shift       shift (positive = left shift before the multiply,
 *                        negative = right shift after it)
 *
 * @return Returns (val * multiplier)/(2 ^ shift)
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_requantize_mve(const int32x4_t val, const int32_t multiplier, const int32_t shift)
{
#ifdef CMSIS_NN_USE_SINGLE_ROUNDING
    // Single-rounding path: force at least one bit of right shift so the
    // final vrshlq performs the only rounding step.
    const int right_shift = MIN(-1, shift);
    const int left_shift = shift - right_shift;

    const int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
    const int32x4_t right_shift_dup = vdupq_n_s32(right_shift);

    int32x4_t result = vqdmulhq_n_s32(vshlq_s32(val, left_shift_dup), multiplier);
    result = vrshlq_s32(result, right_shift_dup);

    return result;
#else
    // Double-rounding path: rounding doubling high multiply followed by a
    // rounding divide by power of two.
    return arm_divide_by_power_of_two_mve(
        arm_doubling_high_mult_mve(vshlq_s32(val, vdupq_n_s32(LEFT_SHIFT(shift))), multiplier), RIGHT_SHIFT(shift));
#endif
}
1036
/**
 * @brief Vector saturating doubling high multiply with per-lane multipliers.
 * @param[in] m1  Multiplicand vector
 * @param[in] m2  Multiplier vector
 * @return        High halves of the saturating rounding doubling products.
 */
__STATIC_FORCEINLINE int32x4_t arm_doubling_high_mult_mve_32x4(const int32x4_t m1, const int32x4_t m2)
{
    return vqrdmulhq_s32(m1, m2);
}
1041
/**
 * @brief Vector rounding divide by power of two with per-lane exponents.
 * @param[in] dividend  Dividend vector
 * @param[in] exponent  Per-lane exponent vector; divisor = power(2, exponent)
 * @return              Rounded result of division. Midpoint is rounded away from zero.
 */
__STATIC_FORCEINLINE int32x4_t arm_divide_by_power_of_two_mve_32x4(const int32x4_t dividend, const int32x4_t exponent)
{
    const int32x4_t shift = -exponent;
    // Same sign fixup as arm_divide_by_power_of_two_mve, applied per lane.
    const int32x4_t fixup = vshrq_n_s32(vandq_s32(dividend, shift), 31);
    const int32x4_t fixed_up_dividend = vqaddq_s32(dividend, fixup);
    return vrshlq_s32(fixed_up_dividend, shift);
}
1049
/**
 * @brief Requantize a vector with per-lane multiplier and shift.
 * @param[in] val         Vector to be requantized
 * @param[in] multiplier  Per-lane multiplier vector
 * @param[in] shift       Per-lane shift vector (positive = left, negative = right)
 * @return                Returns (val * multiplier)/(2 ^ shift) per lane
 */
__STATIC_FORCEINLINE int32x4_t arm_requantize_mve_32x4(const int32x4_t val,
                                                       const int32x4_t multiplier,
                                                       const int32x4_t shift)
{
#ifdef CMSIS_NN_USE_SINGLE_ROUNDING
    // Single-rounding path: clamp the right shift to at least 1 so the final
    // vrshlq performs the only rounding step.
    const int32x4_t right_shift = vminq_s32(vdupq_n_s32(-1), shift);
    const int32x4_t left_shift = vqsubq_s32(shift, right_shift);

    int32x4_t result = vqdmulhq_s32(vshlq_s32(val, left_shift), multiplier);
    result = vrshlq_s32(result, right_shift);

    return result;
#else
    const int32x4_t zz = vdupq_n_s32(0);
    const mve_pred16_t p = vcmpgtq_n_s32(shift, 0);

    // Split 'shift' per lane into a non-negative left shift (where shift > 0)
    // and a non-negative right shift (where shift <= 0).
    const int32x4_t left_shift = vpselq_s32(shift, zz, p);
    const int32x4_t right_shift = -vpselq_s32(zz, shift, p);

    return arm_divide_by_power_of_two_mve_32x4(arm_doubling_high_mult_mve_32x4(vshlq_s32(val, left_shift), multiplier),
                                               right_shift);
#endif
}
#endif
1074
// @note The following functions are used only for softmax layer, scaled bits = 5 assumed

/**
 * @brief Fixed-point exp(val) for val <= 0, returning a Q31-style result.
 *
 * Splits 'val' into a fractional part in [-2^24, 0) handled by a polynomial
 * approximation, and a remainder reconstructed bit-by-bit by multiplying with
 * precomputed exp(-2^k) constants.
 * NOTE(review): input fixed-point format is implied by shift = 24 and the
 * "scaled bits = 5" note above — confirm against the softmax callers.
 */
__STATIC_FORCEINLINE int32_t arm_nn_exp_on_negative_values(int32_t val)
{
    int32_t mask = 0;
    int32_t shift = 24;

    // Fractional part of val, biased into [-2^24, 0).
    const int32_t val_mod_minus_quarter = (val & ((1 << shift) - 1)) - (1 << shift);
    // Remainder carries the higher-order bits handled by the constants below.
    const int32_t remainder = val_mod_minus_quarter - val;
    const int32_t x = (val_mod_minus_quarter << 5) + (1 << 28);
    const int32_t x2 = MUL_SAT(x, x);

    // Polynomial approximation of exp on the fractional interval.
    int32_t result = 1895147668 +
        MUL_SAT(1895147668, x + DIV_POW2(MUL_SAT(DIV_POW2(MUL_SAT(x2, x2), 2) + MUL_SAT(x2, x), 715827883) + x2, 1));

// For each set bit of 'remainder', multiply 'result' by the corresponding
// exp(-2^k) constant; 'shift' advances one bit position per invocation.
#define SELECT_IF_NON_ZERO(x)                                                                                          \
    {                                                                                                                  \
        mask = MASK_IF_NON_ZERO(remainder & (1 << shift++));                                                           \
        result = SELECT_USING_MASK(mask, MUL_SAT(result, x), result);                                                  \
    }

    SELECT_IF_NON_ZERO(1672461947)
    SELECT_IF_NON_ZERO(1302514674)
    SELECT_IF_NON_ZERO(790015084)
    SELECT_IF_NON_ZERO(290630308)
    SELECT_IF_NON_ZERO(39332535)
    SELECT_IF_NON_ZERO(720401)
    SELECT_IF_NON_ZERO(242)

#undef SELECT_IF_NON_ZERO

    // exp(0) == 1, which saturates to NN_Q31_MAX in this representation.
    mask = MASK_IF_ZERO(val);
    return SELECT_USING_MASK(mask, NN_Q31_MAX, result);
}
1109
arm_nn_mult_by_power_of_two(const int32_t val,const int32_t exp)1110 __STATIC_FORCEINLINE int32_t arm_nn_mult_by_power_of_two(const int32_t val, const int32_t exp)
1111 {
1112 const int32_t thresh = ((1 << (31 - exp)) - 1);
1113 int32_t result = val << exp;
1114 result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val > thresh), NN_Q31_MAX, result);
1115 result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val < -thresh), NN_Q31_MIN, result);
1116 return result;
1117 }
1118
arm_nn_one_over_one_plus_x_for_x_in_0_1(int32_t val)1119 __STATIC_FORCEINLINE int32_t arm_nn_one_over_one_plus_x_for_x_in_0_1(int32_t val)
1120 {
1121 const int64_t sum = (int64_t)val + (int64_t)NN_Q31_MAX;
1122 const int32_t half_denominator = (int32_t)((sum + (sum >= 0 ? 1 : -1)) / 2L);
1123 int32_t x = 1515870810 + MUL_SAT(half_denominator, -1010580540);
1124
1125 const int32_t shift = (1 << 29);
1126 x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
1127 x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
1128 x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
1129
1130 return MUL_POW2(x, 1);
1131 }
1132
1133 /**
1134 @brief Write 2 s16 elements and post increment pointer.
1135 @param[in] dest_q15 Pointer to pointer that holds address of destination.
1136 @param[in] src_q31 Input value to be written.
1137 */
arm_nn_write_q15x2_ia(int16_t ** dest_q15,int32_t src_q31)1138 __STATIC_FORCEINLINE void arm_nn_write_q15x2_ia(int16_t **dest_q15, int32_t src_q31)
1139 {
1140 int32_t val = src_q31;
1141
1142 memcpy(*dest_q15, &val, 4);
1143 *dest_q15 += 2;
1144 }
1145
1146 /**
1147 @brief Write 2 s8 elements and post increment pointer.
1148 @param[in] dst Pointer to pointer that holds address of destination.
1149 @param[in] src Input value to be written.
1150 */
arm_nn_write_s8x2_ia(int8_t ** dst,int16_t src)1151 __STATIC_FORCEINLINE void arm_nn_write_s8x2_ia(int8_t **dst, int16_t src)
1152 {
1153 memcpy(*dst, &src, 2);
1154 *dst += 2;
1155 }
1156
// Support functions for LSTM
/**
 * @brief Update LSTM function for an iteration step
 *
 * @param[in]  input                       Input data
 * @param[in]  input_to_input_weight       Input to input gate weights
 * @param[in]  input_to_forget_weight      Input to forget gate weights
 * @param[in]  input_to_cell_weight        Input to cell gate weights
 * @param[in]  input_to_output_weight      Input to output weights
 * @param[in]  recurrent_to_input_weight   Recurrent signal to input weights
 * @param[in]  recurrent_to_forget_weight  Recurrent signal to forget gate weights
 * @param[in]  recurrent_to_cell_weight    Recurrent signal to cell gate weights
 * @param[in]  recurrent_to_output_weight  Recurrent signal to output weights
 * @param[in]  lstm                        LSTM parameters
 * @param[in]  n_batch                     Batch size
 * @param[in]  n_cell                      Cell size
 * @param[in]  n_input                     Input size
 * @param[in]  n_output                    Output size
 * @param[out] output_state                Output state
 * @param[out] cell_state                  Internal state
 * @param[out] output                      Output signal
 * @param[in]  scratch_buffers             Struct containing scratch buffers
 */
arm_cmsis_nn_status arm_nn_lstm_step_s8_s16(const int8_t *input,
                                            const int8_t *input_to_input_weight,
                                            const int8_t *input_to_forget_weight,
                                            const int8_t *input_to_cell_weight,
                                            const int8_t *input_to_output_weight,
                                            const int8_t *recurrent_to_input_weight,
                                            const int8_t *recurrent_to_forget_weight,
                                            const int8_t *recurrent_to_cell_weight,
                                            const int8_t *recurrent_to_output_weight,
                                            const cmsis_nn_lstm_params *lstm,
                                            const int n_batch,
                                            const int n_cell,
                                            const int n_input,
                                            const int n_output,
                                            int8_t *output_state,
                                            int16_t *cell_state,
                                            int8_t *output,
                                            cmsis_nn_lstm_context *scratch_buffers);
1198
/**
 * @brief Updates a LSTM gate for an iteration step of LSTM function, int8x8_16 version.
 *
 * @param[in]  input                      Input data
 * @param[in]  input_to_gate_weights      Input to gate weights
 * @param[in]  input_to_gate_bias         Input to gate bias
 * @param[in]  input_to_gate_scaling      Input to gate scaling
 * @param[in]  output_state               Output state
 * @param[in]  recurrent_to_gate_weights  Recurrent to gate weights
 * @param[in]  recurrent_to_gate_bias     Recurrent to gate bias
 * @param[in]  recurrent_to_gate_scaling  Recurrent to gate scaling
 * @param[in]  n_batch                    Batch size
 * @param[in]  n_input                    Input size
 * @param[in]  n_output                   Output size
 * @param[in]  n_cell                     Cell size
 * @param[in]  activation_type            Activation type (sigmoid or tanh)
 * @param[out] gate                       Calculated gate output
 */
void arm_nn_lstm_calculate_gate_s8_s16(const int8_t *input,
                                       const int8_t *input_to_gate_weights,
                                       const int32_t *input_to_gate_bias,
                                       const cmsis_nn_scaling input_to_gate_scaling,
                                       const int8_t *output_state,
                                       const int8_t *recurrent_to_gate_weights,
                                       const int32_t *recurrent_to_gate_bias,
                                       const cmsis_nn_scaling recurrent_to_gate_scaling,
                                       const int32_t n_batch,
                                       const int32_t n_input,
                                       const int32_t n_output,
                                       const int32_t n_cell,
                                       const arm_nn_activation_type activation_type,
                                       int16_t *gate);
1231
/**
 * @brief Update cell state for a single LSTM iteration step, int8x8_16 version.
 * @param[in]     n_block           Total number of cells for all batches
 * @param[in]     cell_state_scale  Scaling factor of cell state
 * @param[in,out] cell_state        Input/output vector, size n_batch*n_cell
 * @param[in]     input_gate        Input vector of size n_block
 * @param[in]     forget_gate       Input vector of size n_block
 *                                  NOTE(review): original doc said "scratch,
 *                                  always modified" but the parameter is
 *                                  const here — confirm against definition
 * @param[in]     cell_gate         Input vector of size n_block
 */
void arm_nn_lstm_update_cell_state_s16(const int32_t n_block,
                                       const int32_t cell_state_scale,
                                       int16_t *cell_state,
                                       const int16_t *input_gate,
                                       const int16_t *forget_gate,
                                       const int16_t *cell_gate);
1247
/**
 * @brief Calculate the output state tensor of an LSTM step, s8 input/output and s16 weight version.
 *
 * @param[in]     n_batch            The number of distinct vectors in each array
 * @param[in]     n_cell             Number of cells
 * @param[in,out] cell_state         Cell state, size n_batch*n_cell
 * @param[in]     cell_state_scale   Scaling of cell_state
 * @param[in]     output_gate        Output gate
 * @param[in]     hidden_scale       Effective scaling of cell_state .* output_gate
 * @param[in]     hidden_offset      Zero point for cell_state .* output_gate
 * @param[out]    output_state       Output state
 * @param[in]     cell_gate_scratch  Scratch buffer
 */
void arm_nn_lstm_update_output_s8_s16(const int n_batch,
                                      const int n_cell,
                                      int16_t *cell_state,
                                      const int32_t cell_state_scale,
                                      const int16_t *output_gate,
                                      const cmsis_nn_scaling hidden_scale,
                                      const int32_t hidden_offset,
                                      int8_t *output_state,
                                      int16_t *cell_gate_scratch);
1270
/**
 * @brief The result of the multiplication is accumulated to the passed result buffer.
 * Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch dimension composed by input vectors independent
 * from each other).
 *
 * @param[in]  lhs_in          Batched vector
 * @param[in]  rhs_in          Weights - input matrix (H(Rows)xW(Columns))
 * @param[in]  bias            Bias vector
 * @param[out] dst             Output (accumulated into, not overwritten)
 * @param[in]  dst_offset      Output offset
 * @param[in]  dst_multiplier  Multiplier for quantization
 * @param[in]  dst_shift       Shift for quantization
 * @param[in]  rhs_cols        Vector/matrix column length
 * @param[in]  rhs_rows        Row count of matrix
 * @param[in]  batch           Batch size
 */
void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in,
                                      const int8_t *rhs_in,
                                      const int32_t *bias,
                                      int16_t *dst,
                                      const int32_t dst_offset,
                                      const int32_t dst_multiplier,
                                      const int32_t dst_shift,
                                      const int32_t rhs_cols,
                                      const int32_t rhs_rows,
                                      const int32_t batch);
1297
/**
 * @brief s16 elementwise multiplication with s8 output
 * @param[in]     input_1_vect  pointer to input vector 1
 * @param[in]     input_2_vect  pointer to input vector 2
 * @param[in,out] output        pointer to output vector
 * @param[in]     out_offset    output offset
 * @param[in]     out_mult      output multiplier
 * @param[in]     out_shift     output shift
 * @param[in]     block_size    number of samples
 * @return        The function returns ARM_CMSIS_NN_SUCCESS
 *
 * @details Supported framework: TensorFlow Lite micro
 */
arm_cmsis_nn_status arm_elementwise_mul_s16_s8(const int16_t *input_1_vect,
                                               const int16_t *input_2_vect,
                                               int8_t *output,
                                               const int32_t out_offset,
                                               const int32_t out_mult,
                                               const int32_t out_shift,
                                               const int32_t block_size);
1318
1319 #ifdef __cplusplus
1320 }
1321 #endif
1322
1323 #endif
1324