Showing 20 changed files with 4,450 additions and 12 deletions.
@@ -1,3 +1,42 @@
Implementation of "CLIF: Complementary Leaky Integrate-and-Fire Neuron for Spiking Neural Networks"

Paper: https://arxiv.org/pdf/2402.04663.pdf

# CLIF

## Dependencies
- Python 3
- PyTorch, torchvision
- spikingjelly 0.0.0.0.12
- Python packages: `pip install tqdm progress torchtoolbox thop`

## Training
We use a single GTX 4090 GPU to run all the experiments. Multi-GPU training is not supported in the current code.

### Setup
CIFAR-10, CIFAR-100, Tiny-ImageNet, DVS-CIFAR10, and DVS-Gesture:

# CIFAR-10
python train_BPTT.py -data_dir ./data_dir -dataset cifar10 -model spiking_resnet18 -T_max 200 -epochs 200 -weight_decay 5e-5 -neuron CLIF

# CIFAR-100
python train_BPTT.py -data_dir ./data_dir -dataset cifar100 -model spiking_resnet18 -T_max 200 -epochs 200 -neuron CLIF

# Tiny-ImageNet
python train_BPTT.py -data_dir ./data_dir -dataset tiny_imagenet -model spiking_vgg13_bn -neuron CLIF

# DVS-CIFAR10
python train_BPTT.py -data_dir ./data_dir -dataset DVSCIFAR10 -T 10 -drop_rate 0.3 -model spiking_vgg11_bn -lr=0.05 -mse_n_reg -neuron CLIF

# DVS-Gesture
python train_BPTT.py -data_dir ./data_dir -dataset dvsgesture -model spiking_vgg11_bn -T 20 -b 16 -drop_rate 0.4 -neuron CLIF

To switch the neuron type, replace the value after ``-neuron`` with ``LIF`` or ``PLIF``.

For example, to use the LIF neuron for the CIFAR-10 task:

# LIF neuron for CIFAR-10
python train_BPTT.py -data_dir ./data_dir -dataset cifar10 -model spiking_resnet18 -amp -T_max 200 -epochs 200 -weight_decay 5e-5 -neuron LIF

## Inference
For the inference setup, refer to the file ``run_inference_script``.
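As a rough illustration only (the actual setup is in ``run_inference_script``), the sketch below shows how a trained model could be evaluated over several time-steps. The module path, the checkpoint format, the number of time-steps, and the use of spikingjelly's `LIFNode` in place of the CLIF neuron class are all assumptions made for this example.

```python
# Hypothetical inference sketch -- the repository's real setup is run_inference_script.
import torch
from spikingjelly.clock_driven import neuron, functional

from models.spiking_resnet import spiking_resnet18  # assumed module path

# LIFNode stands in for the CLIF neuron class, whose import path is not shown here.
model = spiking_resnet18(neuron=neuron.LIFNode, num_classes=10, neuron_dropout=0.0)
model.load_state_dict(torch.load('checkpoint.pth', map_location='cpu'))  # assumed plain state_dict
model.eval()

T = 4                                  # illustrative number of time-steps
x = torch.rand(8, 3, 32, 32)           # a batch of CIFAR-10-sized inputs

with torch.no_grad():
    logits = sum(model(x) for _ in range(T)) / T   # average the readout over time-steps
    pred = logits.argmax(dim=1)

functional.reset_net(model)            # clear membrane potentials before the next batch
```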
@@ -0,0 +1,231 @@
import torch.nn as nn
from spikingjelly.clock_driven import layer

__all__ = [
    'PreActResNet', 'spiking_resnet18', 'spiking_resnet34', 'spiking_resnet50', 'spiking_resnet101', 'spiking_resnet152'
]


class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.'''
    expansion = 1

    def __init__(self, in_channels, out_channels, stride, dropout, neuron: callable = None, **kwargs):
        super(PreActBlock, self).__init__()
        whether_bias = True
        self.bn1 = nn.BatchNorm2d(in_channels)

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=whether_bias)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.dropout = layer.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, self.expansion * out_channels, kernel_size=3, stride=1, padding=1,
                               bias=whether_bias)

        if stride != 1 or in_channels != self.expansion * out_channels:
            self.shortcut = nn.Conv2d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride,
                                      padding=0, bias=whether_bias)
        else:
            self.shortcut = nn.Sequential()

        self.relu1 = neuron(**kwargs)
        self.relu2 = neuron(**kwargs)

    def forward(self, x):
        x = self.relu1(self.bn1(x))
        out = self.conv1(x)
        out = self.conv2(self.dropout(self.relu2(self.bn2(out))))
        out = out + self.shortcut(x)
        return out


class PreActBottleneck(nn.Module):
    '''Pre-activation version of the original Bottleneck module.'''
    expansion = 4

    def __init__(self, in_channels, out_channels, stride, dropout, neuron: callable = None, **kwargs):
        super(PreActBottleneck, self).__init__()
        whether_bias = True

        self.bn1 = nn.BatchNorm2d(in_channels)

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=whether_bias)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=whether_bias)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.dropout = layer.Dropout(dropout)
        self.conv3 = nn.Conv2d(out_channels, self.expansion * out_channels, kernel_size=1, stride=1, padding=0,
                               bias=whether_bias)

        if stride != 1 or in_channels != self.expansion * out_channels:
            self.shortcut = nn.Conv2d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride,
                                      padding=0, bias=whether_bias)
        else:
            self.shortcut = nn.Sequential()

        self.relu1 = neuron(**kwargs)
        self.relu2 = neuron(**kwargs)
        self.relu3 = neuron(**kwargs)

    def forward(self, x):
        x = self.relu1(self.bn1(x))

        out = self.conv1(x)
        out = self.conv2(self.relu2(self.bn2(out)))
        out = self.conv3(self.dropout(self.relu3(self.bn3(out))))

        out = out + self.shortcut(x)

        return out


class PreActResNet(nn.Module):

    def __init__(self, block, num_blocks, num_classes, dropout, neuron: callable = None, **kwargs):
        super(PreActResNet, self).__init__()
        self.num_blocks = num_blocks

        self.data_channels = kwargs.get('c_in', 3)
        self.init_channels = 64
        self.conv1 = nn.Conv2d(self.data_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], 1, dropout, neuron, **kwargs)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], 2, dropout, neuron, **kwargs)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], 2, dropout, neuron, **kwargs)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], 2, dropout, neuron, **kwargs)

        self.bn1 = nn.BatchNorm2d(512 * block.expansion)
        self.pool = nn.AvgPool2d(4)
        self.flat = nn.Flatten()
        self.drop = layer.Dropout(dropout)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

        self.relu1 = neuron(**kwargs)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, val=1)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    def _make_layer(self, block, out_channels, num_blocks, stride, dropout, neuron, **kwargs):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.init_channels, out_channels, stride, dropout, neuron, **kwargs))
            self.init_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.pool(self.relu1(self.bn1(out)))
        out = self.drop(self.flat(out))
        out = self.linear(out)
        return out


# class Bottleneck(nn.Module):
#     expansion = 4
#
#     def __init__(self, in_planes, planes, stride=1, bn_type='', **kwargs_spikes):
#         super(Bottleneck, self).__init__()
#         self.kwargs_spikes = kwargs_spikes
#         self.nb_steps = kwargs_spikes['nb_steps']
#         self.conv1 = tdLayer(nn.Conv2d(in_planes, planes, kernel_size=1, bias=False), self.nb_steps)
#         self.bn1 = warpBN(planes, bn_type, self.nb_steps)
#         self.spike1 = LIFLayer(**kwargs_spikes)
#         self.conv2 = tdLayer(nn.Conv2d(planes, planes, kernel_size=3,
#                                        stride=stride, padding=1, bias=False), self.nb_steps)
#         self.bn2 = warpBN(planes, bn_type, self.nb_steps)
#         self.spike2 = LIFLayer(**kwargs_spikes)
#         self.conv3 = tdLayer(nn.Conv2d(planes, self.expansion *
#                                        planes, kernel_size=1, bias=False), self.nb_steps)
#         self.bn3 = warpBN(self.expansion *
#                           planes, bn_type, self.nb_steps)
#
#         self.shortcut = nn.Sequential()
#         if stride != 1 or in_planes != self.expansion * planes:
#             self.shortcut = nn.Sequential(
#                 tdLayer(nn.Conv2d(in_planes, self.expansion * planes,
#                                   kernel_size=1, stride=stride, bias=False), self.nb_steps),
#                 warpBN(self.expansion * planes, bn_type, self.nb_steps)
#             )
#         self.spike3 = LIFLayer(**kwargs_spikes)
#
#     def forward(self, x):
#         out = self.spike1(self.bn1(self.conv1(x)))
#         out = self.spike2(self.bn2(self.conv2(out)))
#         out = self.bn3(self.conv3(out))
#         out += self.shortcut(x)
#         out = self.spike3(out)
#         return out
#
#
# class ResNet19(nn.Module):
#     def __init__(self, block, num_block_layers, num_classes=10, in_channel=3, bn_type='', **kwargs_spikes):
#         super(ResNet19, self).__init__()
#         self.in_planes = 128
#         self.bn_type = bn_type
#         self.kwargs_spikes = kwargs_spikes
#         self.nb_steps = kwargs_spikes['nb_steps']
#         self.conv0 = nn.Sequential(
#             tdLayer(nn.Conv2d(in_channel, self.in_planes, kernel_size=3, padding=1, stride=1, bias=False),
#                     nb_steps=self.nb_steps),
#             warpBN(self.in_planes, bn_type, self.nb_steps),
#             LIFLayer(**kwargs_spikes)
#         )
#         self.layer1 = self._make_layer(block, 128, num_block_layers[0], stride=1)
#         self.layer2 = self._make_layer(block, 256, num_block_layers[1], stride=2)
#         self.layer3 = self._make_layer(block, 512, num_block_layers[2], stride=2)
#         self.avg_pool = tdLayer(nn.AdaptiveAvgPool2d((1, 1)), nb_steps=self.nb_steps)
#         self.classifier = nn.Sequential(
#             tdLayer(nn.Linear(512 * block.expansion, 256, bias=False), nb_steps=self.nb_steps),
#             LIFLayer(**kwargs_spikes),
#             tdLayer(nn.Linear(256, num_classes, bias=False), nb_steps=self.nb_steps),
#             Readout()
#         )
#
#     def _make_layer(self, block, planes, num_blocks, stride):
#         strides = [stride] + [1] * (num_blocks - 1)
#         layers = []
#         for stride in strides:
#             layers.append(block(self.in_planes, planes, stride, self.bn_type, **self.kwargs_spikes))
#             self.in_planes = planes * block.expansion
#         return nn.Sequential(*layers)
#
#     def forward(self, x):
#         out, _ = torch.broadcast_tensors(x, torch.zeros((self.nb_steps,) + x.shape))
#         out = self.conv0(out)
#         out = self.layer1(out)
#         out = self.layer2(out)
#         out = self.layer3(out)
#         out = self.avg_pool(out)
#         out = out.view(out.shape[0], out.shape[1], -1)
#         out = self.classifier(out)
#         return out


def spiking_resnet18(neuron: callable = None, num_classes=10, neuron_dropout=0, **kwargs):
    return PreActResNet(PreActBlock, [2, 2, 2, 2], num_classes, neuron_dropout, neuron=neuron, **kwargs)


def spiking_resnet34(neuron: callable = None, num_classes=10, neuron_dropout=0, **kwargs):
    return PreActResNet(PreActBlock, [3, 4, 6, 3], num_classes, neuron_dropout, neuron=neuron, **kwargs)


def spiking_resnet50(neuron: callable = None, num_classes=10, neuron_dropout=0, **kwargs):
    return PreActResNet(PreActBottleneck, [3, 4, 6, 3], num_classes, neuron_dropout, neuron=neuron, **kwargs)


def spiking_resnet101(neuron: callable = None, num_classes=10, neuron_dropout=0, **kwargs):
    return PreActResNet(PreActBottleneck, [3, 4, 23, 3], num_classes, neuron_dropout, neuron=neuron, **kwargs)


def spiking_resnet152(neuron: callable = None, num_classes=10, neuron_dropout=0, **kwargs):
    return PreActResNet(PreActBottleneck, [3, 8, 36, 3], num_classes, neuron_dropout, neuron=neuron, **kwargs)
@@ -0,0 +1,115 @@
import torch.nn as nn
from spikingjelly.clock_driven import layer

__all__ = [
    'SpikingVGGBN', 'spiking_vgg11_bn', 'spiking_vgg13_bn', 'spiking_vgg16_bn', 'spiking_vgg19_bn'
]

cfg = {

    'VGG11': [
        [64, 'M'],
        [128, 'M'],
        [256, 256, 'M'],
        [512, 512, 'M'],
        [512, 512, 'M']
    ],
    'VGG13': [
        [64, 64, 'M'],
        [128, 128, 'M'],
        [256, 256, 'M'],
        [512, 512, 'M'],
        [512, 512, 'M']
    ],
    'VGG16': [
        [64, 64, 'M'],
        [128, 128, 'M'],
        [256, 256, 256, 'M'],
        [512, 512, 512, 'M'],
        [512, 512, 512, 'M']
    ],
    'VGG19': [
        [64, 64, 'M'],
        [128, 128, 'M'],
        [256, 256, 256, 256, 'M'],
        [512, 512, 512, 512, 'M'],
        [512, 512, 512, 512, 'M']
    ]
}


class SpikingVGGBN(nn.Module):
    def __init__(self, vgg_name, neuron: callable = None, dropout=0.0, num_classes=10, **kwargs):
        super(SpikingVGGBN, self).__init__()
        self.whether_bias = True
        self.init_channels = kwargs.get('c_in', 2)

        self.layer1 = self._make_layers(cfg[vgg_name][0], dropout, neuron, **kwargs)
        self.layer2 = self._make_layers(cfg[vgg_name][1], dropout, neuron, **kwargs)
        self.layer3 = self._make_layers(cfg[vgg_name][2], dropout, neuron, **kwargs)
        self.layer4 = self._make_layers(cfg[vgg_name][3], dropout, neuron, **kwargs)
        self.layer5 = self._make_layers(cfg[vgg_name][4], dropout, neuron, **kwargs)

        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 7 * 7, num_classes),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def _make_layers(self, cfg, dropout, neuron, **kwargs):
        layers = []
        for x in cfg:
            if x == 'M':
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2))
            else:
                layers.append(nn.Conv2d(self.init_channels, x, kernel_size=3, padding=1, bias=self.whether_bias))
                layers.append(nn.BatchNorm2d(x))
                # kwargs["l_i"] += 1
                layers.append(neuron(**kwargs))
                layers.append(layer.Dropout(dropout))
                self.init_channels = x
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.avgpool(out)
        out = self.classifier(out)

        return out


def spiking_vgg9_bn(neuron: callable = None, num_classes=10, neuron_dropout=0.0, **kwargs):
    # Note: cfg as defined above has no 'VGG9' entry and this constructor is not
    # listed in __all__, so calling it as written would raise a KeyError.
    return SpikingVGGBN('VGG9', neuron=neuron, dropout=neuron_dropout, num_classes=num_classes, **kwargs)


def spiking_vgg11_bn(neuron: callable = None, num_classes=10, neuron_dropout=0.0, **kwargs):
    return SpikingVGGBN('VGG11', neuron=neuron, dropout=neuron_dropout, num_classes=num_classes, **kwargs)


def spiking_vgg13_bn(neuron: callable = None, num_classes=10, neuron_dropout=0.0, **kwargs):
    return SpikingVGGBN('VGG13', neuron=neuron, dropout=neuron_dropout, num_classes=num_classes, **kwargs)


def spiking_vgg16_bn(neuron: callable = None, num_classes=10, neuron_dropout=0.0, **kwargs):
    return SpikingVGGBN('VGG16', neuron=neuron, dropout=neuron_dropout, num_classes=num_classes, **kwargs)


def spiking_vgg19_bn(neuron: callable = None, num_classes=10, neuron_dropout=0.0, **kwargs):
    return SpikingVGGBN('VGG19', neuron=neuron, dropout=neuron_dropout, num_classes=num_classes, **kwargs)