1# Bug corrections for version 1.9
2import cmsisdsp as dsp
3import cmsisdsp.fixedpoint as f
4import numpy as np
5import colorama
6from colorama import init,Fore, Back, Style
7from numpy.testing import assert_allclose
8
9from numpy.linalg import norm
10
11import matplotlib
12import matplotlib as mpl
13import matplotlib.pyplot as plt
14
# Initialise colorama once, before printTitle / printSubTitle emit the
# Fore/Style escape codes they use.
colorama.init()
16
def printTitle(s):
    """Print the string *s* as a bright green heading, preceded by a newline.

    The colour/brightness attributes are reset after the text.
    """
    heading = "".join(("\n", Fore.GREEN, Style.BRIGHT, s, Style.RESET_ALL))
    print(heading)
19
def printSubTitle(s):
    """Print the string *s* as a bright (uncoloured) sub-heading, preceded by a newline.

    The brightness attribute is reset after the text.
    """
    heading = "".join(("\n", Style.BRIGHT, s, Style.RESET_ALL))
    print(heading)
22
printTitle("DTW Window")

printSubTitle("SAKOE_CHIBA_WINDOW")

# Expected 10x5 Sakoe-Chiba mask for wsize=2. A 1 marks a cell the DTW may
# visit. The reference corresponds to |i - j| <= wsize, which is why rows 7-9
# are entirely outside the band.
refWin1=np.array([[1, 1, 1, 0, 0],
       [1, 1, 1, 1, 0],
       [1, 1, 1, 1, 1],
       [0, 1, 1, 1, 1],
       [0, 0, 1, 1, 1],
       [0, 0, 0, 1, 1],
       [0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0]], dtype=np.int8)

dtwWindow = np.zeros((10, 5), dtype=np.int8)
wsize = 2
status, w = dsp.arm_dtw_init_window_q7(dsp.ARM_DTW_SAKOE_CHIBA_WINDOW, wsize, dtwWindow)

# 0 is ARM_MATH_SUCCESS in the CMSIS-DSP arm_status enum.
np.testing.assert_equal(status, 0)
# Unlike a bare `assert`, this is not stripped under `python -O`. It also
# checks the shape and reports the mismatching cells on failure.
np.testing.assert_array_equal(w, refWin1)
43
printSubTitle("SLANTED_BAND_WINDOW")

# Expected 10x5 slanted-band mask for wsize=1. A 1 marks a cell the DTW may
# visit. The band appears to follow the diagonal rescaled to the 10x5 shape
# (j close to i * 5 / 10) with half-width wsize.
refWin2=np.array([[1, 1, 0, 0, 0],
       [1, 1, 0, 0, 0],
       [1, 1, 1, 0, 0],
       [0, 1, 1, 0, 0],
       [0, 1, 1, 1, 0],
       [0, 0, 1, 1, 0],
       [0, 0, 1, 1, 1],
       [0, 0, 0, 1, 1],
       [0, 0, 0, 1, 1],
       [0, 0, 0, 0, 1]], dtype=np.int8)

dtwWindow = np.zeros((10, 5), dtype=np.int8)
wsize = 1
# NOTE: the `w` produced here is the window reused by the "With a window"
# DTW distance check later in this script.
status, w = dsp.arm_dtw_init_window_q7(dsp.ARM_DTW_SLANTED_BAND_WINDOW, wsize, dtwWindow)

# 0 is ARM_MATH_SUCCESS in the CMSIS-DSP arm_status enum.
np.testing.assert_equal(status, 0)
# Unlike a bare `assert`, this is not stripped under `python -O` and it checks
# the shape as well as the values.
np.testing.assert_array_equal(w, refWin2)
62
63
printTitle("DTW Cost Matrix and DTW Distance")

# Query series (10 samples).
query=np.array([ 0.08387197,  0.68082274,  1.06756417,  0.88914541,  0.42513398, -0.3259053,
 -0.80934885, -0.90979435, -0.64026483,  0.06923695])

# Template series (5 samples).
template=np.array([ 1.00000000e+00,  7.96326711e-04, -9.99998732e-01, -2.38897811e-03,
  9.99994927e-01])

# Derived from the data (10 and 5) so the lengths cannot drift out of sync
# with the arrays above.
QUERY_LENGTH = len(query)
TEMPLATE_LENGTH = len(template)

# NOTE(review): `cols` and `rows` are not referenced anywhere in the visible
# script; confirm whether they are still needed.
cols = np.array([1, 2, 3])
rows = np.array([10, 11, 12])
77
printSubTitle("Without a window")

