-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
32 lines (25 loc) · 1.34 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from scipy.io import loadmat
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import snntorch
from snntorch.spikevision import spikedata
import numpy as np
import torch
import os
from torch.utils.data import Dataset
from torchvision import transforms, utils
def load_data(args):
    """Build train/test DataLoaders for a supported spiking dataset.

    Parameters
    ----------
    args : namespace-like
        Must provide ``dataset`` (``"N-MNIST"`` or ``"SHD"``), ``steps``
        (number of time bins each sample is binned into),
        ``train_batch_size`` and ``eval_batch_size``.

    Returns
    -------
    tuple
        ``(train_dl, test_dl)``. The training loader shuffles; the
        evaluation loader does not. Both use ``pin_memory=True``.

    Raises
    ------
    ValueError
        If ``args.dataset`` is not one of the supported dataset names.
    """
    if args.dataset == "N-MNIST":
        # Spread a 300000-unit recording window evenly over `steps` bins.
        # NOTE(review): unit presumably microseconds (~300 ms per N-MNIST
        # sample) -- confirm against the snntorch spikedata docs.
        dt = int(300000 / args.steps)
        train_ds = spikedata.NMNIST(
            "data/nmnist", train=True, num_steps=args.steps, dt=dt
        )
        test_ds = spikedata.NMNIST(
            "data/nmnist", train=False, num_steps=args.steps, dt=dt
        )
        loader_kwargs = {"pin_memory": True, "num_workers": 8}
    elif args.dataset == "SHD":
        # Spread an 800 * 1000-unit window evenly over `steps` bins.
        # NOTE(review): presumably 800 ms expressed in microseconds -- confirm.
        dt = int(800 * 1000 / args.steps)
        train_ds = spikedata.SHD(
            "data/shd", train=True, num_steps=args.steps, dt=dt
        )
        test_ds = spikedata.SHD(
            "data/shd", train=False, num_steps=args.steps, dt=dt
        )
        # The original code used single-process loading (no num_workers)
        # for SHD; kept as-is.
        loader_kwargs = {"pin_memory": True}
    else:
        # Previously an unknown name fell through and crashed with an
        # UnboundLocalError on `train_dl`; fail loudly and clearly instead.
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; "
            "expected 'N-MNIST' or 'SHD'"
        )

    train_dl = DataLoader(
        train_ds, shuffle=True, batch_size=args.train_batch_size, **loader_kwargs
    )
    test_dl = DataLoader(
        test_ds, shuffle=False, batch_size=args.eval_batch_size, **loader_kwargs
    )
    return train_dl, test_dl