// -*- C++ -*-
/** @file */
#pragma once

#ifdef DOXYGEN
#define ARM_MATH_MVEI
#define ARM_MATH_MVEF
#define ARM_MATH_MVE_FLOAT16
#endif

/** \addtogroup HELIUMALG
 * @{
 */

#if defined(ARM_MATH_MVE_FLOAT16)

/*

These kernels can't be used with a stride bigger than 21845,
which is acceptable for embedded use.

No check is done at runtime or build time that the stride is not
too big. (An illustrative guard sketch follows the 2x2 kernel below.)

*/

template<typename MA,
         typename MB,
         typename RES,
typename std::enable_if<
         has_vector_inst<MA>() &&
         SameElementType<MA,float16_t>::value,bool>::type = true>
__STATIC_INLINE void _arm_mat_mult_2x2_mve(
    const MA &pSrcA,
    const MB &pSrcB,
    RES &&pDst)
{
    using T = typename traits<MA>::Scalar;
    //using ACC = typename vector_traits<T>::temp_accumulator;
    using VEC = typename vector_traits<T>::vector;

    /* offsetA reads 2 successive column elements of A, each duplicated
       (pattern repeated in the upper half of the vector) */
    const uint16_t offsetA[8] = { 0, 0, (uint16_t)pSrcA.stride(), (uint16_t)pSrcA.stride(),
                                  0, 0, (uint16_t)pSrcA.stride(), (uint16_t)pSrcA.stride() };
    /* offsetB allows reading and duplicating one row of B */
    const uint16_t offsetB[8] = { 0, 1, 0, 1, 0, 1, 0, 1 };

    /* {d00, d01, d10, d11} */
    const uint16_t offsetD[8] = { 0, 1, (uint16_t)pDst.stride(), (uint16_t)(pDst.stride()+1),
                                  0, 0, 0, 0 };

    uint16x8_t vecOffsA, vecOffsB, vecOffsD;
    VEC vecInA, vecInB, vecDst;
    T *pOut = pDst.ptr(); /* output data matrix pointer */

    /*
     * load initial offsets
     */
    vecOffsA = vldrhq_u16((uint16_t const *) offsetA);
    vecOffsB = vldrhq_u16((uint16_t const *) offsetB);
    /*
     * load {a00 a00 a10 a10 x x x x}
     */
    vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
    /*
     * load {b00 b01 b00 b01 x x x x}
     */
    vecInB = vldrhq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
    /*
     * { a00 b00  a00 b01
     *   a10 b00  a10 b01
     *   x x
     *   x x }
     */
    vecDst = vmulq(vecInA, vecInB);
    /*
     * move to 2nd column of matrix A
     */
    vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 1);
    /*
     * load {a01 a01 a11 a11 x x x x}
     */
    vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
    /*
     * move to next B row
     */
    vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) pSrcB.stride());
    /*
     * load {b10 b11 b10 b11 x x x x}
     */
    vecInB = vldrhq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
    /*
     * { a00 b00 + a01 b10  a00 b01 + a01 b11
     *   a10 b00 + a11 b10  a10 b01 + a11 b11
     *   x x
     *   x x }
     */
    vecDst = vfmaq(vecDst, vecInA, vecInB);

    /* predicate covering the 2*2 = 4 useful lanes */
    mve_pred16_t p0 = vctp16q(2*2);
    /*
     * Store the result in the destination buffer
     * (lower half of the vector)
     */
    vecOffsD = vldrhq_u16((uint16_t const *) offsetD);

    vstrhq_scatter_shifted_offset_p(pOut, vecOffsD, vecDst, p0);
}
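
/*
 * Illustrative sketch only (not library API): a debug-build guard for the
 * stride limit documented at the top of this file could look like the
 * helper below. The helper name is an invention, and assert() from
 * <assert.h> is assumed to be available.
 */
#include <assert.h>

template<typename M>
__STATIC_INLINE void _check_mve_f16_gather_stride(const M &m)
{
    /* Gather/scatter offsets are built in uint16_t lanes. The largest
       offset used by these kernels is about 3*stride() (4x4 case), which
       must fit in 16 bits, hence the 21845 (~65535/3) bound. */
    assert(m.stride() <= 21845);
}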
template<typename MA,
         typename MB,
         typename RES,
typename std::enable_if<
         has_vector_inst<MA>() &&
         SameElementType<MA,float16_t>::value,bool>::type = true>
__STATIC_INLINE void _arm_mat_mult_3x3_mve(
    const MA &pSrcA,
    const MB &pSrcB,
    RES &&pDst)
{
    const uint16_t offsetA[8] = { 0, 0, 0,
                                  (uint16_t)pSrcA.stride(), (uint16_t)pSrcA.stride(), (uint16_t)pSrcA.stride(),
                                  (uint16_t)(2U*pSrcA.stride()), (uint16_t)(2U*pSrcA.stride()) };
    /* offsetB allows reading and duplicating one row of B */
    const uint16_t offsetB[8] = { 0, 1, 2, 0, 1, 2, 0, 1 };
    const uint16_t offsetD[8] = { 0, 1, 2,
                                  (uint16_t)(0+pDst.stride()), (uint16_t)(1+pDst.stride()),
                                  (uint16_t)(2+pDst.stride()),
                                  (uint16_t)(0+2*pDst.stride()),
                                  (uint16_t)(1+2*pDst.stride()) };

    uint16x8_t vecOffsA, vecOffsB, vecOffsD;
    float16x8_t vecInA, vecInB, vecDst;
    float16_t *pOut = pDst.ptr(); /* output data matrix pointer */

    /*
     * load initial offsets
     */
    vecOffsA = vldrhq_u16((uint16_t const *) offsetA);
    vecOffsB = vldrhq_u16((uint16_t const *) offsetB);

    /*
     * load {a00 a00 a00 a10 a10 a10 a20 a20}
     */
    vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
    /*
     * load {b00 b01 b02 b00 b01 b02 b00 b01}
     */
    vecInB = vldrhq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
    /*
     * { a00 b00  a00 b01  a00 b02
     *   a10 b00  a10 b01  a10 b02
     *   a20 b00  a20 b01 }
     */
    vecDst = vmulq(vecInA, vecInB);

    /*
     * move to 2nd column of matrix A
     */
    vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 1);
    /*
     * load {a01 a01 a01 a11 a11 a11 a21 a21}
     */
    vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
    /*
     * move to next B row
     */
    vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) pSrcB.stride());
    /*
     * load {b10 b11 b12 b10 b11 b12 b10 b11}
     */
    vecInB = vldrhq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
    /*
     * { a00 b00 + a01 b10  a00 b01 + a01 b11  a00 b02 + a01 b12
     *   a10 b00 + a11 b10  a10 b01 + a11 b11  a10 b02 + a11 b12
     *   a20 b00 + a21 b10  a20 b01 + a21 b11 }
     */
    vecDst = vfmaq(vecDst, vecInA, vecInB);
    /*
     * move to 3rd column of matrix A
     */
    vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 1);
    /*
     * load {a02 a02 a02 a12 a12 a12 a22 a22}
     */
    vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
    /*
     * move to next B row
     */
    vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) pSrcB.stride());
    /*
     * load {b20 b21 b22 b20 b21 b22 b20 b21}
     */
    vecInB = vldrhq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
    /*
     * { a00 b00 + a01 b10 + a02 b20  a00 b01 + a01 b11 + a02 b21  a00 b02 + a01 b12 + a02 b22
     *   a10 b00 + a11 b10 + a12 b20  a10 b01 + a11 b11 + a12 b21  a10 b02 + a11 b12 + a12 b22
     *   a20 b00 + a21 b10 + a22 b20  a20 b01 + a21 b11 + a22 b21 }
     */
    vecDst = vfmaq(vecDst, vecInA, vecInB);

    /*
     * Store the result in the destination buffer
     */
    vecOffsD = vldrhq_u16((uint16_t const *) offsetD);

    vstrhq_scatter_shifted_offset(pOut, vecOffsD, vecDst);

    pOut += 2*pDst.stride() + 2;

    /* The 9th result does not fit in the 8-lane vector, so the last
     * element is computed in scalar mode:
     * a20 b02 + a21 b12 + a22 b22
     */
    const _Float16 *pA = (const _Float16 *)pSrcA.const_ptr();
    const _Float16 *pB = (const _Float16 *)pSrcB.const_ptr();
    const index_t sa = pSrcA.stride();
    const index_t sb = pSrcB.stride();
    *pOut = pA[2*sa] * pB[2] + pA[1+2*sa] * pB[2+sb] + pA[2+2*sa] * pB[2+2*sb];
}
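
/*
 * Illustrative sketch only (not library API): a plain scalar reference for
 * the 3x3 kernel above, working on raw pointers and strides. Such a
 * reference can be used in unit tests to cross-check the gather/scatter
 * version. The function name is an invention; float16_t, index_t and
 * __STATIC_INLINE are the same types/macros used elsewhere in this file.
 */
__STATIC_INLINE void _arm_mat_mult_3x3_f16_scalar_ref(
    const float16_t *pA, const index_t sa,
    const float16_t *pB, const index_t sb,
    float16_t *pD, const index_t sd)
{
    const _Float16 *a = (const _Float16 *)pA;
    const _Float16 *b = (const _Float16 *)pB;
    _Float16 *d = (_Float16 *)pD;

    for (int row = 0; row < 3; row++) {
        for (int col = 0; col < 3; col++) {
            _Float16 acc = (_Float16)0.0f;
            for (int k = 0; k < 3; k++) {
                /* acc += a[row][k] * b[k][col], with explicit strides */
                acc += a[row*sa + k] * b[k*sb + col];
            }
            d[row*sd + col] = acc;
        }
    }
}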
template<typename MA,
         typename MB,
         typename RES,
typename std::enable_if<
         has_vector_inst<MA>() &&
         SameElementType<MA,float16_t>::value,bool>::type = true>
__STATIC_INLINE void _arm_mat_mult_4x4_mve(
    const MA &pSrcA,
    const MB &pSrcB,
    RES &&pDst)
{
    /* offsetA allows reading and duplicating 2 successive column elements of A */
    const uint16_t offsetA[8] = { 0, 0, 0, 0,
                                  (uint16_t)pSrcA.stride(), (uint16_t)pSrcA.stride(),
                                  (uint16_t)pSrcA.stride(), (uint16_t)pSrcA.stride() };
    /* offsetB allows reading and duplicating one row of B */
    const uint16_t offsetB[8] = { 0, 1, 2, 3, 0, 1, 2, 3 };

    const uint16_t offsetD[8] = { 0, 1, 2, 3,
                                  (uint16_t)(0+pDst.stride()), (uint16_t)(1+pDst.stride()),
                                  (uint16_t)(2+pDst.stride()), (uint16_t)(3+pDst.stride()) };

    uint16x8_t vecOffsA, vecOffsB, vecOffsD;
    float16x8_t vecInA, vecInB, vecDst0, vecDst1;
    float16_t *pOut = pDst.ptr(); /* output data matrix pointer */

    /*
     * load initial offsets
     */
    vecOffsA = vldrhq_u16((uint16_t const *) offsetA);
    vecOffsB = vldrhq_u16((uint16_t const *) offsetB);

    /*
     * load {a00 a00 a00 a00 a10 a10 a10 a10}
     */
    vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
    /*
     * load {b00 b01 b02 b03 b00 b01 b02 b03}
     */
    vecInB = vldrhq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);

    /*
     * { a00 b00  a00 b01  a00 b02  a00 b03
     *   a10 b00  a10 b01  a10 b02  a10 b03 }
     */
    vecDst0 = vmulq(vecInA, vecInB);
    /*
     * jump 2 x A rows (2nd half of matrix)
     */
    vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) pSrcA.stride()*2);
    /*
     * load {a20 a20 a20 a20 a30 a30 a30 a30}
     */
    vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
    /*
     * { a20 b00  a20 b01  a20 b02  a20 b03
     *   a30 b00  a30 b01  a30 b02  a30 b03 }
     */
    vecDst1 = vmulq(vecInA, vecInB);
    /*
     * rewind back to top half of the A matrix (2nd column)
     */
    vecOffsA = vsubq(vecOffsA, (uint16_t) (2*pSrcA.stride()-1));
    /*
     * load {a01 a01 a01 a01 a11 a11 a11 a11}
     */
    vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);

    /*
     * move to next B row
     */
    vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) pSrcB.stride());
    /*
     * load {b10 b11 b12 b13 b10 b11 b12 b13}
     */
    vecInB = vldrhq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
    /*
     * { a00 b00 + a01 b10  a00 b01 + a01 b11  a00 b02 + a01 b12  a00 b03 + a01 b13
     *   a10 b00 + a11 b10  a10 b01 + a11 b11  a10 b02 + a11 b12  a10 b03 + a11 b13 }
     */
    vecDst0 = vfmaq(vecDst0, vecInA, vecInB);
    /*
     * jump 2 x A rows (2nd half of matrix)
     */
    vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) pSrcA.stride()*2);
    /*
     * load {a21 a21 a21 a21 a31 a31 a31 a31}
     */
    vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
    /*
     * { a20 b00 + a21 b10  a20 b01 + a21 b11  a20 b02 + a21 b12  a20 b03 + a21 b13
     *   a30 b00 + a31 b10  a30 b01 + a31 b11  a30 b02 + a31 b12  a30 b03 + a31 b13 }
     */
    vecDst1 = vfmaq(vecDst1, vecInA, vecInB);

    /*
     * rewind back to top half of the A matrix (3rd column)
     */
    vecOffsA = vsubq(vecOffsA, (uint16_t) (2*pSrcA.stride()-1));
    /*
     * load {a02 a02 a02 a02 a12 a12 a12 a12}
     */
    vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
    /*
     * move to next B row
     */
    vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) pSrcB.stride());
    /*
     * load {b20 b21 b22 b23 b20 b21 b22 b23}
     */
    vecInB = vldrhq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
    /*
     * { a00 b00 + a01 b10 + a02 b20  a00 b01 + a01 b11 + a02 b21  a00 b02 + a01 b12 + a02 b22  a00 b03 + a01 b13 + a02 b23
     *   a10 b00 + a11 b10 + a12 b20  a10 b01 + a11 b11 + a12 b21  a10 b02 + a11 b12 + a12 b22  a10 b03 + a11 b13 + a12 b23 }
     */
    vecDst0 = vfmaq(vecDst0, vecInA, vecInB);
    /*
     * jump 2 x A rows
     */
    vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 2*pSrcA.stride());

    /*
     * load {a22 a22 a22 a22 a32 a32 a32 a32}
     */
    vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
    /*
     * { a20 b00 + a21 b10 + a22 b20  a20 b01 + a21 b11 + a22 b21  a20 b02 + a21 b12 + a22 b22  a20 b03 + a21 b13 + a22 b23
     *   a30 b00 + a31 b10 + a32 b20  a30 b01 + a31 b11 + a32 b21  a30 b02 + a31 b12 + a32 b22  a30 b03 + a31 b13 + a32 b23 }
     */
    vecDst1 = vfmaq(vecDst1, vecInA, vecInB);

    /*
     * rewind back to top half of the A matrix (4th column)
     */
    vecOffsA = vsubq(vecOffsA, (uint16_t) (2*pSrcA.stride()-1));
    /*
     * load {a03 a03 a03 a03 a13 a13 a13 a13}
     */
    vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
    /*
     * move to next B row
     */
    vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) pSrcB.stride());
    /*
     * load {b30 b31 b32 b33 b30 b31 b32 b33}
     */
    vecInB = vldrhq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
    /*
     * { a00 b00 +...+ a03 b30  a00 b01 +...+ a03 b31  a00 b02 +...+ a03 b32  a00 b03 +...+ a03 b33
     *   a10 b00 +...+ a13 b30  a10 b01 +...+ a13 b31  a10 b02 +...+ a13 b32  a10 b03 +...+ a13 b33 }
     */
    vecDst0 = vfmaq(vecDst0, vecInA, vecInB);
    /*
     * jump 2 x A rows
     */
    vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) pSrcA.stride()*2);
    /*
     * load {a23 a23 a23 a23 a33 a33 a33 a33}
     */
    vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
    /*
     * { a20 b00 +...+ a23 b30  a20 b01 +...+ a23 b31  a20 b02 +...+ a23 b32  a20 b03 +...+ a23 b33
     *   a30 b00 +...+ a33 b30  a30 b01 +...+ a33 b31  a30 b02 +...+ a33 b32  a30 b03 +...+ a33 b33 }
     */
    vecDst1 = vfmaq(vecDst1, vecInA, vecInB);

    /*
     * Store the result in the destination buffer:
     * rows 0-1 first, then rows 2-3
     */
    vecOffsD = vldrhq_u16((uint16_t const *) offsetD);
    vstrhq_scatter_shifted_offset(pOut, vecOffsD, vecDst0);
    pOut += 2*pDst.stride();
    vstrhq_scatter_shifted_offset(pOut, vecOffsD, vecDst1);
}
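
/*
 * Illustrative sketch only (not necessarily this library's dispatch logic):
 * fixed-size kernels like the ones above are typically selected by a general
 * matrix-multiply routine when both operands are small squares. The
 * dispatcher name and the rows() accessor are assumptions, and std::forward
 * assumes <utility> is available.
 */
template<typename MA, typename MB, typename RES>
__STATIC_INLINE void _small_square_mat_mult_mve(const MA &pSrcA,
                                                const MB &pSrcB,
                                                RES &&pDst)
{
    switch (pSrcA.rows()) {
        case 2: _arm_mat_mult_2x2_mve(pSrcA, pSrcB, std::forward<RES>(pDst)); break;
        case 3: _arm_mat_mult_3x3_mve(pSrcA, pSrcB, std::forward<RES>(pDst)); break;
        case 4: _arm_mat_mult_4x4_mve(pSrcA, pSrcB, std::forward<RES>(pDst)); break;
        default: break; /* larger sizes would go to a general kernel */
    }
}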
#endif

/*! @} */