/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_MICRO_ARC_SCRATCH_BUF_MGR_H_
#define TENSORFLOW_LITE_MICRO_ARC_SCRATCH_BUF_MGR_H_

#include "mli_api.h"  // NOLINT
#include "mli_interface.h"
#include "tensorflow/lite/c/common.h"

namespace tflite {
namespace ops {
namespace micro {

/**
 * @brief Function to allocate scratch buffers for the convolution tensors
 *
 * @details This function will update the data pointers in the 4 tensors with
 * pointers to scratch buffers in fast local memory.
 *
 * @param context [I] pointer to TfLite context (needed for error handling)
 * @param in [IO] pointer to the input tensor
 * @param weights [IO] pointer to the weights tensor
 * @param bias [IO] pointer to the bias tensor
 * @param out [IO] pointer to the output tensor
 *
 * @return Tf Lite status code
 */
TfLiteStatus get_arc_scratch_buffer_for_conv_tensors(
    TfLiteContext* context, MliTensorInterface* in, MliTensorInterface* weights,
    MliTensorInterface* bias, MliTensorInterface* out);

/**
 * @brief Function to allocate scratch buffers for pooling kernels with only
 * input and output buffers
 *
 * @details This function will update the data pointers in the 2 tensors with
 * pointers to scratch buffers in fast local memory.
 *
 * @param context [I] pointer to TfLite context (needed for error handling)
 * @param in [IO] pointer to the input tensor
 * @param out [IO] pointer to the output tensor
 *
 * @return Tf Lite status code
 */
TfLiteStatus get_arc_scratch_buffer_for_pooling_tensors(
    TfLiteContext* context, MliTensorInterface* in, MliTensorInterface* out);

/**
 * @brief Function to allocate scratch buffers for the fully connected tensors
 *
 * @details This function will update the data pointers in the 4 tensors with
 * pointers to scratch buffers in fast local memory.
 *
 * @param context [I] pointer to TfLite context (needed for error handling)
 * @param in [IO] pointer to the input tensor
 * @param weights [IO] pointer to the weights tensor
 * @param bias [IO] pointer to the bias tensor
 * @param out [IO] pointer to the output tensor
 *
 * @return Tf Lite status code
 */
TfLiteStatus get_arc_scratch_buffer_for_fully_connect_tensors(
    TfLiteContext* context, MliTensorInterface* in, MliTensorInterface* weights,
    MliTensorInterface* bias, MliTensorInterface* out);

/**
 * @brief Function to calculate slice size for io tensors
 *
 * @details This function will calculate the slice size in the height dimension
 * for input and output tensors. It takes into account the kernel size and the
 * padding. The function will look at the capacity field in the in and out
 * tensor to determine the available buffer size.
 *
 * @param in [I] pointer to the input tensor
 * @param out [I] pointer to the output tensor
 * @param kernelHeight [I] size of the kernel in height dimension
 * @param strideHeight [I] input stride in height dimension
 * @param padding_top [I] number of lines with zeros at the top
 * @param padding_bot [I] number of lines with zeros at the bottom
 * @param in_slice_height [O] slice size in height dimension for the input
 * tensor
 * @param out_slice_height [O] slice size in height dimension for the output
 * tensor
 *
 * @return Tf Lite status code
 */
TfLiteStatus arc_scratch_buffer_calc_slice_size_io(
    const MliTensorInterface* in, const MliTensorInterface* out,
    const int kernelHeight, const int strideHeight, const int padding_top,
    const int padding_bot, int* in_slice_height, int* out_slice_height);

/**
 * @brief Function to calculate slice size for weight slicing
 *
 * @details This function will calculate the slice size in the output channel
 * dimension for weight and bias tensors. The function will look at the capacity
 * field in the weights and bias tensor to determine the available buffer size.
 *
 * @param weights [I] pointer to the weights tensor
 * @param bias [I] pointer to the bias tensor
 * @param weight_out_ch_dimension [I] dimension of the output channels in the
 * weights tensor
 * @param slice_channels [O] slice size in output channel dimension
 *
 * @return Tf Lite status code
 */
TfLiteStatus arc_scratch_buffer_calc_slice_size_weights(
    const MliTensorInterface* weights, const MliTensorInterface* bias,
    const int weight_out_ch_dimension, int* slice_channels);

}  // namespace micro
}  // namespace ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_MICRO_ARC_SCRATCH_BUF_MGR_H_