# Expected accumulated DTW cost matrix (query rows x template columns) when
# no window is applied.
referenceCost=np.array([[0.91612804, 0.9992037 , 2.0830743 , 2.1693354 , 3.0854583 ],
       [1.2353053 , 1.6792301 , 3.3600516 , 2.8525472 , 2.8076797 ],
       [1.3028694 , 2.3696373 , 4.4372    , 3.9225004 , 2.875249  ],
       [1.4137241 , 2.302073  , 4.1912174 , 4.814035  , 2.9860985 ],
       [1.98859   , 2.2623994 , 3.6875322 , 4.115055  , 3.5609593 ],
       [3.3144953 , 2.589101  , 3.2631946 , 3.586711  , 4.8868594 ],
       [5.123844  , 3.3992462 , 2.9704008 , 3.7773607 , 5.5867043 ],
       [7.0336385 , 4.309837  , 3.0606053 , 3.9680107 , 5.8778    ],
       [8.673903  , 4.950898  , 3.420339  , 4.058215  , 5.698475  ],
       [9.604667  , 5.0193386 , 4.489575  , 3.563591  , 4.494349  ]],
      dtype=np.float32)

# Expected DTW distance. Numerically this is referenceCost[-1, -1] / (10 + 5),
# i.e. the final accumulated cost normalised by the total series length.
referenceDistance = 0.2996232807636261

# Pairwise |template[j] - query[i]| matrix. meshgrid places template values
# along the columns and query values along the rows, so each row is one
# query sample.
a, b = np.meshgrid(template, query)
distance = np.abs(a - b).astype(np.float32)

status, dtwDistance, dtwMatrix = dsp.arm_dtw_distance_f32(distance, None)

# 0 is ARM_MATH_SUCCESS in the CMSIS-DSP arm_status enum.
np.testing.assert_equal(status, 0)
# assert_allclose(actual, desired): rtol is applied relative to the reference.
assert_allclose(dtwDistance, referenceDistance)
assert_allclose(dtwMatrix, referenceCost)
103
printSubTitle("Path")

# A copy is passed, presumably because the path routine may write into its
# input matrix. NOTE(review): confirm this against the arm_dtw_path_f32 docs.
# The copy keeps dtwMatrix intact for the cell annotations below.
path = dsp.arm_dtw_path_f32(np.copy(dtwMatrix))

# The stride-2 slicing assumes `path` is a flat sequence of interleaved
# (row, col) index pairs. Mark every visited cell with 1.
pathMatrix = np.zeros(dtwMatrix.shape)
for cell in zip(path[0::2], path[1::2]):
    pathMatrix[cell] = 1

fig, ax = plt.subplots()
# With vmin at the data minimum (0), vmax=2.0 puts path cells (value 1) at
# mid-scale of the colormap, so the white annotations stay readable.
im = ax.imshow(pathMatrix, vmax=2.0)

# Annotate every cell with its accumulated cost. The bounds come from the
# matrix itself rather than from separately hard-coded lengths.
for i, j in np.ndindex(dtwMatrix.shape):
    ax.text(j, i, f"{dtwMatrix[i, j]:.1f}",
            ha="center", va="center", color="w")
fig.tight_layout()
plt.show()
122
printSubTitle("With a window")

# Expected DTW distance with the window applied. Numerically this is
# referenceCost[-1, -1] / (10 + 5).
referenceDistance = 0.617099940776825

# Expected accumulated cost matrix with the window applied. Cells outside the
# window are NaN and are not compared.
# Use np.nan: the np.NAN / np.NaN aliases were removed in NumPy 2.0.
referenceCost = np.array(
    [[9.1612804e-01, 9.9920368e-01, np.nan, np.nan, np.nan],
     [1.2353053e+00, 1.6792301e+00, np.nan, np.nan, np.nan],
     [1.3028694e+00, 2.3696373e+00, 4.4372001e+00, np.nan, np.nan],
     [np.nan, 3.0795674e+00, 4.9687119e+00, np.nan, np.nan],
     [np.nan, 3.5039051e+00, 4.9290380e+00, 5.3565612e+00, np.nan],
     [np.nan, np.nan, 4.8520918e+00, 5.1756082e+00, np.nan],
     [np.nan, np.nan, 5.0427418e+00, 5.8497019e+00, 7.6590457e+00],
     [np.nan, np.nan, np.nan, 6.7571073e+00, 8.6668968e+00],
     [np.nan, np.nan, np.nan, 7.3949833e+00, 9.0352430e+00],
     [np.nan, np.nan, np.nan, np.nan, 9.2564993e+00]],
    dtype=np.float32,
)

# NOTE: `w` is the window from the SLANTED_BAND_WINDOW section earlier in the
# script (the last assignment to `w`). The reference NaN pattern must match it
# exactly, so an accidental change to `w` fails here instead of silently
# comparing the wrong cells.
inWindow = w == 1
np.testing.assert_array_equal(np.isnan(referenceCost), ~inWindow)

status, dtwDistance, dtwMatrix = dsp.arm_dtw_distance_f32(distance, w)

# 0 is ARM_MATH_SUCCESS in the CMSIS-DSP arm_status enum.
np.testing.assert_equal(status, 0)
# assert_allclose(actual, desired): rtol is applied relative to the reference.
assert_allclose(dtwDistance, referenceDistance)
# Only in-window cells have defined expected values.
assert_allclose(dtwMatrix[inWindow], referenceCost[inWindow])
153
154