/*
 * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_transpose_conv_row_s8_s32
 * Description:  Transpose convolution helper function.
 *
 * $Date:        22 Oct 2024
 * $Revision:    V.1.0.0
 *
 * Target : Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportConvolution
 * @{
 */

/*
 * Computation of transposed convolution for one row of input into a rolling scratch buffer.
 *
 * Refer header file for details.
 *
 */
arm_nn_transpose_conv_row_s8_s32(const int8_t * lhs,const int8_t * rhs,int32_t * output_start,const int32_t output_index,const int32_t output_max,const int32_t rhs_rows,const int32_t rhs_cols,const int32_t input_channels,const int32_t output_channels,const int32_t lhs_offset,const int32_t row_offset,const int32_t input_x,const int32_t stride_x,const int32_t skip_rows_top,const int32_t skip_rows_bottom)48 arm_cmsis_nn_status arm_nn_transpose_conv_row_s8_s32(const int8_t *lhs,
49 const int8_t *rhs,
50 int32_t *output_start,
51 const int32_t output_index,
52 const int32_t output_max,
53 const int32_t rhs_rows,
54 const int32_t rhs_cols,
55 const int32_t input_channels,
56 const int32_t output_channels,
57 const int32_t lhs_offset,
58 const int32_t row_offset,
59 const int32_t input_x,
60 const int32_t stride_x,
61 const int32_t skip_rows_top,
62 const int32_t skip_rows_bottom)
63 {
64
65 const int32_t skip_pre_rows = skip_rows_top * rhs_cols * input_channels;
66 const int32_t skip_post_rows = skip_rows_bottom * rhs_cols * input_channels;
67 const int32_t rhs_rows_count = rhs_rows - skip_rows_top - skip_rows_bottom;
68
69 int32_t input_count = input_x;
70 for (; input_count > 3; input_count -= 4)
71 {
72 const int8_t *rhs_ptr = rhs;
73
74 for (int32_t i_out_channel = 0; i_out_channel < output_channels; i_out_channel++)
75 {
76 rhs_ptr += skip_pre_rows;
77 int32_t index = output_index;
78
79 for (int32_t i_row = 0; i_row < rhs_rows_count; i_row++)
80 {
81 int32_t *output_ptr0 = output_start + index;
82
83 for (int32_t i_col = 0; i_col < rhs_cols; i_col++)
84 {
85 const int8_t *lhs_ptr0 = lhs;
86
87 int32_t result0 = 0;
88 int32_t result1 = 0;
89 int32_t result2 = 0;
90 int32_t result3 = 0;
91
92 #if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
93 const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
94 const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
95
96 for (int32_t channel_count = input_channels; channel_count > 3; channel_count -= 4)
97 {
98 const int8_t *lhs_temp = lhs_ptr0;
99 int32_t lhs00 = arm_nn_read_s8x4(lhs_temp);
100 int32_t lhs01 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)lhs00, 8);
101 lhs00 = SXTAB16(lhs_offset_s16x2, lhs00);
102 lhs_temp += input_channels;
103 int32_t lhs10 = arm_nn_read_s8x4(lhs_temp);
104 int32_t lhs11 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)lhs10, 8);
105 lhs10 = SXTAB16(lhs_offset_s16x2, lhs10);
106 lhs_temp += input_channels;
107 int32_t lhs20 = arm_nn_read_s8x4(lhs_temp);
108 int32_t lhs21 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)lhs20, 8);
109 lhs20 = SXTAB16(lhs_offset_s16x2, lhs20);
110 lhs_temp += input_channels;
111 int32_t lhs30 = arm_nn_read_s8x4(lhs_temp);
112 int32_t lhs31 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)lhs30, 8);
113 lhs30 = SXTAB16(lhs_offset_s16x2, lhs30);
114 lhs_ptr0 += 4;
115
116 int32_t rhs0 = arm_nn_read_s8x4(rhs_ptr);
117 int32_t rhs1 = SXTB16_RORn((uint32_t)rhs0, 8);
118 rhs0 = SXTB16(rhs0);
119 rhs_ptr += 4;
120
121 result0 = SMLAD(lhs00, rhs0, result0);
122 result0 = SMLAD(lhs01, rhs1, result0);
123 result1 = SMLAD(lhs10, rhs0, result1);
124 result1 = SMLAD(lhs11, rhs1, result1);
125 result2 = SMLAD(lhs20, rhs0, result2);
126 result2 = SMLAD(lhs21, rhs1, result2);
127 result3 = SMLAD(lhs30, rhs0, result3);
128 result3 = SMLAD(lhs31, rhs1, result3);
129 }
130
131 for (int32_t i = 0; i < (input_channels & 0b11); i++)
132 {
133 const int8_t *lhs_temp = lhs_ptr0;
134 const int32_t lhs_val00 = *lhs_temp + lhs_offset;
135 lhs_temp += input_channels;
136 const int32_t lhs_val10 = *lhs_temp + lhs_offset;
137 lhs_temp += input_channels;
138 const int32_t lhs_val20 = *lhs_temp + lhs_offset;
139 lhs_temp += input_channels;
140 const int32_t lhs_val30 = *lhs_temp + lhs_offset;
141 lhs_ptr0++;
142
143 const int32_t rhs_val0 = *rhs_ptr++;
144
145 result0 += lhs_val00 * rhs_val0;
146 result1 += lhs_val10 * rhs_val0;
147 result2 += lhs_val20 * rhs_val0;
148 result3 += lhs_val30 * rhs_val0;
149 }
150
151 int32_t *output_temp = output_ptr0;
152 *output_ptr0 += result0;
153 output_temp += stride_x * output_channels;
154 *output_temp += result1;
155 output_temp += stride_x * output_channels;
156 *output_temp += result2;
157 output_temp += stride_x * output_channels;
158 *output_temp += result3;
159
160 output_ptr0 += output_channels;
161 #else
162 int32_t rhs_sum = 0;
163 #if defined(ARM_MATH_MVEI)
164
165 int channel_count = input_channels;
166 for (int channel_i = 0; channel_i < (input_channels + 15) / 16; channel_i++)
167 {
168 mve_pred16_t p0 = vctp8q((uint32_t)channel_count);
169 channel_count -= 16;
170
171 const int8_t *lhs_temp = lhs_ptr0;
172 int8x16_t v_lhs00 = vldrbq_z_s8(lhs_temp, p0);
173 lhs_temp += input_channels;
174 int8x16_t v_lhs10 = vldrbq_z_s8(lhs_temp, p0);
175 lhs_temp += input_channels;
176 int8x16_t v_lhs20 = vldrbq_z_s8(lhs_temp, p0);
177 lhs_temp += input_channels;
178 int8x16_t v_lhs30 = vldrbq_z_s8(lhs_temp, p0);
179
180 lhs_ptr0 += 16;
181 int8x16_t v_rhs0 = vldrbq_z_s8(rhs_ptr, p0);
182 rhs_ptr += 16;
183
184 result0 = vmladavaq_s8(result0, v_lhs00, v_rhs0);
185 result1 = vmladavaq_s8(result1, v_lhs10, v_rhs0);
186 result2 = vmladavaq_s8(result2, v_lhs20, v_rhs0);
187 result3 = vmladavaq_s8(result3, v_lhs30, v_rhs0);
188
189 rhs_sum = vaddvaq_s8(rhs_sum, v_rhs0);
190 }
191
192 rhs_ptr += channel_count;
193
194 #else
195 for (int32_t channel_count = 0; channel_count < input_channels / 2; channel_count++)
196 {
197 const int8_t *lhs_temp = lhs_ptr0;
198 const int32_t lhs_val00 = *lhs_temp;
199 lhs_temp += input_channels;
200 const int32_t lhs_val10 = *lhs_temp;
201 lhs_temp += input_channels;
202 const int32_t lhs_val20 = *lhs_temp;
203 lhs_temp += input_channels;
204 const int32_t lhs_val30 = *lhs_temp;
205 lhs_ptr0++;
206
207 lhs_temp = lhs_ptr0;
208 const int32_t lhs_val01 = *lhs_temp;
209 lhs_temp += input_channels;
210 const int32_t lhs_val11 = *lhs_temp;
211 lhs_temp += input_channels;
212 const int32_t lhs_val21 = *lhs_temp;
213 lhs_temp += input_channels;
214 const int32_t lhs_val31 = *lhs_temp;
215 lhs_ptr0++;
216
217 const int32_t rhs_val0 = *rhs_ptr++;
218 const int32_t rhs_val1 = *rhs_ptr++;
219
220 result0 += lhs_val00 * rhs_val0;
221 result0 += lhs_val01 * rhs_val1;
222
223 result1 += lhs_val10 * rhs_val0;
224 result1 += lhs_val11 * rhs_val1;
225
226 result2 += lhs_val20 * rhs_val0;
227 result2 += lhs_val21 * rhs_val1;
228
229 result3 += lhs_val30 * rhs_val0;
230 result3 += lhs_val31 * rhs_val1;
231
232 rhs_sum += rhs_val0;
233 rhs_sum += rhs_val1;
234 }
235
236 // Input channel tail-handling
237 if (input_channels & 0b1)
238 {
239 const int8_t *lhs_temp = lhs_ptr0;
240 const int32_t lhs_val00 = *lhs_temp;
241 lhs_temp += input_channels;
242 const int32_t lhs_val10 = *lhs_temp;
243 lhs_temp += input_channels;
244 const int32_t lhs_val20 = *lhs_temp;
245 lhs_temp += input_channels;
246 const int32_t lhs_val30 = *lhs_temp;
247 lhs_ptr0++;
248
249 const int32_t rhs_val0 = *rhs_ptr++;
250
251 result0 += lhs_val00 * rhs_val0;
252 result1 += lhs_val10 * rhs_val0;
253 result2 += lhs_val20 * rhs_val0;
254 result3 += lhs_val30 * rhs_val0;
255
256 rhs_sum += rhs_val0;
257 }
258 #endif
259 int32_t *output_temp = output_ptr0;
260 *output_ptr0 += result0 + rhs_sum * lhs_offset;
261 output_temp += stride_x * output_channels;
262 *output_temp += result1 + rhs_sum * lhs_offset;
263 output_temp += stride_x * output_channels;
264 *output_temp += result2 + rhs_sum * lhs_offset;
265 output_temp += stride_x * output_channels;
266 *output_temp += result3 + rhs_sum * lhs_offset;
267
268 output_ptr0 += output_channels;
269 #endif
270 }
271
272 // Next row, wrapping around the circular buffer
273 index = (index + row_offset) % output_max;
274 }
275 // Next output_channel
276 ++output_start;
277 rhs_ptr += skip_post_rows;
278 }
279
280 output_start += (4 * stride_x - 1) * output_channels;
281 lhs += 4 * input_channels;
282 }
283
284 // Input column tail handling
285 if (input_count & 0b10)
286 {
287 const int8_t *rhs_ptr = rhs;
288
289 for (int32_t i_out_channel = 0; i_out_channel < output_channels; i_out_channel++)
290 {
291 int32_t index = output_index;
292 rhs_ptr += skip_pre_rows;
293
294 for (int32_t i_row = 0; i_row < rhs_rows_count; i_row++)
295 {
296 int32_t *output_ptr0 = output_start + index;
297
298 for (int32_t i_col = 0; i_col < rhs_cols; i_col++)
299 {
300 const int8_t *lhs_ptr0 = lhs;
301
302 int32_t result0 = 0;
303 int32_t result1 = 0;
304
305 #if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
306 const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
307 const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
308
309 for (int32_t channel_count = input_channels; channel_count > 3; channel_count -= 4)
310 {
311 const int8_t *lhs_temp = lhs_ptr0;
312 int32_t lhs00 = arm_nn_read_s8x4(lhs_temp);
313 int32_t lhs01 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)lhs00, 8);
314 lhs00 = SXTAB16(lhs_offset_s16x2, lhs00);
315 lhs_temp += input_channels;
316 int32_t lhs10 = arm_nn_read_s8x4(lhs_temp);
317 int32_t lhs11 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)lhs10, 8);
318 lhs10 = SXTAB16(lhs_offset_s16x2, lhs10);
319 lhs_ptr0 += 4;
320
321 int32_t rhs0 = arm_nn_read_s8x4(rhs_ptr);
322 int32_t rhs1 = SXTB16_RORn((uint32_t)rhs0, 8);
323 rhs0 = SXTB16(rhs0);
324 rhs_ptr += 4;
325
326 result0 = SMLAD(lhs00, rhs0, result0);
327 result0 = SMLAD(lhs01, rhs1, result0);
328 result1 = SMLAD(lhs10, rhs0, result1);
329 result1 = SMLAD(lhs11, rhs1, result1);
330 }
331
332 for (int32_t i = 0; i < (input_channels & 0b11); i++)
333 {
334 const int8_t *lhs_temp = lhs_ptr0;
335 const int32_t lhs_val00 = *lhs_temp + lhs_offset;
336 lhs_temp += input_channels;
337 const int32_t lhs_val10 = *lhs_temp + lhs_offset;
338 lhs_ptr0++;
339
340 const int32_t rhs_val0 = *rhs_ptr++;
341
342 result0 += lhs_val00 * rhs_val0;
343 result1 += lhs_val10 * rhs_val0;
344 }
345
346 int32_t *output_temp = output_ptr0;
347 *output_ptr0 += result0;
348 output_temp += stride_x * output_channels;
349 *output_temp += result1;
350
351 output_ptr0 += output_channels;
352 #else
353 int32_t rhs_sum = 0;
354 #if defined(ARM_MATH_MVEI)
355 int channel_count = input_channels;
356 for (int channel_i = 0; channel_i < (input_channels + 15) / 16; channel_i++)
357 {
358 mve_pred16_t p0 = vctp8q((uint32_t)channel_count);
359 channel_count -= 16;
360
361 const int8_t *lhs_temp = lhs_ptr0;
362 int8x16_t v_lhs00 = vldrbq_z_s8(lhs_temp, p0);
363 lhs_temp += input_channels;
364 int8x16_t v_lhs10 = vldrbq_z_s8(lhs_temp, p0);
365 lhs_ptr0 += 16;
366 int8x16_t v_rhs0 = vldrbq_z_s8(rhs_ptr, p0);
367 rhs_ptr += 16;
368
369 result0 = vmladavaq_s8(result0, v_lhs00, v_rhs0);
370 result1 = vmladavaq_s8(result1, v_lhs10, v_rhs0);
371
372 rhs_sum = vaddvaq_s8(rhs_sum, v_rhs0);
373 }
374
375 rhs_ptr += channel_count;
376
377 #else
378 for (int32_t channel_count = 0; channel_count < input_channels; channel_count++)
379 {
380 const int8_t *lhs_temp = lhs_ptr0;
381 const int32_t lhs_val00 = *lhs_temp;
382 lhs_temp += input_channels;
383 const int32_t lhs_val10 = *lhs_temp;
384 lhs_ptr0++;
385
386 const int32_t rhs_val0 = *rhs_ptr++;
387
388 result0 += lhs_val00 * rhs_val0;
389 result1 += lhs_val10 * rhs_val0;
390
391 rhs_sum += rhs_val0;
392 }
393 #endif
394 int32_t *output_temp = output_ptr0;
395 *output_ptr0 += result0 + rhs_sum * lhs_offset;
396 output_temp += stride_x * output_channels;
397 *output_temp += result1 + rhs_sum * lhs_offset;
398
399 output_ptr0 += output_channels;
400 #endif
401 }
402
403 // Next row, wrapping around the circular buffer
404 index = (index + row_offset) % output_max;
405 }
406
407 // Next output_channel
408 ++output_start;
409 rhs_ptr += skip_post_rows;
410 }
411
412 output_start += (2 * stride_x - 1) * output_channels;
413 lhs += 2 * input_channels;
414 }
415
416 if (input_count & 0b1)
417 {
418 const int8_t *rhs_ptr = rhs;
419
420 for (int32_t i_out_channel = 0; i_out_channel < output_channels; i_out_channel++)
421 {
422 int32_t index = output_index;
423 rhs_ptr += skip_pre_rows;
424
425 for (int32_t i_row = 0; i_row < rhs_rows_count; i_row++)
426 {
427 int32_t *output_ptr0 = output_start + index;
428
429 for (int32_t i_col = 0; i_col < rhs_cols; i_col++)
430 {
431 const int8_t *lhs_ptr0 = lhs;
432
433 int32_t result0 = 0;
434 #if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
435 const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
436 const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
437
438 for (int32_t channel_count = input_channels; channel_count > 3; channel_count -= 4)
439 {
440 const int8_t *lhs_temp = lhs_ptr0;
441 int32_t lhs00 = arm_nn_read_s8x4(lhs_temp);
442 int32_t lhs01 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)lhs00, 8);
443 lhs00 = SXTAB16(lhs_offset_s16x2, lhs00);
444 lhs_ptr0 += 4;
445
446 int32_t rhs0 = arm_nn_read_s8x4(rhs_ptr);
447 int32_t rhs1 = SXTB16_RORn((uint32_t)rhs0, 8);
448 rhs0 = SXTB16(rhs0);
449 rhs_ptr += 4;
450
451 result0 = SMLAD(lhs00, rhs0, result0);
452 result0 = SMLAD(lhs01, rhs1, result0);
453 }
454
455 for (int32_t i = 0; i < (input_channels & 0b11); i++)
456 {
457 const int8_t *lhs_temp = lhs_ptr0;
458 const int32_t lhs_val00 = *lhs_temp + lhs_offset;
459 lhs_ptr0++;
460
461 const int32_t rhs_val0 = *rhs_ptr++;
462
463 result0 += lhs_val00 * rhs_val0;
464 }
465 #else
466 #if defined(ARM_MATH_MVEI)
467 int channel_count = input_channels;
468 for (int channel_i = 0; channel_i < (input_channels + 15) / 16; channel_i++)
469 {
470 mve_pred16_t p0 = vctp8q((uint32_t)channel_count);
471 channel_count -= 16;
472
473 int8x16_t v_lhs00 = vldrbq_z_s8(lhs_ptr0, p0);
474 lhs_ptr0 += 16;
475 int8x16_t v_rhs0 = vldrbq_z_s8(rhs_ptr, p0);
476 rhs_ptr += 16;
477
478 result0 = vmladavaq_s8(result0, v_lhs00, v_rhs0);
479
480 int32_t rhs_sum = vaddvaq_s8(0, v_rhs0);
481 result0 += rhs_sum * lhs_offset;
482 }
483
484 rhs_ptr += channel_count;
485 #else
486 for (int32_t channel_count = 0; channel_count < input_channels; channel_count++)
487 {
488 const int32_t lhs_val00 = *lhs_ptr0;
489 lhs_ptr0++;
490
491 const int32_t rhs_val0 = *rhs_ptr++;
492
493 result0 += (lhs_val00 + lhs_offset) * rhs_val0;
494 }
495 #endif
496 #endif
497 *output_ptr0 += result0;
498 output_ptr0 += output_channels;
499 }
500
501 // Next row, wrapping around the circular buffer
502 index = (index + row_offset) % output_max;
503 }
504
505 // Next output_channel
506 ++output_start;
507 rhs_ptr += skip_post_rows;
508 }
509 }
510 return ARM_CMSIS_NN_SUCCESS;
511 }

/**
 * @} end of Doxygen group
 */