From 4a5e036bf586f34ea7a3b8e3aa0cc66bb19e79f9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 13 Jun 2019 16:16:51 +0200 Subject: [PATCH 1/2] added a generic test for the datasets --- test/fakedata_generation.py | 2 +- test/test_datasets.py | 62 ++++++++++++------------------------- 2 files changed, 21 insertions(+), 43 deletions(-) diff --git a/test/fakedata_generation.py b/test/fakedata_generation.py index b0b0aec7cbf..7c871054e4e 100644 --- a/test/fakedata_generation.py +++ b/test/fakedata_generation.py @@ -29,7 +29,7 @@ def _make_image_file(filename, num_images): f.write(img.numpy().tobytes()) def _make_label_file(filename, num_images): - labels = torch.randint(0, 10, size=(num_images,), dtype=torch.uint8) + labels = torch.zeros((num_images,), dtype=torch.uint8) with open(filename, "wb") as f: f.write(_encode(2049)) # magic header f.write(_encode(num_images)) diff --git a/test/test_datasets.py b/test/test_datasets.py index 19f2428db6a..ce332270da3 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -9,6 +9,14 @@ from fakedata_generation import mnist_root, cifar_root, imagenet_root +def generic_dataset_test(tester, dataset, num_images=1, cls='fakedata'): + tester.assertEqual(len(dataset), num_images) + img, target = dataset[0] + tester.assertTrue(isinstance(img, PIL.Image.Image)) + tester.assertTrue(isinstance(target, int)) + tester.assertEqual(dataset.class_to_idx[cls], target) + + class Tester(unittest.TestCase): def test_imagefolder(self): # TODO: create the fake data on-the-fly @@ -64,47 +72,33 @@ def test_mnist(self, mock_download_extract): num_examples = 30 with mnist_root(num_examples, "MNIST") as root: dataset = torchvision.datasets.MNIST(root, download=True) - self.assertEqual(len(dataset), num_examples) - img, target = dataset[0] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) + generic_dataset_test(self, dataset, num_images=num_examples, + cls=dataset.classes[0]) @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') def test_kmnist(self, mock_download_extract): num_examples = 30 with mnist_root(num_examples, "KMNIST") as root: dataset = torchvision.datasets.KMNIST(root, download=True) - img, target = dataset[0] - self.assertEqual(len(dataset), num_examples) - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) + generic_dataset_test(self, dataset, num_images=num_examples, + cls=dataset.classes[0]) @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') def test_fashionmnist(self, mock_download_extract): num_examples = 30 with mnist_root(num_examples, "FashionMNIST") as root: dataset = torchvision.datasets.FashionMNIST(root, download=True) - img, target = dataset[0] - self.assertEqual(len(dataset), num_examples) - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) + generic_dataset_test(self, dataset, num_images=num_examples, + cls=dataset.classes[0]) @mock.patch('torchvision.datasets.utils.download_url') def test_imagenet(self, mock_download): with imagenet_root() as root: dataset = torchvision.datasets.ImageNet(root, split='train', download=True) - self.assertEqual(len(dataset), 1) - img, target = dataset[0] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) - self.assertEqual(dataset.class_to_idx['fakedata'], target) + generic_dataset_test(self, dataset) dataset = torchvision.datasets.ImageNet(root, split='val', download=True) - self.assertEqual(len(dataset), 1) - img, target = dataset[0] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) - self.assertEqual(dataset.class_to_idx['fakedata'], target) + generic_dataset_test(self, dataset) @mock.patch('torchvision.datasets.cifar.check_integrity') @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity') @@ -113,18 +107,10 @@ def test_cifar10(self, mock_ext_check, mock_int_check): mock_int_check.return_value = True with cifar_root('CIFAR10') as root: dataset = torchvision.datasets.CIFAR10(root, train=True, download=True) - self.assertEqual(len(dataset), 5) - img, target = dataset[0] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) - self.assertEqual(dataset.class_to_idx['fakedata'], target) + generic_dataset_test(self, dataset, num_images=5) dataset = torchvision.datasets.CIFAR10(root, train=False, download=True) - self.assertEqual(len(dataset), 1) - img, target = dataset[0] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) - self.assertEqual(dataset.class_to_idx['fakedata'], target) + generic_dataset_test(self, dataset) @mock.patch('torchvision.datasets.cifar.check_integrity') @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity') @@ -133,18 +119,10 @@ def test_cifar100(self, mock_ext_check, mock_int_check): mock_int_check.return_value = True with cifar_root('CIFAR100') as root: dataset = torchvision.datasets.CIFAR100(root, train=True, download=True) - self.assertEqual(len(dataset), 1) - img, target = dataset[0] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) - self.assertEqual(dataset.class_to_idx['fakedata'], target) + generic_dataset_test(self, dataset) dataset = torchvision.datasets.CIFAR100(root, train=False, download=True) - self.assertEqual(len(dataset), 1) - img, target = dataset[0] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) - self.assertEqual(dataset.class_to_idx['fakedata'], target) + generic_dataset_test(self, dataset) if __name__ == '__main__': From 20e69223996e59ac05cf3d7f8974d6d96b75612b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Sat, 15 Jun 2019 11:10:54 +0200 Subject: [PATCH 2/2] addressed requested changes - renamed generic*() to generic_classification*() - moved function inside Tester - test class_to_idx attribute outside of generic_classification*() --- test/test_datasets.py | 49 +++++++++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index ce332270da3..861c5dc0f6c 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -9,15 +9,13 @@ from fakedata_generation import mnist_root, cifar_root, imagenet_root -def generic_dataset_test(tester, dataset, num_images=1, cls='fakedata'): - tester.assertEqual(len(dataset), num_images) - img, target = dataset[0] - tester.assertTrue(isinstance(img, PIL.Image.Image)) - tester.assertTrue(isinstance(target, int)) - tester.assertEqual(dataset.class_to_idx[cls], target) - - class Tester(unittest.TestCase): + def generic_classification_dataset_test(self, dataset, num_images=1): + self.assertEqual(len(dataset), num_images) + img, target = dataset[0] + self.assertTrue(isinstance(img, PIL.Image.Image)) + self.assertTrue(isinstance(target, int)) + def test_imagefolder(self): # TODO: create the fake data on-the-fly FAKEDATA_DIR = get_file_path_2( @@ -72,33 +70,36 @@ def test_mnist(self, mock_download_extract): num_examples = 30 with mnist_root(num_examples, "MNIST") as root: dataset = torchvision.datasets.MNIST(root, download=True) - generic_dataset_test(self, dataset, num_images=num_examples, - cls=dataset.classes[0]) + self.generic_classification_dataset_test(dataset, num_images=num_examples) + img, target = dataset[0] + self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') def test_kmnist(self, mock_download_extract): num_examples = 30 with mnist_root(num_examples, "KMNIST") as root: dataset = torchvision.datasets.KMNIST(root, download=True) - generic_dataset_test(self, dataset, num_images=num_examples, - cls=dataset.classes[0]) + self.generic_classification_dataset_test(dataset, num_images=num_examples) + img, target = dataset[0] + self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') def test_fashionmnist(self, mock_download_extract): num_examples = 30 with mnist_root(num_examples, "FashionMNIST") as root: dataset = torchvision.datasets.FashionMNIST(root, download=True) - generic_dataset_test(self, dataset, num_images=num_examples, - cls=dataset.classes[0]) + self.generic_classification_dataset_test(dataset, num_images=num_examples) + img, target = dataset[0] + self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) @mock.patch('torchvision.datasets.utils.download_url') def test_imagenet(self, mock_download): with imagenet_root() as root: dataset = torchvision.datasets.ImageNet(root, split='train', download=True) - generic_dataset_test(self, dataset) + self.generic_classification_dataset_test(dataset) dataset = torchvision.datasets.ImageNet(root, split='val', download=True) - generic_dataset_test(self, dataset) + self.generic_classification_dataset_test(dataset) @mock.patch('torchvision.datasets.cifar.check_integrity') @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity') @@ -107,10 +108,14 @@ def test_cifar10(self, mock_ext_check, mock_int_check): mock_int_check.return_value = True with cifar_root('CIFAR10') as root: dataset = torchvision.datasets.CIFAR10(root, train=True, download=True) - generic_dataset_test(self, dataset, num_images=5) + self.generic_classification_dataset_test(dataset, num_images=5) + img, target = dataset[0] + self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) dataset = torchvision.datasets.CIFAR10(root, train=False, download=True) - generic_dataset_test(self, dataset) + self.generic_classification_dataset_test(dataset) + img, target = dataset[0] + self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) @mock.patch('torchvision.datasets.cifar.check_integrity') @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity') @@ -119,10 +124,14 @@ def test_cifar100(self, mock_ext_check, mock_int_check): mock_int_check.return_value = True with cifar_root('CIFAR100') as root: dataset = torchvision.datasets.CIFAR100(root, train=True, download=True) - generic_dataset_test(self, dataset) + self.generic_classification_dataset_test(dataset) + img, target = dataset[0] + self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) dataset = torchvision.datasets.CIFAR100(root, train=False, download=True) - generic_dataset_test(self, dataset) + self.generic_classification_dataset_test(dataset) + img, target = dataset[0] + self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) if __name__ == '__main__':