forked from antoine311200/Spiking-Neural-Networks-Benchmark
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
38 lines (29 loc) · 1.1 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from datasets import SHD_dataloaders, SSC_dataloaders, GSC_dataloaders
from config import Config
from snn_delays import SnnDelays
import torch
from snn import SNN
import utils
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\n=====> Device = {device} \n\n")
config = Config()
if config.model_type == 'snn':
model = SNN(config).to(device)
else:
model = SnnDelays(config).to(device)
if config.model_type == 'snn_delays_lr0':
model.round_pos()
print(f"===> Dataset = {config.dataset}")
print(f"===> Model type = {config.model_type}")
print(f"===> Model size = {utils.count_parameters(model)}\n\n")
if config.dataset == 'shd':
train_loader, valid_loader = SHD_dataloaders(config)
test_loader = None
elif config.dataset == 'ssc':
train_loader, valid_loader, test_loader = SSC_dataloaders(config)
elif config.dataset == 'gsc':
train_loader, valid_loader, test_loader = GSC_dataloaders(config)
else:
raise Exception(f'dataset {config.dataset} not implemented')
if __name__ == '__main__':
model.train_model(train_loader, valid_loader, test_loader, device)