diff --git a/mnist/main.py b/mnist/main.py index 0a733cd48b..04dd254117 100644 --- a/mnist/main.py +++ b/mnist/main.py @@ -100,12 +100,14 @@ def main(): device = torch.device("cuda" if use_cuda else "cpu") - kwargs = {'batch_size': args.batch_size} + train_kwargs = {'batch_size': args.batch_size} + test_kwargs = {'batch_size': args.test_batch_size} if use_cuda: - kwargs.update({'num_workers': 1, + cuda_kwargs = {'num_workers': 1, 'pin_memory': True, - 'shuffle': True}, - ) + 'shuffle': True} + train_kwargs.update(cuda_kwargs) + test_kwargs.update(cuda_kwargs) transform=transforms.Compose([ transforms.ToTensor(), @@ -115,8 +117,8 @@ def main(): transform=transform) dataset2 = datasets.MNIST('../data', train=False, transform=transform) - train_loader = torch.utils.data.DataLoader(dataset1,**kwargs) - test_loader = torch.utils.data.DataLoader(dataset2, **kwargs) + train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) + test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) model = Net().to(device) optimizer = optim.Adadelta(model.parameters(), lr=args.lr)