Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
HuuYuLong authored May 6, 2024
1 parent ed19c47 commit fb58e65
Show file tree
Hide file tree
Showing 20 changed files with 4,450 additions and 12 deletions.
45 changes: 42 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,42 @@
Implementation of "CLIF: Complementary Leaky Integrate-and-Fire Neuron for Spiking Neural Networks"

Paper: https://arxiv.org/pdf/2402.04663.pdf
# CLIF

## Dependencies
- Python 3
- PyTorch, torchvision
- spikingjelly 0.0.0.0.12
- Python packages: `pip install tqdm progress torchtoolbox thop`


## Training
We use single GTX4090 GPU for running all the experiments. Multi-GPU training is not supported in the current codes.


### Setup
CIFAR-10, CIFAR-100, Tiny-Imagenet, DVS-CIFAR10, and DVS-Gesture:

# CIFAR-10
python train_BPTT.py -data_dir ./data_dir -dataset cifar10 -model spiking_resnet18 -T_max 200 -epochs 200 -weight_decay 5e-5 -neuron CLIF

# CIFAR-100
python train_BPTT.py -data_dir ./data_dir -dataset cifar100 -model spiking_resnet18 -T_max 200 -epochs 200 -neuron CLIF

# Tiny-Imagenet
python train_BPTT.py -data_dir ./data_dir -dataset tiny_imagenet -model spiking_vgg13_bn -neuron CLIF
# DVS-CIFAR10
python train_BPTT.py -data_dir ./data_dir -dataset DVSCIFAR10 -T 10 -drop_rate 0.3 -model spiking_vgg11_bn -lr=0.05 -mse_n_reg -neuron CLIF

# DVS-Gesture
python train_BPTT.py -data_dir ./data_dir -dataset dvsgesture -model spiking_vgg11_bn -T 20 -b 16 -drop_rate 0.4 -neuron CLIF

If changing neuron, you can change hyperparameters to ``LIF`` or ``PLIF`` directly after ``-neuron``.

For example to setup LIF neuron for CIFAR-10 task:

# LIF neuron for CIFAR-10
python train_BPTT.py -data_dir ./data_dir -dataset cifar10 -model spiking_resnet18 -amp -T_max 200 -epochs 200 -weight_decay 5e-5 -neuron LIF



## Inference
The inference setup could refer file: ``run_inference_script``
538 changes: 538 additions & 0 deletions inference.py

Large diffs are not rendered by default.

Empty file added models/__init__.py
Empty file.
231 changes: 231 additions & 0 deletions models/spiking_resnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
import torch.nn as nn
from spikingjelly.clock_driven import layer

__all__ = [
'PreActResNet', 'spiking_resnet18', 'spiking_resnet34', 'spiking_resnet50', 'spiking_resnet101', 'spiking_resnet152'
]


class PreActBlock(nn.Module):
'''Pre-activation version of the BasicBlock.'''
expansion = 1

def __init__(self, in_channels, out_channels, stride, dropout, neuron: callable = None, **kwargs):
super(PreActBlock, self).__init__()
whether_bias = True
self.bn1 = nn.BatchNorm2d(in_channels)

self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=whether_bias)
self.bn2 = nn.BatchNorm2d(out_channels)

self.dropout = layer.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels, self.expansion * out_channels, kernel_size=3, stride=1, padding=1,
bias=whether_bias)

if stride != 1 or in_channels != self.expansion * out_channels:
self.shortcut = nn.Conv2d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride,
padding=0, bias=whether_bias)
else:
self.shortcut = nn.Sequential()

self.relu1 = neuron(**kwargs)
self.relu2 = neuron(**kwargs)

def forward(self, x):
x = self.relu1(self.bn1(x))
out = self.conv1(x)
out = self.conv2(self.dropout(self.relu2(self.bn2(out))))
out = out + self.shortcut(x)
return out


class PreActBottleneck(nn.Module):
'''Pre-activation version of the original Bottleneck module.'''
expansion = 4

