Skip to content

Commit

Permalink
Fix unused --test-batch-size command line argument
Browse files Browse the repository at this point in the history
  • Loading branch information
pbelevich authored and soumith committed Oct 10, 2020
1 parent 599654a commit 8d9f910
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions mnist/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,14 @@ def main():

device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'batch_size': args.batch_size}
train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
if use_cuda:
kwargs.update({'num_workers': 1,
cuda_kwargs = {'num_workers': 1,
'pin_memory': True,
'shuffle': True},
)
'shuffle': True}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)

transform=transforms.Compose([
transforms.ToTensor(),
Expand All @@ -115,8 +117,8 @@ def main():
transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1,**kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
Expand Down

0 comments on commit 8d9f910

Please sign in to comment.