-
Notifications
You must be signed in to change notification settings - Fork 0
/
rbf_classification_som.py
133 lines (104 loc) · 3.93 KB
/
rbf_classification_som.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
This program builds a 3D surface from four 2D Gaussian distributions (three
added, one subtracted) and creates training data around their centers. It uses
a SOM network to find the cluster centers in the training data, and an RBF
network is then trained using the centers found.
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from sklearn import datasets
from sklearn_som.som import SOM
import rbf_net
import rbf_layer as rbf
import torch.nn as nn
# Shared matplotlib handles, all initially None:
#   fig - current figure (set by present_som and present_initial),
#   ax1 - target-surface axes (set by present_initial),
#   ax2 - SOM scatter axes (set by present_som),
#   ax3 - predicted-surface axes (set by present).
# on_move reads fig, ax1 and ax3 to keep the two 3D views in sync.
fig = ax1 = ax2 = ax3 = None
def get_constants():
    """Return the fixed hyper-parameters and cluster centres of the demo.

    Returns:
        tuple: ``(max_epoch, each_turn, max_loss, learning_rate,
        neuron_hidden_num, centers_cluster)``. ``centers_cluster`` is a
        freshly built list of four ``[x, y]`` centres on every call.
    """
    centers_cluster = [
        [-1.5, -2.5],
        [-2, 2.3],
        [2.2, -2.2],
        [2.8, 2.5],
    ]
    settings = (
        1500,  # max_epoch: total number of training epochs
        75,    # each_turn: epochs per training turn between plot refreshes
        0.02,  # max_loss: presumably a target loss (not used by main here)
        0.05,  # learning_rate
        12,    # neuron_hidden_num: number of hidden RBF neurons
        centers_cluster,
    )
    return settings
def generate_data(centers):
    """Build the plotting grid, its target surface and the SOM training points.

    Args:
        centers: Cluster centres, used both for the surface (via ``func``) and
            as the blob centres for ``datasets.make_blobs``.

    Returns:
        X, Y: 30x30 meshgrid over [-5, 5] x [-5, 5] (unrounded).
        XY: the grid flattened into (900, 2) ``(x, y)`` rows, rounded to one decimal.
        Z: ``func(XY, centers)`` reshaped to ``X.shape``.
        indexes: 200 sample points drawn by ``make_blobs`` around ``centers``
            (cluster_std=0.9, unshuffled, random_state=10).
    """
    axis = np.linspace(-5, 5, 30, endpoint=True)
    X, Y = np.meshgrid(axis, axis)
    # Note: Z is evaluated at the rounded XY points, while X/Y keep the
    # unrounded grid values used for plotting.
    XY = np.around(np.column_stack([X.flat, Y.flat]), 1)
    Z = func(XY, centers).reshape(X.shape)
    # make_blobs also returns cluster labels; only the points are needed.
    indexes, _ = datasets.make_blobs(
        n_samples=200,
        centers=centers,
        cluster_std=0.9,
        shuffle=False,
        random_state=10,
    )
    return X, Y, XY, Z, indexes
def func(xy, centers):
    """Evaluate the signed Gaussian-mixture surface at the given points.

    The surface is the sum of isotropic 2-D normal densities centred on
    ``centers[0]`` (std 1.1), ``centers[1]`` (std 1.4) and ``centers[2]``
    (std 1.8), minus one centred on ``centers[3]`` (std 1.3).

    Args:
        xy: Points passed straight to ``multivariate_normal.pdf`` (e.g. an
            ``(N, 2)`` array of ``(x, y)`` rows).
        centers: Indexable collection of at least four 2-D means; only the
            first four are used.

    Returns:
        The signed sum of the four densities, shaped like the
        ``multivariate_normal.pdf`` output for ``xy``.
    """
    # (per-axis standard deviation, +1 to add / -1 to subtract) per centre.
    layout = ((1.1, 1), (1.4, 1), (1.8, 1), (1.3, -1))
    surface = None
    for position, (sigma, sign) in enumerate(layout):
        cov = np.diag(np.array([sigma, sigma]) ** 2)
        density = multivariate_normal.pdf(xy, mean=np.array(centers[position]), cov=cov)
        if surface is None:
            surface = density
        elif sign > 0:
            surface += density
        else:
            surface -= density
    return surface
def present_som(indexes, centers):
    """Show the training points with the SOM centres and block until a button press.

    Sets the module-level ``fig`` and ``ax2`` and closes the figure after
    ``plt.waitforbuttonpress()`` returns.

    Args:
        indexes: Training points, indexed as ``[:, 0]`` (x) and ``[:, 1]`` (y).
        centers: SOM centres, indexed as ``[:, :, 0]`` / ``[:, :, 1]``, i.e. a
            grid of nodes whose last axis holds ``(x, y)``.
    """
    global fig, ax2
    fig = plt.figure()
    ax2 = fig.add_subplot(111)
    plt.text(1.5, -6, 'Click anywhere to continue.')
    point_xs, point_ys = indexes[:, 0], indexes[:, 1]
    ax2.scatter(point_xs, point_ys, linewidths=1, marker='.', color='lightcoral', label='x and y of train data')
    center_xs, center_ys = centers[:, :, 0], centers[:, :, 1]
    ax2.scatter(center_xs, center_ys, color='black', marker='*', label='centers SOM found')
    ax2.legend(loc='lower right')
    plt.waitforbuttonpress()
    plt.close(fig)
