/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <math.h>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/activation_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/svdf.h"
#include "tensorflow/lite/micro/micro_utils.h"

namespace tflite {

/**
 * This version of SVDF is specific to TFLite Micro. It differs from the
 * TFLite version in the following ways:
 *
 * 1.) Scratch tensor allocation - scratch tensors must be known ahead of time
 * for the Micro interpreter.
 * 2.) Output dimensions - the TFLite version determines output size at
 * runtime and resizes the output tensor. The Micro runtime does not support
 * tensor resizing.
 */

const int kSvdfInputTensor = 0;
const int kSvdfWeightsFeatureTensor = 1;
const int kSvdfWeightsTimeTensor = 2;
const int kSvdfBiasTensor = 3;
const int kSvdfInputActivationStateTensor =
    4;  // This is a variable tensor, and will be modified by this op.
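// Note: this index is into the node's outputs, whereas the indices above are
// into its inputs.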
const int kSvdfOutputTensor = 0;

void EvalIntegerSvdfReference(TfLiteContext* context, TfLiteNode* node,
                              const TfLiteEvalTensor* input_tensor,
                              const TfLiteEvalTensor* weights_feature_tensor,
                              const TfLiteEvalTensor* weights_time_tensor,
                              const TfLiteEvalTensor* bias_tensor,
                              const TfLiteSVDFParams* params,
                              TfLiteEvalTensor* activation_state_tensor,
                              TfLiteEvalTensor* output_tensor,
                              const OpData& data) {
  const int n_rank = params->rank;
  const int n_batch = input_tensor->dims->data[0];
  const int n_input = input_tensor->dims->data[1];
  const int n_filter = weights_feature_tensor->dims->data[0];
  const int n_unit = n_filter / n_rank;
  const int n_memory = weights_time_tensor->dims->data[1];
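
  // Tensor shapes (validated in PrepareSvdf):
  //   input_tensor:            {n_batch, n_input}
  //   weights_feature_tensor:  {n_filter, n_input}
  //   weights_time_tensor:     {n_filter, n_memory}
  //   activation_state_tensor: {n_batch, n_memory * n_filter}
  //   output_tensor:           {n_batch, n_unit}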

  TFLITE_DCHECK(context != nullptr);
  TFLITE_DCHECK(context->GetScratchBuffer != nullptr);

  int32_t* scratch_tensor = static_cast<int32_t*>(
      context->GetScratchBuffer(context, data.scratch_tensor_index));
  int32_t* scratch_output_tensor = static_cast<int32_t*>(
      context->GetScratchBuffer(context, data.scratch_output_tensor_index));

  // Shift states.
  int16_t* const state_ptr =
      tflite::micro::GetTensorData<int16_t>(activation_state_tensor);

  // Left shift the activation_state.
  {
    int16_t* new_state_start = state_ptr;
    const int16_t* old_state_start = state_ptr + 1;
    const int16_t* old_state_end = state_ptr + n_batch * n_filter * n_memory;
    while (old_state_start != old_state_end) {
      *new_state_start++ = *old_state_start++;
    }
  }

  // Note: no need to clear the latest activation, matmul is not accumulative.

  // Feature matmul.
  {
    int16_t* state =
        tflite::micro::GetTensorData<int16_t>(activation_state_tensor);
    const int8_t* input = tflite::micro::GetTensorData<int8_t>(input_tensor);
    const int8_t* weight_feature =
        tflite::micro::GetTensorData<int8_t>(weights_feature_tensor);
    const int32_t output_max = std::numeric_limits<int16_t>::max();
    const int32_t output_min = std::numeric_limits<int16_t>::min();
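    // Each filter's newest sample is written into the last slot of that
    // filter's memory window, i.e. starting at state[n_memory - 1] and
    // advancing with a stride of n_memory (see the matching comment in the
    // float path below).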
    int16_t* result_in_batch = state + (n_memory - 1);
    for (int b = 0; b < n_batch; b++) {
      const int8_t* matrix_ptr = weight_feature;
      for (int r = 0; r < n_filter; r++) {
        int32_t dot_prod = 0;
        const int8_t* vector_in_batch = input + b * n_input;
        for (int c = 0; c < n_input; c++) {
          dot_prod +=
              *matrix_ptr++ * (*vector_in_batch++ - data.input_zero_point);
        }
        dot_prod = MultiplyByQuantizedMultiplier(
            dot_prod, data.effective_scale_1_a, data.effective_scale_1_b);
        dot_prod = std::min(std::max(output_min, dot_prod), output_max);
        // This assumes the state is symmetrically quantized. Otherwise the
        // newest state element should be initialized to its zero point and
        // have dot_prod accumulated into it. Equivalent to:
        //     *result_in_batch = zero point, which happens to be zero.
        //     *result_in_batch += dot_prod.
        *result_in_batch = dot_prod;
        result_in_batch += n_memory;
      }
    }
  }

  // Time.
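  // Per batch, dot-product each filter's time weights against that filter's
  // memory window of the activation state, accumulating into scratch_tensor.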
  {
    for (int b = 0; b < n_batch; ++b) {
      int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;

      // Perform batched vector dot product:
      const int16_t* vector1_ptr =
          tflite::micro::GetTensorData<int16_t>(weights_time_tensor);
      const int16_t* vector2_ptr =
          tflite::micro::GetTensorData<int16_t>(activation_state_tensor) +
          b * n_memory * n_filter;

      for (int i = 0; i < n_filter; i++) {
        *scratch_ptr_batch = 0;
        for (int j = 0; j < n_memory; j++) {
          *scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++;
        }
        scratch_ptr_batch++;
      }
    }
  }
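
  // scratch_tensor now holds n_batch * n_filter accumulated sums; below they
  // are reduced to n_batch * n_unit outputs, rescaled, and written out.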

  // Reduce, add bias, rescale, activation.
  {
    // Add bias.
    if (bias_tensor) {
      // Vector batch assign:
      const int32_t* bias_data =
          tflite::micro::GetTensorData<int32_t>(bias_tensor);
      for (int i = 0; i < n_batch; ++i) {
        int32_t* output_ptr = scratch_output_tensor + i * n_unit;
        const int32_t* bias_ptr = bias_data;
        for (int j = 0; j < n_unit; ++j) {
          *output_ptr++ = *bias_ptr++;
        }
      }
    } else {
      int32_t* output_ptr = scratch_output_tensor;
      for (int i = 0; i < n_batch * n_unit; ++i) {
        *output_ptr++ = 0;
      }
    }

    // Reduce.
    for (int b = 0; b < n_batch; ++b) {
      int32_t* output_temp_ptr = scratch_output_tensor + b * n_unit;
      int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;

      // Reduction sum vector
      for (int i = 0; i < n_unit; ++i) {
        for (int j = 0; j < n_rank; ++j) {
          output_temp_ptr[i] += *scratch_ptr_batch++;
        }
      }
    }

    // Rescale.
    const int32_t output_max = std::numeric_limits<int8_t>::max();
    const int32_t output_min = std::numeric_limits<int8_t>::min();
    for (int i = 0; i < n_batch * n_unit; ++i) {
      int32_t x1 = scratch_output_tensor[i];
      int32_t x2 = MultiplyByQuantizedMultiplier(x1, data.effective_scale_2_a,
                                                 data.effective_scale_2_b);
      int32_t x3 = x2 + data.output_zero_point;
      int32_t x4 = std::min(std::max(output_min, x3), output_max);
      tflite::micro::GetTensorData<int8_t>(output_tensor)[i] =
          static_cast<int8_t>(x4);
    }
  }
}

static inline void ApplyTimeWeightsBiasAndActivation(
    int batch_size, int memory_size, int num_filters, int num_units, int rank,
    const float* const __restrict__ weights_time_ptr,
    const float* const __restrict__ bias_ptr, TfLiteFusedActivation activation,
    float* const __restrict__ state_ptr, float* const __restrict__ scratch_ptr,
    float* const __restrict__ output_ptr) {
  // Compute matmul(activation_state, weights_time).
  for (int b = 0; b < batch_size; ++b) {
    // Perform batched vector dot product:
    float* scratch_ptr_batch = scratch_ptr + b * num_filters;
    const float* vector1_ptr = weights_time_ptr;
    const float* vector2_ptr = state_ptr + b * memory_size * num_filters;
    for (int i = 0; i < num_filters; ++i) {
      *scratch_ptr_batch = 0.f;
      for (int j = 0; j < memory_size; ++j) {
        *scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++;
      }
      scratch_ptr_batch++;
    }
  }

  // Initialize output with bias if provided.
  if (bias_ptr) {
    // VectorBatchVectorAssign
    for (int i = 0; i < batch_size; ++i) {
      float* output_data = output_ptr + i * num_units;
      const float* bias_data = bias_ptr;
      for (int j = 0; j < num_units; ++j) {
        *output_data++ = *bias_data++;
      }
    }
  } else {
    float* output_data = output_ptr;
    for (int i = 0; i < batch_size * num_units; ++i) {
      *output_data++ = 0.0f;
    }
  }

  // Reduction sum.
  for (int b = 0; b < batch_size; ++b) {
    float* output_ptr_batch = output_ptr + b * num_units;
    float* scratch_ptr_batch = scratch_ptr + b * num_filters;

    // Reduction sum vector
    for (int i = 0; i < num_units; ++i) {
      for (int j = 0; j < rank; j++) {
        output_ptr_batch[i] += *scratch_ptr_batch++;
      }
    }
  }

  // Apply activation.
  for (int b = 0; b < batch_size; ++b) {
    float* output_ptr_batch = output_ptr + b * num_units;
    for (int i = 0; i < num_units; ++i) {
      *output_ptr_batch =
          tflite::ops::micro::ActivationValFloat(activation, *output_ptr_batch);
      ++output_ptr_batch;
    }
  }
}

void EvalFloatSvdfReference(
    TfLiteContext* context, TfLiteNode* node, const TfLiteEvalTensor* input,
    const TfLiteEvalTensor* weights_feature,
    const TfLiteEvalTensor* weights_time, const TfLiteEvalTensor* bias,
    const TfLiteSVDFParams* params, int scratch_tensor_index,
    TfLiteEvalTensor* activation_state, TfLiteEvalTensor* output) {
  const int rank = params->rank;
  const int batch_size = input->dims->data[0];
  const int input_size = input->dims->data[1];
  const int num_filters = weights_feature->dims->data[0];
  const int num_units = num_filters / rank;
  const int memory_size = weights_time->dims->data[1];

  const float* weights_feature_ptr =
      tflite::micro::GetTensorData<float>(weights_feature);
  const float* weights_time_ptr =
      tflite::micro::GetTensorData<float>(weights_time);
  const float* bias_ptr = tflite::micro::GetTensorData<float>(bias);
  const float* input_ptr = tflite::micro::GetTensorData<float>(input);

  float* state_ptr = tflite::micro::GetTensorData<float>(activation_state);

  TFLITE_DCHECK(context != nullptr);
  TFLITE_DCHECK(context->GetScratchBuffer != nullptr);

  float* scratch_ptr = static_cast<float*>(
      context->GetScratchBuffer(context, scratch_tensor_index));

  float* output_ptr = tflite::micro::GetTensorData<float>(output);

  // Left shift the activation_state.
  {
    float* new_state_start = state_ptr;
    const float* old_state_start = state_ptr + 1;
    const float* old_state_end =
        state_ptr + batch_size * num_filters * memory_size;
    while (old_state_start != old_state_end) {
      *new_state_start++ = *old_state_start++;
    }
  }

  // Note: no need to clear the latest activation, matmul is not accumulative.

  // Compute conv1d(inputs, weights_feature).
  // The activation_state's rightmost column is used to save current cycle
  // activation. This is achieved by starting at state_ptr[memory_size - 1] and
  // having the stride equal to memory_size.

  // Perform batched matrix vector multiply operation:
  {
    const float* matrix = weights_feature_ptr;
    const float* vector = input_ptr;
    float* result = &state_ptr[memory_size - 1];
    float* result_in_batch = result;
    for (int i = 0; i < batch_size; ++i) {
      const float* matrix_ptr = matrix;
      for (int j = 0; j < num_filters; ++j) {
        float dot_prod = 0.0f;
        const float* vector_in_batch = vector + i * input_size;
        for (int k = 0; k < input_size; ++k) {
          dot_prod += *matrix_ptr++ * *vector_in_batch++;
        }
        *result_in_batch = dot_prod;
        result_in_batch += memory_size;
      }
    }
  }

  ApplyTimeWeightsBiasAndActivation(
      batch_size, memory_size, num_filters, num_units, rank, weights_time_ptr,
      bias_ptr, params->activation, state_ptr, scratch_ptr, output_ptr);
}

TfLiteStatus PrepareSvdf(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->builtin_data != nullptr);

  const auto* params = static_cast<const TfLiteSVDFParams*>(node->builtin_data);

  // Validate Tensor Inputs (dtype depends on quantization):
  // [0] = Input, {2, batch_size, input_size}
  // [1] = Weights Feature, {2, num_filters, input_size}
  // [2] = Weights Time, {2, num_filters, memory_size}
  // [3] = Bias (optional), {1, num_units}
  // [4] = Activation State (variable),
  //         {2, batch_size, memory_size * num_filters}
  const TfLiteTensor* input = GetInput(context, node, kSvdfInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  const TfLiteTensor* weights_feature =
      GetInput(context, node, kSvdfWeightsFeatureTensor);
  TF_LITE_ENSURE(context, weights_feature != nullptr);
  const TfLiteTensor* weights_time =
      GetInput(context, node, kSvdfWeightsTimeTensor);
  TF_LITE_ENSURE(context, weights_time != nullptr);
  const TfLiteTensor* bias =
      GetOptionalInputTensor(context, node, kSvdfBiasTensor);
  const TfLiteTensor* activation_state =
      GetInput(context, node, kSvdfInputActivationStateTensor);
  TF_LITE_ENSURE(context, activation_state != nullptr);

  // Define input constants based on input tensor definition above:
  const int rank = params->rank;
  const int input_size = input->dims->data[1];
  const int batch_size = input->dims->data[0];
  const int num_filters = weights_feature->dims->data[0];
  TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
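  // num_filters must be an exact multiple of rank: each output unit
  // aggregates `rank` filters in the reduction step.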
  const int num_units = num_filters / rank;
  const int memory_size = weights_time->dims->data[1];

  // Validate Input Tensor:
  TF_LITE_ENSURE(context,
                 input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);

  // Validate Tensor Output:
  // [0] = float/int8_t, {2, batch_size, num_units}
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
  TfLiteTensor* output = GetOutput(context, node, kSvdfOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);
  TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
  TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);

  // Validate Weights Feature Input Tensor:
  TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2);
  TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size);

  // Validate Weights Time Input Tensor:
  TF_LITE_ENSURE_EQ(context, NumDimensions(weights_time), 2);
  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters);
  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size);

  // Validate Optional Bias Input Tensor:
  if (bias != nullptr) {
    TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
  }

  // Validate Activation State Input Tensor:
  TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
  TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],
                    memory_size * num_filters);
  // Since is_variable is not part of TfLiteEvalTensor, check is_variable here.
  TF_LITE_ENSURE_EQ(context, activation_state->is_variable, true);

  TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);

  TFLITE_DCHECK(node->user_data != nullptr);
  OpData* data = static_cast<OpData*>(node->user_data);

  if (input->type == kTfLiteInt8) {
    TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
    TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16);
    TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
    if (bias != nullptr) {
      TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
    }

    TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);

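    // Fold the quantization scales of the two stages (feature matmul into the
    // int16 state, and state/time matmul into the int8 output) into single
    // effective scales; QuantizeMultiplier() below converts each into the
    // fixed-point multiplier and shift consumed by
    // MultiplyByQuantizedMultiplier() at evaluation time.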
    const double effective_scale_1 = static_cast<double>(
        input->params.scale * weights_feature->params.scale /
        activation_state->params.scale);
    const double effective_scale_2 =
        static_cast<double>(activation_state->params.scale *
                            weights_time->params.scale / output->params.scale);

    // TODO(b/162018098): Use TF_LITE_ENSURE_NEAR when it is ready.
    if (bias != nullptr) {
      TF_LITE_ENSURE(
          context,
          std::abs(static_cast<double>(bias->params.scale) -
                   static_cast<double>(activation_state->params.scale *
                                       weights_time->params.scale)) < 1e-5);
    }

    QuantizeMultiplier(effective_scale_1, &(data->effective_scale_1_a),
                       &(data->effective_scale_1_b));
    QuantizeMultiplier(effective_scale_2, &(data->effective_scale_2_a),
                       &(data->effective_scale_2_b));

    data->input_zero_point = input->params.zero_point;
    data->output_zero_point = output->params.zero_point;

    TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);

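    // Request two scratch buffers: one for the per-filter time-convolution
    // accumulators (batch_size * num_filters int32 values) and one for the
    // pre-activation output accumulator (batch_size * num_units int32
    // values).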
    const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
        context, batch_size * num_filters * sizeof(int32_t),
        &(data->scratch_tensor_index));
    TF_LITE_ENSURE_OK(context, scratch_status);

    const TfLiteStatus scratch_output_status =
        context->RequestScratchBufferInArena(
            context, batch_size * num_units * sizeof(int32_t),
            &(data->scratch_output_tensor_index));
    TF_LITE_ENSURE_OK(context, scratch_output_status);
  } else {
    TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32);
    TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32);
    TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
    if (bias != nullptr) {
      TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
    }
    TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);

    TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
    const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
        context, batch_size * num_filters * sizeof(float),
        &(data->scratch_tensor_index));
    TF_LITE_ENSURE_OK(context, scratch_status);
  }

  return kTfLiteOk;
}

}  // namespace tflite