# Bug corrections for version 1.9
"""Regression checks for the DTW functions of the CMSIS-DSP Python wrapper.

The script compares against hard-coded reference values:

* the windows built by ``arm_dtw_init_window_q7`` (Sakoe-Chiba and
  slanted band),
* the cost matrix and distance returned by ``arm_dtw_distance_f32``, first
  without a window and then with the slanted-band window,

and displays the path returned by ``arm_dtw_path_f32`` on top of the cost
matrix with matplotlib.
"""
import cmsisdsp as dsp
import cmsisdsp.fixedpoint as f
import numpy as np
import colorama
from colorama import init, Fore, Back, Style
from numpy.testing import assert_allclose, assert_array_equal

from numpy.linalg import norm

import matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt

init()

# Every window / distance / cost matrix below has one row per query sample
# and one column per template sample.
QUERY_LENGTH = 10
TEMPLATE_LENGTH = 5


def printTitle(s):
    """Print ``s`` as a bright green section title, preceded by a blank line."""
    print("\n" + Fore.GREEN + Style.BRIGHT + s + Style.RESET_ALL)


def printSubTitle(s):
    """Print ``s`` as a bright (uncoloured) subsection title, preceded by a blank line."""
    print("\n" + Style.BRIGHT + s + Style.RESET_ALL)


printTitle("DTW Window")

printSubTitle("SAKOE_CHIBA_WINDOW")

refWin1 = np.array([[1, 1, 1, 0, 0],
                    [1, 1, 1, 1, 0],
                    [1, 1, 1, 1, 1],
                    [0, 1, 1, 1, 1],
                    [0, 0, 1, 1, 1],
                    [0, 0, 0, 1, 1],
                    [0, 0, 0, 0, 1],
                    [0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0]], dtype=np.int8)

dtwWindow = np.zeros((QUERY_LENGTH, TEMPLATE_LENGTH), dtype=np.int8)
wsize = 2
# NOTE(review): ``status`` is never checked in this script; presumably
# 0 (ARM_MATH_SUCCESS) means success - confirm before asserting on it.
status, w = dsp.arm_dtw_init_window_q7(dsp.ARM_DTW_SAKOE_CHIBA_WINDOW, wsize, dtwWindow)

# assert_array_equal also checks the shape (a plain ``(a == b).all()``
# broadcasts) and is not stripped when Python runs with -O.
assert_array_equal(w, refWin1)

printSubTitle("SLANTED_BAND_WINDOW")

refWin2 = np.array([[1, 1, 0, 0, 0],
                    [1, 1, 0, 0, 0],
                    [1, 1, 1, 0, 0],
                    [0, 1, 1, 0, 0],
                    [0, 1, 1, 1, 0],
                    [0, 0, 1, 1, 0],
                    [0, 0, 1, 1, 1],
                    [0, 0, 0, 1, 1],
                    [0, 0, 0, 1, 1],
                    [0, 0, 0, 0, 1]], dtype=np.int8)

dtwWindow = np.zeros((QUERY_LENGTH, TEMPLATE_LENGTH), dtype=np.int8)
wsize = 1
status, w = dsp.arm_dtw_init_window_q7(dsp.ARM_DTW_SLANTED_BAND_WINDOW, wsize, dtwWindow)

assert_array_equal(w, refWin2)


printTitle("DTW Cost Matrix and DTW Distance")

query = np.array([0.08387197, 0.68082274, 1.06756417, 0.88914541, 0.42513398, -0.3259053,
                  -0.80934885, -0.90979435, -0.64026483, 0.06923695])

template = np.array([1.00000000e+00, 7.96326711e-04, -9.99998732e-01, -2.38897811e-03,
                     9.99994927e-01])

printSubTitle("Without a window")

