1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_MICRO_KERNELS_KERNEL_RUNNER_H_
17 #define TENSORFLOW_LITE_MICRO_KERNELS_KERNEL_RUNNER_H_
18 
19 #include "tensorflow/lite/c/common.h"
20 #include "tensorflow/lite/kernels/internal/compatibility.h"
21 #include "tensorflow/lite/micro/mock_micro_graph.h"
22 #include "tensorflow/lite/micro/simple_memory_allocator.h"
23 
24 namespace tflite {
25 namespace micro {
26 
27 // Helper class to perform a simulated kernel (i.e. TfLiteRegistration)
28 // lifecycle (init, prepare, invoke). All internal allocations are handled by
29 // this class. Simply pass in the registration, list of required tensors, inputs
30 // array, outputs array, and any pre-builtin data. Calling Invoke() will
31 // automatically walk the kernel and outputs will be ready on the TfLiteTensor
32 // output provided during construction.
33 class KernelRunner {
34  public:
35   KernelRunner(const TfLiteRegistration& registration, TfLiteTensor* tensors,
36                int tensors_size, TfLiteIntArray* inputs,
37                TfLiteIntArray* outputs, void* builtin_data);
38 
39   // Calls init and prepare on the kernel (i.e. TfLiteRegistration) struct. Any
40   // exceptions will be DebugLog'd and returned as a status code.
41   TfLiteStatus InitAndPrepare(const char* init_data = nullptr,
42                               size_t length = 0);
43 
44   // Calls init, prepare, and invoke on a given TfLiteRegistration pointer.
45   // After successful invoke, results will be available in the output tensor as
46   // passed into the constructor of this class.
47   TfLiteStatus Invoke();
48 
49   // Returns a pointer to the internal MockMicroGraph which KernelRunner uses
50   // to stub out MicroGraph methods and track invocations on each subgraph.
GetMockGraph()51   MockMicroGraph* GetMockGraph() { return &mock_micro_graph_; }
52 
53  protected:
54   static TfLiteTensor* GetTensor(const struct TfLiteContext* context,
55                                  int tensor_index);
56   static TfLiteEvalTensor* GetEvalTensor(const struct TfLiteContext* context,
57                                          int tensor_index);
58   static void* AllocatePersistentBuffer(TfLiteContext* context, size_t bytes);
59   static TfLiteStatus RequestScratchBufferInArena(TfLiteContext* context,
60                                                   size_t bytes,
61                                                   int* buffer_index);
62   static void* GetScratchBuffer(TfLiteContext* context, int buffer_index);
63   static void ReportOpError(struct TfLiteContext* context, const char* format,
64                             ...);
65   // This method matches GetExecutionPlan from TfLiteContext since TFLM reuses
66   // this method to get the MicroGraph from an operator context.
67   // TODO(b/188226309): Design a cleaner way to get a graph from kernel context.
68   static TfLiteStatus GetGraph(struct TfLiteContext* context,
69                                TfLiteIntArray** args);
70 
71  private:
72   static constexpr int kNumScratchBuffers_ = 12;
73 
74   static constexpr int kKernelRunnerBufferSize_ = 10000;
75   static uint8_t kKernelRunnerBuffer_[kKernelRunnerBufferSize_];
76 
77   SimpleMemoryAllocator* allocator_ = nullptr;
78   const TfLiteRegistration& registration_;
79   TfLiteTensor* tensors_ = nullptr;
80   MockMicroGraph mock_micro_graph_;
81 
82   TfLiteContext context_ = {};
83   TfLiteNode node_ = {};
84 
85   int scratch_buffer_count_ = 0;
86   uint8_t* scratch_buffers_[kNumScratchBuffers_];
87 };
88 
89 }  // namespace micro
90 }  // namespace tflite
91 
92 #endif  // TENSORFLOW_LITE_MICRO_KERNELS_KERNEL_RUNNER_H_
93