/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_

#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#endif
#endif

#include <functional>

#include "fixedpoint/fixedpoint.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {

constexpr int kReverseShift = -1;

inline void GetActivationMinMax(FusedActivationFunctionType ac,
                                float* output_activation_min,
                                float* output_activation_max) {
  switch (ac) {
    case FusedActivationFunctionType::kNone:
      *output_activation_min = std::numeric_limits<float>::lowest();
      *output_activation_max = std::numeric_limits<float>::max();
      break;
    case FusedActivationFunctionType::kRelu:
      *output_activation_min = 0.f;
      *output_activation_max = std::numeric_limits<float>::max();
      break;
    case FusedActivationFunctionType::kRelu1:
      *output_activation_min = -1.f;
      *output_activation_max = 1.f;
      break;
    case FusedActivationFunctionType::kRelu6:
      *output_activation_min = 0.f;
      *output_activation_max = 6.f;
      break;
  }
}

template <typename T>
inline T ActivationFunctionWithMinMax(T x, T output_activation_min,
                                      T output_activation_max) {
  using std::max;
  using std::min;
  return min(max(x, output_activation_min), output_activation_max);
}
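
// Illustrative example (not part of the original header): with a fused
// kRelu6 activation, GetActivationMinMax yields [0, 6], so
// ActivationFunctionWithMinMax(7.5f, 0.f, 6.f) == 6.f and
// ActivationFunctionWithMinMax(-2.f, 0.f, 6.f) == 0.f.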

// Legacy function, left for compatibility only.
template <FusedActivationFunctionType Ac>
float ActivationFunction(float x) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  return ActivationFunctionWithMinMax(x, output_activation_min,
                                      output_activation_max);
}

inline void BiasAndClamp(float clamp_min, float clamp_max, int bias_size,
                         const float* bias_data, int array_size,
                         float* array_data) {
  // Note: see b/132215220: in May 2019 we thought it would be OK to replace
  // this with the Eigen one-liner:
  //   return (array.colwise() + bias).cwiseMax(clamp_min).cwiseMin(clamp_max).
  // This turned out to severely regress performance: +4ms (i.e. 8%) on
  // MobileNet v2 / 1.0 / 224. So we keep custom NEON code for now.
  TFLITE_DCHECK_EQ((array_size % bias_size), 0);
#ifdef USE_NEON
  float* array_ptr = array_data;
  float* array_end_ptr = array_ptr + array_size;
  const auto clamp_min_vec = vdupq_n_f32(clamp_min);
  const auto clamp_max_vec = vdupq_n_f32(clamp_max);
  for (; array_ptr != array_end_ptr; array_ptr += bias_size) {
    int i = 0;
    for (; i <= bias_size - 16; i += 16) {
      auto b0 = vld1q_f32(bias_data + i);
      auto b1 = vld1q_f32(bias_data + i + 4);
      auto b2 = vld1q_f32(bias_data + i + 8);
      auto b3 = vld1q_f32(bias_data + i + 12);
      auto a0 = vld1q_f32(array_ptr + i);
      auto a1 = vld1q_f32(array_ptr + i + 4);
      auto a2 = vld1q_f32(array_ptr + i + 8);
      auto a3 = vld1q_f32(array_ptr + i + 12);
      auto x0 = vaddq_f32(a0, b0);
      auto x1 = vaddq_f32(a1, b1);
      auto x2 = vaddq_f32(a2, b2);
      auto x3 = vaddq_f32(a3, b3);
      x0 = vmaxq_f32(clamp_min_vec, x0);
      x1 = vmaxq_f32(clamp_min_vec, x1);
      x2 = vmaxq_f32(clamp_min_vec, x2);
      x3 = vmaxq_f32(clamp_min_vec, x3);
      x0 = vminq_f32(clamp_max_vec, x0);
      x1 = vminq_f32(clamp_max_vec, x1);
      x2 = vminq_f32(clamp_max_vec, x2);
      x3 = vminq_f32(clamp_max_vec, x3);
      vst1q_f32(array_ptr + i, x0);
      vst1q_f32(array_ptr + i + 4, x1);
      vst1q_f32(array_ptr + i + 8, x2);
      vst1q_f32(array_ptr + i + 12, x3);
    }
    for (; i <= bias_size - 4; i += 4) {
      auto b = vld1q_f32(bias_data + i);
      auto a = vld1q_f32(array_ptr + i);
      auto x = vaddq_f32(a, b);
      x = vmaxq_f32(clamp_min_vec, x);
      x = vminq_f32(clamp_max_vec, x);
      vst1q_f32(array_ptr + i, x);
    }
    for (; i < bias_size; i++) {
      array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i],
                                                  clamp_min, clamp_max);
    }
  }
#else  // not NEON
  for (int array_offset = 0; array_offset < array_size;
       array_offset += bias_size) {
    for (int i = 0; i < bias_size; i++) {
      array_data[array_offset + i] = ActivationFunctionWithMinMax(
          array_data[array_offset + i] + bias_data[i], clamp_min, clamp_max);
    }
  }
#endif
}

inline int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp(
    int32_t x, int32_t quantized_multiplier, int left_shift) {
  using gemmlowp::RoundingDivideByPOT;
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  return RoundingDivideByPOT(
      SaturatingRoundingDoublingHighMul(x, quantized_multiplier), -left_shift);
}

inline int32_t MultiplyByQuantizedMultiplierGreaterThanOne(
    int32_t x, int32_t quantized_multiplier, int left_shift) {
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
                                           quantized_multiplier);
}

inline int32_t MultiplyByQuantizedMultiplier(int32_t x,
                                             int32_t quantized_multiplier,
                                             int shift) {
  using gemmlowp::RoundingDivideByPOT;
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  int left_shift = shift > 0 ? shift : 0;
  int right_shift = shift > 0 ? 0 : -shift;
  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
                                 x * (1 << left_shift), quantized_multiplier),
                             right_shift);
}
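
// Illustrative example (not part of the original header): the effective
// real-valued scale applied by MultiplyByQuantizedMultiplier is
// (quantized_multiplier / 2^31) * 2^shift. For instance, a hypothetical
// multiplier of 1 << 30 (i.e. 0.5 in Q0.31) together with shift = -1
// applies an overall scale of 0.25, so an accumulator value of 1000 is
// requantized to approximately 250:
//
//   int32_t acc = 1000;
//   int32_t out = MultiplyByQuantizedMultiplier(acc, 1 << 30, -1);
//   // out == 250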

inline int32_t MultiplyByQuantizedMultiplier(int64_t x,
                                             int32_t quantized_multiplier,
                                             int shift) {
  // Inputs:
  // - quantized_multiplier has fixed point at bit 31
  // - shift is -31 to +7 (negative for right shift)
  //
  // Assumptions: the following input ranges are assumed
  // - quantized_multiplier >= 0 (the usual range is (1 << 30) to (1 << 31) - 1)
  // - scaling is chosen so that the final scaled result fits in int32_t
  // - input x is in the range -(1<<47) <= x < (1<<47)
  assert(quantized_multiplier >= 0);
  assert(shift >= -31 && shift < 8);
  assert(x >= -(static_cast<int64_t>(1) << 47) &&
         x < (static_cast<int64_t>(1) << 47));

  int32_t reduced_multiplier = (quantized_multiplier < 0x7FFF0000)
                                   ? ((quantized_multiplier + (1 << 15)) >> 16)
                                   : 0x7FFF;
  int total_shift = 15 - shift;
  x = (x * (int64_t)reduced_multiplier) + ((int64_t)1 << (total_shift - 1));
  int32_t result = x >> total_shift;
  return result;
}
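
// Illustrative arithmetic (not part of the original header; values are
// hypothetical): for quantized_multiplier = 1 << 30 and shift = -1, the
// multiplier is first reduced to Q0.15, i.e. reduced_multiplier = 1 << 14,
// and total_shift = 15 - (-1) = 16. For x = 1000 this gives
// (1000 * 16384 + (1 << 15)) >> 16 == 250, matching the int32_t overload
// above for the same multiplier and shift.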

#ifdef USE_NEON
// Rounding uses ARM's rounding shift right (vrshl with a negative shift).
inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows(
    int32x4x4_t input_val, int32_t quantized_multiplier, int shift) {
  const int left_shift = std::max(shift, 0);
  const int right_shift = std::min(shift, 0);
  int32x4x4_t result;

  int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier);
  int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
  int32x4_t right_shift_dup = vdupq_n_s32(right_shift);

  result.val[0] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  result.val[1] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  result.val[2] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[2], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  result.val[3] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[3], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  return result;
}
#endif

template <typename T>
int CountLeadingZeros(T integer_input) {
  static_assert(std::is_unsigned<T>::value,
                "Only unsigned integer types handled.");
#if defined(__GNUC__)
  return integer_input ? __builtin_clz(integer_input)
                       : std::numeric_limits<T>::digits;
#else
  if (integer_input == 0) {
    return std::numeric_limits<T>::digits;
  }

  const T one_in_leading_positive = static_cast<T>(1)
                                    << (std::numeric_limits<T>::digits - 1);
  int leading_zeros = 0;
  while (integer_input < one_in_leading_positive) {
    integer_input <<= 1;
    ++leading_zeros;
  }
  return leading_zeros;
#endif
}

template <typename T>
inline int CountLeadingSignBits(T integer_input) {
  static_assert(std::is_signed<T>::value, "Only signed integer types handled.");
#if defined(__GNUC__) && !defined(__clang__)
  return integer_input ? __builtin_clrsb(integer_input)
                       : std::numeric_limits<T>::digits;
#else
  using U = typename std::make_unsigned<T>::type;
  return integer_input >= 0
             ? CountLeadingZeros(static_cast<U>(integer_input)) - 1
         : integer_input != std::numeric_limits<T>::min()
             ? CountLeadingZeros(2 * static_cast<U>(-integer_input) - 1)
             : 0;
#endif
}

// Use "count leading zeros" helper functions to do a fast Floor(log_2(x)).
template <typename Integer>
inline Integer FloorLog2(Integer n) {
  static_assert(std::is_integral<Integer>::value, "");
  static_assert(std::is_signed<Integer>::value, "");
  static_assert(sizeof(Integer) == 4 || sizeof(Integer) == 8, "");
  TFLITE_CHECK_GT(n, 0);
  if (sizeof(Integer) == 4) {
    return 30 - CountLeadingSignBits(n);
  } else {
    return 62 - CountLeadingSignBits(n);
  }
}
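
// Illustrative examples (not part of the original header): for a 32-bit
// signed value, CountLeadingSignBits(1) == 30, so FloorLog2(1) == 0;
// CountLeadingSignBits(5) == 28, so FloorLog2(5) == 2; and
// CountLeadingZeros(uint32_t{1} << 31) == 0.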

// Generates an INT16 LUT for a function, e.g., the exp(x) and 1/(1+x) tables
// used in softmax.
// func - the function to build the LUT for (e.g., exp(x))
// min, max - table limits
// table - pointer to the output buffer
// num - number of elements in the LUT
inline void gen_lut(double (*func)(double), double min, double max,
                    int16_t* table, const int num) {
  // The size of the table should be num + 1; the last element is used only
  // for slope calculation.
  double step = (max - min) / (num - 1);
  double half_step = step / 2.0;
  for (int i = 0; i < num - 1; i++) {
    double sample_val = TfLiteRound(func(min + i * step) * 32768.0);
    double midpoint_interp_val =
        TfLiteRound((func(min + (i + 1) * step) * 32768.0 +
                     TfLiteRound(func(min + i * step) * 32768.0)) /
                    2.0);
    double midpoint_val =
        TfLiteRound(func(min + i * step + half_step) * 32768.0);
    double midpoint_err = midpoint_interp_val - midpoint_val;
    double bias = TfLiteRound(midpoint_err / 2.0);
    table[i] = std::min<double>(std::max<double>(sample_val - bias, -32768.0),
                                32767.0);
  }
  table[num - 1] = std::min<double>(
      std::max<double>(TfLiteRound(func(max) * 32768.0), -32768.0), 32767.0);
}

// Generates an INT16 LUT for a function, e.g., the exp(x) and 1/(1+x) tables
// used in softmax.
// func - the function to build the LUT for (e.g., exp(x))
// min, max - table limits
// table - pointer to the output buffer
// num - number of elements in the LUT
inline void gen_lut(float (*func)(float), float min, float max, int16_t* table,
                    const int num) {
  // The size of the table should be num + 1; the last element is used only
  // for slope calculation.
  float step = (max - min) / (num - 1);
  float half_step = step / 2.0f;
  for (int i = 0; i < num - 1; i++) {
    float sample_val = TfLiteRound(func(min + i * step) * 32768.0f);
    float midpoint_interp_val =
        TfLiteRound((func(min + (i + 1) * step) * 32768.0f +
                     TfLiteRound(func(min + i * step) * 32768.0f)) /
                    2.0f);
    float midpoint_val =
        TfLiteRound(func(min + i * step + half_step) * 32768.0f);
    float midpoint_err = midpoint_interp_val - midpoint_val;
    float bias = TfLiteRound(midpoint_err / 2.0f);
    table[i] = std::min<float>(std::max<float>(sample_val - bias, -32768.0f),
                               32767.0f);
  }
  table[num - 1] = std::min<float>(
      std::max<float>(TfLiteRound(func(max) * 32768.0f), -32768.0f), 32767.0f);
}
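
// Illustrative usage sketch (not part of the original header; the table size
// of 513 and the input range are assumptions chosen to match
// generic_int16_table_lookup below, which reads one extra entry for the
// slope):
//
//   int16_t exp_lut[513];
//   // Builds a Q0.15 table of exp(x) for x in [-10, 0].
//   gen_lut([](double x) { return std::exp(x); }, -10.0, 0.0, exp_lut, 513);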

// int16_t function table lookup, e.g., for the exp() and 1/(1+x) tables used
// in softmax.
inline int16_t generic_int16_table_lookup(int16_t value, const int16_t* lut) {
  // 512 base values; the final LUT entry is used only to calculate the slope.
  uint16_t index = static_cast<uint16_t>(256 + (value >> 7));
  assert(index < 512 && "LUT index out of range.");
  int16_t offset = value & 0x7f;

  // base and slope are Q0.15
  int16_t base = lut[index];
  int16_t slope = lut[index + 1] - lut[index];

  // Q0.15 * Q0.7 = Q0.22
  // Round and convert from Q0.22 to Q0.15
  int32_t delta = (static_cast<int32_t>(slope) * offset + 64) >> 7;

  // Q0.15 + Q0.15
  return base + delta;
}
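
// Illustrative arithmetic (not part of the original header; values are
// hypothetical): an input of 200 selects index = 256 + (200 >> 7) = 257 and
// offset = 200 & 0x7f = 72. If lut[257] == 1000 and lut[258] == 1128, the
// slope is 128, so delta = (128 * 72 + 64) >> 7 == 72 and the interpolated
// result is 1072.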

// Table of sigmoid(i/24) in 0.16 format - 256 elements.

// We use a combined sigmoid and tanh look-up table, since
// tanh(x) = 2*sigmoid(2*x) - 1.
// Both functions are symmetric, so the LUT is only needed
// for the absolute value of the input.
static const uint16_t sigmoid_table_uint16[256] = {
    32768, 33451, 34133, 34813, 35493, 36169, 36843, 37513, 38180, 38841, 39498,
    40149, 40794, 41432, 42064, 42688, 43304, 43912, 44511, 45102, 45683, 46255,
    46817, 47369, 47911, 48443, 48964, 49475, 49975, 50464, 50942, 51409, 51865,
    52311, 52745, 53169, 53581, 53983, 54374, 54755, 55125, 55485, 55834, 56174,
    56503, 56823, 57133, 57433, 57724, 58007, 58280, 58544, 58800, 59048, 59288,
    59519, 59743, 59959, 60168, 60370, 60565, 60753, 60935, 61110, 61279, 61441,
    61599, 61750, 61896, 62036, 62172, 62302, 62428, 62549, 62666, 62778, 62886,
    62990, 63090, 63186, 63279, 63368, 63454, 63536, 63615, 63691, 63765, 63835,
    63903, 63968, 64030, 64090, 64148, 64204, 64257, 64308, 64357, 64405, 64450,
    64494, 64536, 64576, 64614, 64652, 64687, 64721, 64754, 64786, 64816, 64845,
    64873, 64900, 64926, 64950, 64974, 64997, 65019, 65039, 65060, 65079, 65097,
    65115, 65132, 65149, 65164, 65179, 65194, 65208, 65221, 65234, 65246, 65258,
    65269, 65280, 65291, 65301, 65310, 65319, 65328, 65337, 65345, 65352, 65360,
    65367, 65374, 65381, 65387, 65393, 65399, 65404, 65410, 65415, 65420, 65425,
    65429, 65433, 65438, 65442, 65445, 65449, 65453, 65456, 65459, 65462, 65465,
    65468, 65471, 65474, 65476, 65479, 65481, 65483, 65485, 65488, 65489, 65491,
    65493, 65495, 65497, 65498, 65500, 65501, 65503, 65504, 65505, 65507, 65508,
    65509, 65510, 65511, 65512, 65513, 65514, 65515, 65516, 65517, 65517, 65518,
    65519, 65520, 65520, 65521, 65522, 65522, 65523, 65523, 65524, 65524, 65525,
    65525, 65526, 65526, 65526, 65527, 65527, 65528, 65528, 65528, 65529, 65529,
    65529, 65529, 65530, 65530, 65530, 65530, 65531, 65531, 65531, 65531, 65531,
    65532, 65532, 65532, 65532, 65532, 65532, 65533, 65533, 65533, 65533, 65533,
    65533, 65533, 65533, 65534, 65534, 65534, 65534, 65534, 65534, 65534, 65534,
    65534, 65534, 65535};

// TODO(b/77858996): Add these to gemmlowp.
template <typename IntegerType>
IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

template <>
inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) {
  std::int64_t a64 = a;
  std::int64_t b64 = b;
  std::int64_t sum = a64 + b64;
  return static_cast<std::int32_t>(std::min(
      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
      std::max(
          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
          sum)));
}

template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingAddNonGemmlowp(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingAddNonGemmlowp(a.raw(), b.raw()));
}

template <typename IntegerType>
IntegerType SaturatingSub(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

template <>
inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) {
  std::int32_t a32 = a;
  std::int32_t b32 = b;
  std::int32_t diff = a32 - b32;
  return static_cast<std::int16_t>(
      std::min(static_cast<int32_t>(32767),
               std::max(static_cast<int32_t>(-32768), diff)));
}

template <>
inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) {
  std::int64_t a64 = a;
  std::int64_t b64 = b;
  std::int64_t diff = a64 - b64;
  return static_cast<std::int32_t>(std::min(
      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
      std::max(
          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
          diff)));
}
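
// Illustrative examples (not part of the original header):
// SaturatingSub(std::int16_t{-30000}, std::int16_t{10000}) == -32768 instead
// of wrapping around, and SaturatingAddNonGemmlowp of INT32 max and 1 stays
// at std::numeric_limits<std::int32_t>::max().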

template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingSub(a.raw(), b.raw()));
}
// End section to be moved to gemmlowp.

template <typename IntegerType>
IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
  if (exponent == 0) {
    return x;
  }
  using ScalarIntegerType =
      typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
  const IntegerType min =
      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
  const IntegerType max =
      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
  const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);

  const std::int32_t threshold =
      ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
  const IntegerType positive_mask =
      gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
  const IntegerType negative_mask =
      gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));

  IntegerType result = gemmlowp::ShiftLeft(x, exponent);
  result = gemmlowp::SelectUsingMask(positive_mask, max, result);
  result = gemmlowp::SelectUsingMask(negative_mask, min, result);
  return result;
}

// If we want to leave IntegerBits fixed, then multiplication
// by a power of two has to be saturating/rounding, not exact anymore.
template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits>
SaturatingRoundingMultiplyByPOTParam(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}

// Convert int32_t multiplier to int16_t with rounding.
inline void DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32_t,
                                            int16_t* multiplier_int16_t) {
  TFLITE_DCHECK_GE(multiplier_int32_t, 0);
  static constexpr int32_t kRoundingOffset = 1 << 15;
  if (multiplier_int32_t >=
      std::numeric_limits<int32_t>::max() - kRoundingOffset) {
    *multiplier_int16_t = std::numeric_limits<int16_t>::max();
    return;
  }
  const int32_t result = (multiplier_int32_t + kRoundingOffset) >> 16;
  TFLITE_DCHECK_LE(result << 16, multiplier_int32_t + kRoundingOffset);
  TFLITE_DCHECK_GT(result << 16, multiplier_int32_t - kRoundingOffset);
  *multiplier_int16_t = result;
  TFLITE_DCHECK_EQ(*multiplier_int16_t, result);
}
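
// Illustrative example (not part of the original header): a multiplier of
// 1073741824 (0.5 in Q0.31) is downscaled to 16384 (0.5 in Q0.15), i.e.
// (1073741824 + (1 << 15)) >> 16 == 16384.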

// Minimum output bits to accommodate log of maximum input range.  It actually
// does not matter if one considers, say, [-64,64] or [-64,64).
//
// For example, run this through Octave:
// [0:127; ...
//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
constexpr int min_log_x_output_bits(int input_bits) {
  return input_bits > 90   ? 7
         : input_bits > 44 ? 6
         : input_bits > 21 ? 5
         : input_bits > 10 ? 4
         : input_bits > 4  ? 3
         : input_bits > 1  ? 2
                           : 1;
}

// Although the name of this function says that it cannot handle values less
// than 1, in practice it can handle values as low as 1/x_max, where x_max is
// the largest representable input. In other words, the output range is
// symmetric.
template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32_t, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1_impl(
    gemmlowp::FixedPoint<int32_t, InputIntegerBits> input_val) {
  // assert(__builtin_clz(0u) >= std::numeric_limits<uint32_t>::digits - 1);
  // assert(__builtin_clz(0u) <= std::numeric_limits<uint32_t>::digits);
  using FixedPoint0 = gemmlowp::FixedPoint<int32_t, 0>;
  // The reason for accumulating the result with an extra bit of headroom is
  // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
  // recip_denom will otherwise introduce an error.
  static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
  using FixedPointAccum = gemmlowp::FixedPoint<int32_t, kAccumIntegerBits>;

  const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1488522236, std::log(2.0));
  const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
  const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1518500250, std::sqrt(0.5));
  const FixedPoint0 one_quarter =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);

  const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1057819769,
      2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));

  const FixedPointAccum shifted_quarter =
      gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);

  // Reinterpret the input value as Q0.31, because we will figure out the
  // required shift "ourselves" instead of using, say, Rescale.
  FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
  // z_a_pow_2 = input_integer_bits - z_a_headroom;
  int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32_t>(z_a.raw()));
  FixedPoint0 r_a_tmp =
      SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
  const int32_t r_a_raw =
      SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
  // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
  // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
  //                   InputIntegerBits - z_b_headroom - 0.25);
  const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
          static_cast<int32_t>(InputIntegerBits - z_a_headroom_plus_1),
          31 - kAccumIntegerBits)),
      shifted_quarter);

  // z_b is treated like z_a, but premultiplying by sqrt(0.5).
  FixedPoint0 z_b = z_a * sqrt_half;
  int z_b_headroom = CountLeadingZeros(static_cast<uint32_t>(z_b.raw())) - 1;
  const int32_t r_b_raw =
      SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
  const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
          static_cast<int32_t>(InputIntegerBits - z_b_headroom),
          31 - kAccumIntegerBits)),
      shifted_quarter);

  const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
  const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
      std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));

  const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
  FixedPoint0 q = r - sqrt_sqrt_half;
  q = q + q;

  const FixedPoint0 common_sq = q * q;
  const FixedPoint0 num = q * r + q * common_sq * alpha_n;
  const FixedPoint0 denom_minus_one_0 =
      p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
  const FixedPoint0 recip_denom =
      one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);

  const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
  return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
                                              num_scaled * recip_denom);
}

template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32_t, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1(
    gemmlowp::FixedPoint<int32_t, InputIntegerBits> input_val) {
  static_assert(
      OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
      "Output integer bits must be sufficient to accommodate logs of inputs.");
  return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
                                                     InputIntegerBits>(
      input_val);
}

inline int32_t GetReciprocal(int32_t x, int x_integer_digits,
                             int* num_bits_over_unit) {
  int headroom_plus_one = CountLeadingZeros(static_cast<uint32_t>(x));
  // This is the number of bits to the left of the binary point above 1.0.
  // Consider x=1.25.  In that case shifted_scale=0.8 and
  // no later adjustment will be needed.
  *num_bits_over_unit = x_integer_digits - headroom_plus_one;
  const int32_t shifted_sum_minus_one =
      static_cast<int32_t>((static_cast<uint32_t>(x) << headroom_plus_one) -
                           (static_cast<uint32_t>(1) << 31));

  gemmlowp::FixedPoint<int32_t, 0> shifted_scale =
      gemmlowp::one_over_one_plus_x_for_x_in_0_1(
          gemmlowp::FixedPoint<int32_t, 0>::FromRaw(shifted_sum_minus_one));
  return shifted_scale.raw();
}
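
// Illustrative interpretation (not part of the original header): if the raw
// value x represents a real number X with x_integer_digits integer bits
// (i.e. X = x / 2^(31 - x_integer_digits)), then 1/X is approximately
// (returned raw value / 2^31) / 2^(*num_bits_over_unit). For example, with
// x_integer_digits = 12 and x = 1310720 (X = 2.5), *num_bits_over_unit is
// set to 1 and the returned Q0.31 value is roughly 0.8 * 2^31, giving
// 0.8 / 2 == 0.4 == 1/2.5.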

inline void GetInvSqrtQuantizedMultiplierExp(int32_t input, int reverse_shift,
                                             int32_t* output_inv_sqrt,
                                             int* output_shift) {
  TFLITE_DCHECK_GE(input, 0);
  if (input <= 1) {
    // Handle the input value 1 separately to avoid overflow in that case
    // in the general computation below (b/143972021). Also handle 0 as if it
    // were a 1. 0 is an invalid input here (divide by zero) and 1 is a valid
    // but rare/unrealistic input value. We can expect both to occur in some
    // incompletely trained models, but probably not in fully trained models.
    *output_inv_sqrt = std::numeric_limits<std::int32_t>::max();
    *output_shift = 0;
    return;
  }
  TFLITE_DCHECK_GT(input, 1);
  *output_shift = 11;
  while (input >= (1 << 29)) {
    input /= 4;
    ++*output_shift;
  }
  const unsigned max_left_shift_bits =
      CountLeadingZeros(static_cast<uint32_t>(input)) - 1;
  const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
  const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
  *output_shift -= left_shift_bit_pairs;
  input <<= 2 * left_shift_bit_pairs;
  TFLITE_DCHECK_GE(input, (1 << 27));
  TFLITE_DCHECK_LT(input, (1 << 29));
  using gemmlowp::FixedPoint;
  using gemmlowp::Rescale;
  using gemmlowp::SaturatingRoundingMultiplyByPOT;
  // Using 3 integer bits gives us enough room for the internal arithmetic in
  // this Newton-Raphson iteration.
  using F3 = FixedPoint<int32_t, 3>;
  using F0 = FixedPoint<int32_t, 0>;
  const F3 fixedpoint_input = F3::FromRaw(input >> 1);
  const F3 fixedpoint_half_input =
      SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
  const F3 fixedpoint_half_three =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
  // Newton-Raphson iteration
  // Naive unoptimized starting guess: x = 1
  F3 x = F3::One();
  // Naive unoptimized number of iterations: 5
  for (int i = 0; i < 5; i++) {
    const F3 x3 = Rescale<3>(x * x * x);
    x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
  }
  const F0 fixedpoint_half_sqrt_2 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
  x = x * fixedpoint_half_sqrt_2;
  *output_inv_sqrt = x.raw();
  if (*output_shift < 0) {
    *output_inv_sqrt <<= -*output_shift;
    *output_shift = 0;
  }
  // Convert right shift (right is positive) to left shift.
  *output_shift *= reverse_shift;
}

// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
// BROADCASTING.
//
// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
// rectangular array of numbers.
//
// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
// However, as Dims<N> is to be deprecated, this class exists as an adaptor
// to enable simple unoptimized implementations of element-wise broadcasting
// operations.
template <int N>
struct NdArrayDesc {
  // The "extent" of each dimension. Indices along dimension d must be in the
  // half-open interval [0, extents[d]).
  int extents[N];

  // The number of *elements* (not bytes) between consecutive indices of each
  // dimension.
  int strides[N];
};

// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
// BROADCASTING.
//
// Same as Offset(), except it takes an NdArrayDesc<N> instead of a Dims<N>.
inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
                            int i3) {
  TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
  return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
         i3 * desc.strides[3];
}

inline int SubscriptToIndex(const NdArrayDesc<5>& desc, int indexes[5]) {
  return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
         indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
         indexes[4] * desc.strides[4];
}

inline int SubscriptToIndex(const NdArrayDesc<8>& desc, int indexes[8]) {
  return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
         indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
         indexes[4] * desc.strides[4] + indexes[5] * desc.strides[5] +
         indexes[6] * desc.strides[6] + indexes[7] * desc.strides[7];
}
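
// Illustrative example (not part of the original header): for a contiguous
// NHWC tensor of shape [2, 3, 4, 5], CopyDimsToDesc below would produce
// extents = {2, 3, 4, 5} and strides = {60, 20, 5, 1}, so
// SubscriptToIndex(desc, 1, 2, 3, 4) == 60 + 40 + 15 + 4 == 119, the last
// element.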

// Given the dimensions of the operands for an element-wise binary broadcast,
// adjusts them so that they can be directly iterated over with simple loops.
// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
//
// This function assumes that the two input shapes are compatible up to
// broadcasting and the shorter one has already been prepended with 1s to be the
// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
// Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
//
// When two shapes are compatible up to broadcasting, for each dimension d,
// the input extents are either equal, or one of them is 1.
//
// This function performs the following for each dimension d:
// - If the extents are equal, then do nothing since the loop that walks over
//   both of the input arrays is correct.
// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
//   and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
//   array0 to be referenced *at any index* in dimension d and still access the
//   same slice.
template <int N>
inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
                                                const Dims<N>& input1_dims,
                                                NdArrayDesc<N>* desc0_out,
                                                NdArrayDesc<N>* desc1_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);

  // Copy dims to desc.
  for (int i = 0; i < N; ++i) {
    desc0_out->extents[i] = input0_dims.sizes[i];
    desc0_out->strides[i] = input0_dims.strides[i];
    desc1_out->extents[i] = input1_dims.sizes[i];
    desc1_out->strides[i] = input1_dims.strides[i];
  }

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other and
  // stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = ArraySize(input0_dims, i);
    const int extent1 = ArraySize(input1_dims, i);
    if (extent0 != extent1) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent1;
      } else {
        TFLITE_DCHECK_EQ(extent1, 1);
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent0;
      }
    }
  }
}

// Copies dims to desc, calculating strides.
template <int N>
inline void CopyDimsToDesc(const RuntimeShape& input_shape,
                           NdArrayDesc<N>* desc_out) {
  int desc_stride = 1;
  for (int i = N - 1; i >= 0; --i) {
    desc_out->extents[i] = input_shape.Dims(i);
    desc_out->strides[i] = desc_stride;
    desc_stride *= input_shape.Dims(i);
  }
}

template <int N>
inline void NdArrayDescsForElementwiseBroadcast(
    const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
    NdArrayDesc<N>* desc0_out, NdArrayDesc<N>* desc1_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);

  auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
  auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);

  // Copy dims to desc, calculating strides.
  CopyDimsToDesc<N>(extended_input0_shape, desc0_out);
  CopyDimsToDesc<N>(extended_input1_shape, desc1_out);

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other and
  // stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = extended_input0_shape.Dims(i);
    const int extent1 = extended_input1_shape.Dims(i);
    if (extent0 != extent1) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent1;
      } else {
        TFLITE_DCHECK_EQ(extent1, 1);
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent0;
      }
    }
  }
}
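
// Illustrative example (not part of the original header): with N = 4,
// broadcasting a bias of shape [1, 1, 1, 5] against an input of shape
// [1, 2, 3, 5] leaves the input descriptor unchanged but rewrites the bias
// descriptor to extents {1, 2, 3, 5} with strides {5, 0, 0, 1}; the zero
// strides make every (y, x) subscript map onto the same 5-element slice.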

template <int N>
inline void NdArrayDescsForElementwiseBroadcast(
    const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
    const RuntimeShape& input2_shape, NdArrayDesc<N>* desc0_out,
    NdArrayDesc<N>* desc1_out, NdArrayDesc<N>* desc2_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);
  TFLITE_DCHECK(desc2_out != nullptr);

  auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
  auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);
  auto extended_input2_shape = RuntimeShape::ExtendedShape(N, input2_shape);

  // Copy dims to desc, calculating strides.
  CopyDimsToDesc<N>(extended_input0_shape, desc0_out);
  CopyDimsToDesc<N>(extended_input1_shape, desc1_out);
  CopyDimsToDesc<N>(extended_input2_shape, desc2_out);

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other and
  // stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = extended_input0_shape.Dims(i);
    const int extent1 = extended_input1_shape.Dims(i);
    const int extent2 = extended_input2_shape.Dims(i);

    int extent = extent0;
    if (extent1 != 1) extent = extent1;
    if (extent2 != 1) extent = extent2;

    TFLITE_DCHECK(extent0 == 1 || extent0 == extent);
    TFLITE_DCHECK(extent1 == 1 || extent1 == extent);
    TFLITE_DCHECK(extent2 == 1 || extent2 == extent);

    if (!(extent0 == extent1 && extent1 == extent2)) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent;
      }
      if (extent1 == 1) {
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent;
      }
      if (extent2 == 1) {
        desc2_out->strides[i] = 0;
        desc2_out->extents[i] = extent;
      }
    }
  }
}

// Detailed implementation of NDOpsHelper; the indexes must be a zero array.
// This implementation is equivalent to N nested loops. E.g., if N=4, it can
// be re-written as:
// for (int b = 0; b < output.extents[0]; ++b) {
//   for (int y = 0; y < output.extents[1]; ++y) {
//     for (int x = 0; x < output.extents[2]; ++x) {
//       for (int c = 0; c < output.extents[3]; ++c) {
//           calc({b,y,x,c});
//       }
//     }
//   }
// }
template <int N, int DIM, typename Calc>
typename std::enable_if<DIM != N - 1, void>::type NDOpsHelperImpl(
    const NdArrayDesc<N>& output, const Calc& calc, int indexes[N]) {
  for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM]) {
    NDOpsHelperImpl<N, DIM + 1, Calc>(output, calc, indexes);
  }
}

template <int N, int DIM, typename Calc>
typename std::enable_if<DIM == N - 1, void>::type NDOpsHelperImpl(
    const NdArrayDesc<N>& output, const Calc& calc, int indexes[N]) {
  for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM]) {
    calc(indexes);
  }
}

// Execute the calc function in the innermost iteration based on the shape of
// the output. The calc function should take a single argument of type int[N].
template <int N, typename Calc>
inline void NDOpsHelper(const NdArrayDesc<N>& output, const Calc& calc) {
  int indexes[N] = {0};
  NDOpsHelperImpl<N, 0, Calc>(output, calc, indexes);
}
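
// Illustrative usage sketch (not part of the original header; desc0, desc1,
// output_desc, input0, input1 and output are hypothetical, already
// broadcast-adjusted 4-D descriptors and buffers):
//
//   auto add_one_element = [&](int* indexes) {
//     output[SubscriptToIndex(output_desc, indexes[0], indexes[1],
//                             indexes[2], indexes[3])] =
//         input0[SubscriptToIndex(desc0, indexes[0], indexes[1], indexes[2],
//                                 indexes[3])] +
//         input1[SubscriptToIndex(desc1, indexes[0], indexes[1], indexes[2],
//                                 indexes[3])];
//   };
//   NDOpsHelper<4>(output_desc, add_one_element);
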
// Copied from gemmlowp::RoundDown when we dropped direct dependency on
// gemmlowp.
//
// Returns the runtime argument rounded down to the nearest multiple of
// the fixed Modulus.
template <unsigned Modulus, typename Integer>
Integer RoundDown(Integer i) {
  return i - (i % Modulus);
}

// Copied from gemmlowp::RoundUp when we dropped direct dependency on
// gemmlowp.
//
// Returns the runtime argument rounded up to the nearest multiple of
// the fixed Modulus.
template <unsigned Modulus, typename Integer>
Integer RoundUp(Integer i) {
  return RoundDown<Modulus>(i + Modulus - 1);
}

// Copied from gemmlowp::CeilQuotient when we dropped direct dependency on
// gemmlowp.
//
// Returns the quotient a / b rounded up ('ceil') to the nearest integer.
template <typename Integer>
Integer CeilQuotient(Integer a, Integer b) {
  return (a + b - 1) / b;
}
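
// Illustrative examples (not part of the original header):
// RoundDown<8>(13) == 8, RoundUp<8>(13) == 16, and CeilQuotient(13, 8) == 2.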

// This function is a copy of gemmlowp::HowManyThreads, copied when we dropped
// the direct dependency of internal/optimized/ on gemmlowp.
//
// It computes a reasonable number of threads to use for a GEMM of shape
// (rows, cols, depth).
//
// TODO(b/131910176): get rid of this function by switching each call site
// to its own more sensible logic for its own workload.
template <int KernelRows>
inline int LegacyHowManyThreads(int max_num_threads, int rows, int cols,
                                int depth) {
  // Early-exit in the default case where multi-threading is disabled.
  if (max_num_threads == 1) {
    return 1;
  }

  // Ensure that each thread has KernelRows rows to process, if at all possible.
  int thread_count = std::min(max_num_threads, rows / KernelRows);

  // Limit the number of threads according to the overall size of the problem.
  if (thread_count > 1) {
    // Empirically determined value.
    static constexpr std::uint64_t min_cubic_size_per_thread = 64 * 1024;

    // We can only multiply two out of three sizes without risking overflow
    const std::uint64_t cubic_size =
        std::uint64_t(rows) * std::uint64_t(cols) * std::uint64_t(depth);

    thread_count = std::min(
        thread_count, static_cast<int>(cubic_size / min_cubic_size_per_thread));
  }

  if (thread_count < 1) {
    thread_count = 1;
  }

  assert(thread_count > 0 && thread_count <= max_num_threads);
  return thread_count;
}

template <typename T>
void optimized_ops_preload_l1_stream(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 0 means read */ 0, /* 0 means no locality */ 0);
#else
  (void)ptr;
#endif
}

template <typename T>
void optimized_ops_preload_l1_keep(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 0 means read */ 0, /* 3 means high locality */ 3);
#else
  (void)ptr;
#endif
}

template <typename T>
void optimized_ops_prefetch_write_l1_keep(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 1 means write */ 1, /* 3 means high locality */ 3);
#else
  (void)ptr;
#endif
}

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_