1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
17
18 #include "tensorflow/lite/micro/micro_op_resolver.h"
19 #include "tensorflow/lite/micro/testing/micro_test.h"
20
21 namespace tflite {
22 namespace {
MockInit(TfLiteContext * context,const char * buffer,size_t length)23 void* MockInit(TfLiteContext* context, const char* buffer, size_t length) {
24 // Do nothing.
25 return nullptr;
26 }
27
MockFree(TfLiteContext * context,void * buffer)28 void MockFree(TfLiteContext* context, void* buffer) {
29 // Do nothing.
30 }
31
MockPrepare(TfLiteContext * context,TfLiteNode * node)32 TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) {
33 return kTfLiteOk;
34 }
35
MockInvoke(TfLiteContext * context,TfLiteNode * node)36 TfLiteStatus MockInvoke(TfLiteContext* context, TfLiteNode* node) {
37 return kTfLiteOk;
38 }
39
40 class MockErrorReporter : public ErrorReporter {
41 public:
MockErrorReporter()42 MockErrorReporter() : has_been_called_(false) {}
Report(const char * format,va_list args)43 int Report(const char* format, va_list args) override {
44 has_been_called_ = true;
45 return 0;
46 };
47
HasBeenCalled()48 bool HasBeenCalled() { return has_been_called_; }
49
ResetState()50 void ResetState() { has_been_called_ = false; }
51
52 private:
53 bool has_been_called_;
54 TF_LITE_REMOVE_VIRTUAL_DELETE
55 };
56
57 } // namespace
58 } // namespace tflite
59
60 TF_LITE_MICRO_TESTS_BEGIN
61
// Exercises custom-op registration and lookup (by builtin code and by name) on
// a resolver sized for exactly one operator.
TF_LITE_MICRO_TEST(TestOperations) {
  using tflite::BuiltinOperator_RELU;
  using tflite::MicroMutableOpResolver;

  // Registration whose hooks are the do-nothing mocks defined above.
  static TfLiteRegistration r = {};
  r.init = tflite::MockInit;
  r.free = tflite::MockFree;
  r.prepare = tflite::MockPrepare;
  r.invoke = tflite::MockInvoke;

  MicroMutableOpResolver<1> micro_op_resolver;
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
                          micro_op_resolver.AddCustom("mock_custom", &r));

  // Only one AddCustom per operator should return kTfLiteOk.
  // NOTE(review): the resolver is also already at its capacity of 1 here, so
  // this rejection may stem from either condition -- confirm which is intended.
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
                          micro_op_resolver.AddCustom("mock_custom", &r));

  // Look ops up through the base-class interface.
  tflite::MicroOpResolver* resolver = &micro_op_resolver;

  TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(1),
                          micro_op_resolver.GetRegistrationLength());

  // RELU was never added, so the builtin lookup must come back empty.
  const TfLiteRegistration* registration =
      resolver->FindOp(BuiltinOperator_RELU);
  TF_LITE_MICRO_EXPECT(nullptr == registration);

  registration = resolver->FindOp("mock_custom");
  TF_LITE_MICRO_EXPECT(nullptr != registration);
  // Guard the dereferences below: a failed expectation above is not assumed to
  // stop the test, and calling through a null registration would crash.
  if (registration != nullptr) {
    TF_LITE_MICRO_EXPECT(nullptr == registration->init(nullptr, nullptr, 0));
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
                            registration->prepare(nullptr, nullptr));
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
  }

  registration = resolver->FindOp("nonexistent_custom");
  TF_LITE_MICRO_EXPECT(nullptr == registration);
}
100
// Checks that the resolver stays silent on the supplied error reporter while
// registrations succeed, and that a rejected AddCustom() is reported.
TF_LITE_MICRO_TEST(TestErrorReporting) {
  using tflite::MicroMutableOpResolver;

  // Registration whose hooks are the do-nothing mocks defined above.
  static TfLiteRegistration r = {};
  r.init = tflite::MockInit;
  r.free = tflite::MockFree;
  r.prepare = tflite::MockPrepare;
  r.invoke = tflite::MockInvoke;

  tflite::MockErrorReporter mock_reporter;
  MicroMutableOpResolver<1> micro_op_resolver(&mock_reporter);
  TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled());
  mock_reporter.ResetState();

  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
                          micro_op_resolver.AddCustom("mock_custom_0", &r));
  TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled());
  mock_reporter.ResetState();

  // Attempting to Add more operators than the class template parameter for
  // MicroMutableOpResolver should result in errors.
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, micro_op_resolver.AddRelu());
  // Clear anything the rejected AddRelu() may have reported so that the check
  // after AddCustom() reflects that call alone.
  mock_reporter.ResetState();

  TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
                          micro_op_resolver.AddCustom("mock_custom_1", &r));
  TF_LITE_MICRO_EXPECT_EQ(true, mock_reporter.HasBeenCalled());
  mock_reporter.ResetState();
}
131
132 TF_LITE_MICRO_TESTS_END
133