1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_MICRO_ARC_SCRATCH_BUF_MGR_H_
17 #define TENSORFLOW_LITE_MICRO_ARC_SCRATCH_BUF_MGR_H_
18 
19 #include "mli_api.h"  // NOLINT
20 #include "mli_interface.h"
21 #include "tensorflow/lite/c/common.h"
22 
23 namespace tflite {
24 namespace ops {
25 namespace micro {
26 
27 /**
28  * @brief Function to allocate scratch buffers for the convolution tensors
29  *
30  * @detail This function will update the data pointers in the 4 tensors with
31  * pointers to scratch buffers in fast local memory.
32  *
33  * @param context  [I] pointer to TfLite context (needed for error handling)
34  * @param in [IO] pointer to the input tensor
35  * @param weights [IO] pointer to the weights tensor
36  * @param bias [IO] pointer to the bias tensor
37  * @param output [IO] pointer to the output tensor
38  *
39  * @return Tf Lite status code
40  */
41 TfLiteStatus get_arc_scratch_buffer_for_conv_tensors(
42     TfLiteContext* context, MliTensorInterface* in, MliTensorInterface* weights,
43     MliTensorInterface* bias, MliTensorInterface* out);
44 
45 /**
46  * @brief Function to allocate scratch buffers for pooling kernels with only
47  * input and output buffers
48  *
49  * @detail This function will update the data pointers in the 2 tensors with
50  * pointers to scratch buffers in fast local memory.
51  *
52  * @param context  [I] pointer to TfLite context (needed for error handling)
53  * @param in [IO] pointer to the input tensor
54  * @param output [IO] pointer to the output tensor
55  *
56  * @return Tf Lite status code
57  */
58 TfLiteStatus get_arc_scratch_buffer_for_pooling_tensors(
59     TfLiteContext* context, MliTensorInterface* in, MliTensorInterface* out);
60 
61 /**
62  * @brief Function to allocate scratch buffers for the fully connect tensors
63  *
64  * @detail This function will update the data pointers in the 4 tensors with
65  * pointers to scratch buffers in fast local memory.
66  *
67  * @param context  [I] pointer to TfLite context (needed for error handling)
68  * @param in [IO] pointer to the input tensor
69  * @param weights [IO] pointer to the weights tensor
70  * @param bias [IO] pointer to the bias tensor
71  * @param output [IO] pointer to the output tensor
72  *
73  * @return Tf Lite status code
74  */
75 TfLiteStatus get_arc_scratch_buffer_for_fully_connect_tensors(
76     TfLiteContext* context, MliTensorInterface* in, MliTensorInterface* weights,
77     MliTensorInterface* bias, MliTensorInterface* out);
78 
79 /**
80  * @brief Function to calculate slice size for io tensors
81  *
82  * @detail This function will calculate the slice size in the height dimension
83  * for input and output tensors. it takes into account the kernel size and the
84  * padding. the function will look at the capacity filed in the in and out
85  * tensor to determine the available buffersize.
86  *
87  * @param in [I] pointer to the input tensor
88  * @param out [I] pointer to the output tensor
89  * @param kernelHeight [I] size of the kernel in height dimension
90  * @param strideHeight [I] input stride in height dimension
91  * @param padding_top [I] number of lines with zeros at the top
92  * @param padding_bot [I] number of lines with zeros at the bottom
93  * @param inSliceHeight [O] slice size in height dimension for the input tensor
94  * @param outSliceHeight [O] slice size in height dimension for the output
95  * tensor
96  *
97  * @return Tf Lite status code
98  */
99 TfLiteStatus arc_scratch_buffer_calc_slice_size_io(
100     const MliTensorInterface* in, const MliTensorInterface* out,
101     const int kernelHeight, const int strideHeight, const int padding_top,
102     const int padding_bot, int* in_slice_height, int* out_slice_height);
103 
104 /**
105  * @brief Function to calculate slice size for weight slicing
106  *
107  * @detail This function will calculate the slice size in the output channel
108  * dimension for weight and bias tensors. the function will look at the capacity
109  * filed in the weights and bias tensor to determine the available buffersize.
110  *
111  * @param weights [I] pointer to the input tensor
112  * @param bias [I] pointer to the output tensor
113  * @param weightOutChDimension [I] dimension of the output channels in the
114  * weights tensor
115  * @param sliceChannels [O] slice size in output channel dimension
116  *
117  * @return Tf Lite status code
118  */
119 TfLiteStatus arc_scratch_buffer_calc_slice_size_weights(
120     const MliTensorInterface* weights, const MliTensorInterface* bias,
121     const int weight_out_ch_dimension, int* slice_channels);
122 
123 }  // namespace micro
124 }  // namespace ops
125 }  // namespace tflite
126 
127 #endif  // TENSORFLOW_LITE_MICRO_ARC_SCRATCH_BUF_MGR_H_
128