1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/lite/micro/examples/micro_speech/audio_provider.h"

#include <AudioToolbox/AudioToolbox.h>

#include <atomic>

#include "tensorflow/lite/micro/examples/micro_speech/simple_features/simple_model_settings.h"
21
22 namespace {
23
24 constexpr int kNumberRecordBuffers = 3;
25 bool g_is_audio_initialized = false;
26 constexpr int kAudioCaptureBufferSize = kAudioSampleFrequency * 0.5;
27 int16_t g_audio_capture_buffer[kAudioCaptureBufferSize];
28 int16_t g_audio_output_buffer[kMaxAudioSampleSize];
29 int32_t g_latest_audio_timestamp = 0;
30
// Checks a macOS OSStatus result. If it is not noErr, reports the file, line
// and status code through `error_reporter` and returns kTfLiteError from the
// enclosing function.
//
// `error` is evaluated exactly once into a local. Callers pass OS calls
// directly, so re-evaluating it would re-issue a failing call just to report
// it. There is no trailing semicolon, so `RETURN_IF_OS_ERROR(...);` expands to
// a single statement, which is safe inside an unbraced if/else.
#define RETURN_IF_OS_ERROR(error, error_reporter)                           \
  do {                                                                      \
    const OSStatus return_if_os_error_status = (error);                     \
    if (return_if_os_error_status != noErr) {                               \
      TF_LITE_REPORT_ERROR(error_reporter, "Error: %s:%d (%d)\n", __FILE__, \
                           __LINE__,                                        \
                           static_cast<int>(return_if_os_error_status));    \
      return kTfLiteError;                                                  \
    }                                                                       \
  } while (0)
40
41 // Called when an audio input buffer has been filled.
OnAudioBufferFilledCallback(void * user_data,AudioQueueRef queue,AudioQueueBufferRef buffer,const AudioTimeStamp * start_time,UInt32 num_packets,const AudioStreamPacketDescription * packet_description)42 void OnAudioBufferFilledCallback(
43 void* user_data, AudioQueueRef queue, AudioQueueBufferRef buffer,
44 const AudioTimeStamp* start_time, UInt32 num_packets,
45 const AudioStreamPacketDescription* packet_description) {
46 const int sample_size = buffer->mAudioDataByteSize / sizeof(float);
47 const int64_t sample_offset = start_time->mSampleTime;
48 const int32_t time_in_ms =
49 (sample_offset + sample_size) / (kAudioSampleFrequency / 1000);
50 const float* float_samples = static_cast<const float*>(buffer->mAudioData);
51 for (int i = 0; i < sample_size; ++i) {
52 const int capture_index = (sample_offset + i) % kAudioCaptureBufferSize;
53 g_audio_capture_buffer[capture_index] = float_samples[i] * ((1 << 15) - 1);
54 }
55 // This is how we let the outside world know that new audio data has arrived.
56 g_latest_audio_timestamp = time_in_ms;
57 AudioQueueEnqueueBuffer(queue, buffer, 0, nullptr);
58 }
59
60 // Set up everything we need to capture audio samples from the default recording
61 // device on MacOS.
InitAudioRecording(tflite::ErrorReporter * error_reporter)62 TfLiteStatus InitAudioRecording(tflite::ErrorReporter* error_reporter) {
63 // Set up the format of the audio - single channel, 32-bit float at 16KHz.
64 AudioStreamBasicDescription recordFormat = {};
65 recordFormat.mSampleRate = kAudioSampleFrequency;
66 recordFormat.mFormatID = kAudioFormatLinearPCM;
67 recordFormat.mFormatFlags =
68 kAudioFormatFlagIsFloat | kAudioFormatFlagIsPacked;
69 recordFormat.mBitsPerChannel = 8 * sizeof(float);
70 recordFormat.mChannelsPerFrame = 1;
71 recordFormat.mBytesPerFrame = sizeof(float) * recordFormat.mChannelsPerFrame;
72 recordFormat.mFramesPerPacket = 1;
73 recordFormat.mBytesPerPacket =
74 recordFormat.mBytesPerFrame * recordFormat.mFramesPerPacket;
75 recordFormat.mReserved = 0;
76
77 UInt32 propSize = sizeof(recordFormat);
78 RETURN_IF_OS_ERROR(AudioFormatGetProperty(kAudioFormatProperty_FormatInfo, 0,
79 NULL, &propSize, &recordFormat),
80 error_reporter);
81
82 // Create a recording queue.
83 AudioQueueRef queue;
84 RETURN_IF_OS_ERROR(
85 AudioQueueNewInput(&recordFormat, OnAudioBufferFilledCallback,
86 error_reporter, nullptr, nullptr, 0, &queue),
87 error_reporter);
88
89 // Set up the buffers we'll need.
90 int buffer_bytes = 512 * sizeof(float);
91 for (int i = 0; i < kNumberRecordBuffers; ++i) {
92 AudioQueueBufferRef buffer;
93 RETURN_IF_OS_ERROR(AudioQueueAllocateBuffer(queue, buffer_bytes, &buffer),
94 error_reporter);
95 RETURN_IF_OS_ERROR(AudioQueueEnqueueBuffer(queue, buffer, 0, nullptr),
96 error_reporter);
97 }
98
99 // Start capturing audio.
100 RETURN_IF_OS_ERROR(AudioQueueStart(queue, nullptr), error_reporter);
101
102 return kTfLiteOk;
103 }
104
105 } // namespace
106
GetAudioSamples(tflite::ErrorReporter * error_reporter,int start_ms,int duration_ms,int * audio_samples_size,int16_t ** audio_samples)107 TfLiteStatus GetAudioSamples(tflite::ErrorReporter* error_reporter,
108 int start_ms, int duration_ms,
109 int* audio_samples_size, int16_t** audio_samples) {
110 if (!g_is_audio_initialized) {
111 TfLiteStatus init_status = InitAudioRecording(error_reporter);
112 if (init_status != kTfLiteOk) {
113 return init_status;
114 }
115 for (int i = 0; i < kMaxAudioSampleSize; ++i) {
116 g_audio_output_buffer[i] = 0;
117 }
118 g_is_audio_initialized = true;
119 }
120 // This should only be called when the main thread notices that the latest
121 // audio sample data timestamp has changed, so that there's new data in the
122 // capture ring buffer. The ring buffer will eventually wrap around and
123 // overwrite the data, but the assumption is that the main thread is checking
124 // often enough and the buffer is large enough that this call will be made
125 // before that happens.
126 const int start_offset = start_ms * (kAudioSampleFrequency / 1000);
127 const int duration_sample_count =
128 duration_ms * (kAudioSampleFrequency / 1000);
129 for (int i = 0; i < duration_sample_count; ++i) {
130 const int capture_index = (start_offset + i) % kAudioCaptureBufferSize;
131 g_audio_output_buffer[i] = g_audio_capture_buffer[capture_index];
132 }
133
134 *audio_samples_size = kMaxAudioSampleSize;
135 *audio_samples = g_audio_output_buffer;
136 return kTfLiteOk;
137 }
138
LatestAudioTimestamp()139 int32_t LatestAudioTimestamp() { return g_latest_audio_timestamp; }
140