Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added detection target types to Oxford-IIIT Pet dataset #8425

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2553,7 +2553,7 @@ class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase):

ADDITIONAL_CONFIGS = combinations_grid(
split=("trainval", "test"),
target_types=("category", "binary-category", "segmentation", ["category", "segmentation"], []),
target_types=("category", "binary-category", "detection", "binary-detection", "segmentation", ["category", "segmentation"], []),
)

def inject_fake_data(self, tmpdir, config):
Expand Down
43 changes: 42 additions & 1 deletion torchvision/datasets/oxford_iiit_pet.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import os
import os.path
import pathlib
import torch
from xml.etree.ElementTree import parse as ET_parse
from typing import Any, Callable, Optional, Sequence, Tuple, Union

from PIL import Image

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
from .voc import VOCDetection


class OxfordIIITPet(VisionDataset):
Expand All @@ -20,6 +23,8 @@ class OxfordIIITPet(VisionDataset):

- ``category`` (int): Label for one of the 37 pet categories.
- ``binary-category`` (int): Binary label for cat or dog.
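- ``detection`` (dict): Dict with ``boxes`` and ``labels`` tensors built from the Pascal VOC annotation, labelled with one of the 37 pet categories.
- ``binary-detection`` (dict): Dict with ``boxes`` and ``labels`` tensors built from the Pascal VOC annotation, labelled with the binary cat/dog label.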
- ``segmentation`` (PIL image): Segmentation trimap of the image.

If empty, ``None`` will be returned as target.
Expand All @@ -35,7 +40,7 @@ class OxfordIIITPet(VisionDataset):
("https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", "5c4f3ee8e5d25df40f4fd59a7f44e54c"),
("https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", "95a8c909bbe2e81eed6a22bccdf3f68f"),
)
_VALID_TARGET_TYPES = ("category", "binary-category", "segmentation")
_VALID_TARGET_TYPES = ("category", "binary-category", "detection", "binary-detection", "segmentation")

def __init__(
self,
Expand All @@ -59,6 +64,7 @@ def __init__(
self._images_folder = self._base_folder / "images"
self._anns_folder = self._base_folder / "annotations"
self._segs_folder = self._anns_folder / "trimaps"
self._xmls_folder = self._anns_folder / "xmls"

if download:
self._download()
Expand Down Expand Up @@ -89,6 +95,20 @@ def __init__(

self._images = [self._images_folder / f"{image_id}.jpg" for image_id in image_ids]
self._segs = [self._segs_folder / f"{image_id}.png" for image_id in image_ids]
self._xmls = [self._xmls_folder / f"{image_id}.xml" for image_id in image_ids]

# The Oxford-IIIT Pet dataset ships detection XMLs in VOC format, but some images have no XML file.
# When a detection target type is selected, keep only the samples that have a corresponding XML file.
if "detection" in target_types or "binary-detection" in target_types:
# Notify users this is not a complete dataset
print('Dataset does not contain detection annotations for every sample. Filtering to include' \
' only those that do.') # TODO: Is a simple print the right call here?
# Set up filtered arrays
self._labels = [lbl for lbl,xml_file in zip(self._labels,self._xmls) if os.path.isfile(xml_file)]
self._bin_labels = [lbl for lbl,xml_file in zip(self._bin_labels,self._xmls) if os.path.isfile(xml_file)]
self._images = [img for img,xml_file in zip(self._images,self._xmls) if os.path.isfile(xml_file)]
self._segs = [seg for seg,xml_file in zip(self._segs,self._xmls) if os.path.isfile(xml_file)]
self._xmls = [xml_file for xml_file in self._xmls if os.path.isfile(xml_file)]

def __len__(self) -> int:
return len(self._images)
Expand All @@ -102,6 +122,14 @@ def __getitem__(self, idx: int) -> Tuple[Any, Any]:
target.append(self._labels[idx])
elif target_type == "binary-category":
target.append(self._bin_labels[idx])
elif target_type == "detection":
target.append(self._to_rcnn(VOCDetection.parse_voc_xml(ET_parse(self._xmls[idx]).getroot() \
),self._labels[idx]))
#target[-1]['annotation']['object'][0]['name'] = self.classes[self._labels[idx]]
elif target_type == "binary-detection":
target.append(self._to_rcnn(VOCDetection.parse_voc_xml(ET_parse(self._xmls[idx]).getroot() \
),self._bin_labels[idx]))
#target[-1]['annotation']['object'][0]['name'] = self.classes[self._bin_labels[idx]]
else: # target_type == "segmentation"
target.append(Image.open(self._segs[idx]))

Expand Down Expand Up @@ -130,3 +158,16 @@ def _download(self) -> None:

for url, md5 in self._RESOURCES:
download_and_extract_archive(url, download_root=str(self._base_folder), md5=md5)

def _to_rcnn(self, anno_dict: dict, label: int) -> dict:
    """Convert a parsed annotation dict into a ``boxes``/``labels`` target dict.

    Only the first entry of ``anno_dict["annotation"]["object"]`` is read.

    Args:
        anno_dict (dict): Parsed annotation whose first object carries a
            ``bndbox`` mapping with ``xmin``, ``ymin``, ``xmax`` and ``ymax``
            values convertible with ``float``.
        label (int): Value stored as the single entry of ``labels``.

    Returns:
        dict: ``boxes`` is a ``(1, 4)`` float32 tensor ordered
        ``xmin, ymin, xmax, ymax``; ``labels`` is a ``(1,)`` int64 tensor
        holding ``label``.
    """
    bndbox = anno_dict["annotation"]["object"][0]["bndbox"]
    # Corner order matches the column order of the output ``boxes`` tensor.
    corners = [float(bndbox[key]) for key in ("xmin", "ymin", "xmax", "ymax")]
    return {
        "boxes": torch.tensor([corners], dtype=torch.float32),
        "labels": torch.tensor([label], dtype=torch.int64),
    }