Skip to content

Commit

Permalink
add tests for Coco (#3416)
Browse files Browse the repository at this point in the history
Summary: Co-authored-by: Francisco Massa <[email protected]>

Reviewed By: NicolasHug

Differential Revision: D26605326

fbshipit-source-id: df7d6e8c4a50d43b432906f643c55345e0d85915
  • Loading branch information
fmassa authored and facebook-github-bot committed Feb 23, 2021
1 parent 8f70d1c commit a3107fa
Showing 1 changed file with 66 additions and 0 deletions.
66 changes: 66 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torchvision import datasets
import torch
import shutil
import json


try:
Expand Down Expand Up @@ -839,5 +840,70 @@ def test_annotations(self):
self.assertEqual(object, info["annotation"])


class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CocoDetection
FEATURE_TYPES = (PIL.Image.Image, list)

REQUIRED_PACKAGES = ("pycocotools",)

def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir)

num_images = 3
num_annotations_per_image = 2

image_folder = tmpdir / "images"
files = datasets_utils.create_image_folder(
tmpdir, name="images", file_name_fn=lambda idx: f"{idx:012d}.jpg", num_examples=num_images
)
file_names = [file.relative_to(image_folder) for file in files]

annotation_folder = tmpdir / "annotations"
os.makedirs(annotation_folder)
annotation_file, info = self._create_annotation_file(annotation_folder, file_names, num_annotations_per_image)

info["num_examples"] = num_images
return (str(image_folder), str(annotation_file)), info

def _create_annotation_file(self, root, file_names, num_annotations_per_image):
image_ids = [int(file_name.stem) for file_name in file_names]
images = [dict(file_name=str(file_name), id=id) for file_name, id in zip(file_names, image_ids)]

annotations, info = self._create_annotations(image_ids, num_annotations_per_image)

content = dict(images=images, annotations=annotations)
return self._create_json(root, "annotations.json", content), info

def _create_annotations(self, image_ids, num_annotations_per_image):
annotations = datasets_utils.combinations_grid(
image_id=image_ids, bbox=([1.0, 2.0, 3.0, 4.0],) * num_annotations_per_image
)
for id, annotation in enumerate(annotations):
annotation["id"] = id
return annotations, dict()

def _create_json(self, root, name, content):
file = pathlib.Path(root) / name
with open(file, "w") as fh:
json.dump(content, fh)
return file


class CocoCaptionsTestCase(CocoDetectionTestCase):
DATASET_CLASS = datasets.CocoCaptions

def _create_annotations(self, image_ids, num_annotations_per_image):
captions = [str(idx) for idx in range(num_annotations_per_image)]
annotations = datasets_utils.combinations_grid(image_id=image_ids, caption=captions)
for id, annotation in enumerate(annotations):
annotation["id"] = id
return annotations, dict(captions=captions)

def test_captions(self):
with self.create_dataset() as (dataset, info):
_, captions = dataset[0]
self.assertEqual(tuple(captions), tuple(info["captions"]))


if __name__ == "__main__":
unittest.main()

0 comments on commit a3107fa

Please sign in to comment.