1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_ 17 #define TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_ 18 19 #include <cstdint> 20 21 #include "tensorflow/lite/c/common.h" 22 #include "tensorflow/lite/micro/examples/micro_speech/micro_features/micro_model_settings.h" 23 #include "tensorflow/lite/micro/micro_error_reporter.h" 24 25 // Partial implementation of std::dequeue, just providing the functionality 26 // that's needed to keep a record of previous neural network results over a 27 // short time period, so they can be averaged together to produce a more 28 // accurate overall prediction. This doesn't use any dynamic memory allocation 29 // so it's a better fit for microcontroller applications, but this does mean 30 // there are hard limits on the number of results it can store. 31 class PreviousResultsQueue { 32 public: PreviousResultsQueue(tflite::ErrorReporter * error_reporter)33 PreviousResultsQueue(tflite::ErrorReporter* error_reporter) 34 : error_reporter_(error_reporter), front_index_(0), size_(0) {} 35 36 // Data structure that holds an inference result, and the time when it 37 // was recorded. 38 struct Result { ResultResult39 Result() : time_(0), scores() {} ResultResult40 Result(int32_t time, int8_t* input_scores) : time_(time) { 41 for (int i = 0; i < kCategoryCount; ++i) { 42 scores[i] = input_scores[i]; 43 } 44 } 45 int32_t time_; 46 int8_t scores[kCategoryCount]; 47 }; 48 size()49 int size() { return size_; } empty()50 bool empty() { return size_ == 0; } front()51 Result& front() { return results_[front_index_]; } back()52 Result& back() { 53 int back_index = front_index_ + (size_ - 1); 54 if (back_index >= kMaxResults) { 55 back_index -= kMaxResults; 56 } 57 return results_[back_index]; 58 } 59 push_back(const Result & entry)60 void push_back(const Result& entry) { 61 if (size() >= kMaxResults) { 62 TF_LITE_REPORT_ERROR( 63 error_reporter_, 64 "Couldn't push_back latest result, too many already!"); 65 return; 66 } 67 size_ += 1; 68 back() = entry; 69 } 70 pop_front()71 Result pop_front() { 72 if (size() <= 0) { 73 TF_LITE_REPORT_ERROR(error_reporter_, 74 "Couldn't pop_front result, none present!"); 75 return Result(); 76 } 77 Result result = front(); 78 front_index_ += 1; 79 if (front_index_ >= kMaxResults) { 80 front_index_ = 0; 81 } 82 size_ -= 1; 83 return result; 84 } 85 86 // Most of the functions are duplicates of dequeue containers, but this 87 // is a helper that makes it easy to iterate through the contents of the 88 // queue. from_front(int offset)89 Result& from_front(int offset) { 90 if ((offset < 0) || (offset >= size_)) { 91 TF_LITE_REPORT_ERROR(error_reporter_, 92 "Attempt to read beyond the end of the queue!"); 93 offset = size_ - 1; 94 } 95 int index = front_index_ + offset; 96 if (index >= kMaxResults) { 97 index -= kMaxResults; 98 } 99 return results_[index]; 100 } 101 102 private: 103 tflite::ErrorReporter* error_reporter_; 104 static constexpr int kMaxResults = 50; 105 Result results_[kMaxResults]; 106 107 int front_index_; 108 int size_; 109 }; 110 111 // This class is designed to apply a very primitive decoding model on top of the 112 // instantaneous results from running an audio recognition model on a single 113 // window of samples. It applies smoothing over time so that noisy individual 114 // label scores are averaged, increasing the confidence that apparent matches 115 // are real. 116 // To use it, you should create a class object with the configuration you 117 // want, and then feed results from running a TensorFlow model into the 118 // processing method. The timestamp for each subsequent call should be 119 // increasing from the previous, since the class is designed to process a stream 120 // of data over time. 121 class RecognizeCommands { 122 public: 123 // labels should be a list of the strings associated with each one-hot score. 124 // The window duration controls the smoothing. Longer durations will give a 125 // higher confidence that the results are correct, but may miss some commands. 126 // The detection threshold has a similar effect, with high values increasing 127 // the precision at the cost of recall. The minimum count controls how many 128 // results need to be in the averaging window before it's seen as a reliable 129 // average. This prevents erroneous results when the averaging window is 130 // initially being populated for example. The suppression argument disables 131 // further recognitions for a set time after one has been triggered, which can 132 // help reduce spurious recognitions. 133 explicit RecognizeCommands(tflite::ErrorReporter* error_reporter, 134 int32_t average_window_duration_ms = 1000, 135 uint8_t detection_threshold = 200, 136 int32_t suppression_ms = 1500, 137 int32_t minimum_count = 3); 138 139 // Call this with the results of running a model on sample data. 140 TfLiteStatus ProcessLatestResults(const TfLiteTensor* latest_results, 141 const int32_t current_time_ms, 142 const char** found_command, uint8_t* score, 143 bool* is_new_command); 144 145 private: 146 // Configuration 147 tflite::ErrorReporter* error_reporter_; 148 int32_t average_window_duration_ms_; 149 uint8_t detection_threshold_; 150 int32_t suppression_ms_; 151 int32_t minimum_count_; 152 153 // Working variables 154 PreviousResultsQueue previous_results_; 155 const char* previous_top_label_; 156 int32_t previous_top_label_time_; 157 }; 158 159 #endif // TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_ 160