This is the PyTorch implementation of the paper Online Training Through Time for Spiking Neural Networks (NeurIPS 2022). [arxiv] [openreview]
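In simplified form, OTTT avoids backpropagation through time by keeping, for each presynaptic neuron, an eligibility trace $\hat{a}[t] = \lambda \hat{a}[t-1] + s[t]$ ($\lambda$ being the leak factor). This trace is combined with the gradient available at the current time step, so the weight gradient becomes a sum over $t$ of the outer product of $\partial L[t] / \partial u[t]$ and $\hat{a}[t]$. The pure-Python sketch below illustrates only this bookkeeping for one fully connected layer. The function names and list-based vectors are hypothetical. It also assumes the per-step gradients with respect to the membrane potential $u[t]$ are already given (the paper obtains them through a surrogate derivative of the spike function). It is not code from this repository.

```python
from typing import List

Vector = List[float]


def eligibility_traces(spikes: List[Vector], lam: float) -> List[Vector]:
    """Return a_hat[t] = lam * a_hat[t-1] + s[t] for every step (a_hat[-1] = 0).

    The full history is returned only for inspection; OTTT itself needs just
    the current trace.
    """
    traces: List[Vector] = []
    trace = [0.0] * len(spikes[0])
    for s_t in spikes:
        trace = [lam * a + s for a, s in zip(trace, s_t)]
        traces.append(trace)
    return traces


def ottt_weight_grad(spikes: List[Vector], grad_u: List[Vector], lam: float) -> List[Vector]:
    """Accumulate sum_t outer(grad_u[t], a_hat[t]) one time step at a time.

    spikes[t]: presynaptic spikes at step t (length n_in)
    grad_u[t]: dL[t]/du[t] for the postsynaptic neurons at step t (length n_out)
    Returns an n_out x n_in gradient without storing past activations.
    """
    n_in, n_out = len(spikes[0]), len(grad_u[0])
    grad_w = [[0.0] * n_in for _ in range(n_out)]
    trace = [0.0] * n_in
    for s_t, g_t in zip(spikes, grad_u):
        # Online trace update: only the current trace is kept in memory.
        trace = [lam * a + s for a, s in zip(trace, s_t)]
        # Instantaneous contribution of step t to the weight gradient.
        for i in range(n_out):
            for j in range(n_in):
                grad_w[i][j] += g_t[i] * trace[j]
    return grad_w


if __name__ == "__main__":
    spikes = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 time steps, 2 inputs
    grad_u = [[1.0], [0.5], [-1.0]]                # 1 output neuron
    print(eligibility_traces(spikes, lam=0.5))     # [[1.0, 0.0], [0.5, 1.0], [1.25, 1.5]]
    print(ottt_weight_grad(spikes, grad_u, lam=0.5))  # [[0.0, -1.0]]
```

Because the trace is updated in place, memory stays constant in the number of time steps. This is the property that distinguishes the approach from storing the whole computational graph for BPTT.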
Update 2023/12: Some modules of OTTT have been integrated into the latest version of spikingjelly, and the new neuron-model code supports multi-GPU training. For reference, we provide the corresponding code from the spikingjelly repository in spikingjelly_codes/reference_codes/: neuron.py, layer.py, and functional.py (located in spikingjelly/activation_based/ in that repository) include some of the OTTT modules, and spiking_vggws_ottt.py (located in spikingjelly/activation_based/model/ in that repository) shows how to define a model with them. An example of how to train such a model is provided in spikingjelly_codes/train_ottt_cifar.py.
- Python 3 (Anaconda is recommended)
- PyTorch, torchvision
- NVIDIA GPU + CUDA
- Python packages:
pip install numpy opencv-python
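A quick way to confirm the requirements above before training is a small stdlib-only check like the following. It is a hedged sketch: the module list is assumed from the packages named here (opencv-python is imported as `cv2`), and the CUDA check simply calls `torch.cuda.is_available()` when PyTorch is installed.

```python
import importlib.util
from typing import Iterable, List

# Import names assumed from the requirements listed above.
REQUIRED_MODULES = ["torch", "torchvision", "numpy", "cv2"]


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the module names that cannot be found by this interpreter."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def cuda_available() -> bool:
    """Report whether PyTorch sees a CUDA device (False if torch is absent)."""
    if importlib.util.find_spec("torch") is None:
        return False
    import torch
    return bool(torch.cuda.is_available())


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    print("missing packages:", missing or "none")
    print("CUDA available:", cuda_available())
```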
For OTTT$_A$, run the following:
python train_cifar.py -data_dir path_to_data_dir -dataset cifar10 -out_dir log_checkpoint_name -gpu-id 0
# For VGG-F model
python train_cifar.py -data_dir path_to_data_dir -dataset cifar100 -out_dir log_checkpoint_name -gpu-id 0 -model online_spiking_vgg11f_ws
python train_cifar10dvs.py -data_dir path_to_data_dir -out_dir log_checkpoint_name -gpu-id 0
python train_imagenet.py -data_dir path_to_data_dir -out_dir log_checkpoint_name -gpu-id 0
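If you want to launch several of these runs from Python, a small helper like the one below can assemble the same argument lists. It is a hypothetical convenience sketch. The script names and flags are copied from the commands above, while `build_command`, `launch`, and the dataset-to-script mapping are assumptions, not part of the repository.

```python
import subprocess
from typing import List, Optional

# Dataset -> training script, following the commands above (mapping is an assumption).
SCRIPTS = {
    "cifar10": "train_cifar.py",
    "cifar100": "train_cifar.py",
    "cifar10dvs": "train_cifar10dvs.py",
    "imagenet": "train_imagenet.py",
}


def build_command(dataset: str, data_dir: str, out_dir: str,
                  gpu_id: int = 0, model: Optional[str] = None) -> List[str]:
    """Assemble the argument list for one of the training commands above."""
    script = SCRIPTS[dataset]
    cmd = ["python", script, "-data_dir", data_dir]
    if script == "train_cifar.py":
        # Only train_cifar.py is given a -dataset argument in the examples.
        cmd += ["-dataset", dataset]
    cmd += ["-out_dir", out_dir, "-gpu-id", str(gpu_id)]
    if model is not None:
        cmd += ["-model", model]
    return cmd


def launch(cmd: List[str]) -> None:
    """Run the command in the current directory (assumed to be the repo root)."""
    subprocess.run(cmd, check=True)


if __name__ == "__main__":
    cmd = build_command("cifar100", "path_to_data_dir", "log_checkpoint_name",
                        model="online_spiking_vgg11f_ws")
    print(" ".join(cmd))
```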
For OTTT$_O$, add the argument -online_update, for example:
python train_cifar.py -data_dir path_to_data_dir -dataset cifar10 -out_dir log_checkpoint_name -gpu-id 0 -online_update
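Roughly speaking, the two variants differ in when the parameters are updated. OTTT$_A$ accumulates the per-step gradients over all time steps of a sample and applies a single update, while with -online_update (OTTT$_O$) the parameters are updated right after each time step, so later steps already see the updated weights. The sketch below shows just this scheduling difference on a toy scalar model. `train_sample`, `make_linear_grad_fn`, and the per-step squared-error loss are hypothetical stand-ins for the spiking network and its loss, not the repository's training loop.

```python
from typing import Callable, Sequence

# (current weight, time step) -> dL[t]/dw for that step
GradFn = Callable[[float, int], float]


def make_linear_grad_fn(xs: Sequence[float], ys: Sequence[float]) -> GradFn:
    """Toy per-step loss L[t] = 0.5 * (w * xs[t] - ys[t]) ** 2 (stand-in for the SNN loss)."""
    def grad(w: float, t: int) -> float:
        return (w * xs[t] - ys[t]) * xs[t]
    return grad


def train_sample(w: float, num_steps: int, grad_fn: GradFn,
                 lr: float, online_update: bool) -> float:
    """Process one sample over num_steps time steps and return the new weight.

    online_update=False: sum the per-step gradients, update once at the end (OTTT_A-style).
    online_update=True: update immediately after every time step (OTTT_O-style).
    """
    accumulated = 0.0
    for t in range(num_steps):
        g = grad_fn(w, t)  # gradient of the instantaneous loss at step t
        if online_update:
            w -= lr * g
        else:
            accumulated += g
    if not online_update:
        w -= lr * accumulated
    return w


if __name__ == "__main__":
    grad_fn = make_linear_grad_fn(xs=[1.0, 1.0], ys=[0.0, 0.0])
    print(train_sample(1.0, 2, grad_fn, lr=0.5, online_update=False))  # 0.0
    print(train_sample(1.0, 2, grad_fn, lr=0.5, online_update=True))   # 0.25
```

With the same starting weight and learning rate, the two schedules end at different weights. The online variant's second gradient is computed with the already-updated weight.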
The default hyperparameters in the code are the same as in the paper.
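Note: The code in this repository currently supports only single-GPU training (the spikingjelly-based neuron code described above supports multi-GPU training).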
We provide example code for computing firing-rate statistics during evaluation. Run the following:
python get_rate_cifar.py -data_dir path_to_data_dir -dataset cifar10 -gpu-id 0 -resume path_to_checkpoint
python get_rate_imagenet.py -data_dir path_to_data_dir -gpu-id 0 -resume path_to_checkpoint
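To see what such statistics measure, consider one common definition of a layer's firing rate: the total number of spikes divided by the number of neurons times the number of time steps. The sketch below computes that quantity from recorded spike trains. The layer names, list-based records, and this exact definition are assumptions for illustration and may differ from what get_rate_cifar.py and get_rate_imagenet.py report.

```python
from typing import Dict, List

# spikes[t][i] is 1.0 if neuron i fired at time step t, else 0.0
SpikeRecord = List[List[float]]


def firing_rate(spikes: SpikeRecord) -> float:
    """Average spikes per neuron per time step for one layer and one input."""
    num_steps, num_neurons = len(spikes), len(spikes[0])
    total = sum(sum(step) for step in spikes)
    return total / (num_steps * num_neurons)


def layer_firing_rates(records: Dict[str, SpikeRecord]) -> Dict[str, float]:
    """Firing rate of every recorded layer, keyed by a hypothetical layer name."""
    return {name: firing_rate(spikes) for name, spikes in records.items()}


if __name__ == "__main__":
    records = {
        "conv1": [[1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]],
        "fc": [[0.0, 0.0], [0.0, 1.0]],
    }
    print(layer_firing_rates(records))  # {'conv1': 0.5, 'fc': 0.25}
```

Over a whole evaluation set, these per-input rates would typically be averaged across all samples.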
Some pretrained models can be downloaded from Google Drive or Baidu Drive (extraction code: gppq).
Some of the code for the neuron model and data preprocessing is adapted from the spikingjelly repository, and some utility code is taken from the pytorch-classification repository.
If you have any questions, please contact [email protected].