def __init__(self, in_channels, out_channels, stride, dropout, neuron: callable = None, **kwargs):
super(PreActBottleneck, self).__init__()
whether_bias = True

self.bn1 = nn.BatchNorm2d(in_channels)

self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=whether_bias)
self.bn2 = nn.BatchNorm2d(out_channels)

self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=whether_bias)
self.bn3 = nn.BatchNorm2d(out_channels)
self.dropout = layer.Dropout(dropout)
self.conv3 = nn.Conv2d(out_channels, self.expansion * out_channels, kernel_size=1, stride=1, padding=0,
bias=whether_bias)

if stride != 1 or in_channels != self.expansion * out_channels:
self.shortcut = nn.Conv2d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride,
padding=0, bias=whether_bias)
else:
self.shortcut = nn.Sequential()

self.relu1 = neuron(**kwargs)
self.relu2 = neuron(**kwargs)
self.relu3 = neuron(**kwargs)

def forward(self, x):
x = self.relu1(self.bn1(x))

out = self.conv1(x)
out = self.conv2(self.relu2(self.bn2(out)))
out = self.conv3(self.dropout(self.relu3(self.bn3(out))))

out = out + self.shortcut(x)

return out


class PreActResNet(nn.Module):

def __init__(self, block, num_blocks, num_classes, dropout, neuron: callable = None, **kwargs):
super(PreActResNet, self).__init__()
self.num_blocks = num_blocks

self.data_channels = kwargs.get('c_in', 3)
self.init_channels = 64
self.conv1 = nn.Conv2d(self.data_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.layer1 = self._make_layer(block, 64, num_blocks[0], 1, dropout, neuron, **kwargs)
self.layer2 = self._make_layer(block, 128, num_blocks[1], 2, dropout, neuron, **kwargs)
self.layer3 = self._make_layer(block, 256, num_blocks[2], 2, dropout, neuron, **kwargs)
self.layer4 = self._make_layer(block, 512, num_blocks[3], 2, dropout, neuron, **kwargs)

self.bn1 = nn.BatchNorm2d(512 * block.expansion)
self.pool = nn.AvgPool2d(4)
self.flat = nn.Flatten()
self.drop = layer.Dropout(dropout)
self.linear = nn.Linear(512 * block.expansion, num_classes)

self.relu1 = neuron(**kwargs)

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, val=1)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.zeros_(m.bias)

def _make_layer(self, block, out_channels, num_blocks, stride, dropout, neuron, **kwargs):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.init_channels, out_channels, stride, dropout, neuron, **kwargs))
self.init_channels = out_channels * block.expansion
return nn.Sequential(*layers)

def forward(self, x):
out = self.conv1(x)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = self.pool(self.relu1(self.bn1(out)))
out = self.drop(self.flat(out))
out = self.linear(out)
return out