def present_initial(X, Y, Z):
    """Open the 10x5 side-by-side figure and draw the target surface on the left.

    Rebinds the module-level ``fig`` and ``ax1``; the right-hand slot (122)
    is left for the prediction drawn by ``present``.
    """
    global fig, ax1
    fig = plt.figure(figsize=(10, 5))
    # Left half of a 1x2 grid, as a 3D axes.
    ax1 = fig.add_subplot(121, projection='3d')
    ax1.plot_surface(X, Y, Z, cmap='plasma')
def present(X, Y, predicted):
    """Redraw the predicted surface in the right-hand 3D axes and refresh the window.

    Replaces the module-level ``ax3`` with a fresh 3D subplot (slot 122 of the
    current ``fig``), plots ``predicted`` over the ``X``/``Y`` grid, and
    pauses briefly so the GUI can update.

    The motion handler that keeps the two 3D views in sync is connected only
    once per figure. Connecting it on every call (as before) stacked one
    duplicate handler per training turn.
    """
    global fig
    global ax3
    if ax3 is not None:
        ax3.remove()
    ax3 = fig.add_subplot(122, projection='3d')
    if ax1 is not None:
        # A new axes starts at the default camera; match the left-hand view so
        # the two surfaces stay aligned between refreshes.
        ax3.view_init(elev=ax1.elev, azim=ax1.azim)
    ax3.plot_surface(X, Y, predicted, cmap='plasma')
    # Remember the connection id on the figure itself so each figure gets
    # exactly one on_move handler.
    if getattr(fig, '_rbf_motion_cid', None) is None:
        fig._rbf_motion_cid = fig.canvas.mpl_connect('motion_notify_event', on_move)
    plt.pause(0.01)
def on_move(event):
    """Mouse-motion callback: copy the hovered 3D axes' camera onto the other one.

    If the pointer is over ``ax1``, ``ax3`` takes ``ax1``'s elevation and
    azimuth, and vice versa; the canvas is then redrawn lazily. Events over
    any other area are ignored.
    """
    hovered = event.inaxes
    if hovered == ax1:
        source, target = ax1, ax3
    elif hovered == ax3:
        source, target = ax3, ax1
    else:
        return
    target.view_init(elev=source.elev, azim=source.azim)
    fig.canvas.draw_idle()
def main():
    """Run the SOM + RBF demo end to end.

    1. Build the target surface and the blob training points (``generate_data``).
    2. Fit a 3x4 SOM to the blob points and show its centres (``present_som``).
    3. Train an ``RBFNet`` seeded with the SOM centres on the surface, training
       ``each_turn`` epochs at a time and redrawing the prediction after each
       turn, then leave the final figure open with ``plt.show()``.
    """
    max_epoch, each_turn, max_loss, learning_rate, neuron_hidden_num, centers_cluster = get_constants()
    # NOTE(review): max_loss is unpacked but never used (no early stopping).
    X, Y, XY, Z, indexes = generate_data(centers_cluster)

    # Find cluster centres with a 3x4 SOM (12 nodes, matching neuron_hidden_num = 12).
    som = SOM(m=3, n=4, dim=2, random_state=20)
    som.fit(np.array(indexes))
    centers_som = som.cluster_centers_
    present_som(indexes, centers_som)

    present_initial(X, Y, Z)
    basis_func = rbf.gaussian
    # NOTE(review): centers_som is passed as the SOM returns it (a grid of 2-D
    # centres, as present_som's [:, :, 0] indexing implies); confirm RBFNet
    # accepts this shape for `c`.
    net = rbf_net.RBFNet([2, 1], neuron_hidden_num, basis_func, c=centers_som)

    # The training grid and loss do not change between turns: build them once.
    inputs = torch.from_numpy(XY).float()
    targets = torch.from_numpy(Z.flatten()).float()
    loss_fn = nn.MSELoss()

    for i in range(max_epoch // each_turn):
        net.fit(inputs, targets, each_turn, learning_rate, loss_fn)
        net.eval()
        with torch.no_grad():
            prediction = net(inputs).detach().numpy()
        present(X, Y, prediction.reshape(Z.shape))
        print(f'\rEpoch {(i+1) * each_turn}/{max_epoch}', end='')
    # Terminate the carriage-return progress line before blocking in show().
    print()
    # present() already connects the on_move handler; connecting it again here
    # would only register a duplicate callback.
    plt.show()


if __name__ == '__main__':
    main()