1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_
17 #define TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_
18 
19 #include <cstdint>
20 
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/micro/examples/micro_speech/micro_features/micro_model_settings.h"
23 #include "tensorflow/lite/micro/micro_error_reporter.h"
24 
25 // Partial implementation of std::dequeue, just providing the functionality
26 // that's needed to keep a record of previous neural network results over a
27 // short time period, so they can be averaged together to produce a more
28 // accurate overall prediction. This doesn't use any dynamic memory allocation
29 // so it's a better fit for microcontroller applications, but this does mean
30 // there are hard limits on the number of results it can store.
31 class PreviousResultsQueue {
32  public:
PreviousResultsQueue(tflite::ErrorReporter * error_reporter)33   PreviousResultsQueue(tflite::ErrorReporter* error_reporter)
34       : error_reporter_(error_reporter), front_index_(0), size_(0) {}
35 
36   // Data structure that holds an inference result, and the time when it
37   // was recorded.
38   struct Result {
ResultResult39     Result() : time_(0), scores() {}
ResultResult40     Result(int32_t time, int8_t* input_scores) : time_(time) {
41       for (int i = 0; i < kCategoryCount; ++i) {
42         scores[i] = input_scores[i];
43       }
44     }
45     int32_t time_;
46     int8_t scores[kCategoryCount];
47   };
48 
size()49   int size() { return size_; }
empty()50   bool empty() { return size_ == 0; }
front()51   Result& front() { return results_[front_index_]; }
back()52   Result& back() {
53     int back_index = front_index_ + (size_ - 1);
54     if (back_index >= kMaxResults) {
55       back_index -= kMaxResults;
56     }
57     return results_[back_index];
58   }
59 
push_back(const Result & entry)60   void push_back(const Result& entry) {
61     if (size() >= kMaxResults) {
62       TF_LITE_REPORT_ERROR(
63           error_reporter_,
64           "Couldn't push_back latest result, too many already!");
65       return;
66     }
67     size_ += 1;
68     back() = entry;
69   }
70 
pop_front()71   Result pop_front() {
72     if (size() <= 0) {
73       TF_LITE_REPORT_ERROR(error_reporter_,
74                            "Couldn't pop_front result, none present!");
75       return Result();
76     }
77     Result result = front();
78     front_index_ += 1;
79     if (front_index_ >= kMaxResults) {
80       front_index_ = 0;
81     }
82     size_ -= 1;
83     return result;
84   }
85 
86   // Most of the functions are duplicates of dequeue containers, but this
87   // is a helper that makes it easy to iterate through the contents of the
88   // queue.
from_front(int offset)89   Result& from_front(int offset) {
90     if ((offset < 0) || (offset >= size_)) {
91       TF_LITE_REPORT_ERROR(error_reporter_,
92                            "Attempt to read beyond the end of the queue!");
93       offset = size_ - 1;
94     }
95     int index = front_index_ + offset;
96     if (index >= kMaxResults) {
97       index -= kMaxResults;
98     }
99     return results_[index];
100   }
101 
102  private:
103   tflite::ErrorReporter* error_reporter_;
104   static constexpr int kMaxResults = 50;
105   Result results_[kMaxResults];
106 
107   int front_index_;
108   int size_;
109 };
110 
111 // This class is designed to apply a very primitive decoding model on top of the
112 // instantaneous results from running an audio recognition model on a single
113 // window of samples. It applies smoothing over time so that noisy individual
114 // label scores are averaged, increasing the confidence that apparent matches
115 // are real.
116 // To use it, you should create a class object with the configuration you
117 // want, and then feed results from running a TensorFlow model into the
118 // processing method. The timestamp for each subsequent call should be
119 // increasing from the previous, since the class is designed to process a stream
120 // of data over time.
121 class RecognizeCommands {
122  public:
123   // labels should be a list of the strings associated with each one-hot score.
124   // The window duration controls the smoothing. Longer durations will give a
125   // higher confidence that the results are correct, but may miss some commands.
126   // The detection threshold has a similar effect, with high values increasing
127   // the precision at the cost of recall. The minimum count controls how many
128   // results need to be in the averaging window before it's seen as a reliable
129   // average. This prevents erroneous results when the averaging window is
130   // initially being populated for example. The suppression argument disables
131   // further recognitions for a set time after one has been triggered, which can
132   // help reduce spurious recognitions.
133   explicit RecognizeCommands(tflite::ErrorReporter* error_reporter,
134                              int32_t average_window_duration_ms = 1000,
135                              uint8_t detection_threshold = 200,
136                              int32_t suppression_ms = 1500,
137                              int32_t minimum_count = 3);
138 
139   // Call this with the results of running a model on sample data.
140   TfLiteStatus ProcessLatestResults(const TfLiteTensor* latest_results,
141                                     const int32_t current_time_ms,
142                                     const char** found_command, uint8_t* score,
143                                     bool* is_new_command);
144 
145  private:
146   // Configuration
147   tflite::ErrorReporter* error_reporter_;
148   int32_t average_window_duration_ms_;
149   uint8_t detection_threshold_;
150   int32_t suppression_ms_;
151   int32_t minimum_count_;
152 
153   // Working variables
154   PreviousResultsQueue previous_results_;
155   const char* previous_top_label_;
156   int32_t previous_top_label_time_;
157 };
158 
159 #endif  // TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_
160