-
Notifications
You must be signed in to change notification settings - Fork 1
/
mnist.py
40 lines (32 loc) · 1.57 KB
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Builds the MNIST data loaders and starts training on them (see `mnist()`).
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from utils import AddUniformNoise #, AddGaussianNoise
from torch.cuda import empty_cache
from torch.utils.data import DataLoader # Dataset
# from torchvision import transforms
import config as c
import arguments as a
from train import train
from torch.utils.data.sampler import SubsetRandomSampler # RandomSampling
# At import time, on GPU-configured runs only, ask PyTorch's CUDA caching
# allocator to release any cached, unused memory.
if c.gpu:
    empty_cache()
def mnist(load_only=False):
    """Build MNIST train/validation loaders and train a model on them.

    Every image (train and test split alike) goes through the same
    preprocessing pipeline:

    - ``ToTensor``;
    - the project's ``AddUniformNoise`` transform;
    - ``Resize`` to ``c.img_size[0]`` x ``c.img_size[0]``;
    - ``Normalize`` with mean 0.1307 and std 0.3081.

    The datasets are stored in (and downloaded to, if missing) ``./data``.

    When ``c.test_run`` is set, each loader draws only from indices 0..199 of
    its split, in random order, for quick smoke runs.

    Args:
        load_only: Accepted for interface compatibility but currently unused;
            training always runs.

    Returns:
        Tuple ``(mdl, train_loader, val_loader)``, where ``mdl`` is whatever
        ``train(train_loader, val_loader)`` returns.
    """
    # NOTE(review): ``load_only`` is never read, so training always happens.
    # Confirm the intended meaning with callers before wiring it up.
    mnist_pre = Compose([
        ToTensor(),
        AddUniformNoise(),
        # NOTE(review): both sides use img_size[0]. If img_size can be
        # non-square, the second entry was probably meant to be img_size[1].
        Resize((c.img_size[0], c.img_size[0])),
        Normalize((0.1307,), (0.3081,)),
    ])
    mnist_train = MNIST(root='./data', train=True, download=True, transform=mnist_pre)
    mnist_test = MNIST(root='./data', train=False, download=True, transform=mnist_pre)

    # A SubsetRandomSampler draws a fresh random permutation each time it is
    # iterated, so sharing one instance between both loaders is safe.
    toy_sampler = SubsetRandomSampler(range(200)) if c.test_run else None

    # Pinned host memory only speeds up host->GPU copies. On CPU-only runs it
    # gives no benefit and PyTorch warns about it.
    pin = bool(c.gpu)

    train_loader = DataLoader(
        mnist_train,
        batch_size=a.args.batch_size,
        pin_memory=pin,
        # DataLoader rejects shuffle=True combined with a sampler. Without one,
        # shuffle so training batches don't follow the fixed dataset order.
        shuffle=toy_sampler is None,
        sampler=toy_sampler,
    )
    val_loader = DataLoader(
        mnist_test,
        batch_size=a.args.batch_size,
        pin_memory=pin,
        shuffle=False,  # evaluation needs no shuffling
        sampler=toy_sampler,
    )

    mdl = train(train_loader, val_loader)
    return mdl, train_loader, val_loader