from layer import *
from network import NeuralNet #, sample_binary_stochastic
from numpy import *


class RBM(NeuralNet):

    def __init__(self, numvis, numhid, vislayer=None, hidlayer=None, vishid=None):
        '''Initialize an RBM with numvis visible units and numhid hidden units.
        The weights are randomly initialized unless explicitly passed in via vishid.'''
        self.numvis = numvis
        self.numhid = numhid
        weights = [vishid] if vishid is not None else None
        NeuralNet.__init__(self, [vislayer or BinaryStochasticLayer(numvis),
                                  hidlayer or BinaryStochasticLayer(numhid)], weights)

    def get_vislayer(self):
        return self.layers[0]

    def get_hidlayer(self):
        return self.layers[1]

    def get_vishid(self):
        return self.weights[0]

    def sample_hid(self, data, prob=False):
        '''Sample the hidden layer given data as the state of the visible units.
        Returns the activation probabilities instead of samples if prob is True.'''
        data = self.forward_pass(data, skip_layer=1)[1]
        return self.get_hidlayer().probs if prob else data

    def sample_vis(self, data, prob=False):
        '''Sample the visible layer given data as the state of the hidden units.
        Returns the activation probabilities instead of samples if prob is True.'''
        data = self.backward_pass(data, skip_layer=1)[0]
        return self.get_vislayer().probs if prob else data

    def gibbs_given_h(self, data, K):
        '''Run K steps of alternating Gibbs sampling between the visible and
        hidden layers, starting from data as the current activation of the hiddens.'''
        hidact = data
        visact = None
        for k in range(K):
            visact = self.sample_vis(hidact)
            hidact = self.sample_hid(visact)
        return visact, hidact

    def gibbs_given_v(self, data, K):
        '''Run K steps of alternating Gibbs sampling between the hidden and
        visible layers, starting from data as the current activation of the visibles.'''
        visact = data
        for k in range(K):
            hidact = self.sample_hid(visact)
            visact = self.sample_vis(hidact)
        return visact, hidact

    def reconstruction_error(self, data, K=1):
        '''Squared error between data and its reconstruction after K Gibbs steps.'''
        self.gibbs_given_v(data, K)
        visprobs = self.get_vislayer().probs
        return square(data - visprobs).sum()

    def train(self, data, K, learning_rate=0.05, bias_learn_rate=0.1, weightcost=0.0001):
        '''Perform one CD-K update on a batch of normalized data (one case per row)
        and return the reconstruction error. A usage sketch appears at the bottom
        of this file.'''
        assert self.numvis == data.shape[1], "Data does not match number of visible units."
        # get the activation probabilities for hidden units on each data case
        self.get_vislayer().probs = data
        hidact_data = self.sample_hid(data)  # NxH matrix
        hidprobs_data = self.get_hidlayer().probs
        # compute the positive term in the CD learning rule, Expected(sisj)_data
        expect_pairact_data = dot(transpose(data), hidprobs_data)
        # the same quantity for the biases, Expected(si)_data (i.e. the bias unit is always 1)
        expect_bias_hid_data = hidprobs_data.sum(0)
        # expect_bias_vis_data = data.sum(0)
        # now get the logistic output after K steps of Gibbs sampling and use it
        # as the probability of turning on
        self.gibbs_given_h(hidact_data, K)
        visprobs_cd, hidprobs_cd = self.get_vislayer().probs, self.get_hidlayer().probs
        # the negative statistics for contrastive divergence, Expected(sisj)_model
        expect_pairact_cd = dot(transpose(visprobs_cd), hidprobs_cd)
        # again, negative statistics for learning the biases
        expect_bias_hid_cd = hidprobs_cd.sum(0)
        # expect_bias_vis_cd = visprobs_cd.sum(0)
        recons_error = square(data - visprobs_cd).sum()
        # learning time
        N = float(data.shape[0])
        delta_vishid = (learning_rate / N) * ((expect_pairact_data - expect_pairact_cd) - weightcost * self.weights[0])
        # delta_bias_vis = (bias_learn_rate/N)*(expect_bias_vis_data - expect_bias_vis_cd)
        delta_bias_hid = (bias_learn_rate / N) * (expect_bias_hid_data - expect_bias_hid_cd)
        self.weights[0] += delta_vishid
        # self.layers[0].bias += delta_bias_vis
        self.layers[1].bias += delta_bias_hid
        # print 'Reconstruction Error:', recons_error
        return recons_error
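

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): trains the RBM above with CD-1
# on synthetic random binary data, then draws fantasy samples with a long
# Gibbs chain. It assumes layer.py and network.py from this repo are
# importable; since the data here is unstructured noise, the reconstruction
# error is only a smoke test, not a demonstration of learning.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    random.seed(0)  # numpy.random, via the wildcard import above
    data = (random.rand(500, 20) > 0.5).astype(float)  # 500 cases, 20 visible units
    rbm = RBM(numvis=20, numhid=10)
    batch_size = 100
    for epoch in range(10):
        err = 0.0
        for start in range(0, data.shape[0], batch_size):
            err += rbm.train(data[start:start + batch_size], K=1)
        print('epoch %d, reconstruction error %.2f' % (epoch, err))
    # fantasy samples: start at a few data cases and run 100 Gibbs steps
    fantasy_vis, fantasy_hid = rbm.gibbs_given_v(data[:5], K=100)
    print('fantasy visible samples:')
    print(fantasy_vis)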