diff --git a/test/fakedata_generation.py b/test/fakedata_generation.py index b0b0aec7cbf..7c871054e4e 100644 --- a/test/fakedata_generation.py +++ b/test/fakedata_generation.py @@ -29,7 +29,7 @@ def _make_image_file(filename, num_images): f.write(img.numpy().tobytes()) def _make_label_file(filename, num_images): - labels = torch.randint(0, 10, size=(num_images,), dtype=torch.uint8) + labels = torch.zeros((num_images,), dtype=torch.uint8) with open(filename, "wb") as f: f.write(_encode(2049)) # magic header f.write(_encode(num_images)) diff --git a/test/test_datasets.py b/test/test_datasets.py index 19f2428db6a..861c5dc0f6c 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -10,6 +10,12 @@ class Tester(unittest.TestCase): + def generic_classification_dataset_test(self, dataset, num_images=1): + self.assertEqual(len(dataset), num_images) + img, target = dataset[0] + self.assertTrue(isinstance(img, PIL.Image.Image)) + self.assertTrue(isinstance(target, int)) + def test_imagefolder(self): # TODO: create the fake data on-the-fly FAKEDATA_DIR = get_file_path_2( @@ -64,47 +70,36 @@ def test_mnist(self, mock_download_extract): num_examples = 30 with mnist_root(num_examples, "MNIST") as root: dataset = torchvision.datasets.MNIST(root, download=True) - self.assertEqual(len(dataset), num_examples) + self.generic_classification_dataset_test(dataset, num_images=num_examples) img, target = dataset[0] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) + self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') def test_kmnist(self, mock_download_extract): num_examples = 30 with mnist_root(num_examples, "KMNIST") as root: dataset = torchvision.datasets.KMNIST(root, download=True) + self.generic_classification_dataset_test(dataset, num_images=num_examples) img, target = dataset[0] - self.assertEqual(len(dataset), num_examples) - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) + self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') def test_fashionmnist(self, mock_download_extract): num_examples = 30 with mnist_root(num_examples, "FashionMNIST") as root: dataset = torchvision.datasets.FashionMNIST(root, download=True) + self.generic_classification_dataset_test(dataset, num_images=num_examples) img, target = dataset[0] - self.assertEqual(len(dataset), num_examples) - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) + self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) @mock.patch('torchvision.datasets.utils.download_url') def test_imagenet(self, mock_download): with imagenet_root() as root: dataset = torchvision.datasets.ImageNet(root, split='train', download=True) - self.assertEqual(len(dataset), 1) - img, target = dataset[0] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) - self.assertEqual(dataset.class_to_idx['fakedata'], target) + self.generic_classification_dataset_test(dataset) dataset = torchvision.datasets.ImageNet(root, split='val', download=True) - self.assertEqual(len(dataset), 1) - img, target = dataset[0] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) - self.assertEqual(dataset.class_to_idx['fakedata'], target) + self.generic_classification_dataset_test(dataset) @mock.patch('torchvision.datasets.cifar.check_integrity') @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity') @@ -113,18 +108,14 @@ def test_cifar10(self, mock_ext_check, mock_int_check): mock_int_check.return_value = True with cifar_root('CIFAR10') as root: dataset = torchvision.datasets.CIFAR10(root, train=True, download=True) - self.assertEqual(len(dataset), 5) + self.generic_classification_dataset_test(dataset, num_images=5) img, target = dataset[0] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) - self.assertEqual(dataset.class_to_idx['fakedata'], target) + self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) dataset = torchvision.datasets.CIFAR10(root, train=False, download=True) - self.assertEqual(len(dataset), 1) + self.generic_classification_dataset_test(dataset) img, target = dataset[0] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) - self.assertEqual(dataset.class_to_idx['fakedata'], target) + self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) @mock.patch('torchvision.datasets.cifar.check_integrity') @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity') @@ -133,18 +124,14 @@ def test_cifar100(self, mock_ext_check, mock_int_check): mock_int_check.return_value = True with cifar_root('CIFAR100') as root: dataset = torchvision.datasets.CIFAR100(root, train=True, download=True) - self.assertEqual(len(dataset), 1) + self.generic_classification_dataset_test(dataset) img, target = dataset[0] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) - self.assertEqual(dataset.class_to_idx['fakedata'], target) + self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) dataset = torchvision.datasets.CIFAR100(root, train=False, download=True) - self.assertEqual(len(dataset), 1) + self.generic_classification_dataset_test(dataset) img, target = dataset[0] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) - self.assertEqual(dataset.class_to_idx['fakedata'], target) + self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) if __name__ == '__main__':