
Commit

- Update to DCLL July 14/2021
Emre committed Jul 14, 2021
1 parent 8e092e8 commit 61aca54
Showing 19 changed files with 1,864 additions and 797 deletions.
318 changes: 203 additions & 115 deletions decolle/base_model.py

Large diffs are not rendered by default.

75 changes: 75 additions & 0 deletions decolle/init_functions.py
@@ -0,0 +1,75 @@
#!/bin/python
#-----------------------------------------------------------------------------
# File Name : init_functions.py
# Author: Emre Neftci
#
# Creation Date : Fri 26 Feb 2021 11:48:40 AM PST
# Last Modified :
#
# Copyright : (c) UC Regents, Emre Neftci
# Licence : GPLv2
#-----------------------------------------------------------------------------
import torch
import numpy as np

from torch.nn import init



def init_LSUV(net, data_batch, mu=0.0, var=1.0):
    '''
    Initialization inspired by Mishkin, D. and Matas, J., "All you need is a good init", arXiv:1511.06422 [cs],
    February 2016.
    '''
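    # Data-driven initialization: orthogonalize the weights, then iteratively
    # rescale them and shift the biases until each layer's membrane potential
    # has roughly the target mean `mu` and variance `var` on `data_batch`.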
    ##Initialize
    if mu is None:
        mu = 0.0
    if var is None:
        var = 1.0
    with torch.no_grad():
        net.init_parameters(data_batch)
        #def lsuv(net, data_batch):
        for l in net.LIF_layers:
            l.base_layer.bias.data *= 0
            init.orthogonal_(l.base_layer.weight)

            if hasattr(l,'rec_layer'):
                l.rec_layer.bias.data *= 0
                init.orthogonal_(l.rec_layer.weight)
        alldone = False
        while not alldone:
            alldone = True
            s,r,u = net.process_output(data_batch)
            for i in range(len(net)):
                v=np.var(u[i][-1].flatten())
                m=np.mean(u[i][-1].flatten())
                mus=np.mean(s[i][-1].flatten())
                print(i,v,m,mus)
                if np.isnan(v) or np.isnan(m):
                    print('Nan encountered during init')
                    mus = -.1
                if np.abs(v-var)>.1:
                    net.LIF_layers[i].base_layer.weight.data /= np.sqrt(v)*np.sqrt(var)
                    ## Won't converge:
                    #if hasattr(net.LIF_layers[i],'rec_layer'):
                    #    net.LIF_layers[i].rec_layer.weight.data /= np.sqrt(v)*np.sqrt(var)
                    done=False
                else:
                    done=True

                if np.abs(m-mu+.1)>.2:
                    net.LIF_layers[i].base_layer.bias.data -= .5*(m-mu)
                    #if hasattr(net.LIF_layers[i],'rec_layer'):
                    #    net.LIF_layers[i].rec_layer.bias.data -= .5*(m-mu)
                    done=False
                else:
                    done=True
                alldone*=done


def init_LSUV_actrate(net, data_batch, act_rate, threshold=0., var=1.0):
    from scipy.stats import norm
    import scipy.optimize
    tgt_mu = scipy.optimize.fmin(lambda loc: (act_rate-(1-norm.cdf(threshold,loc,var)))**2, x0=0.)[0]
    init_LSUV(net, data_batch, mu=tgt_mu, var=var)
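
For context (not part of this commit), a minimal usage sketch of these initializers: it assumes a LenetDECOLLE network built as in the __main__ test of lenet_decolle_model.py below, and a spike batch laid out as [batch, time, channels, height, width]; the exact layout expected by net.init_parameters and net.process_output is an assumption here.

# Hedged usage sketch -- assumed data layout and values, not part of the commit.
import torch
from decolle.lenet_decolle_model import LenetDECOLLE
from decolle.init_functions import init_LSUV, init_LSUV_actrate

net = LenetDECOLLE(Nhid=[1, 8], Mhid=[32, 64], out_channels=10, input_shape=[1, 28, 28])
data_batch = (torch.rand(16, 300, 1, 28, 28) < 0.05).float()  # random spikes, [batch, time, C, H, W] (assumed)
init_LSUV(net, data_batch)                          # target zero-mean, unit-variance membrane potentials
init_LSUV_actrate(net, data_batch, act_rate=0.05)   # or target an average firing rate of roughly 5% instead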

235 changes: 144 additions & 91 deletions decolle/lenet_decolle_model.py
@@ -23,53 +23,102 @@ def __init__(self,
alpha=[.9],
beta=[.85],
alpharp=[.65],
wrp=[1.0],
dropout=[0.5],
num_conv_layers=2,
num_mlp_layers=1,
deltat=1000,
lc_ampl=.5,
lc_ampl=[.5],
lif_layer_type = LIFLayer,
method='rtrl',
with_output_layer = False):

self.with_output_layer = with_output_layer
if with_output_layer:
Mhid += [out_channels]
num_mlp_layers += 1
self.num_layers = num_layers = num_conv_layers + num_mlp_layers
self.num_conv_layers = num_conv_layers
self.num_mlp_layers = num_mlp_layers

if Mhid is None:
Mhid = []
if self.with_output_layer:
Mhid += [out_channels]
self.num_layers += 1

# If only one value is provided, it is duplicated for each layer
if len(kernel_size) == 1: kernel_size = kernel_size * num_conv_layers
if stride is None: stride=[1]
if len(stride) == 1: stride = stride * num_conv_layers
if pool_size is None: pool_size = [1]
if len(pool_size) == 1: pool_size = pool_size * num_conv_layers
if len(alpha) == 1: alpha = alpha * num_layers
if len(alpharp) == 1: alpharp = alpharp * num_layers
if len(beta) == 1: beta = beta * num_layers
if self.num_conv_layers>0:
if len(kernel_size) == 1: kernel_size = kernel_size * self.num_conv_layers
if stride is None: stride=[1]
if len(stride) == 1: stride = stride * self.num_conv_layers
if pool_size is None: pool_size = [1]
if len(pool_size) == 1: pool_size = pool_size * self.num_conv_layers

if len(alpha) == 1: alpha = alpha * self.num_layers
self.alpha = alpha
if len(alpharp) == 1: alpharp = alpharp * self.num_layers
self.alpharp = alpharp
if not hasattr(wrp, '__len__'): wrp = [wrp]
if len(wrp) == 1: wrp = wrp * self.num_layers
self.wrp = wrp
if len(beta) == 1: beta = beta * self.num_layers
self.beta = beta

if dropout == [] or dropout is None: dropout = [1.0]
if not hasattr(dropout, '__len__'): dropout = [dropout]
if len(dropout) == 1: self.dropout = dropout = dropout * num_layers
if Nhid is None: self.Nhid = Nhid = []
if Mhid is None: self.Mhid = Mhid = []
if len(dropout) == 1: dropout = dropout * self.num_layers
self.dropout = dropout

if Nhid is None: Nhid = []
self.Nhid = Nhid
if Mhid is None: Mhid = []
self.Mhid = Mhid

super(LenetDECOLLE, self).__init__()
if hasattr(lif_layer_type, '__len__'):
self.lif_layer_type = lif_layer_type
else:
self.lif_layer_type = [lif_layer_type]*len(Nhid) + [lif_layer_type]*len(Mhid)

# Computing padding to preserve feature size
padding = (np.array(kernel_size) - 1) // 2 # TODO try to remove padding
self.deltat = deltat
self.method = method
if lc_ampl is not None:
lc_ampl = [lc_ampl]*self.num_layers
self.lc_ampl = lc_ampl



self.out_channels = out_channels

super(LenetDECOLLE, self).__init__()


# The following lists need to be nn.ModuleList so that PyTorch properly loads and saves the state_dict
self.pool_layers = nn.ModuleList()
self.dropout_layers = nn.ModuleList()
self.input_shape = input_shape
Nhid = [input_shape[0]] + Nhid
self.num_conv_layers = num_conv_layers
self.num_mlp_layers = num_mlp_layers


# Compute the number of channels for the convolutional and feedforward stacks.

feature_height = self.input_shape[1]
feature_width = self.input_shape[2]

if num_conv_layers == 0: #No convolutional layer
mlp_in = int(np.prod(self.input_shape))
else:
self.Nhid = [input_shape[0]] + self.Nhid
conv_stack_output_shape = self.build_conv_stack(self.Nhid, feature_height, feature_width, pool_size, kernel_size, stride, out_channels)
mlp_in = int(np.prod(conv_stack_output_shape))

self.Mhid = [mlp_in] + self.Mhid

mlp_stack_output_shape = self.build_mlp_stack(self.Mhid, out_channels)

if self.with_output_layer:
output_shape = self.build_output_layer(self.Mhid, out_channels)


def build_conv_stack(self, Nhid, feature_height, feature_width, pool_size, kernel_size, stride, out_channels):
output_shape = None
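# Same-style padding, (kernel_size - 1) // 2, keeps each convolution from shrinking the feature map.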
padding = (np.array(kernel_size) - 1) // 2
for i in range(self.num_conv_layers):
feature_height, feature_width = get_output_shape(
[feature_height, feature_width],
@@ -80,58 +129,95 @@
feature_height //= pool_size[i]
feature_width //= pool_size[i]
base_layer = nn.Conv2d(Nhid[i], Nhid[i + 1], kernel_size[i], stride[i], padding[i])
layer = lif_layer_type(base_layer,
alpha=alpha[i],
beta=beta[i],
alpharp=alpharp[i],
deltat=deltat,
do_detach= True if method == 'rtrl' else False)
layer = self.lif_layer_type[i](base_layer,
alpha=self.alpha[i],
beta=self.beta[i],
alpharp=self.alpharp[i],
wrp=self.wrp[i],
deltat=self.deltat,
do_detach= True if self.method == 'rtrl' else False)
pool = nn.MaxPool2d(kernel_size=pool_size[i])
readout = nn.Linear(int(feature_height * feature_width * Nhid[i + 1]), out_channels)
if self.lc_ampl is not None:
readout = nn.Linear(int(feature_height * feature_width * Nhid[i + 1]), out_channels)

# Readout layer has random fixed weights
for param in readout.parameters():
param.requires_grad = False
self.reset_lc_parameters(readout, lc_ampl)
# Readout layer has random fixed weights
for param in readout.parameters():
param.requires_grad = False
self.reset_lc_parameters(readout, self.lc_ampl[i])
else:
readout = nn.Identity()
self.readout_layers.append(readout)

if self.dropout[i] > 0.0:
dropout_layer = nn.Dropout(self.dropout[i])
else:
dropout_layer = nn.Identity()

dropout_layer = nn.Dropout(dropout[i])

self.LIF_layers.append(layer)
self.pool_layers.append(pool)
self.readout_layers.append(readout)
self.dropout_layers.append(dropout_layer)
return (Nhid[-1],feature_height, feature_width)

if num_conv_layers == 0: #No convolutional layer
mlp_in = int(np.prod(self.input_shape))
else:
mlp_in = int(feature_height * feature_width * Nhid[-1])
Mhid = [mlp_in] + Mhid
for i in range(num_mlp_layers):
def build_mlp_stack(self, Mhid, out_channels):
output_shape = None

for i in range(self.num_mlp_layers):
base_layer = nn.Linear(Mhid[i], Mhid[i+1])
layer = lif_layer_type(base_layer,
alpha=alpha[i],
beta=beta[i],
alpharp=alpharp[i],
deltat=deltat,
do_detach= True if method == 'rtrl' else False)

if self.with_output_layer and i+1==num_mlp_layers:
readout = nn.Identity()
dropout_layer = nn.Identity()
else:
layer = self.lif_layer_type[i+self.num_conv_layers](base_layer,
alpha=self.alpha[i],
beta=self.beta[i],
alpharp=self.alpharp[i],
wrp=self.wrp[i],
deltat=self.deltat,
do_detach=True if self.method == 'rtrl' else False)
if self.lc_ampl is not None:
readout = nn.Linear(Mhid[i+1], out_channels)
# Readout layer has random fixed weights
for param in readout.parameters():
param.requires_grad = False
self.reset_lc_parameters(readout, lc_ampl)
dropout_layer = nn.Dropout(dropout[self.num_conv_layers+i])
self.reset_lc_parameters(readout, self.lc_ampl[i])
else:
readout = nn.Identity()

if self.dropout[i] > 0.0:
dropout_layer = nn.Dropout(self.dropout[i])
else:
dropout_layer = nn.Identity()
output_shape = out_channels

self.LIF_layers.append(layer)
self.pool_layers.append(nn.Sequential())
self.readout_layers.append(readout)
self.dropout_layers.append(dropout_layer)
return (output_shape,)

def build_output_layer(self, Mhid, out_channels):
if self.with_output_layer:
i=self.num_mlp_layers
base_layer = nn.Linear(Mhid[i], out_channels)
layer = self.lif_layer_type[-1](base_layer,
alpha=self.alpha[i],
beta=self.beta[i],
alpharp=self.alpharp[i],
wrp=self.wrp[i],
deltat=self.deltat,
do_detach=True if self.method == 'rtrl' else False)
readout = nn.Identity()
if self.dropout[i] > 0.0:
dropout_layer = nn.Dropout(self.dropout[i])
else:
dropout_layer = nn.Identity()

output_shape = out_channels

self.LIF_layers.append(layer)
self.pool_layers.append(nn.Sequential())
self.readout_layers.append(readout)
self.dropout_layers.append(dropout_layer)
return (output_shape,)

def forward(self, input):
def step(self, input, *args, **kwargs):
s_out = []
r_out = []
u_out = []
@@ -141,51 +227,18 @@ def forward(self, input):
input = input.view(input.size(0), -1)
s, u = lif(input)
u_p = pool(u)
if i+1 == self.num_layers:
if i+1 == self.num_layers and self.with_output_layer:
s_ = sigmoid(u_p)
sd_ = u_p
else:
s_ = lif.sg_function(u_p)
sd_ = do(s_)
sd_ = do(s_)
r_ = ro(sd_.reshape(sd_.size(0), -1))

s_out.append(s_)
r_out.append(r_)
u_out.append(u_p)
input = s_.detach() if lif.do_detach else s_
i+=1

return s_out, r_out, u_out
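
For orientation (not part of the diff): each entry of r_out comes from a fixed random readout (requires_grad is set to False above), so a local loss can be attached to every layer. A hedged sketch of that consumption follows; the loss choice, one-hot target encoding, and helper name local_losses are illustrative assumptions, not repository code.

# Illustrative per-layer local losses on the outputs of step() -- assumptions only.
import torch

loss_fn = torch.nn.SmoothL1Loss()

def local_losses(net, x_t, y_onehot):
    # s_out, r_out, u_out each hold one entry per LIF layer
    s_out, r_out, u_out = net.step(x_t)
    return [loss_fn(r, y_onehot) for r in r_out]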

class TimeWrappedLenetDECOLLE(LenetDECOLLE):
def forward(self, Sin):
t_sample = Sin.shape[1]
out = []
for t in (range(0,t_sample)):
Sin_t = Sin[:,t]
out.append(super().forward(Sin_t))
return out

def init(self, data_batch, burnin):
'''
Necessary to reset the state of the network whenever a new batch is presented
'''
if self.requires_init is False:
return
for l in self.LIF_layers:
l.state = None
with torch.no_grad():
self.forward(data_batch[:, burnin:])

def init_parameters(self, data_batch):
Sin = data_batch[:, :, :, :]
s_out = self.forward(Sin)[0][0]
ins = [self.LIF_layers[0].state.Q]+s_out
for i,l in enumerate(self.LIF_layers):
l.init_parameters(ins[i])



if __name__ == "__main__":
# Test building the network
net = LenetDECOLLE(Nhid=[1,8],Mhid=[32,64],out_channels=10, input_shape=[1,28,28])
d = torch.zeros([1,1,28,28])
net(d)
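
For reference (not part of the commit), a sketch of a constructor call exercising the arguments this commit adds or generalizes (wrp, the per-layer lc_ampl handling, list-valued lif_layer_type, and the num_conv_layers == 0 MLP-only path); the values below are illustrative.

# Illustrative values only. wrp is new; lc_ampl=None would replace the fixed
# random readouts with nn.Identity; lif_layer_type may also be a list with one
# entry per layer.
net = LenetDECOLLE(out_channels=10,
                   Nhid=[1, 8], Mhid=[32, 64], input_shape=[1, 28, 28],
                   alpha=[.9], beta=[.85], alpharp=[.65],
                   wrp=1.0,            # refractory weight, broadcast to every layer
                   dropout=[0.5],
                   lc_ampl=.5,         # amplitude of the fixed random readout weights
                   method='rtrl',
                   with_output_layer=True)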