# class Bottleneck(nn.Module):
# expansion = 4
#
# def __init__(self, in_planes, planes, stride=1, bn_type='', **kwargs_spikes):
# super(Bottleneck, self).__init__()
# self.kwargs_spikes = kwargs_spikes
# self.nb_steps = kwargs_spikes['nb_steps']
# self.conv1 = tdLayer(nn.Conv2d(in_planes, planes, kernel_size=1, bias=False), self.nb_steps)
# self.bn1 = warpBN(planes, bn_type, self.nb_steps)
# self.spike1 = LIFLayer(**kwargs_spikes)
# self.conv2 = tdLayer(nn.Conv2d(planes, planes, kernel_size=3,
# stride=stride, padding=1, bias=False), self.nb_steps)
# self.bn2 = warpBN(planes, bn_type, self.nb_steps)
# self.spike2 = LIFLayer(**kwargs_spikes)
# self.conv3 = tdLayer(nn.Conv2d(planes, self.expansion *
# planes, kernel_size=1, bias=False), self.nb_steps)
# self.bn3 = warpBN(self.expansion *
# planes, bn_type, self.nb_steps)
#
# self.shortcut = nn.Sequential()
# if stride != 1 or in_planes != self.expansion * planes:
# self.shortcut = nn.Sequential(
# tdLayer(nn.Conv2d(in_planes, self.expansion * planes,
# kernel_size=1, stride=stride, bias=False), self.nb_steps),
# warpBN(self.expansion * planes, bn_type, self.nb_steps)
# )
# self.spike3 = LIFLayer(**kwargs_spikes)
#
# def forward(self, x):
# out = self.spike1(self.bn1(self.conv1(x)))
# out = self.spike2(self.bn2(self.conv2(out)))
# out = self.bn3(self.conv3(out))
# out += self.shortcut(x)
# out = self.spike3(out)
# return out
#
#
# class ResNet19(nn.Module):
# def __init__(self, block, num_block_layers, num_classes=10, in_channel=3, bn_type='', **kwargs_spikes):
# super(ResNet19, self).__init__()
# self.in_planes = 128
# self.bn_type = bn_type
# self.kwargs_spikes = kwargs_spikes
# self.nb_steps = kwargs_spikes['nb_steps']
# self.conv0 = nn.Sequential(
# tdLayer(nn.Conv2d(in_channel, self.in_planes, kernel_size=3, padding=1, stride=1, bias=False),
# nb_steps=self.nb_steps),
# warpBN(self.in_planes, bn_type, self.nb_steps),
# LIFLayer(**kwargs_spikes)
# )
# self.layer1 = self._make_layer(block, 128, num_block_layers[0], stride=1)
# self.layer2 = self._make_layer(block, 256, num_block_layers[1], stride=2)
# self.layer3 = self._make_layer(block, 512, num_block_layers[2], stride=2)
# self.avg_pool = tdLayer(nn.AdaptiveAvgPool2d((1, 1)), nb_steps=self.nb_steps)
# self.classifier = nn.Sequential(
# tdLayer(nn.Linear(512 * block.expansion, 256, bias=False), nb_steps=self.nb_steps),
# LIFLayer(**kwargs_spikes),
# tdLayer(nn.Linear(256, num_classes, bias=False), nb_steps=self.nb_steps),
# Readout()
# )
#
# def _make_layer(self, block, planes, num_blocks, stride):
# strides = [stride] + [1] * (num_blocks - 1)
# layers = []
# for stride in strides:
# layers.append(block(self.in_planes, planes, stride, self.bn_type, **self.kwargs_spikes))
# self.in_planes = planes * block.expansion
# return nn.Sequential(*layers)
#
# def forward(self, x):
# out, _ = torch.broadcast_tensors(x, torch.zeros((self.nb_steps,) + x.shape))
# out = self.conv0(out)
# out = self.layer1(out)
# out = self.layer2(out)
# out = self.layer3(out)
# out = self.avg_pool(out)
# out = out.view(out.shape[0], out.shape[1], -1)
# out = self.classifier(out)
# return out

def spiking_resnet18(neuron: callable = None, num_classes=10, neuron_dropout=0, **kwargs):
return PreActResNet(PreActBlock, [2, 2, 2, 2], num_classes, neuron_dropout, neuron=neuron, **kwargs)


def spiking_resnet34(neuron: callable = None, num_classes=10, neuron_dropout=0, **kwargs):
return PreActResNet(PreActBlock, [3, 4, 6, 3], num_classes, neuron_dropout, neuron=neuron, **kwargs)


def spiking_resnet50(neuron: callable = None, num_classes=10, neuron_dropout=0, **kwargs):
return PreActResNet(PreActBottleneck, [3, 4, 6, 3], num_classes, neuron_dropout, neuron=neuron, **kwargs)


def spiking_resnet101(neuron: callable = None, num_classes=10, neuron_dropout=0, **kwargs):
return PreActResNet(PreActBottleneck, [3, 4, 23, 3], num_classes, neuron_dropout, neuron=neuron, **kwargs)


def spiking_resnet152(neuron: callable = None, num_classes=10, neuron_dropout=0, **kwargs):
return PreActResNet(PreActBottleneck, [3, 8, 36, 3], num_classes, neuron_dropout, neuron=neuron, **kwargs)
115 changes: 115 additions & 0 deletions models/spiking_vgg_bn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import torch.nn as nn
from spikingjelly.clock_driven import layer

