# Provenance: commit 7b7cfdd4f71b1b918da343715ac139e3d65b86e8, "add tests for CelebA (#3413)",
# by Philip Meier (Mon, 22 Feb 2021), co-authored by Francisco Massa.
# The patch inserts the class below into test/test_datasets.py, after the CIFAR100 test case and
# directly before the ``__main__`` guard.
# NOTE(review): assumes ``datasets``, ``datasets_utils``, ``os``, ``pathlib``, ``PIL``, ``torch`` and
# ``unittest`` are already imported at the top of that file - confirm before pasting elsewhere.


class CelebATestCase(datasets_utils.ImageDatasetTestCase):
    """Tests for ``datasets.CelebA`` that run against a small, generated fake dataset.

    ``inject_fake_data`` writes a ``celeba`` folder containing six images (3 train, 2 valid, 1 test)
    together with the partition, attribute, identity, bounding box, and landmark annotation files.
    """

    DATASET_CLASS = datasets.CelebA
    # First entry is the image type; second entry lists the types a target may have.
    FEATURE_TYPES = (PIL.Image.Image, (torch.Tensor, int, tuple, type(None)))

    # Presumably expands to one config per (split, target_type) pair - see datasets_utils.combinations_grid.
    CONFIGS = datasets_utils.combinations_grid(
        split=("train", "valid", "test", "all"),
        target_type=("attr", "identity", "bbox", "landmarks", ["attr", "identity"]),
    )
    REQUIRED_PACKAGES = ("pandas",)

    # Integer written to list_eval_partition.txt for each named split.
    _SPLIT_TO_IDX = dict(train=0, valid=1, test=2)

    def inject_fake_data(self, tmpdir, config):
        """Create the fake CelebA folder structure below ``tmpdir``.

        Args:
            tmpdir: Directory in which the ``celeba`` base folder is created.
            config: Dataset configuration; only ``config["split"]`` is read here.

        Returns:
            dict with ``num_examples`` (number of images in the configured split) and ``attr_names``
            (the attribute header written to ``list_attr_celeba.txt``).
        """
        base_folder = pathlib.Path(tmpdir) / "celeba"
        os.makedirs(base_folder)

        num_images, num_images_per_split = self._create_split_txt(base_folder)

        # Image names are 1-based and zero-padded so they match the ids written by ``_create_txt``.
        datasets_utils.create_image_folder(
            base_folder, "img_align_celeba", lambda idx: f"{idx + 1:06d}.jpg", num_images
        )
        attr_names = self._create_attr_txt(base_folder, num_images)
        self._create_identity_txt(base_folder, num_images)
        self._create_bbox_txt(base_folder, num_images)
        self._create_landmarks_txt(base_folder, num_images)

        return dict(num_examples=num_images_per_split[config["split"]], attr_names=attr_names)

    def _create_split_txt(self, root):
        """Write ``list_eval_partition.txt`` and return ``(num_images, num_images_per_split)``.

        The returned mapping additionally contains an ``"all"`` key holding the total image count.
        """
        num_images_per_split = dict(train=3, valid=2, test=1)

        # One single-column row per image, holding the index of the split the image belongs to.
        data = [
            [self._SPLIT_TO_IDX[split]] for split, num_images in num_images_per_split.items() for _ in range(num_images)
        ]
        self._create_txt(root, "list_eval_partition.txt", data)

        num_images_per_split["all"] = num_images = sum(num_images_per_split.values())
        return num_images, num_images_per_split

    def _create_attr_txt(self, root, num_images):
        """Write ``list_attr_celeba.txt`` with random attributes and return the attribute names."""
        header = ("5_o_Clock_Shadow", "Young")
        # Maps uniform random numbers to the values -1 and 1.
        data = torch.rand((num_images, len(header))).ge(0.5).int().mul(2).sub(1).tolist()
        self._create_txt(root, "list_attr_celeba.txt", data, header=header, add_num_examples=True)
        return header

    def _create_identity_txt(self, root, num_images):
        """Write ``identity_CelebA.txt`` with one random identity in ``[1, 3]`` per image."""
        data = torch.randint(1, 4, size=(num_images, 1)).tolist()
        self._create_txt(root, "identity_CelebA.txt", data)

    def _create_bbox_txt(self, root, num_images):
        """Write ``list_bbox_celeba.txt`` with random bounding box values in ``[0, 9]``."""
        header = ("x_1", "y_1", "width", "height")
        data = torch.randint(10, size=(num_images, len(header))).tolist()
        self._create_txt(
            root, "list_bbox_celeba.txt", data, header=header, add_num_examples=True, add_image_id_to_header=True
        )

    def _create_landmarks_txt(self, root, num_images):
        """Write ``list_landmarks_align_celeba.txt`` with random landmark values in ``[0, 9]``."""
        header = ("lefteye_x", "rightmouth_y")
        data = torch.randint(10, size=(num_images, len(header))).tolist()
        self._create_txt(root, "list_landmarks_align_celeba.txt", data, header=header, add_num_examples=True)

    def _create_txt(self, root, name, data, header=None, add_num_examples=False, add_image_id_to_header=False):
        """Write a space-separated annotation file.

        Args:
            root: Folder the file is created in.
            name: File name.
            data: Sequence of rows; every row is prefixed with a 1-based ``NNNNNN.jpg`` image id.
            header: Optional column names written before the rows.
            add_num_examples: If true, the first line of the file is ``len(data)``.
            add_image_id_to_header: If true, ``"image_id"`` is prepended to ``header``.
        """
        with open(pathlib.Path(root) / name, "w") as fh:
            if add_num_examples:
                fh.write(f"{len(data)}\n")

            if header:
                columns = ("image_id", *header) if add_image_id_to_header else header
                fh.write(f"{' '.join(columns)}\n")

            for idx, line in enumerate(data, 1):
                fields = (f"{idx:06d}.jpg", *(str(value) for value in line))
                fh.write(f"{' '.join(fields)}\n")

    def test_combined_targets(self):
        """Requesting several target types at once yields one target per type, each of the same type
        as when that target type is requested on its own."""
        target_types = ["attr", "identity", "bbox", "landmarks"]

        individual_targets = []
        for target_type in target_types:
            with self.create_dataset(target_type=target_type) as (dataset, _):
                _, target = dataset[0]
                individual_targets.append(target)

        with self.create_dataset(target_type=target_types) as (dataset, _):
            _, combined_targets = dataset[0]

        # The combined targets are the value under test, so they are ``actual``; this keeps the
        # numbers in the failure message in the order the message describes them.
        actual = len(combined_targets)
        expected = len(individual_targets)
        self.assertEqual(
            actual,
            expected,
            f"The number of the returned combined targets does not match the number of targets if requested "
            f"individually: {actual} != {expected}",
        )

        for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets):
            with self.subTest(target_type=target_type):
                actual = type(combined_target)
                expected = type(individual_target)
                self.assertIs(
                    actual,
                    expected,
                    f"Type of the combined target does not match the type of the corresponding individual target: "
                    f"{actual} is not {expected}",
                )

    def test_no_target(self):
        """An empty ``target_type`` list yields ``None`` as the target."""
        with self.create_dataset(target_type=[]) as (dataset, _):
            _, target = dataset[0]

        self.assertIsNone(target)

    def test_attr_names(self):
        """``dataset.attr_names`` matches the attribute header written by ``inject_fake_data``."""
        with self.create_dataset() as (dataset, info):
            self.assertEqual(tuple(dataset.attr_names), info["attr_names"])


if __name__ == "__main__":
    unittest.main()