Skip to content

Commit

Permalink
add tests for CelebA (#3413)
Browse files Browse the repository at this point in the history
Co-authored-by: Francisco Massa <[email protected]>
  • Loading branch information
pmeier and fmassa authored Feb 22, 2021
1 parent 7f59e8c commit 7b7cfdd
Showing 1 changed file with 117 additions and 0 deletions.
117 changes: 117 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,5 +639,122 @@ class CIFAR100(CIFAR10TestCase):
)


class CelebATestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CelebA
FEATURE_TYPES = (PIL.Image.Image, (torch.Tensor, int, tuple, type(None)))

CONFIGS = datasets_utils.combinations_grid(
split=("train", "valid", "test", "all"),
target_type=("attr", "identity", "bbox", "landmarks", ["attr", "identity"]),
)
REQUIRED_PACKAGES = ("pandas",)

_SPLIT_TO_IDX = dict(train=0, valid=1, test=2)

def inject_fake_data(self, tmpdir, config):
base_folder = pathlib.Path(tmpdir) / "celeba"
os.makedirs(base_folder)

num_images, num_images_per_split = self._create_split_txt(base_folder)

datasets_utils.create_image_folder(
base_folder, "img_align_celeba", lambda idx: f"{idx + 1:06d}.jpg", num_images
)
attr_names = self._create_attr_txt(base_folder, num_images)
self._create_identity_txt(base_folder, num_images)
self._create_bbox_txt(base_folder, num_images)
self._create_landmarks_txt(base_folder, num_images)

return dict(num_examples=num_images_per_split[config["split"]], attr_names=attr_names)

def _create_split_txt(self, root):
num_images_per_split = dict(train=3, valid=2, test=1)

data = [
[self._SPLIT_TO_IDX[split]] for split, num_images in num_images_per_split.items() for _ in range(num_images)
]
self._create_txt(root, "list_eval_partition.txt", data)

num_images_per_split["all"] = num_images = sum(num_images_per_split.values())
return num_images, num_images_per_split

def _create_attr_txt(self, root, num_images):
header = ("5_o_Clock_Shadow", "Young")
data = torch.rand((num_images, len(header))).ge(0.5).int().mul(2).sub(1).tolist()
self._create_txt(root, "list_attr_celeba.txt", data, header=header, add_num_examples=True)
return header

def _create_identity_txt(self, root, num_images):
data = torch.randint(1, 4, size=(num_images, 1)).tolist()
self._create_txt(root, "identity_CelebA.txt", data)

def _create_bbox_txt(self, root, num_images):
header = ("x_1", "y_1", "width", "height")
data = torch.randint(10, size=(num_images, len(header))).tolist()
self._create_txt(
root, "list_bbox_celeba.txt", data, header=header, add_num_examples=True, add_image_id_to_header=True
)

def _create_landmarks_txt(self, root, num_images):
header = ("lefteye_x", "rightmouth_y")
data = torch.randint(10, size=(num_images, len(header))).tolist()
self._create_txt(root, "list_landmarks_align_celeba.txt", data, header=header, add_num_examples=True)

def _create_txt(self, root, name, data, header=None, add_num_examples=False, add_image_id_to_header=False):
with open(pathlib.Path(root) / name, "w") as fh:
if add_num_examples:
fh.write(f"{len(data)}\n")

if header:
if add_image_id_to_header:
header = ("image_id", *header)
fh.write(f"{' '.join(header)}\n")

for idx, line in enumerate(data, 1):
fh.write(f"{' '.join((f'{idx:06d}.jpg', *[str(value) for value in line]))}\n")

def test_combined_targets(self):
target_types = ["attr", "identity", "bbox", "landmarks"]

individual_targets = []
for target_type in target_types:
with self.create_dataset(target_type=target_type) as (dataset, _):
_, target = dataset[0]
individual_targets.append(target)

with self.create_dataset(target_type=target_types) as (dataset, _):
_, combined_targets = dataset[0]

actual = len(individual_targets)
expected = len(combined_targets)
self.assertEqual(
actual,
expected,
f"The number of the returned combined targets does not match the the number targets if requested "
f"individually: {actual} != {expected}",
)

for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets):
with self.subTest(target_type=target_type):
actual = type(combined_target)
expected = type(individual_target)
self.assertIs(
actual,
expected,
f"Type of the combined target does not match the type of the corresponding individual target: "
f"{actual} is not {expected}",
)

def test_no_target(self):
with self.create_dataset(target_type=[]) as (dataset, _):
_, target = dataset[0]

self.assertIsNone(target)

def test_attr_names(self):
with self.create_dataset() as (dataset, info):
self.assertEqual(tuple(dataset.attr_names), info["attr_names"])


if __name__ == "__main__":
unittest.main()

0 comments on commit 7b7cfdd

Please sign in to comment.