__all__ = [
'SpikingVGGBN', 'spiking_vgg11_bn', 'spiking_vgg13_bn', 'spiking_vgg16_bn', 'spiking_vgg19_bn'
]

cfg = {

'VGG11': [
[64, 'M'],
[128, 'M'],
[256, 256, 'M'],
[512, 512, 'M'],
[512, 512, 'M']
],
'VGG13': [
[64, 64, 'M'],
[128, 128, 'M'],
[256, 256, 'M'],
[512, 512, 'M'],
[512, 512, 'M']
],
'VGG16': [
[64, 64, 'M'],
[128, 128, 'M'],
[256, 256, 256, 'M'],
[512, 512, 512, 'M'],
[512, 512, 512, 'M']
],
'VGG19': [
[64, 64, 'M'],
[128, 128, 'M'],
[256, 256, 256, 256, 'M'],
[512, 512, 512, 512, 'M'],
[512, 512, 512, 512, 'M']
]
}


class SpikingVGGBN(nn.Module):
def __init__(self, vgg_name, neuron: callable = None, dropout=0.0, num_classes=10, **kwargs):
super(SpikingVGGBN, self).__init__()
self.whether_bias = True
self.init_channels = kwargs.get('c_in', 2)

self.layer1 = self._make_layers(cfg[vgg_name][0], dropout, neuron, **kwargs)
self.layer2 = self._make_layers(cfg[vgg_name][1], dropout, neuron, **kwargs)
self.layer3 = self._make_layers(cfg[vgg_name][2], dropout, neuron, **kwargs)
self.layer4 = self._make_layers(cfg[vgg_name][3], dropout, neuron, **kwargs)
self.layer5 = self._make_layers(cfg[vgg_name][4], dropout, neuron, **kwargs)

self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(512 * 7 * 7, num_classes),
)

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)

def _make_layers(self, cfg, dropout, neuron, **kwargs):
layers = []
for x in cfg:
if x == 'M':
layers.append(nn.AvgPool2d(kernel_size=2, stride=2))
else:
layers.append(nn.Conv2d(self.init_channels, x, kernel_size=3, padding=1, bias=self.whether_bias))
layers.append(nn.BatchNorm2d(x))
# kwargs["l_i"] += 1
layers.append(neuron(**kwargs))
layers.append(layer.Dropout(dropout))
self.init_channels = x
return nn.Sequential(*layers)

def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = self.layer5(out)
out = self.avgpool(out)
out = self.classifier(out)

return out


def spiking_vgg9_bn(neuron: callable = None, num_classes=10, neuron_dropout=0.0, **kwargs):
return SpikingVGGBN('VGG9', neuron=neuron, dropout=neuron_dropout, num_classes=num_classes, **kwargs)


def spiking_vgg11_bn(neuron: callable = None, num_classes=10, neuron_dropout=0.0, **kwargs):
return SpikingVGGBN('VGG11', neuron=neuron, dropout=neuron_dropout, num_classes=num_classes, **kwargs)


def spiking_vgg13_bn(neuron: callable = None, num_classes=10, neuron_dropout=0.0, **kwargs):
return SpikingVGGBN('VGG13', neuron=neuron, dropout=neuron_dropout, num_classes=num_classes, **kwargs)


def spiking_vgg16_bn(neuron: callable = None, num_classes=10, neuron_dropout=0.0, **kwargs):
return SpikingVGGBN('VGG16', neuron=neuron, dropout=neuron_dropout, num_classes=num_classes, **kwargs)


def spiking_vgg19_bn(neuron: callable = None, num_classes=10, neuron_dropout=0.0, **kwargs):
return SpikingVGGBN('VGG19', neuron=neuron, dropout=neuron_dropout, num_classes=num_classes, **kwargs)
Loading

0 comments on commit fb58e65

Please sign in to comment.