"""Train a Gaussian naive Bayes classifier on three 2-D point clusters.

The fitted parameters (per-class means, variances, class priors and the
variance-smoothing epsilon) are printed so they can be copied into a
CMSIS-DSP naive Bayes test. The three training clusters are then plotted.
"""
from sklearn.naive_bayes import GaussianNB
import random
import numpy as np
import math

from pylab import scatter,figure, clf, plot, xlabel, ylabel, xlim, ylim, title, grid, axes, show,semilogx, semilogy
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties

# Generation of data to train the classifier.
# 100 vectors are generated per class. Vectors have dimension 2, so they can
# be represented as points in the plane.
NBVECS = 100
VECDIM = 2

# 3 clusters of points are generated around fixed centres.
ballRadius = 1.0
x1 = [1.5, 1] + ballRadius * np.random.randn(NBVECS,VECDIM)
x2 = [-1.5, 1] + ballRadius * np.random.randn(NBVECS,VECDIM)
x3 = [0, -3] + ballRadius * np.random.randn(NBVECS,VECDIM)

# All points are concatenated
X_train=np.concatenate((x1,x2,x3))

# The classes are 0, 1 and 2 (one label per cluster, in concatenation order).
Y_train=np.concatenate((np.zeros(NBVECS),np.ones(NBVECS),2*np.ones(NBVECS)))

gnb = GaussianNB()
gnb.fit(X_train, Y_train)

# Classify each cluster centre; expected labels are 0, 1 and 2 respectively.
print("Testing")
y_pred = gnb.predict([[1.5,1.0]])
print(y_pred)

y_pred = gnb.predict([[-1.5,1.0]])
print(y_pred)

y_pred = gnb.predict([[0,-3.0]])
print(y_pred)

# Dump of data for CMSIS-DSP.
# np.ravel(...).tolist() flattens in C (row-major) order, like the previous
# reshape-to-size, but yields plain Python floats. A list of NumPy scalars
# would print as "np.float64(...)" under NumPy >= 2, which is not a usable
# C initializer.

print("Parameters")
# Gaussian averages
print("Theta = ",np.ravel(gnb.theta_).tolist())

# Gaussian variances.
# scikit-learn 1.0 renamed GaussianNB.sigma_ to var_, and 1.2 removed sigma_.
# Fall back to sigma_ only on releases that predate var_.
gaussianVariances = gnb.var_ if hasattr(gnb, "var_") else gnb.sigma_
print("Sigma = ",np.ravel(gaussianVariances).tolist())

# Class priors
print("Prior = ",np.ravel(gnb.class_prior_).tolist())

print("Epsilon = ",gnb.epsilon_)


# Some bounds are computed for the graphical representation.
# NOTE(review): these bounds are not referenced by the plotting code below.
x_min = X_train[:, 0].min()
x_max = X_train[:, 0].max()
y_min = X_train[:, 1].min()
y_max = X_train[:, 1].max()

font = FontProperties()
font.set_size(20)

r=plt.figure()
plt.axis('off')
# Label each cluster centre with a letter.
plt.text(1.5,1.0,"A", verticalalignment='center', horizontalalignment='center',fontproperties=font)
plt.text(-1.5,1.0,"B",verticalalignment='center', horizontalalignment='center', fontproperties=font)
plt.text(0,-3,"C", verticalalignment='center', horizontalalignment='center',fontproperties=font)
scatter(x1[:,0],x1[:,1],s=1.0,color='#FF6B00')
scatter(x2[:,0],x2[:,1],s=1.0,color='#95D600')
scatter(x3[:,0],x3[:,1],s=1.0,color='#00C1DE')
#r.savefig('fig.jpeg')
#plt.close(r)
show()