1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/micro/kernels/kernel_runner.h"
17
18 #include "tensorflow/lite/micro/micro_error_reporter.h"
19 #include "tensorflow/lite/micro/simple_memory_allocator.h"
20 #include "tensorflow/lite/micro/test_helpers.h"
21
22 namespace tflite {
23 namespace micro {
24
25 namespace {
26 constexpr size_t kBufferAlignment = 16;
27 } // namespace
28
29 // TODO(b/161841696): Consider moving away from global arena buffers:
30 constexpr int KernelRunner::kNumScratchBuffers_;
31 constexpr int KernelRunner::kKernelRunnerBufferSize_;
32 uint8_t KernelRunner::kKernelRunnerBuffer_[];
33
KernelRunner(const TfLiteRegistration & registration,TfLiteTensor * tensors,int tensors_size,TfLiteIntArray * inputs,TfLiteIntArray * outputs,void * builtin_data)34 KernelRunner::KernelRunner(const TfLiteRegistration& registration,
35 TfLiteTensor* tensors, int tensors_size,
36 TfLiteIntArray* inputs, TfLiteIntArray* outputs,
37 void* builtin_data)
38 : allocator_(SimpleMemoryAllocator::Create(GetMicroErrorReporter(),
39 kKernelRunnerBuffer_,
40 kKernelRunnerBufferSize_)),
41 registration_(registration),
42 tensors_(tensors),
43 mock_micro_graph_(allocator_) {
44 // Prepare TfLiteContext:
45 context_.impl_ = static_cast<void*>(this);
46 context_.ReportError = ReportOpError;
47 context_.recommended_num_threads = 1;
48 context_.GetTensor = GetTensor;
49 context_.GetEvalTensor = GetEvalTensor;
50 context_.AllocatePersistentBuffer = AllocatePersistentBuffer;
51 context_.RequestScratchBufferInArena = RequestScratchBufferInArena;
52 context_.GetScratchBuffer = GetScratchBuffer;
53 context_.GetExecutionPlan = GetGraph;
54 context_.recommended_num_threads = 0;
55
56 // Prepare TfLiteNode:
57 node_.inputs = inputs;
58 node_.outputs = outputs;
59 node_.builtin_data = builtin_data;
60 }
61
InitAndPrepare(const char * init_data,size_t length)62 TfLiteStatus KernelRunner::InitAndPrepare(const char* init_data,
63 size_t length) {
64 if (registration_.init) {
65 node_.user_data = registration_.init(&context_, init_data, length);
66 }
67 if (registration_.prepare) {
68 TF_LITE_ENSURE_STATUS(registration_.prepare(&context_, &node_));
69 }
70 return kTfLiteOk;
71 }
72
Invoke()73 TfLiteStatus KernelRunner::Invoke() {
74 if (registration_.invoke == nullptr) {
75 MicroPrintf("TfLiteRegistration missing invoke function pointer!");
76 return kTfLiteError;
77 }
78 return registration_.invoke(&context_, &node_);
79 }
80
GetTensor(const struct TfLiteContext * context,int tensor_index)81 TfLiteTensor* KernelRunner::GetTensor(const struct TfLiteContext* context,
82 int tensor_index) {
83 TFLITE_DCHECK(context != nullptr);
84 KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
85 TFLITE_DCHECK(runner != nullptr);
86
87 return &runner->tensors_[tensor_index];
88 }
89
GetEvalTensor(const struct TfLiteContext * context,int tensor_index)90 TfLiteEvalTensor* KernelRunner::GetEvalTensor(
91 const struct TfLiteContext* context, int tensor_index) {
92 TFLITE_DCHECK(context != nullptr);
93 KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
94 TFLITE_DCHECK(runner != nullptr);
95
96 TfLiteEvalTensor* eval_tensor =
97 reinterpret_cast<TfLiteEvalTensor*>(runner->allocator_->AllocateTemp(
98 sizeof(TfLiteEvalTensor), alignof(TfLiteEvalTensor)));
99 TFLITE_DCHECK(eval_tensor != nullptr);
100
101 // In unit tests, the TfLiteTensor pointer contains the source of truth for
102 // buffers and values:
103 eval_tensor->data = runner->tensors_[tensor_index].data;
104 eval_tensor->dims = runner->tensors_[tensor_index].dims;
105 eval_tensor->type = runner->tensors_[tensor_index].type;
106 return eval_tensor;
107 }
108
AllocatePersistentBuffer(TfLiteContext * context,size_t bytes)109 void* KernelRunner::AllocatePersistentBuffer(TfLiteContext* context,
110 size_t bytes) {
111 TFLITE_DCHECK(context != nullptr);
112 KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
113 TFLITE_DCHECK(runner != nullptr);
114
115 return runner->allocator_->AllocateFromTail(bytes, kBufferAlignment);
116 }
117
RequestScratchBufferInArena(TfLiteContext * context,size_t bytes,int * buffer_index)118 TfLiteStatus KernelRunner::RequestScratchBufferInArena(TfLiteContext* context,
119 size_t bytes,
120 int* buffer_index) {
121 TFLITE_DCHECK(context != nullptr);
122 TFLITE_DCHECK(buffer_index != nullptr);
123
124 KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
125 TFLITE_DCHECK(runner != nullptr);
126
127 if (runner->scratch_buffer_count_ == kNumScratchBuffers_) {
128 MicroPrintf("Exceeded the maximum number of scratch tensors allowed (%d).",
129 kNumScratchBuffers_);
130 return kTfLiteError;
131 }
132
133 // For tests, we allocate scratch buffers from the tail and keep them around
134 // for the lifetime of model. This means that the arena size in the tests will
135 // be more than what we would have if the scratch buffers could share memory.
136 runner->scratch_buffers_[runner->scratch_buffer_count_] =
137 runner->allocator_->AllocateFromTail(bytes, kBufferAlignment);
138 TFLITE_DCHECK(runner->scratch_buffers_[runner->scratch_buffer_count_] !=
139 nullptr);
140
141 *buffer_index = runner->scratch_buffer_count_++;
142 return kTfLiteOk;
143 }
144
GetScratchBuffer(TfLiteContext * context,int buffer_index)145 void* KernelRunner::GetScratchBuffer(TfLiteContext* context, int buffer_index) {
146 TFLITE_DCHECK(context != nullptr);
147 KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
148 TFLITE_DCHECK(runner != nullptr);
149
150 TFLITE_DCHECK(runner->scratch_buffer_count_ <= kNumScratchBuffers_);
151 if (buffer_index >= runner->scratch_buffer_count_) {
152 return nullptr;
153 }
154 return runner->scratch_buffers_[buffer_index];
155 }
156
ReportOpError(struct TfLiteContext * context,const char * format,...)157 void KernelRunner::ReportOpError(struct TfLiteContext* context,
158 const char* format, ...) {
159 va_list args;
160 va_start(args, format);
161 GetMicroErrorReporter()->Report(format, args);
162 va_end(args);
163 }
164
GetGraph(struct TfLiteContext * context,TfLiteIntArray ** args)165 TfLiteStatus KernelRunner::GetGraph(struct TfLiteContext* context,
166 TfLiteIntArray** args) {
167 TFLITE_DCHECK(context != nullptr);
168 KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
169 TFLITE_DCHECK(runner != nullptr);
170 // TODO(b/188226309): Design a cleaner way to get a graph from kernel context.
171 *args = reinterpret_cast<TfLiteIntArray*>(runner->GetMockGraph());
172 return kTfLiteOk;
173 }
174
175 } // namespace micro
176 } // namespace tflite
177