1 // -*- C++ -*-
2 /** @file */
3 #pragma once
4 
5 /** \addtogroup SCALARALG
6  *  @{
7  */
8 
9 /**
10  * @brief      Matrix times matrix for scalar architecture and float
11  *
12  * @param[in]  pSrcA      The source a
13  * @param[in]  pSrcB      The source b
14  * @param      pDst       The destination
15  *
16  * @tparam     MA         Left hand side datatype
17  * @tparam     MB         Right hand side datatype
18  * @tparam     RES        Result datatype
19  * @tparam     <unnamed>  Check if float
20  */
21 template<typename MA,
22          typename MB,
23          typename RES,
24          typename std::enable_if<number_traits<typename traits<MA>::Scalar>::is_float,bool>::type = true>
_dot_m_m(const MA & pSrcA,const MB & pSrcB,RES && pDst,const Scalar * =nullptr)25 __STATIC_INLINE void _dot_m_m(const MA&pSrcA,const MB&pSrcB,
26                      RES &&pDst,
27                      const Scalar* = nullptr)
28 {
29   using T = typename traits<MA>::Scalar;
30   using Acc = typename number_traits<T>::accumulator;
31   //using Comp = typename number_traits<T>::compute_type;
32   T *pIn1 = pSrcA.ptr();                /* Input data matrix pointer A */
33   T *pIn2 = pSrcB.ptr();                /* Input data matrix pointer B */
34   T *pInA = pSrcA.ptr();                /* Input data matrix pointer A */
35   T *pInB = pSrcB.ptr();                /* Input data matrix pointer B */
36   T *pOut = pDst.ptr();                 /* Output data matrix pointer */
37   T *px;                                 /* Temporary output data matrix pointer */
38   Acc sum;                                 /* Accumulator */
39   uint16_t numRowsA = pSrcA.rows();            /* Number of rows of input matrix A */
40   uint16_t numColsB = pSrcB.columns();            /* Number of columns of input matrix B */
41   uint16_t numColsA = pSrcA.columns();            /* Number of columns of input matrix A */
42   uint32_t col, i = 0U, row = numRowsA, colCnt;  /* Loop counters */
43 
44 
45     /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
46     /* row loop */
47     do
48     {
49       /* Output pointer is set to starting address of row being processed */
50       px = pOut + i;
51 
52       /* For every row wise process, column loop counter is to be initiated */
53       col = numColsB;
54 
55       /* For every row wise process, pIn2 pointer is set to starting address of pSrcB data */
56       pIn2 = pSrcB.ptr();
57 
58       /* column loop */
59       do
60       {
61         /* Set the variable sum, that acts as accumulator, to zero */
62         sum = Acc{};
63 
64         /* Initialize pointer pIn1 to point to starting address of column being processed */
65         pIn1 = pInA;
66 
67 
68         /* Loop unrolling: Compute 4 MACs at a time. */
69         colCnt = numColsA >> 2U;
70 
71         /* matrix multiplication */
72         while (colCnt > 0U)
73         {
74           /* c(m,p) = a(m,1) * b(1,p) + a(m,2) * b(2,p) + .... + a(m,n) * b(n,p) */
75 
76           /* Perform the multiply-accumulates */
77           sum = inner::mac(sum, *pIn1++, *pIn2);
78           pIn2 += pSrcB.stride();
79 
80           sum = inner::mac(sum, *pIn1++, *pIn2);
81           pIn2 += pSrcB.stride();
82 
83           sum = inner::mac(sum, *pIn1++, *pIn2);
84           pIn2 += pSrcB.stride();
85 
86           sum = inner::mac(sum, *pIn1++, *pIn2);
87           pIn2 += pSrcB.stride();
88 
89           /* Decrement loop counter */
90           colCnt--;
91         }
92 
93         /* Loop unrolling: Compute remaining MACs */
94         colCnt = numColsA % 0x4U;
95 
96         while (colCnt > 0U)
97         {
98           /* c(m,p) = a(m,1) * b(1,p) + a(m,2) * b(2,p) + .... + a(m,n) * b(n,p) */
99 
100           /* Perform the multiply-accumulates */
101           sum = inner::mac(sum, *pIn1++, *pIn2);
102           pIn2 += pSrcB.stride();
103 
104           /* Decrement loop counter */
105           colCnt--;
106         }
107 
108         /* Store result in destination buffer */
109         *px++ = inner::from_accumulator(sum);
110 
111         /* Decrement column loop counter */
112         col--;
113 
114         /* Update pointer pIn2 to point to starting address of next column */
115         pIn2 = pInB + (numColsB - col);
116 
117       } while (col > 0U);
118 
119       /* Update pointer pInA to point to starting address of next row */
120       i = i + pDst.stride();
121       pInA = pInA + pSrcA.stride();
122 
123       /* Decrement row loop counter */
124       row--;
125 
126     } while (row > 0U);
127 
128 
129 }
130 
131 /*! @} */