1from sklearn.naive_bayes import GaussianNB
2import random
3import numpy as np
4import math
5
6from pylab import scatter,figure, clf, plot, xlabel, ylabel, xlim, ylim, title, grid, axes, show,semilogx, semilogy
7import matplotlib.pyplot as plt
8from matplotlib.font_manager import FontProperties
9
# Training data for the classifier.
# NBVECS random vectors are drawn per class. With VECDIM = 2 every vector is a
# point in the plane, which is what makes the clusters easy to plot below.
NBVECS = 100
VECDIM = 2

# Three clusters: standard-normal noise scaled by ballRadius, shifted onto each
# centre. The generator draws the noise in the order x1, x2, x3.
ballRadius = 1.0
x1, x2, x3 = (
    np.asarray(centre) + ballRadius * np.random.randn(NBVECS, VECDIM)
    for centre in ([1.5, 1], [-1.5, 1], [0, -3])
)

# Stack every point into a single training matrix, cluster after cluster.
X_train = np.vstack((x1, x2, x3))

# Class labels 0, 1 and 2 (stored as floats), one run of NBVECS per cluster,
# aligned with the rows of X_train.
Y_train = np.repeat(np.arange(3, dtype=np.float64), NBVECS)
26
# Fit a Gaussian naive Bayes model on the three labelled clusters.
gnb = GaussianNB()
gnb.fit(X_train, Y_train)

# Sanity check: classify each cluster centre and print the predicted label.
# The centres are expected to come back as classes 0, 1 and 2 respectively.
print("Testing")
for probe in ([1.5, 1.0], [-1.5, 1.0], [0, -3.0]):
    y_pred = gnb.predict([probe])
    print(y_pred)
39
# Dump of the trained model parameters for CMSIS-DSP. Each array is printed as
# a flat, row-major list of plain floats so it can be pasted into C sources.


def _as_float_list(values):
    """Return *values* flattened in row-major order as plain Python floats.

    ``ndarray.tolist()`` yields built-in floats, so the printed list looks the
    same on every NumPy version. By contrast, ``list(ndarray)`` prints
    ``np.float64(...)`` elements with NumPy >= 2, which is not valid C
    initializer syntax.
    """
    return np.ravel(values).tolist()


def _gaussian_variances(model):
    """Return the per-class variances stored on a fitted GaussianNB model.

    scikit-learn renamed ``sigma_`` to ``var_`` in 1.0 and removed ``sigma_``
    in 1.2. Prefer the new attribute and fall back to the old one so the
    script works on both old and current releases.
    """
    variances = getattr(model, "var_", None)
    if variances is None:
        variances = model.sigma_
    return variances


print("Parameters")
# Gaussian averages
print("Theta = ", _as_float_list(gnb.theta_))

# Gaussian variances
print("Sigma = ", _as_float_list(_gaussian_variances(gnb)))

# Class priors
print("Prior = ", _as_float_list(gnb.class_prior_))

# Variance smoothing term added by GaussianNB; printed as a plain float.
print("Epsilon = ", float(gnb.epsilon_))
53
54
# Bounds of the training data, taken column by column (column 0 is x,
# column 1 is y). They are not used by the figure drawn below.
lower = X_train.min(axis=0)
upper = X_train.max(axis=0)
x_min, y_min = lower[0], lower[1]
x_max, y_max = upper[0], upper[1]

font = FontProperties(size=20)

r = plt.figure()
plt.axis('off')

# Write the cluster letter on top of each cluster centre.
for (cx, cy), letter in (((1.5, 1.0), "A"), ((-1.5, 1.0), "B"), ((0, -3), "C")):
    plt.text(cx, cy, letter,
             verticalalignment='center',
             horizontalalignment='center',
             fontproperties=font)

# One small-marker scatter per cluster, each in its own colour.
for points, colour in zip((x1, x2, x3), ('#FF6B00', '#95D600', '#00C1DE')):
    scatter(points[:, 0], points[:, 1], s=1.0, color=colour)

# To write the figure to disk instead, uncomment:
# r.savefig('fig.jpeg')
# plt.close(r)
show()