/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/micro/kernels/kernel_runner.h"

#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/simple_memory_allocator.h"
#include "tensorflow/lite/micro/test_helpers.h"

namespace tflite {
namespace micro {

namespace {
constexpr size_t kBufferAlignment = 16;
}  // namespace

// TODO(b/161841696): Consider moving away from global arena buffers:
constexpr int KernelRunner::kNumScratchBuffers_;
constexpr int KernelRunner::kKernelRunnerBufferSize_;
uint8_t KernelRunner::kKernelRunnerBuffer_[];

KernelRunner::KernelRunner(const TfLiteRegistration& registration,
                           TfLiteTensor* tensors, int tensors_size,
                           TfLiteIntArray* inputs, TfLiteIntArray* outputs,
                           void* builtin_data)
    : allocator_(SimpleMemoryAllocator::Create(GetMicroErrorReporter(),
                                               kKernelRunnerBuffer_,
                                               kKernelRunnerBufferSize_)),
      registration_(registration),
      tensors_(tensors),
      mock_micro_graph_(allocator_) {
  // Prepare TfLiteContext:
  context_.impl_ = static_cast<void*>(this);
  context_.ReportError = ReportOpError;
  context_.GetTensor = GetTensor;
  context_.GetEvalTensor = GetEvalTensor;
  context_.AllocatePersistentBuffer = AllocatePersistentBuffer;
  context_.RequestScratchBufferInArena = RequestScratchBufferInArena;
  context_.GetScratchBuffer = GetScratchBuffer;
  context_.GetExecutionPlan = GetGraph;
  context_.recommended_num_threads = 0;

  // Prepare TfLiteNode:
  node_.inputs = inputs;
  node_.outputs = outputs;
  node_.builtin_data = builtin_data;
}

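// Example (sketch, not part of this file): a typical unit test might construct
// a KernelRunner with helpers from test_helpers.h. The tensor contents and the
// registration used here are illustrative assumptions, not prescriptive:
//
//   int input_dims_data[] = {2, 1, 4};   // 2 dims, shape {1, 4}
//   int output_dims_data[] = {2, 1, 4};
//   TfLiteTensor tensors[] = {
//       tflite::testing::CreateTensor(
//           input_data, tflite::testing::IntArrayFromInts(input_dims_data)),
//       tflite::testing::CreateTensor(
//           output_data, tflite::testing::IntArrayFromInts(output_dims_data)),
//   };
//   int inputs_array_data[] = {1, 0};   // one input: tensor index 0
//   int outputs_array_data[] = {1, 1};  // one output: tensor index 1
//   KernelRunner runner(registration, tensors, /*tensors_size=*/2,
//                       tflite::testing::IntArrayFromInts(inputs_array_data),
//                       tflite::testing::IntArrayFromInts(outputs_array_data),
//                       /*builtin_data=*/nullptr);
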
TfLiteStatus KernelRunner::InitAndPrepare(const char* init_data,
                                          size_t length) {
  if (registration_.init) {
    node_.user_data = registration_.init(&context_, init_data, length);
  }
  if (registration_.prepare) {
    TF_LITE_ENSURE_STATUS(registration_.prepare(&context_, &node_));
  }
  return kTfLiteOk;
}

TfLiteStatus KernelRunner::Invoke() {
  if (registration_.invoke == nullptr) {
    MicroPrintf("TfLiteRegistration missing invoke function pointer!");
    return kTfLiteError;
  }
  return registration_.invoke(&context_, &node_);
}

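// Example (sketch): the usual test flow is InitAndPrepare() followed by
// Invoke(), asserting kTfLiteOk at each step (expectation macros from the
// micro testing framework):
//
//   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
//                           runner.InitAndPrepare(/*init_data=*/nullptr,
//                                                 /*length=*/0));
//   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
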
TfLiteTensor* KernelRunner::GetTensor(const struct TfLiteContext* context,
                                      int tensor_index) {
  TFLITE_DCHECK(context != nullptr);
  KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
  TFLITE_DCHECK(runner != nullptr);

  return &runner->tensors_[tensor_index];
}

TfLiteEvalTensor* KernelRunner::GetEvalTensor(
    const struct TfLiteContext* context, int tensor_index) {
  TFLITE_DCHECK(context != nullptr);
  KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
  TFLITE_DCHECK(runner != nullptr);

  TfLiteEvalTensor* eval_tensor =
      reinterpret_cast<TfLiteEvalTensor*>(runner->allocator_->AllocateTemp(
          sizeof(TfLiteEvalTensor), alignof(TfLiteEvalTensor)));
  TFLITE_DCHECK(eval_tensor != nullptr);

  // In unit tests, the TfLiteTensor pointer contains the source of truth for
  // buffers and values:
  eval_tensor->data = runner->tensors_[tensor_index].data;
  eval_tensor->dims = runner->tensors_[tensor_index].dims;
  eval_tensor->type = runner->tensors_[tensor_index].type;
  return eval_tensor;
}

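// Example (sketch): kernels under test normally reach the hook above through
// the accessors in micro/kernels/kernel_util.h rather than calling it
// directly:
//
//   const TfLiteEvalTensor* input =
//       tflite::micro::GetEvalInput(context, node, /*index=*/0);
//   TfLiteEvalTensor* output =
//       tflite::micro::GetEvalOutput(context, node, /*index=*/0);
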
void* KernelRunner::AllocatePersistentBuffer(TfLiteContext* context,
                                             size_t bytes) {
  TFLITE_DCHECK(context != nullptr);
  KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
  TFLITE_DCHECK(runner != nullptr);

  return runner->allocator_->AllocateFromTail(bytes, kBufferAlignment);
}

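// Example (sketch): a kernel's Init typically uses this hook, via the context,
// to allocate op data that must survive through Prepare and Eval. OpData here
// is a hypothetical kernel-defined struct:
//
//   void* Init(TfLiteContext* context, const char* buffer, size_t length) {
//     return context->AllocatePersistentBuffer(context, sizeof(OpData));
//   }
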
TfLiteStatus KernelRunner::RequestScratchBufferInArena(TfLiteContext* context,
                                                       size_t bytes,
                                                       int* buffer_index) {
  TFLITE_DCHECK(context != nullptr);
  TFLITE_DCHECK(buffer_index != nullptr);

  KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
  TFLITE_DCHECK(runner != nullptr);

  if (runner->scratch_buffer_count_ == kNumScratchBuffers_) {
    MicroPrintf("Exceeded the maximum number of scratch tensors allowed (%d).",
                kNumScratchBuffers_);
    return kTfLiteError;
  }

  // For tests, we allocate scratch buffers from the tail and keep them around
  // for the lifetime of the model. This means the arena size in tests will be
  // larger than it would be if scratch buffers could share memory.
  runner->scratch_buffers_[runner->scratch_buffer_count_] =
      runner->allocator_->AllocateFromTail(bytes, kBufferAlignment);
  TFLITE_DCHECK(runner->scratch_buffers_[runner->scratch_buffer_count_] !=
                nullptr);

  *buffer_index = runner->scratch_buffer_count_++;
  return kTfLiteOk;
}

void* KernelRunner::GetScratchBuffer(TfLiteContext* context, int buffer_index) {
  TFLITE_DCHECK(context != nullptr);
  KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
  TFLITE_DCHECK(runner != nullptr);

  TFLITE_DCHECK(runner->scratch_buffer_count_ <= kNumScratchBuffers_);
  if (buffer_index >= runner->scratch_buffer_count_) {
    return nullptr;
  }
  return runner->scratch_buffers_[buffer_index];
}

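// Example (sketch): together, the two hooks above back the standard
// scratch-buffer pattern in kernels; scratch_index is a hypothetical OpData
// field:
//
//   // In Prepare:
//   TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
//       context, scratch_bytes, &op_data->scratch_index));
//   // In Eval:
//   void* scratch = context->GetScratchBuffer(context, op_data->scratch_index);
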
void KernelRunner::ReportOpError(struct TfLiteContext* context,
                                 const char* format, ...) {
  va_list args;
  va_start(args, format);
  GetMicroErrorReporter()->Report(format, args);
  va_end(args);
}

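// Example (sketch): kernel code normally reports errors through the
// TF_LITE_KERNEL_LOG macro, which forwards to the context's ReportError and
// thus lands in the function above:
//
//   TF_LITE_KERNEL_LOG(context, "Input type %s is not supported.",
//                      TfLiteTypeGetName(input->type));
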
TfLiteStatus KernelRunner::GetGraph(struct TfLiteContext* context,
                                    TfLiteIntArray** args) {
  TFLITE_DCHECK(context != nullptr);
  KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
  TFLITE_DCHECK(runner != nullptr);
  // TODO(b/188226309): Design a cleaner way to get a graph from kernel
  // context.
  *args = reinterpret_cast<TfLiteIntArray*>(runner->GetMockGraph());
  return kTfLiteOk;
}

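// Example (sketch): until b/188226309 is resolved, a control-flow kernel
// recovers the graph by casting back through the TfLiteIntArray** parameter,
// the inverse of the cast in GetGraph() above:
//
//   tflite::MicroGraph* graph_info = nullptr;
//   context->GetExecutionPlan(context,
//                             reinterpret_cast<TfLiteIntArray**>(&graph_info));
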
}  // namespace micro
}  // namespace tflite