referenceCost = np.array([[0.91612804, 0.9992037 , 2.0830743 , 2.1693354 , 3.0854583 ],
                          [1.2353053 , 1.6792301 , 3.3600516 , 2.8525472 , 2.8076797 ],
                          [1.3028694 , 2.3696373 , 4.4372    , 3.9225004 , 2.875249  ],
                          [1.4137241 , 2.302073  , 4.1912174 , 4.814035  , 2.9860985 ],
                          [1.98859   , 2.2623994 , 3.6875322 , 4.115055  , 3.5609593 ],
                          [3.3144953 , 2.589101  , 3.2631946 , 3.586711  , 4.8868594 ],
                          [5.123844  , 3.3992462 , 2.9704008 , 3.7773607 , 5.5867043 ],
                          [7.0336385 , 4.309837  , 3.0606053 , 3.9680107 , 5.8778    ],
                          [8.673903  , 4.950898  , 3.420339  , 4.058215  , 5.698475  ],
                          [9.604667  , 5.0193386 , 4.489575  , 3.563591  , 4.494349  ]],
                         dtype=np.float32)

referenceDistance = 0.2996232807636261

# meshgrid(template, query) yields arrays of shape (len(query), len(template)),
# so distance[i, j] == |template[j] - query[i]|: each row is one query sample.
a, b = np.meshgrid(template, query)
distance = abs(a - b).astype(np.float32)

status, dtwDistance, dtwMatrix = dsp.arm_dtw_distance_f32(distance, None)

# assert_allclose(actual, desired): the relative tolerance scales with the reference.
assert_allclose(dtwDistance, referenceDistance)
assert_allclose(dtwMatrix, referenceCost)

printSubTitle("Path")

# The path function is given a copy so dtwMatrix is left unchanged for the
# plot below (NOTE(review): presumably it may write into its argument - confirm).
path = dsp.arm_dtw_path_f32(np.copy(dtwMatrix))
# The path is read as a flat sequence of index pairs:
# even positions are row indices, odd positions are column indices.
pathMatrix = np.zeros(dtwMatrix.shape)
for row, col in zip(path[0::2], path[1::2]):
    pathMatrix[row, col] = 1


fig, ax = plt.subplots()
ax.imshow(pathMatrix, vmax=2.0)

# Write each cost-matrix value in its cell on top of the path image.
for i in range(QUERY_LENGTH):
    for j in range(TEMPLATE_LENGTH):
        ax.text(j, i, "%.1f" % dtwMatrix[i, j],
                ha="center", va="center", color="w")
fig.tight_layout()
plt.show()

printSubTitle("With a window")

# ``w`` is still the slanted-band window (wsize=1) built above. NaN marks
# reference cells outside that window; only cells with w == 1 are compared.
# (np.nan, not np.NAN: the upper-case alias was removed in NumPy 2.0.)
referenceDistance = 0.617099940776825
referenceCost = np.array([[9.1612804e-01, 9.9920368e-01, np.nan, np.nan,
                           np.nan],
                          [1.2353053e+00, 1.6792301e+00, np.nan, np.nan,
                           np.nan],
                          [1.3028694e+00, 2.3696373e+00, 4.4372001e+00, np.nan,
                           np.nan],
                          [np.nan, 3.0795674e+00, 4.9687119e+00, np.nan,
                           np.nan],
                          [np.nan, 3.5039051e+00, 4.9290380e+00, 5.3565612e+00,
                           np.nan],
                          [np.nan, np.nan, 4.8520918e+00, 5.1756082e+00,
                           np.nan],
                          [np.nan, np.nan, 5.0427418e+00, 5.8497019e+00,
                           7.6590457e+00],
                          [np.nan, np.nan, np.nan, 6.7571073e+00,
                           8.6668968e+00],
                          [np.nan, np.nan, np.nan, 7.3949833e+00,
                           9.0352430e+00],
                          [np.nan, np.nan, np.nan, np.nan,
                           9.2564993e+00]], dtype=np.float32)


status, dtwDistance, dtwMatrix = dsp.arm_dtw_distance_f32(distance, w)

inWindow = w == 1
assert_allclose(dtwDistance, referenceDistance)
assert_allclose(dtwMatrix[inWindow], referenceCost[inWindow])