-
Notifications
You must be signed in to change notification settings - Fork 0
/
create_doublenmnist.py
135 lines (114 loc) · 4.91 KB
/
create_doublenmnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#!/bin/python
#-----------------------------------------------------------------------------
# File Name : create_doublenmnist.py
# Author: Emre Neftci
#
# Creation Date : Thu Nov 7 20:30:14 2019
# Last Modified :
#
# Copyright : (c) UC Regents, Emre Neftci
# Licence : GPLv2
#-----------------------------------------------------------------------------
from torchneuromorphic.doublenmnist.doublenmnist_dataloaders import *
def create_datasets(
        root = 'data/nmnist/n_mnist.hdf5',
        batch_size = 72 ,
        chunk_size_train = 300,
        chunk_size_test = 300,
        ds = 1,
        dt = 1000,
        transform_train = None,
        transform_test = None,
        target_transform_train = None,
        target_transform_test = None,
        nclasses = 5,
        samples_per_class = 2,
        classes_meta = np.arange(100, dtype='int')):
    """Build train/test DoubleNMNIST datasets over a random class subset.

    Draws ``nclasses`` labels (without replacement) from ``classes_meta``
    and constructs a training and a testing ``DoubleNMNISTDataset`` that
    share that label subset.

    root: path to the N-MNIST HDF5 file.
    batch_size: unused here; kept for signature parity with create_dataloader.
    chunk_size_train / chunk_size_test: number of time bins per sample.
    ds: spatial downsampling factor; dt: temporal downsampling factor (us).
    transform_* / target_transform_*: optional overrides; defaults are
        built below when None.
    Returns (train_ds, test_ds).
    """
    frame_size = [2, 32 // ds, 32 // ds]

    def _event_pipeline(num_bins):
        # Shared event-stream -> dense-tensor pipeline; num_bins is the
        # temporal length of the output.
        return Compose([
            CropDims(low_crop=[0, 0], high_crop=[32, 32], dims=[2, 3]),
            Downsample(factor=[dt, 1, ds, ds]),
            ToEventSum(T=num_bins, size=frame_size),
            ToTensor()])

    if transform_train is None:
        transform_train = _event_pipeline(chunk_size_train)
    if transform_test is None:
        transform_test = _event_pipeline(chunk_size_test)
    if target_transform_train is None:
        target_transform_train = Compose([Repeat(chunk_size_train)])
    if target_transform_test is None:
        target_transform_test = Compose([Repeat(chunk_size_test)])

    # 100 meta-classes because labels are ordered digit pairs (10 x 10).
    labels_u = np.random.choice(classes_meta, nclasses, replace=False)

    shared = dict(nclasses=nclasses,
                  samples_per_class=samples_per_class,
                  labels_u=labels_u)
    train_ds = DoubleNMNISTDataset(root, train=True,
                                   transform=transform_train,
                                   target_transform=target_transform_train,
                                   chunk_size=chunk_size_train,
                                   **shared)
    test_ds = DoubleNMNISTDataset(root, train=False,
                                  transform=transform_test,
                                  target_transform=target_transform_test,
                                  chunk_size=chunk_size_test,
                                  **shared)
    return train_ds, test_ds
def create_dataloader(
        root = 'data/nmnist/n_mnist.hdf5',
        batch_size = 72 ,
        chunk_size_train = 300,
        chunk_size_test = 300,
        ds = 1,
        dt = 1000,
        transform_train = None,
        transform_test = None,
        target_transform_train = None,
        target_transform_test = None,
        nclasses = 5,
        samples_per_class = 2,
        classes_meta = np.arange(100, dtype='int'),
        **dl_kwargs):
    """Create train/test DataLoaders over DoubleNMNIST datasets.

    All dataset-construction arguments are forwarded to create_datasets;
    extra keyword arguments (``num_workers``, ``pin_memory``, ...) go to
    ``torch.utils.data.DataLoader``.

    Returns (train_dl, test_dl); the training loader shuffles, the test
    loader does not.
    """
    train_d, test_d = create_datasets(
        # BUG FIX: the root path was previously hardcoded to
        # 'data/nmnist/n_mnist.hdf5', silently ignoring the caller's
        # ``root`` argument.
        root = root,
        batch_size = batch_size,
        chunk_size_train = chunk_size_train,
        chunk_size_test = chunk_size_test,
        ds = ds,
        dt = dt,
        transform_train = transform_train,
        transform_test = transform_test,
        target_transform_train = target_transform_train,
        target_transform_test = target_transform_test,
        classes_meta = classes_meta,
        nclasses = nclasses,
        samples_per_class = samples_per_class)
    train_dl = torch.utils.data.DataLoader(train_d, shuffle=True, batch_size=batch_size, **dl_kwargs)
    test_dl = torch.utils.data.DataLoader(test_d, shuffle=False, batch_size=batch_size, **dl_kwargs)
    return train_dl, test_dl
def sample_double_mnist_task( N = 5,
        K = 2,
        meta_split = (range(64), range(64,80), range(80,100)),
        meta_dataset_type = 'train',
        **kwargs):
    """Sample an N-way K-shot DoubleNMNIST task from one meta split.

    N: number of classes per task (ways).
    K: samples per class (shots).
    meta_split: three class-index collections for the train/val/test
        meta splits. Default is a tuple (was a mutable list default).
    meta_dataset_type: which split to draw from: 'train', 'val', 'test'.
    Remaining kwargs are forwarded to create_dataloader.
    Raises ValueError on an unknown ``meta_dataset_type``.
    """
    # Validate before doing any work. This was an ``assert`` placed after
    # the split dict was built; asserts are stripped under ``python -O``,
    # so raise explicitly instead.
    if meta_dataset_type not in ('train', 'val', 'test'):
        raise ValueError(
            "meta_dataset_type must be 'train', 'val' or 'test', got "
            + repr(meta_dataset_type))
    classes_meta = {
        'train': np.array(meta_split[0], dtype='int'),
        'val':   np.array(meta_split[1], dtype='int'),
        'test':  np.array(meta_split[2], dtype='int'),
    }
    return create_dataloader(classes_meta = classes_meta[meta_dataset_type], nclasses= N, samples_per_class = K, **kwargs)
def create_doublenmnist(batch_size = 50, shots=10, ways=5, mtype='train', **kwargs):
    """Convenience wrapper: build a ways-way / shots-shot DoubleNMNIST task.

    mtype selects the meta split ('train', 'val' or 'test'); additional
    keyword arguments are passed through to the dataloader machinery.
    Returns the (train_dl, test_dl) pair from sample_double_mnist_task.
    """
    return sample_double_mnist_task(
        meta_dataset_type=mtype,
        N=ways,
        K=shots,
        root='data/nmnist/n_mnist.hdf5',
        batch_size=batch_size,
        num_workers=4,
        **kwargs)
if __name__ == '__main__':
    # Smoke test: build the meta-train dataloaders with 2x spatial
    # downsampling (expects data/nmnist/n_mnist.hdf5 to exist).
    trdl, tedl = create_doublenmnist(ds=2)