1 // -*- C++ -*-
2 /** @file */
3 #pragma once
4 
5 #ifdef DOXYGEN
6 #define ARM_MATH_MVEI
7 #define ARM_MATH_MVEF
8 #define ARM_MATH_MVE_FLOAT16
9 #endif
10 
11 /** \addtogroup HELIUMALG
12  *  @{
13  */
14 
15 template<typename MA,
16          typename MB,
17          typename RES,
18 typename std::enable_if<
19          has_vector_inst<MA>() &&
20          SameElementType<MA,float32_t>::value,bool>::type = true>
21 __STATIC_INLINE  void _arm_mat_mult_2x2_mve(
22     const MA &pSrcA,
23     const MB &pSrcB,
24     RES &&pDst)
25 {
26     using T = typename traits<MA>::Scalar;
27     //using ACC = typename vector_traits<T>::temp_accumulator;
28     using VEC = typename vector_traits<T>::vector;
29 
30     /* {a00, a00, a10, a10} */
31     const uint32_t  offsetA0[4] = { 0, 0, pSrcA.stride(), pSrcA.stride() };
32     /* {b00, b01, b00, b01} */
33     const uint32_t  offsetB0[4] = { 0, 1, 0, 1 };
34     /* {a01, a01, a11, a11} */
35     const uint32_t  offsetA1[4] = { 1, 1, pSrcA.stride() + 1, pSrcA.stride() + 1 };
36     /* {b10, b11, b10, b11} */
37     const uint32_t  offsetB1[4] = { pSrcB.stride(), pSrcB.stride()+1, pSrcB.stride(), pSrcB.stride()+1 };
38 
39     /* {d00, d01, d10, d11} */
40     const uint32_t  offsetD[4] = { 0, 1, pDst.stride(), pDst.stride()+1 };
41 
42     uint32x4_t vecOffsA, vecOffsB,vecOffsC;
43     VEC vecInA, vecInB, vecDst;
44 
45     if constexpr (!HasStaticStride<MA>::value)
46     {
47        vecOffsA = vldrwq_u32((uint32_t const *) offsetA0);
48     }
49     vecOffsB = vldrwq_u32((uint32_t const *) offsetB0);
50 
51     if constexpr (!HasStaticStride<MA>::value)
52     {
53        vecInA = vldrwq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
54     }
55     else
56     {
57         constexpr int s = StaticStride<MA>::value;
58         vecInA = inner::vload1_gen_stride<0, 0, s, s>::run(pSrcA.const_ptr());
59     }
60 
61     if constexpr (!HasStaticStride<MB>::value)
62     {
63         vecInB = vldrwq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
64     }
65     else
66     {
67         vecInB = inner::vload1_gen_stride<0, 1, 0, 1>::run(pSrcB.const_ptr());
68     }
69     vecDst = inner::vmul(vecInA, vecInB);
70 
71     if constexpr (!HasStaticStride<MA>::value)
72     {
73         vecOffsA = vldrwq_u32((uint32_t const *) offsetA1);
74     }
75 
76     if constexpr (!HasStaticStride<MB>::value)
77     {
78         vecOffsB = vldrwq_u32((uint32_t const *) offsetB1);
79     }
80 
81     if constexpr (!HasStaticStride<MA>::value)
82     {
83        vecInA = vldrwq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
84     }
85     else
86     {
87         constexpr int s = StaticStride<MA>::value;
88         vecInA = inner::vload1_gen_stride<1, 1, s+1, s+1>::run(pSrcA.const_ptr());
89 
90     }
91 
92     if constexpr (!HasStaticStride<MB>::value)
93     {
94         vecInB = vldrwq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
95     }
96     else
97     {
98         constexpr int s = StaticStride<MB>::value;
99         vecInB = inner::vload1_gen_stride<s, s+1, s, s+1>::run(pSrcB.const_ptr());
100     }
101 
102     if constexpr (!HasStaticStride<RES>::value)
103     {
104         vecOffsC = vldrwq_u32((uint32_t const *) offsetD);
105     }
106 
107     vecDst = inner::vmacc(vecDst, vecInA, vecInB);
108 
109     //inner::vstore1<1>(pDst.ptr(), vecDst);
110     if constexpr (!HasStaticStride<RES>::value)
111     {
112        vstrwq_scatter_shifted_offset(pDst.ptr(),vecOffsC,vecDst);
113     }
114     else
115     {
116         constexpr int s = StaticStride<RES>::value;
117         inner::vstore1_gen_stride<0, 1, s, s+1>::run(pDst.ptr(),vecDst);
118     }
119 
120 }
121 
122 template<typename MA,
123          typename MB,
124          typename RES,
125 typename std::enable_if<
126          has_vector_inst<MA>() &&
127          SameElementType<MA,float32_t>::value,bool>::type = true>
128 __STATIC_INLINE  void _arm_mat_mult_3x3_mve(
129     const MA &pSrcA,
130     const MB &pSrcB,
131     RES &&pDst)
132 {
133     using T = typename traits<MA>::Scalar;
134     using ACC = typename vector_traits<T>::temp_accumulator;
135     using VEC = typename vector_traits<T>::vector;
136     T   *pInB = pSrcB.ptr(); /* input data matrix pointer B */
137     T   *pInA = pSrcA.ptr(); /* input data matrix pointer A  */
138     T   *pOut = pDst.ptr();  /* output data matrix pointer */
139     T   *pInA0, *pInA1, *pInA2;
140     ACC    vecMac0, vecMac1, vecMac2;
141     VEC    vecInB;
142     T const *pSrBVec;
143 
144     pSrBVec = (float32_t const *) pInB;
145 
146     pInA0 = pInA;
147     pInA1 = pInA0 + pSrcA.stride();
148     pInA2 = pInA1 + pSrcA.stride();
149     /* enable predication to disable last (4th) vector element */
150     mve_pred16_t p0 = inner::vctpq<T>::mk(MATRIX_DIM3);
151 
152     /*
153      * load {b0,0, b0,1, b0,2, 0}
154      */
155     vecInB = inner::vload1_z<1>(pSrBVec, MATRIX_DIM3,p0);
156     pSrBVec += pSrcB.stride();
157 
158     vecMac0 = inner::vmul(vecInB, *pInA0++);
159     vecMac1 = inner::vmul(vecInB, *pInA1++);
160     vecMac2 = inner::vmul(vecInB, *pInA2++);
161     /*
162      * load {b1,0, b1,1, b1,2, 0}
163      */
164     vecInB = inner::vload1_z<1>(pSrBVec, MATRIX_DIM3,p0);
165     pSrBVec += pSrcB.stride();
166 
167     vecMac0 = inner::vmacc(vecMac0, vecInB, *pInA0++);
168     vecMac1 = inner::vmacc(vecMac1, vecInB, *pInA1++);
169     vecMac2 = inner::vmacc(vecMac2, vecInB, *pInA2++);
170     /*
171      * load {b2,0, b2,1 , b2,2, 0}
172      */
173     vecInB = inner::vload1_z<1>(pSrBVec, MATRIX_DIM3,p0);
174     pSrBVec += pSrcB.stride();
175 
176     vecMac0 = inner::vmacc(vecMac0, vecInB, *pInA0++);
177     vecMac1 = inner::vmacc(vecMac1, vecInB, *pInA1++);
178     vecMac2 = inner::vmacc(vecMac2, vecInB, *pInA2++);
179 
180     /* partial vector stores */
181     inner::vstore1_z<1>(pOut, vecMac0, MATRIX_DIM3,p0);
182     pOut += pDst.stride();
183     inner::vstore1_z<1>(pOut, vecMac1, MATRIX_DIM3,p0);
184     pOut += pDst.stride();
185     inner::vstore1_z<1>(pOut, vecMac2, MATRIX_DIM3,p0);
186     /*
187      * Return to application
188      */
189 }
190 
191 template<typename MA,
192          typename MB,
193          typename RES,
194 typename std::enable_if<
195          has_vector_inst<MA>() &&
196          SameElementType<MA,float32_t>::value,bool>::type = true>
197 __STATIC_INLINE  void _arm_mat_mult_4x4_mve(
198     const MA &pSrcA,
199     const MB &pSrcB,
200     RES &&pDst)
201 {
202     using T = typename traits<MA>::Scalar;
203     using ACC = typename vector_traits<T>::temp_accumulator;
204     using VEC = typename vector_traits<T>::vector;
205     T const *pSrBVec;
206     T *pInB = pSrcB.ptr(); /* input data matrix pointer B */
207     T *pInA = pSrcA.ptr(); /* input data matrix pointer A  */
208     T *pOut = pDst.ptr();  /* output data matrix pointer */
209     T *pInA0, *pInA1, *pInA2, *pInA3;
210     ACC vecMac0, vecMac1, vecMac2, vecMac3;
211     VEC vecInB;
212 
213     pSrBVec = (float32_t const *) pInB;
214 
215     pInA0 = pInA;
216     pInA1 = pInA0 + pSrcA.stride();
217     pInA2 = pInA1 + pSrcA.stride();
218     pInA3 = pInA2 + pSrcA.stride();
219     /*
220      * load {b0,0, b0,1, b0,2, b0,3}
221      */
222     vecInB = inner::vload1<1>(pSrBVec);
223     pSrBVec += pSrcB.stride();
224 
225     vecMac0 = inner::vmul(vecInB, *pInA0++);
226     vecMac1 = inner::vmul(vecInB, *pInA1++);
227     vecMac2 = inner::vmul(vecInB, *pInA2++);
228     vecMac3 = inner::vmul(vecInB, *pInA3++);
229     /*
230      * load {b1,0, b1,1, b1,2, b1,3}
231      */
232     vecInB = inner::vload1<1>(pSrBVec);
233     pSrBVec += pSrcB.stride();
234 
235     vecMac0 = inner::vmacc(vecMac0, vecInB, *pInA0++);
236     vecMac1 = inner::vmacc(vecMac1, vecInB, *pInA1++);
237     vecMac2 = inner::vmacc(vecMac2, vecInB, *pInA2++);
238     vecMac3 = inner::vmacc(vecMac3, vecInB, *pInA3++);
239     /*
240      * load {b2,0, b2,1, b2,2, b2,3}
241      */
242     vecInB = inner::vload1<1>(pSrBVec);
243     pSrBVec += pSrcB.stride();
244 
245     vecMac0 = inner::vmacc(vecMac0, vecInB, *pInA0++);
246     vecMac1 = inner::vmacc(vecMac1, vecInB, *pInA1++);
247     vecMac2 = inner::vmacc(vecMac2, vecInB, *pInA2++);
248     vecMac3 = inner::vmacc(vecMac3, vecInB, *pInA3++);
249     /*
250      * load {b3,0, b3,1, b3,2, b3,3}
251      */
252     vecInB = inner::vload1<1>(pSrBVec);
253     pSrBVec += pSrcB.stride();
254 
255     vecMac0 = inner::vmacc(vecMac0, vecInB, *pInA0++);
256     vecMac1 = inner::vmacc(vecMac1, vecInB, *pInA1++);
257     vecMac2 = inner::vmacc(vecMac2, vecInB, *pInA2++);
258     vecMac3 = inner::vmacc(vecMac3, vecInB, *pInA3++);
259 
260     inner::vstore1<1>(pOut, vecMac0);
261     pOut += pDst.stride();
262     inner::vstore1<1>(pOut, vecMac1);
263     pOut += pDst.stride();
264     inner::vstore1<1>(pOut, vecMac2);
265     pOut += pDst.stride();
266     inner::vstore1<1>(pOut, vecMac3);
267 
268 }
269 
270 /*! @} */