1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
17 
18 #include "tensorflow/lite/micro/micro_op_resolver.h"
19 #include "tensorflow/lite/micro/testing/micro_test.h"
20 
21 namespace tflite {
22 namespace {
MockInit(TfLiteContext * context,const char * buffer,size_t length)23 void* MockInit(TfLiteContext* context, const char* buffer, size_t length) {
24   // Do nothing.
25   return nullptr;
26 }
27 
MockFree(TfLiteContext * context,void * buffer)28 void MockFree(TfLiteContext* context, void* buffer) {
29   // Do nothing.
30 }
31 
MockPrepare(TfLiteContext * context,TfLiteNode * node)32 TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) {
33   return kTfLiteOk;
34 }
35 
MockInvoke(TfLiteContext * context,TfLiteNode * node)36 TfLiteStatus MockInvoke(TfLiteContext* context, TfLiteNode* node) {
37   return kTfLiteOk;
38 }
39 
40 class MockErrorReporter : public ErrorReporter {
41  public:
MockErrorReporter()42   MockErrorReporter() : has_been_called_(false) {}
Report(const char * format,va_list args)43   int Report(const char* format, va_list args) override {
44     has_been_called_ = true;
45     return 0;
46   };
47 
HasBeenCalled()48   bool HasBeenCalled() { return has_been_called_; }
49 
ResetState()50   void ResetState() { has_been_called_ = false; }
51 
52  private:
53   bool has_been_called_;
54   TF_LITE_REMOVE_VIRTUAL_DELETE
55 };
56 
57 }  // namespace
58 }  // namespace tflite
59 
60 TF_LITE_MICRO_TESTS_BEGIN
61 
// Registers one custom op in a single-slot MicroMutableOpResolver and checks
// duplicate rejection, the registration count, and lookups by builtin code and
// by custom name.
TF_LITE_MICRO_TEST(TestOperations) {
  using tflite::BuiltinOperator_RELU;
  using tflite::MicroMutableOpResolver;

  static TfLiteRegistration r = {};
  r.init = tflite::MockInit;
  r.free = tflite::MockFree;
  r.prepare = tflite::MockPrepare;
  r.invoke = tflite::MockInvoke;

  MicroMutableOpResolver<1> micro_op_resolver;
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
                          micro_op_resolver.AddCustom("mock_custom", &r));

  // Only one AddCustom per operator should return kTfLiteOk.
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
                          micro_op_resolver.AddCustom("mock_custom", &r));

  // Lookups go through the base-class interface.
  tflite::MicroOpResolver* resolver = &micro_op_resolver;

  TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(1),
                          micro_op_resolver.GetRegistrationLength());

  // RELU was never added, so the builtin lookup is expected to miss.
  const TfLiteRegistration* registration =
      resolver->FindOp(BuiltinOperator_RELU);
  TF_LITE_MICRO_EXPECT(nullptr == registration);

  registration = resolver->FindOp("mock_custom");
  TF_LITE_MICRO_EXPECT(nullptr != registration);
  // Guard the dereferences: if the lookup failed, the expectation above has
  // already flagged it, and the hooks must not be called through a null
  // pointer.
  if (registration != nullptr) {
    TF_LITE_MICRO_EXPECT(nullptr == registration->init(nullptr, nullptr, 0));
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
                            registration->prepare(nullptr, nullptr));
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
  }

  registration = resolver->FindOp("nonexistent_custom");
  TF_LITE_MICRO_EXPECT(nullptr == registration);
}
100 
// Checks that a single-slot MicroMutableOpResolver stays silent on a
// successful registration and calls the supplied error reporter when further
// registrations are rejected.
TF_LITE_MICRO_TEST(TestErrorReporting) {
  using tflite::MicroMutableOpResolver;

  static TfLiteRegistration r = {};
  r.init = tflite::MockInit;
  r.free = tflite::MockFree;
  r.prepare = tflite::MockPrepare;
  r.invoke = tflite::MockInvoke;

  tflite::MockErrorReporter mock_reporter;
  MicroMutableOpResolver<1> micro_op_resolver(&mock_reporter);
  TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled());
  mock_reporter.ResetState();

  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
                          micro_op_resolver.AddCustom("mock_custom_0", &r));
  TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled());
  mock_reporter.ResetState();

  // Attempting to Add more operators than the class template parameter for
  // MicroMutableOpResolver should result in errors.
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, micro_op_resolver.AddRelu());
  // Check and reset after each failing call so that the two rejections are
  // verified independently rather than by one shared flag.
  // NOTE(review): assumes the builtin-op path reports through the error
  // reporter just like the custom-op path - confirm against the resolver.
  TF_LITE_MICRO_EXPECT_EQ(true, mock_reporter.HasBeenCalled());
  mock_reporter.ResetState();

  TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
                          micro_op_resolver.AddCustom("mock_custom_1", &r));
  TF_LITE_MICRO_EXPECT_EQ(true, mock_reporter.HasBeenCalled());
  mock_reporter.ResetState();
}
131 
132 TF_LITE_MICRO_TESTS_END
133