/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_

#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#endif
#endif

#include <functional>

#include "fixedpoint/fixedpoint.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {

constexpr int kReverseShift = -1;

inline void GetActivationMinMax(FusedActivationFunctionType ac,
                                float* output_activation_min,
                                float* output_activation_max) {
  switch (ac) {
    case FusedActivationFunctionType::kNone:
      *output_activation_min = std::numeric_limits<float>::lowest();
      *output_activation_max = std::numeric_limits<float>::max();
      break;
    case FusedActivationFunctionType::kRelu:
      *output_activation_min = 0.f;
      *output_activation_max = std::numeric_limits<float>::max();
      break;
    case FusedActivationFunctionType::kRelu1:
      *output_activation_min = -1.f;
      *output_activation_max = 1.f;
      break;
    case FusedActivationFunctionType::kRelu6:
      *output_activation_min = 0.f;
      *output_activation_max = 6.f;
      break;
  }
}

template <typename T>
inline T ActivationFunctionWithMinMax(T x, T output_activation_min,
                                      T output_activation_max) {
  using std::max;
  using std::min;
  return min(max(x, output_activation_min), output_activation_max);
}

// Legacy function, left for compatibility only.
template <FusedActivationFunctionType Ac>
float ActivationFunction(float x) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  return ActivationFunctionWithMinMax(x, output_activation_min,
                                      output_activation_max);
}

inline void BiasAndClamp(float clamp_min, float clamp_max, int bias_size,
                         const float* bias_data, int array_size,
                         float* array_data) {
  // Note: see b/132215220: in May 2019 we thought it would be OK to replace
  // this with the Eigen one-liner:
  //   return (array.colwise() + bias).cwiseMax(clamp_min).cwiseMin(clamp_max).
  // This turned out to severely regress performance: +4ms (i.e. 8%) on
  // MobileNet v2 / 1.0 / 224. So we keep custom NEON code for now.
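  //
  // For example, with bias_size = 4 and array_size = 8, bias_data[0..3] is
  // added to array_data[0..3] and again to array_data[4..7], and each sum is
  // clamped to [clamp_min, clamp_max]; array_size must be a multiple of
  // bias_size.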
  TFLITE_DCHECK_EQ((array_size % bias_size), 0);
#ifdef USE_NEON
  float* array_ptr = array_data;
  float* array_end_ptr = array_ptr + array_size;
  const auto clamp_min_vec = vdupq_n_f32(clamp_min);
  const auto clamp_max_vec = vdupq_n_f32(clamp_max);
  for (; array_ptr != array_end_ptr; array_ptr += bias_size) {
    int i = 0;
    for (; i <= bias_size - 16; i += 16) {
      auto b0 = vld1q_f32(bias_data + i);
      auto b1 = vld1q_f32(bias_data + i + 4);
      auto b2 = vld1q_f32(bias_data + i + 8);
      auto b3 = vld1q_f32(bias_data + i + 12);
      auto a0 = vld1q_f32(array_ptr + i);
      auto a1 = vld1q_f32(array_ptr + i + 4);
      auto a2 = vld1q_f32(array_ptr + i + 8);
      auto a3 = vld1q_f32(array_ptr + i + 12);
      auto x0 = vaddq_f32(a0, b0);
      auto x1 = vaddq_f32(a1, b1);
      auto x2 = vaddq_f32(a2, b2);
      auto x3 = vaddq_f32(a3, b3);
      x0 = vmaxq_f32(clamp_min_vec, x0);
      x1 = vmaxq_f32(clamp_min_vec, x1);
      x2 = vmaxq_f32(clamp_min_vec, x2);
      x3 = vmaxq_f32(clamp_min_vec, x3);
      x0 = vminq_f32(clamp_max_vec, x0);
      x1 = vminq_f32(clamp_max_vec, x1);
      x2 = vminq_f32(clamp_max_vec, x2);
      x3 = vminq_f32(clamp_max_vec, x3);
      vst1q_f32(array_ptr + i, x0);
      vst1q_f32(array_ptr + i + 4, x1);
      vst1q_f32(array_ptr + i + 8, x2);
      vst1q_f32(array_ptr + i + 12, x3);
    }
    for (; i <= bias_size - 4; i += 4) {
      auto b = vld1q_f32(bias_data + i);
      auto a = vld1q_f32(array_ptr + i);
      auto x = vaddq_f32(a, b);
      x = vmaxq_f32(clamp_min_vec, x);
      x = vminq_f32(clamp_max_vec, x);
      vst1q_f32(array_ptr + i, x);
    }
    for (; i < bias_size; i++) {
      array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i],
                                                  clamp_min, clamp_max);
    }
  }
#else  // not NEON
  for (int array_offset = 0; array_offset < array_size;
       array_offset += bias_size) {
    for (int i = 0; i < bias_size; i++) {
      array_data[array_offset + i] = ActivationFunctionWithMinMax(
          array_data[array_offset + i] + bias_data[i], clamp_min, clamp_max);
    }
  }
#endif
}

inline int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp(
    int32_t x, int32_t quantized_multiplier, int left_shift) {
  using gemmlowp::RoundingDivideByPOT;
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  return RoundingDivideByPOT(
      SaturatingRoundingDoublingHighMul(x, quantized_multiplier), -left_shift);
}

inline int32_t MultiplyByQuantizedMultiplierGreaterThanOne(
    int32_t x, int32_t quantized_multiplier, int left_shift) {
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
                                           quantized_multiplier);
}

inline int32_t MultiplyByQuantizedMultiplier(int32_t x,
                                             int32_t quantized_multiplier,
                                             int shift) {
  using gemmlowp::RoundingDivideByPOT;
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  int left_shift = shift > 0 ? shift : 0;
  int right_shift = shift > 0 ? 0 : -shift;
  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
                                 x * (1 << left_shift), quantized_multiplier),
                             right_shift);
}
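
// MultiplyByQuantizedMultiplier computes, with rounding,
//   result ~= x * quantized_multiplier * 2^shift / 2^31,
// i.e. it applies a real-valued scale encoded as a Q0.31 multiplier plus a
// power-of-two exponent. For example, a real scale of 0.375 = 0.75 * 2^-1 can
// be encoded as quantized_multiplier = round(0.75 * 2^31) = 1610612736 with
// shift = -1, and MultiplyByQuantizedMultiplier(100, 1610612736, -1) == 38
// (i.e. 37.5 rounded to nearest).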

inline int32_t MultiplyByQuantizedMultiplier(int64_t x,
                                             int32_t quantized_multiplier,
                                             int shift) {
  // Inputs:
  // - quantized_multiplier has its fixed point at bit 31
  // - shift is -31 to +7 (negative for right shift)
  //
  // Assumptions: the following input ranges are assumed
  // - quantized_multiplier >= 0 (the usual range is (1 << 30) to (1 << 31) - 1)
  // - scaling is chosen so the final scaled result fits in int32_t
  // - input x is in the range -(1 << 47) <= x < (1 << 47)
  assert(quantized_multiplier >= 0);
  assert(shift >= -31 && shift < 8);
  assert(x >= -(static_cast<int64_t>(1) << 47) &&
         x < (static_cast<int64_t>(1) << 47));

  int32_t reduced_multiplier = (quantized_multiplier < 0x7FFF0000)
                                   ? ((quantized_multiplier + (1 << 15)) >> 16)
                                   : 0x7FFF;
  int total_shift = 15 - shift;
  x = (x * (int64_t)reduced_multiplier) + ((int64_t)1 << (total_shift - 1));
  int32_t result = x >> total_shift;
  return result;
}
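
// The 64-bit overload above first rounds the Q0.31 multiplier to a Q0.15
// value (reduced_multiplier) and then performs a single 64-bit
// multiply-round-shift. For example, with quantized_multiplier = 1 << 30
// (a real scale of 0.5) and shift = 0, reduced_multiplier = 1 << 14 and
// total_shift = 15, so x = 1000 maps to (1000 * 16384 + 16384) >> 15 == 500.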

#ifdef USE_NEON
// Round uses ARM's rounding shift right.
inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows(
    int32x4x4_t input_val, int32_t quantized_multiplier, int shift) {
  const int left_shift = std::max(shift, 0);
  const int right_shift = std::min(shift, 0);
  int32x4x4_t result;

  int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier);
  int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
  int32x4_t right_shift_dup = vdupq_n_s32(right_shift);

  result.val[0] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  result.val[1] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  result.val[2] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[2], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  result.val[3] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[3], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  return result;
}
#endif

template <typename T>
int CountLeadingZeros(T integer_input) {
  static_assert(std::is_unsigned<T>::value,
                "Only unsigned integer types handled.");
#if defined(__GNUC__)
  return integer_input ? __builtin_clz(integer_input)
                       : std::numeric_limits<T>::digits;
#else
  if (integer_input == 0) {
    return std::numeric_limits<T>::digits;
  }

  const T one_in_leading_positive = static_cast<T>(1)
                                    << (std::numeric_limits<T>::digits - 1);
  int leading_zeros = 0;
  while (integer_input < one_in_leading_positive) {
    integer_input <<= 1;
    ++leading_zeros;
  }
  return leading_zeros;
#endif
}
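
// For example, CountLeadingZeros<uint32_t>(1) == 31,
// CountLeadingZeros<uint32_t>(0x80000000u) == 0, and a zero input yields the
// full bit width (32 for uint32_t).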

template <typename T>
inline int CountLeadingSignBits(T integer_input) {
  static_assert(std::is_signed<T>::value, "Only signed integer types handled.");
#if defined(__GNUC__) && !defined(__clang__)
  return integer_input ? __builtin_clrsb(integer_input)
                       : std::numeric_limits<T>::digits;
#else
  using U = typename std::make_unsigned<T>::type;
  return integer_input >= 0
             ? CountLeadingZeros(static_cast<U>(integer_input)) - 1
             : integer_input != std::numeric_limits<T>::min()
                   ? CountLeadingZeros(2 * static_cast<U>(-integer_input) - 1)
                   : 0;
#endif
}

// Use "count leading zeros" helper functions to do a fast Floor(log_2(x)).
template <typename Integer>
inline Integer FloorLog2(Integer n) {
  static_assert(std::is_integral<Integer>::value, "");
  static_assert(std::is_signed<Integer>::value, "");
  static_assert(sizeof(Integer) == 4 || sizeof(Integer) == 8, "");
  TFLITE_CHECK_GT(n, 0);
  if (sizeof(Integer) == 4) {
    return 30 - CountLeadingSignBits(n);
  } else {
    return 62 - CountLeadingSignBits(n);
  }
}
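
// For example, FloorLog2(1) == 0 and FloorLog2(16) == FloorLog2(17) == 4; the
// argument must be strictly positive.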

// Generates an INT16 LUT for a function, e.g., exp(x) and 1/(1+x) as used in
// softmax.
// func - the function to build the LUT for (e.g. exp(x))
// min, max - table limits
// table - pointer to the output buffer
// num - number of elements in the LUT
inline void gen_lut(double (*func)(double), double min, double max,
                    int16_t* table, const int num) {
  // The table buffer must hold num entries; the last entry is used only for
  // slope calculation during lookup.
  double step = (max - min) / (num - 1);
  double half_step = step / 2.0;
  for (int i = 0; i < num - 1; i++) {
    double sample_val = TfLiteRound(func(min + i * step) * 32768.0);
    double midpoint_interp_val =
        TfLiteRound((func(min + (i + 1) * step) * 32768.0 +
                     TfLiteRound(func(min + i * step) * 32768.0)) /
                    2.0);
    double midpoint_val =
        TfLiteRound(func(min + i * step + half_step) * 32768.0);
    double midpoint_err = midpoint_interp_val - midpoint_val;
    double bias = TfLiteRound(midpoint_err / 2.0);
    table[i] = std::min<double>(std::max<double>(sample_val - bias, -32768.0),
                                32767.0);
  }
  table[num - 1] = std::min<double>(
      std::max<double>(TfLiteRound(func(max) * 32768.0), -32768.0), 32767.0);
}

// Generates an INT16 LUT for a function, e.g., exp(x) and 1/(1+x) as used in
// softmax.
// func - the function to build the LUT for (e.g. exp(x))
// min, max - table limits
// table - pointer to the output buffer
// num - number of elements in the LUT
inline void gen_lut(float (*func)(float), float min, float max, int16_t* table,
                    const int num) {
  // The table buffer must hold num entries; the last entry is used only for
  // slope calculation during lookup.
  float step = (max - min) / (num - 1);
  float half_step = step / 2.0f;
  for (int i = 0; i < num - 1; i++) {
    float sample_val = TfLiteRound(func(min + i * step) * 32768.0f);
    float midpoint_interp_val =
        TfLiteRound((func(min + (i + 1) * step) * 32768.0f +
                     TfLiteRound(func(min + i * step) * 32768.0f)) /
                    2.0f);
    float midpoint_val =
        TfLiteRound(func(min + i * step + half_step) * 32768.0f);
    float midpoint_err = midpoint_interp_val - midpoint_val;
    float bias = TfLiteRound(midpoint_err / 2.0f);
    table[i] = std::min<float>(std::max<float>(sample_val - bias, -32768.0f),
                               32767.0f);
  }
  table[num - 1] = std::min<float>(
      std::max<float>(TfLiteRound(func(max) * 32768.0f), -32768.0f), 32767.0f);
}
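
// Usage sketch (hypothetical range and buffer name; the real kernels choose
// their own limits): building the 513-entry exp table consumed by
// generic_int16_table_lookup below might look like
//   int16_t exp_lut[513];
//   gen_lut([](double x) { return std::exp(x); }, -10.0, 0.0, exp_lut, 513);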

// int16_t func table lookup, e.g., lookup exp() and 1/(1+x) used in softmax.
inline int16_t generic_int16_table_lookup(int16_t value, const int16_t* lut) {
  // 512 base values; the extra 513th entry is only read when computing the
  // slope of the last interval.
  uint16_t index = static_cast<uint16_t>(256 + (value >> 7));
  assert(index < 512 && "LUT index out of range.");
  int16_t offset = value & 0x7f;

  // base and slope are Q0.15
  int16_t base = lut[index];
  int16_t slope = lut[index + 1] - lut[index];

  // Q0.15 * Q0.7 = Q0.22
  // Round and convert from Q0.22 to Q0.15
  int32_t delta = (static_cast<int32_t>(slope) * offset + 64) >> 7;

  // Q0.15 + Q0.15
  return base + delta;
}
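
// For example, for value = 1000: index = 256 + (1000 >> 7) = 263 and
// offset = 1000 & 0x7f = 104, so the result is
// lut[263] + (((lut[264] - lut[263]) * 104 + 64) >> 7), i.e. a linear
// interpolation 104/128 of the way from lut[263] towards lut[264].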

// Table of sigmoid(i/24) at 0.16 format - 256 elements.

// We use a combined sigmoid and tanh look-up table, since
// tanh(x) = 2*sigmoid(2*x) - 1.
// Both functions are symmetric, so the LUT is only needed
// for the absolute value of the input.
static const uint16_t sigmoid_table_uint16[256] = {
    32768, 33451, 34133, 34813, 35493, 36169, 36843, 37513, 38180, 38841, 39498,
    40149, 40794, 41432, 42064, 42688, 43304, 43912, 44511, 45102, 45683, 46255,
    46817, 47369, 47911, 48443, 48964, 49475, 49975, 50464, 50942, 51409, 51865,
    52311, 52745, 53169, 53581, 53983, 54374, 54755, 55125, 55485, 55834, 56174,
    56503, 56823, 57133, 57433, 57724, 58007, 58280, 58544, 58800, 59048, 59288,
    59519, 59743, 59959, 60168, 60370, 60565, 60753, 60935, 61110, 61279, 61441,
    61599, 61750, 61896, 62036, 62172, 62302, 62428, 62549, 62666, 62778, 62886,
    62990, 63090, 63186, 63279, 63368, 63454, 63536, 63615, 63691, 63765, 63835,
    63903, 63968, 64030, 64090, 64148, 64204, 64257, 64308, 64357, 64405, 64450,
    64494, 64536, 64576, 64614, 64652, 64687, 64721, 64754, 64786, 64816, 64845,
    64873, 64900, 64926, 64950, 64974, 64997, 65019, 65039, 65060, 65079, 65097,
    65115, 65132, 65149, 65164, 65179, 65194, 65208, 65221, 65234, 65246, 65258,
    65269, 65280, 65291, 65301, 65310, 65319, 65328, 65337, 65345, 65352, 65360,
    65367, 65374, 65381, 65387, 65393, 65399, 65404, 65410, 65415, 65420, 65425,
    65429, 65433, 65438, 65442, 65445, 65449, 65453, 65456, 65459, 65462, 65465,
    65468, 65471, 65474, 65476, 65479, 65481, 65483, 65485, 65488, 65489, 65491,
    65493, 65495, 65497, 65498, 65500, 65501, 65503, 65504, 65505, 65507, 65508,
    65509, 65510, 65511, 65512, 65513, 65514, 65515, 65516, 65517, 65517, 65518,
    65519, 65520, 65520, 65521, 65522, 65522, 65523, 65523, 65524, 65524, 65525,
    65525, 65526, 65526, 65526, 65527, 65527, 65528, 65528, 65528, 65529, 65529,
    65529, 65529, 65530, 65530, 65530, 65530, 65531, 65531, 65531, 65531, 65531,
    65532, 65532, 65532, 65532, 65532, 65532, 65533, 65533, 65533, 65533, 65533,
    65533, 65533, 65533, 65534, 65534, 65534, 65534, 65534, 65534, 65534, 65534,
    65534, 65534, 65535};

// TODO(b/77858996): Add these to gemmlowp.
template <typename IntegerType>
IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

template <>
inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) {
  std::int64_t a64 = a;
  std::int64_t b64 = b;
  std::int64_t sum = a64 + b64;
  return static_cast<std::int32_t>(std::min(
      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
      std::max(
          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
          sum)));
}

template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingAddNonGemmlowp(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingAddNonGemmlowp(a.raw(), b.raw()));
}

template <typename IntegerType>
IntegerType SaturatingSub(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

template <>
inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) {
  std::int32_t a32 = a;
  std::int32_t b32 = b;
  std::int32_t diff = a32 - b32;
  return static_cast<std::int16_t>(
      std::min(static_cast<int32_t>(32767),
               std::max(static_cast<int32_t>(-32768), diff)));
}

template <>
inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) {
  std::int64_t a64 = a;
  std::int64_t b64 = b;
  std::int64_t diff = a64 - b64;
  return static_cast<std::int32_t>(std::min(
      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
      std::max(
          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
          diff)));
}

template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingSub(a.raw(), b.raw()));
}
// End section to be moved to gemmlowp.

template <typename IntegerType>
IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
  if (exponent == 0) {
    return x;
  }
  using ScalarIntegerType =
      typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
  const IntegerType min =
      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
  const IntegerType max =
      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
  const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);

  const std::int32_t threshold =
      ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
  const IntegerType positive_mask =
      gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
  const IntegerType negative_mask =
      gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));

  IntegerType result = gemmlowp::ShiftLeft(x, exponent);
  result = gemmlowp::SelectUsingMask(positive_mask, max, result);
  result = gemmlowp::SelectUsingMask(negative_mask, min, result);
  return result;
}

// If we want to leave IntegerBits fixed, then multiplication
// by a power of two has to be saturating/rounding, not exact anymore.
template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits>
SaturatingRoundingMultiplyByPOTParam(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}

// Convert int32_t multiplier to int16_t with rounding.
inline void DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32_t,
                                            int16_t* multiplier_int16_t) {
  TFLITE_DCHECK_GE(multiplier_int32_t, 0);
  static constexpr int32_t kRoundingOffset = 1 << 15;
  if (multiplier_int32_t >=
      std::numeric_limits<int32_t>::max() - kRoundingOffset) {
    *multiplier_int16_t = std::numeric_limits<int16_t>::max();
    return;
  }
  const int32_t result = (multiplier_int32_t + kRoundingOffset) >> 16;
  TFLITE_DCHECK_LE(result << 16, multiplier_int32_t + kRoundingOffset);
  TFLITE_DCHECK_GT(result << 16, multiplier_int32_t - kRoundingOffset);
  *multiplier_int16_t = result;
  TFLITE_DCHECK_EQ(*multiplier_int16_t, result);
}
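
// For example, a Q0.31 multiplier of 1 << 30 (representing 0.5) is rounded to
// the Q0.15 value 1 << 14, which also represents 0.5.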

// Minimum output bits to accommodate log of maximum input range. It actually
// does not matter if one considers, say, [-64,64] or [-64,64).
//
// For example, run this through Octave:
// [0:127; ...
//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
constexpr int min_log_x_output_bits(int input_bits) {
  return input_bits > 90   ? 7
         : input_bits > 44 ? 6
         : input_bits > 21 ? 5
         : input_bits > 10 ? 4
         : input_bits > 4  ? 3
         : input_bits > 1  ? 2
                           : 1;
}

// Although currently the name of this function says that it cannot handle
// values less than 1, in practice it can handle as low as 1/x_max, where
// x_max is the largest representable input. In other words, the output range
// is symmetric.
template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32_t, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1_impl(
    gemmlowp::FixedPoint<int32_t, InputIntegerBits> input_val) {
  // assert(__builtin_clz(0u) >= std::numeric_limits<uint32_t>::digits - 1);
  // assert(__builtin_clz(0u) <= std::numeric_limits<uint32_t>::digits);
  using FixedPoint0 = gemmlowp::FixedPoint<int32_t, 0>;
  // The reason for accumulating the result with an extra bit of headroom is
  // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
  // recip_denom will otherwise introduce an error.
  static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
  using FixedPointAccum = gemmlowp::FixedPoint<int32_t, kAccumIntegerBits>;

  const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1488522236, std::log(2.0));
  const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
  const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1518500250, std::sqrt(0.5));
  const FixedPoint0 one_quarter =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);

  const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1057819769,
      2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));

  const FixedPointAccum shifted_quarter =
      gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);

  // Reinterpret the input value as Q0.31, because we will figure out the
  // required shift "ourselves" instead of using, say, Rescale.
  FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
  // z_a_pow_2 = input_integer_bits - z_a_headroom;
  int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32_t>(z_a.raw()));
  FixedPoint0 r_a_tmp =
      SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
  const int32_t r_a_raw =
      SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
  // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
  // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
  //                   InputIntegerBits - z_b_headroom - 0.25);
  const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
          static_cast<int32_t>(InputIntegerBits - z_a_headroom_plus_1),
          31 - kAccumIntegerBits)),
      shifted_quarter);

  // z_b is treated like z_a, but premultiplying by sqrt(0.5).
  FixedPoint0 z_b = z_a * sqrt_half;
  int z_b_headroom = CountLeadingZeros(static_cast<uint32_t>(z_b.raw())) - 1;
  const int32_t r_b_raw =
      SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
  const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
          static_cast<int32_t>(InputIntegerBits - z_b_headroom),
          31 - kAccumIntegerBits)),
      shifted_quarter);

  const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
  const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
      std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));

  const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
  FixedPoint0 q = r - sqrt_sqrt_half;
  q = q + q;

  const FixedPoint0 common_sq = q * q;
  const FixedPoint0 num = q * r + q * common_sq * alpha_n;
  const FixedPoint0 denom_minus_one_0 =
      p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
  const FixedPoint0 recip_denom =
      one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);

  const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
  return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
                                              num_scaled * recip_denom);
}

template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32_t, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1(
    gemmlowp::FixedPoint<int32_t, InputIntegerBits> input_val) {
  static_assert(
      OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
      "Output integer bits must be sufficient to accommodate logs of inputs.");
  return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
                                                     InputIntegerBits>(
      input_val);
}

inline int32_t GetReciprocal(int32_t x, int x_integer_digits,
                             int* num_bits_over_unit) {
  int headroom_plus_one = CountLeadingZeros(static_cast<uint32_t>(x));
  // This is the number of bits to the left of the binary point above 1.0.
  // Consider x=1.25. In that case shifted_scale=0.8 and
  // no later adjustment will be needed.
  *num_bits_over_unit = x_integer_digits - headroom_plus_one;
  const int32_t shifted_sum_minus_one =
      static_cast<int32_t>((static_cast<uint32_t>(x) << headroom_plus_one) -
                           (static_cast<uint32_t>(1) << 31));

  gemmlowp::FixedPoint<int32_t, 0> shifted_scale =
      gemmlowp::one_over_one_plus_x_for_x_in_0_1(
          gemmlowp::FixedPoint<int32_t, 0>::FromRaw(shifted_sum_minus_one));
  return shifted_scale.raw();
}
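
// Worked example (illustrative): for x = 0x60000000 interpreted with
// x_integer_digits = 1 (i.e. x = 1.5 in Q1.30), the returned raw Q0.31 value
// is approximately 0.667 (1 / 1.5) and *num_bits_over_unit is set to 0; when
// the input needs more integer bits, the extra power of two is reported in
// *num_bits_over_unit instead of being folded into the returned value.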

inline void GetInvSqrtQuantizedMultiplierExp(int32_t input, int reverse_shift,
                                             int32_t* output_inv_sqrt,
                                             int* output_shift) {
  TFLITE_DCHECK_GE(input, 0);
  if (input <= 1) {
    // Handle the input value 1 separately to avoid overflow in that case
    // in the general computation below (b/143972021). Also handle 0 as if it
    // were a 1. 0 is an invalid input here (divide by zero) and 1 is a valid
    // but rare/unrealistic input value. We can expect both to occur in some
    // incompletely trained models, but probably not in fully trained models.
    *output_inv_sqrt = std::numeric_limits<std::int32_t>::max();
    *output_shift = 0;
    return;
  }
  TFLITE_DCHECK_GT(input, 1);
  *output_shift = 11;
  while (input >= (1 << 29)) {
    input /= 4;
    ++*output_shift;
  }
  const unsigned max_left_shift_bits =
      CountLeadingZeros(static_cast<uint32_t>(input)) - 1;
  const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
  const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
  *output_shift -= left_shift_bit_pairs;
  input <<= 2 * left_shift_bit_pairs;
  TFLITE_DCHECK_GE(input, (1 << 27));
  TFLITE_DCHECK_LT(input, (1 << 29));
  using gemmlowp::FixedPoint;
  using gemmlowp::Rescale;
  using gemmlowp::SaturatingRoundingMultiplyByPOT;
  // Using 3 integer bits gives us enough room for the internal arithmetic in
  // this Newton-Raphson iteration.
  using F3 = FixedPoint<int32_t, 3>;
  using F0 = FixedPoint<int32_t, 0>;
  const F3 fixedpoint_input = F3::FromRaw(input >> 1);
  const F3 fixedpoint_half_input =
      SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
  const F3 fixedpoint_half_three =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
  // Newton-Raphson iteration
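  // Each step below computes x <- x * (1.5 - 0.5 * input * x^2), the
  // Newton-Raphson update for f(x) = 1/x^2 - input, so x converges to the
  // inverse square root of the normalized input.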
  // Naive unoptimized starting guess: x = 1
  F3 x = F3::One();
  // Naive unoptimized number of iterations: 5
  for (int i = 0; i < 5; i++) {
    const F3 x3 = Rescale<3>(x * x * x);
    x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
  }
  const F0 fixedpoint_half_sqrt_2 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
  x = x * fixedpoint_half_sqrt_2;
  *output_inv_sqrt = x.raw();
  if (*output_shift < 0) {
    *output_inv_sqrt <<= -*output_shift;
    *output_shift = 0;
  }
  // Convert right shift (right is positive) to left shift.
  *output_shift *= reverse_shift;
}

// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
// BROADCASTING.
//
// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
// rectangular array of numbers.
//
// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
// However, as Dims<N> is to be deprecated, this class exists as an adaptor
// to enable simple unoptimized implementations of element-wise broadcasting
// operations.
template <int N>
struct NdArrayDesc {
  // The "extent" of each dimension. Indices along dimension d must be in the
  // half-open interval [0, extents[d]).
  int extents[N];

  // The number of *elements* (not bytes) between consecutive indices of each
  // dimension.
  int strides[N];
};

// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
// BROADCASTING.
//
// Same as Offset(), except it takes an NdArrayDesc<N> instead of Dims<N>.
inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
                            int i3) {
  TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
  return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
         i3 * desc.strides[3];
}

inline int SubscriptToIndex(const NdArrayDesc<5>& desc, int indexes[5]) {
  return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
         indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
         indexes[4] * desc.strides[4];
}

inline int SubscriptToIndex(const NdArrayDesc<8>& desc, int indexes[8]) {
  return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
         indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
         indexes[4] * desc.strides[4] + indexes[5] * desc.strides[5] +
         indexes[6] * desc.strides[6] + indexes[7] * desc.strides[7];
}
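
// For a fully materialized 4-D array with extents {N, H, W, C} and strides
// {H * W * C, W * C, C, 1} (as produced by CopyDimsToDesc below),
// SubscriptToIndex(desc, b, y, x, c) returns the familiar flat offset
// ((b * H + y) * W + x) * C + c. Broadcasting works by setting a stride to 0,
// so every index along that dimension maps onto the same slice.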

// Given the dimensions of the operands for an element-wise binary broadcast,
// adjusts them so that they can be directly iterated over with simple loops.
// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
//
// This function assumes that the two input shapes are compatible up to
// broadcasting and the shorter one has already been prepended with 1s to be the
// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
// Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
//
// When two shapes are compatible up to broadcasting, for each dimension d,
// the input extents are either equal, or one of them is 1.
//
// This function performs the following for each dimension d:
// - If the extents are equal, then do nothing since the loop that walks over
//   both of the input arrays is correct.
// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
//   and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
//   array0 to be referenced *at any index* in dimension d and still access the
//   same slice.
template <int N>
inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
                                                const Dims<N>& input1_dims,
                                                NdArrayDesc<N>* desc0_out,
                                                NdArrayDesc<N>* desc1_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);

  // Copy dims to desc.
  for (int i = 0; i < N; ++i) {
    desc0_out->extents[i] = input0_dims.sizes[i];
    desc0_out->strides[i] = input0_dims.strides[i];
    desc1_out->extents[i] = input1_dims.sizes[i];
    desc1_out->strides[i] = input1_dims.strides[i];
  }

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other and
  // stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = ArraySize(input0_dims, i);
    const int extent1 = ArraySize(input1_dims, i);
    if (extent0 != extent1) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent1;
      } else {
        TFLITE_DCHECK_EQ(extent1, 1);
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent0;
      }
    }
  }
}

// Copies dims to desc, calculating strides.
template <int N>
inline void CopyDimsToDesc(const RuntimeShape& input_shape,
                           NdArrayDesc<N>* desc_out) {
  int desc_stride = 1;
  for (int i = N - 1; i >= 0; --i) {
    desc_out->extents[i] = input_shape.Dims(i);
    desc_out->strides[i] = desc_stride;
    desc_stride *= input_shape.Dims(i);
  }
}

template <int N>
inline void NdArrayDescsForElementwiseBroadcast(
    const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
    NdArrayDesc<N>* desc0_out, NdArrayDesc<N>* desc1_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);

  auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
  auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);

  // Copy dims to desc, calculating strides.
  CopyDimsToDesc<N>(extended_input0_shape, desc0_out);
  CopyDimsToDesc<N>(extended_input1_shape, desc1_out);

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other and
  // stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = extended_input0_shape.Dims(i);
    const int extent1 = extended_input1_shape.Dims(i);
    if (extent0 != extent1) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent1;
      } else {
        TFLITE_DCHECK_EQ(extent1, 1);
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent0;
      }
    }
  }
}

template <int N>
inline void NdArrayDescsForElementwiseBroadcast(
    const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
    const RuntimeShape& input2_shape, NdArrayDesc<N>* desc0_out,
    NdArrayDesc<N>* desc1_out, NdArrayDesc<N>* desc2_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);
  TFLITE_DCHECK(desc2_out != nullptr);

  auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
  auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);
  auto extended_input2_shape = RuntimeShape::ExtendedShape(N, input2_shape);

  // Copy dims to desc, calculating strides.
  CopyDimsToDesc<N>(extended_input0_shape, desc0_out);
  CopyDimsToDesc<N>(extended_input1_shape, desc1_out);
  CopyDimsToDesc<N>(extended_input2_shape, desc2_out);

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other and
  // stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = extended_input0_shape.Dims(i);
    const int extent1 = extended_input1_shape.Dims(i);
    const int extent2 = extended_input2_shape.Dims(i);

    int extent = extent0;
    if (extent1 != 1) extent = extent1;
    if (extent2 != 1) extent = extent2;

    TFLITE_DCHECK(extent0 == 1 || extent0 == extent);
    TFLITE_DCHECK(extent1 == 1 || extent1 == extent);
    TFLITE_DCHECK(extent2 == 1 || extent2 == extent);

    if (!(extent0 == extent1 && extent1 == extent2)) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent;
      }
      if (extent1 == 1) {
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent;
      }
      if (extent2 == 1) {
        desc2_out->strides[i] = 0;
        desc2_out->extents[i] = extent;
      }
    }
  }
}

// Detailed implementation of NDOpsHelper; the indexes must be a zero array.
// This implementation is equivalent to N nested loops. E.g., if N=4, it can be
// rewritten as:
// for (int b = 0; b < output.extents[0]; ++b) {
//   for (int y = 0; y < output.extents[1]; ++y) {
//     for (int x = 0; x < output.extents[2]; ++x) {
//       for (int c = 0; c < output.extents[3]; ++c) {
//           calc({b,y,x,c});
//       }
//     }
//   }
// }
template <int N, int DIM, typename Calc>
typename std::enable_if<DIM != N - 1, void>::type NDOpsHelperImpl(
    const NdArrayDesc<N>& output, const Calc& calc, int indexes[N]) {
  for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM]) {
    NDOpsHelperImpl<N, DIM + 1, Calc>(output, calc, indexes);
  }
}

template <int N, int DIM, typename Calc>
typename std::enable_if<DIM == N - 1, void>::type NDOpsHelperImpl(
    const NdArrayDesc<N>& output, const Calc& calc, int indexes[N]) {
  for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM]) {
    calc(indexes);
  }
}

// Execute the calc function in the innermost iteration based on the shape of
// the output. The calc function should take a single argument of type int[N].
template <int N, typename Calc>
inline void NDOpsHelper(const NdArrayDesc<N>& output, const Calc& calc) {
  int indexes[N] = {0};
  NDOpsHelperImpl<N, 0, Calc>(output, calc, indexes);
}
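
// Usage sketch (illustrative; buffer names are hypothetical): a simple
// broadcast binary op combines the helpers above roughly like
//
//   NdArrayDesc<4> desc0, desc1, output_desc;
//   NdArrayDescsForElementwiseBroadcast(input0_shape, input1_shape,
//                                       &desc0, &desc1);
//   CopyDimsToDesc<4>(RuntimeShape::ExtendedShape(4, output_shape),
//                     &output_desc);
//   auto add_func = [&](int indexes[4]) {
//     const int out = SubscriptToIndex(output_desc, indexes[0], indexes[1],
//                                      indexes[2], indexes[3]);
//     const int in0 = SubscriptToIndex(desc0, indexes[0], indexes[1],
//                                      indexes[2], indexes[3]);
//     const int in1 = SubscriptToIndex(desc1, indexes[0], indexes[1],
//                                      indexes[2], indexes[3]);
//     output_data[out] = input0_data[in0] + input1_data[in1];
//   };
//   NDOpsHelper<4>(output_desc, add_func);
//
// The reference broadcast kernels (e.g. BroadcastBinaryFunction4DSlow in
// reference_ops.h) follow this pattern.
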
// Copied from gemmlowp::RoundDown when we dropped direct dependency on
// gemmlowp.
//
// Returns the runtime argument rounded down to the nearest multiple of
// the fixed Modulus.
template <unsigned Modulus, typename Integer>
Integer RoundDown(Integer i) {
  return i - (i % Modulus);
}

// Copied from gemmlowp::RoundUp when we dropped direct dependency on
// gemmlowp.
//
// Returns the runtime argument rounded up to the nearest multiple of
// the fixed Modulus.
template <unsigned Modulus, typename Integer>
Integer RoundUp(Integer i) {
  return RoundDown<Modulus>(i + Modulus - 1);
}

// Copied from gemmlowp::CeilQuotient when we dropped direct dependency on
// gemmlowp.
//
// Returns the quotient a / b rounded up ('ceil') to the nearest integer.
template <typename Integer>
Integer CeilQuotient(Integer a, Integer b) {
  return (a + b - 1) / b;
}
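
// For example, RoundDown<8>(13) == 8, RoundUp<8>(13) == 16 and
// CeilQuotient(13, 8) == 2.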

// This function is a copy of gemmlowp::HowManyThreads, copied when we dropped
// the direct dependency of internal/optimized/ on gemmlowp.
//
// It computes a reasonable number of threads to use for a GEMM of shape
// (rows, cols, depth).
//
// TODO(b/131910176): get rid of this function by switching each call site
// to its own more sensible logic for its own workload.
template <int KernelRows>
inline int LegacyHowManyThreads(int max_num_threads, int rows, int cols,
                                int depth) {
  // Early-exit in the default case where multi-threading is disabled.
  if (max_num_threads == 1) {
    return 1;
  }

  // Ensure that each thread has KernelRows rows to process, if at all possible.
  int thread_count = std::min(max_num_threads, rows / KernelRows);

  // Limit the number of threads according to the overall size of the problem.
  if (thread_count > 1) {
    // Empirically determined value.
    static constexpr std::uint64_t min_cubic_size_per_thread = 64 * 1024;

    // We can only multiply two out of three sizes without risking overflow
    const std::uint64_t cubic_size =
        std::uint64_t(rows) * std::uint64_t(cols) * std::uint64_t(depth);

    thread_count = std::min(
        thread_count, static_cast<int>(cubic_size / min_cubic_size_per_thread));
  }

  if (thread_count < 1) {
    thread_count = 1;
  }

  assert(thread_count > 0 && thread_count <= max_num_threads);
  return thread_count;
}
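
// For example, LegacyHowManyThreads<4>(/*max_num_threads=*/8, /*rows=*/1024,
// /*cols=*/1024, /*depth=*/64) caps 1024 / 4 = 256 candidate threads at 8,
// and the cubic-size limit (1024 * 1024 * 64 / 65536 = 1024) does not bind,
// so it returns 8.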

template <typename T>
void optimized_ops_preload_l1_stream(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 0 means read */ 0, /* 0 means no locality */ 0);
#else
  (void)ptr;
#endif
}

template <typename T>
void optimized_ops_preload_l1_keep(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 0 means read */ 0, /* 3 means high locality */ 3);
#else
  (void)ptr;
#endif
}

template <typename T>
void optimized_ops_prefetch_write_l1_keep(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 1 means write */ 1, /* 3 means high locality */ 3);
#else
  (void)ptr;
#endif
}

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_