// -*- C++ -*-
/** @file */
#pragma once

#ifdef DOXYGEN
#define ARM_MATH_MVEI
#define ARM_MATH_MVEF
#define ARM_MATH_MVE_FLOAT16
#endif

/** \addtogroup HELIUMALG
 * @{
 */

/**
 * @brief  2x2 matrix multiplication (float32, Helium/MVE): pDst = pSrcA * pSrcB.
 *         The four result elements are formed with a single vector multiply
 *         followed by a single vector multiply-accumulate, using gather/scatter
 *         offsets (or compile-time strides when the matrix stride is static).
 *
 * @param pSrcA  first input matrix (2x2)
 * @param pSrcB  second input matrix (2x2)
 * @param pDst   destination matrix (2x2)
 */
template<typename MA,
         typename MB,
         typename RES,
         typename std::enable_if<
             has_vector_inst<MA>() &&
             SameElementType<MA,float32_t>::value,bool>::type = true>
__STATIC_INLINE void _arm_mat_mult_2x2_mve(
    const MA &pSrcA,
    const MB &pSrcB,
    RES &&pDst)
{
    using T = typename traits<MA>::Scalar;
    //using ACC = typename vector_traits<T>::temp_accumulator;
    using VEC = typename vector_traits<T>::vector;

    /* {a00, a00, a10, a10} */
    const uint32_t offsetA0[4] = { 0, 0, pSrcA.stride(), pSrcA.stride() };
    /* {b00, b01, b00, b01} */
    const uint32_t offsetB0[4] = { 0, 1, 0, 1 };
    /* {a01, a01, a11, a11} */
    const uint32_t offsetA1[4] = { 1, 1, pSrcA.stride() + 1, pSrcA.stride() + 1 };
    /* {b10, b11, b10, b11} */
    const uint32_t offsetB1[4] = { pSrcB.stride(), pSrcB.stride() + 1, pSrcB.stride(), pSrcB.stride() + 1 };

    /* {d00, d01, d10, d11} */
    const uint32_t offsetD[4] = { 0, 1, pDst.stride(), pDst.stride() + 1 };

    uint32x4_t vecOffsA, vecOffsB, vecOffsC;
    VEC vecInA, vecInB, vecDst;

    if constexpr (!HasStaticStride<MA>::value)
    {
        vecOffsA = vldrwq_u32((uint32_t const *) offsetA0);
    }
    vecOffsB = vldrwq_u32((uint32_t const *) offsetB0);

    if constexpr (!HasStaticStride<MA>::value)
    {
        vecInA = vldrwq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
    }
    else
    {
        constexpr int s = StaticStride<MA>::value;
        vecInA = inner::vload1_gen_stride<0, 0, s, s>::run(pSrcA.const_ptr());
    }

    if constexpr (!HasStaticStride<MB>::value)
    {
        vecInB = vldrwq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
    }
    else
    {
        vecInB = inner::vload1_gen_stride<0, 1, 0, 1>::run(pSrcB.const_ptr());
    }
    vecDst = inner::vmul(vecInA, vecInB);

    if constexpr (!HasStaticStride<MA>::value)
    {
        vecOffsA = vldrwq_u32((uint32_t const *) offsetA1);
    }

    if constexpr (!HasStaticStride<MB>::value)
    {
        vecOffsB = vldrwq_u32((uint32_t const *) offsetB1);
    }

    if constexpr (!HasStaticStride<MA>::value)
    {
        vecInA = vldrwq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
    }
    else
    {
        constexpr int s = StaticStride<MA>::value;
        vecInA = inner::vload1_gen_stride<1, 1, s + 1, s + 1>::run(pSrcA.const_ptr());
    }

    if constexpr (!HasStaticStride<MB>::value)
    {
        vecInB = vldrwq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
    }
    else
    {
        constexpr int s = StaticStride<MB>::value;
        vecInB = inner::vload1_gen_stride<s, s + 1, s, s + 1>::run(pSrcB.const_ptr());
    }

    if constexpr (!HasStaticStride<RES>::value)
    {
        vecOffsC = vldrwq_u32((uint32_t const *) offsetD);
    }

    vecDst = inner::vmacc(vecDst, vecInA, vecInB);

    //inner::vstore1<1>(pDst.ptr(), vecDst);
    if constexpr (!HasStaticStride<RES>::value)
    {
        vstrwq_scatter_shifted_offset(pDst.ptr(), vecOffsC, vecDst);
    }
    else
    {
        constexpr int s = StaticStride<RES>::value;
        inner::vstore1_gen_stride<0, 1, s, s + 1>::run(pDst.ptr(), vecDst);
    }
}

/**
 * @brief  3x3 matrix multiplication (float32, Helium/MVE): pDst = pSrcA * pSrcB.
 *         Each row of pSrcB is loaded with tail predication (the unused 4th
 *         lane is disabled), scaled by the matching element of a row of pSrcA,
 *         and accumulated into the corresponding row of pDst.
 *
 * @param pSrcA  first input matrix (3x3)
 * @param pSrcB  second input matrix (3x3)
 * @param pDst   destination matrix (3x3)
 */
template<typename MA,
         typename MB,
         typename RES,
         typename std::enable_if<
             has_vector_inst<MA>() &&
             SameElementType<MA,float32_t>::value,bool>::type = true>
__STATIC_INLINE void _arm_mat_mult_3x3_mve(
    const MA &pSrcA,
    const MB &pSrcB,
    RES &&pDst)
{
    using T = typename traits<MA>::Scalar;
    using ACC = typename vector_traits<T>::temp_accumulator;
    using VEC = typename vector_traits<T>::vector;
    T *pInB = pSrcB.ptr();  /* input data matrix pointer B */
    T *pInA = pSrcA.ptr();  /* input data matrix pointer A */
    T *pOut = pDst.ptr();   /* output data matrix pointer  */
    T *pInA0, *pInA1, *pInA2;
    ACC vecMac0, vecMac1, vecMac2;
    VEC vecInB;
    T const *pSrBVec;

    pSrBVec = (float32_t const *) pInB;

    pInA0 = pInA;
    pInA1 = pInA0 + pSrcA.stride();
    pInA2 = pInA1 + pSrcA.stride();
    /* enable predication to disable last (4th) vector element */
    mve_pred16_t p0 = inner::vctpq<T>::mk(MATRIX_DIM3);

    /*
     * load {b0,0, b0,1, b0,2, 0}
     */
    vecInB = inner::vload1_z<1>(pSrBVec, MATRIX_DIM3, p0);
    pSrBVec += pSrcB.stride();

    vecMac0 = inner::vmul(vecInB, *pInA0++);
    vecMac1 = inner::vmul(vecInB, *pInA1++);
    vecMac2 = inner::vmul(vecInB, *pInA2++);
    /*
     * load {b1,0, b1,1, b1,2, 0}
     */
    vecInB = inner::vload1_z<1>(pSrBVec, MATRIX_DIM3, p0);
    pSrBVec += pSrcB.stride();

    vecMac0 = inner::vmacc(vecMac0, vecInB, *pInA0++);
    vecMac1 = inner::vmacc(vecMac1, vecInB, *pInA1++);
    vecMac2 = inner::vmacc(vecMac2, vecInB, *pInA2++);
    /*
     * load {b2,0, b2,1, b2,2, 0}
     */
    vecInB = inner::vload1_z<1>(pSrBVec, MATRIX_DIM3, p0);
    pSrBVec += pSrcB.stride();

    vecMac0 = inner::vmacc(vecMac0, vecInB, *pInA0++);
    vecMac1 = inner::vmacc(vecMac1, vecInB, *pInA1++);
    vecMac2 = inner::vmacc(vecMac2, vecInB, *pInA2++);

    /* partial vector stores */
    inner::vstore1_z<1>(pOut, vecMac0, MATRIX_DIM3, p0);
    pOut += pDst.stride();
    inner::vstore1_z<1>(pOut, vecMac1, MATRIX_DIM3, p0);
    pOut += pDst.stride();
    inner::vstore1_z<1>(pOut, vecMac2, MATRIX_DIM3, p0);
    /*
     * Return to application
     */
}

/**
 * @brief  4x4 matrix multiplication (float32, Helium/MVE): pDst = pSrcA * pSrcB.
 *         Each row of pSrcB fills a full vector; rows are scaled by the matching
 *         elements of the rows of pSrcA and accumulated into the rows of pDst.
 *
 * @param pSrcA  first input matrix (4x4)
 * @param pSrcB  second input matrix (4x4)
 * @param pDst   destination matrix (4x4)
 */
template<typename MA,
         typename MB,
         typename RES,
         typename std::enable_if<
             has_vector_inst<MA>() &&
             SameElementType<MA,float32_t>::value,bool>::type = true>
__STATIC_INLINE void _arm_mat_mult_4x4_mve(
    const MA &pSrcA,
    const MB &pSrcB,
    RES &&pDst)
{
    using T = typename traits<MA>::Scalar;
    using ACC = typename vector_traits<T>::temp_accumulator;
    using VEC = typename vector_traits<T>::vector;
    T const *pSrBVec;
    T *pInB = pSrcB.ptr();  /* input data matrix pointer B */
    T *pInA = pSrcA.ptr();  /* input data matrix pointer A */
    T *pOut = pDst.ptr();   /* output data matrix pointer  */
    T *pInA0, *pInA1, *pInA2, *pInA3;
    ACC vecMac0, vecMac1, vecMac2, vecMac3;
    VEC vecInB;

    pSrBVec = (float32_t const *) pInB;

    pInA0 = pInA;
    pInA1 = pInA0 + pSrcA.stride();
    pInA2 = pInA1 + pSrcA.stride();
    pInA3 = pInA2 + pSrcA.stride();
    /*
     * load {b0,0, b0,1, b0,2, b0,3}
     */
    vecInB = inner::vload1<1>(pSrBVec);
    pSrBVec += pSrcB.stride();

    vecMac0 = inner::vmul(vecInB, *pInA0++);
    vecMac1 = inner::vmul(vecInB, *pInA1++);
    vecMac2 = inner::vmul(vecInB, *pInA2++);
    vecMac3 = inner::vmul(vecInB, *pInA3++);
    /*
     * load {b1,0, b1,1, b1,2, b1,3}
     */
    vecInB = inner::vload1<1>(pSrBVec);
    pSrBVec += pSrcB.stride();

    vecMac0 = inner::vmacc(vecMac0, vecInB, *pInA0++);
    vecMac1 = inner::vmacc(vecMac1, vecInB, *pInA1++);
    vecMac2 = inner::vmacc(vecMac2, vecInB, *pInA2++);
    vecMac3 = inner::vmacc(vecMac3, vecInB, *pInA3++);
    /*
     * load {b2,0, b2,1, b2,2, b2,3}
     */
    vecInB = inner::vload1<1>(pSrBVec);
    pSrBVec += pSrcB.stride();

    vecMac0 = inner::vmacc(vecMac0, vecInB, *pInA0++);
    vecMac1 = inner::vmacc(vecMac1, vecInB, *pInA1++);
    vecMac2 = inner::vmacc(vecMac2, vecInB, *pInA2++);
    vecMac3 = inner::vmacc(vecMac3, vecInB, *pInA3++);
    /*
     * load {b3,0, b3,1, b3,2, b3,3}
     */
    vecInB = inner::vload1<1>(pSrBVec);
    pSrBVec += pSrcB.stride();

    vecMac0 = inner::vmacc(vecMac0, vecInB, *pInA0++);
    vecMac1 = inner::vmacc(vecMac1, vecInB, *pInA1++);
    vecMac2 = inner::vmacc(vecMac2, vecInB, *pInA2++);
    vecMac3 = inner::vmacc(vecMac3, vecInB, *pInA3++);

    inner::vstore1<1>(pOut, vecMac0);
    pOut += pDst.stride();
    inner::vstore1<1>(pOut, vecMac1);
    pOut += pDst.stride();
    inner::vstore1<1>(pOut, vecMac2);
    pOut += pDst.stride();
    inner::vstore1<1>(pOut, vecMac3);
}

/*! @} */
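
/*
 * Illustrative sketch only (not part of the original header): one way a caller
 * could dispatch on the run-time dimension of small square float32 matrices to
 * select one of the fixed-size Helium kernels above. The dispatcher name
 * _arm_mat_mult_small_square_mve and the rows() accessor are assumptions about
 * the surrounding matrix interface (which already provides ptr(), const_ptr()
 * and stride()); the general-size fallback is intentionally omitted.
 */
template<typename MA,
         typename MB,
         typename RES,
         typename std::enable_if<
             has_vector_inst<MA>() &&
             SameElementType<MA,float32_t>::value,bool>::type = true>
__STATIC_INLINE void _arm_mat_mult_small_square_mve(
    const MA &pSrcA,
    const MB &pSrcB,
    RES &&pDst)
{
    switch (pSrcA.rows())  /* assumed accessor: row count of the square operand */
    {
        case 2: _arm_mat_mult_2x2_mve(pSrcA, pSrcB, pDst); break;
        case 3: _arm_mat_mult_3x3_mve(pSrcA, pSrcB, pDst); break;
        case 4: _arm_mat_mult_4x4_mve(pSrcA, pSrcB, pDst); break;
        default: break; /* larger sizes would use the general product (not shown) */
    }
}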