Skip to content

Commit

Permalink
Adding fgvc_aircraft dataset (#5178)
Browse files Browse the repository at this point in the history
* add fgvc_aircraft dataset

* add docstring & remove useless import

* resolve lint issue

* address comments

* add more annotation levels

* nit

* address comments

* Apply suggestions from code review

* unify format

* remove useless line

Co-authored-by: Philip Meier <[email protected]>
  • Loading branch information
yiwen-song and pmeier authored Jan 14, 2022
1 parent 1feb637 commit adf8466
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
FlyingChairs
FlyingThings3D
Food101
FGVCAircraft
GTSRB
HD1K
HMDB51
Expand Down
51 changes: 51 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2206,6 +2206,57 @@ def inject_fake_data(self, tmpdir: str, config):
return len(sampled_classes * n_samples_per_class)


class FGVCAircraftTestCase(datasets_utils.ImageDatasetTestCase):
    """Run the image-dataset test suite against ``datasets.FGVCAircraft``.

    Every combination of ``split`` and ``annotation_level`` is exercised on a
    small synthetic copy of the ``fgvc-aircraft-2013b/data`` directory layout.
    """

    DATASET_CLASS = datasets.FGVCAircraft
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        split=("train", "val", "trainval", "test"), annotation_level=("variant", "family", "manufacturer")
    )

    def inject_fake_data(self, tmpdir: str, config):
        """Write fake images and annotation files; return the sample count.

        Args:
            tmpdir: Directory that acts as the dataset ``root``.
            config: Mapping holding the ``split`` and ``annotation_level``
                of the configuration under test.

        Returns:
            The number of ``<image id> <class>`` lines written to the
            split's annotation file.
        """
        level = config["annotation_level"]
        subset = config["split"]

        # File that lists every class name for the requested annotation level.
        class_list_file_by_level = {
            "variant": "variants.txt",
            "family": "families.txt",
            "manufacturer": "manufacturers.txt",
        }

        base_dir = pathlib.Path(tmpdir) / "fgvc-aircraft-2013b" / "data"

        class_names = ["707-320", "Hawk T1", "Tornado"]
        images_per_class = 5

        # Images are named "<idx>.jpg"; class ``k`` owns the ids in
        # [k * images_per_class, (k + 1) * images_per_class).
        datasets_utils.create_image_folder(
            base_dir,
            "images",
            file_name_fn=lambda idx: f"{idx}.jpg",
            num_examples=images_per_class * len(class_names),
        )

        (base_dir / class_list_file_by_level[level]).write_text("\n".join(class_names))

        # "trainval" gets twice as many samples per class as the other splits.
        picks_per_class = 2 if subset != "trainval" else 4

        annotation_lines = []
        for position, class_name in enumerate(class_names):
            first_id = position * images_per_class
            chosen_ids = random.sample(range(first_id, first_id + images_per_class), picks_per_class)
            annotation_lines += [f"{image_id} {class_name}" for image_id in chosen_ids]

        (base_dir / f"images_{level}_{subset}.txt").write_text("\n".join(annotation_lines))

        return len(class_names) * picks_per_class


class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SUN397

Expand Down
2 changes: 2 additions & 0 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .dtd import DTD
from .fakedata import FakeData
from .fer2013 import FER2013
from .fgvc_aircraft import FGVCAircraft
from .flickr import Flickr8k, Flickr30k
from .flowers102 import Flowers102
from .folder import ImageFolder, DatasetFolder
Expand Down Expand Up @@ -95,4 +96,5 @@
"CLEVRClassification",
"OxfordIIITPet",
"Country211",
"FGVCAircraft",
)
114 changes: 114 additions & 0 deletions torchvision/datasets/fgvc_aircraft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from __future__ import annotations

import os
from typing import Any, Callable, Optional, Tuple

import PIL.Image

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset


class FGVCAircraft(VisionDataset):
    """`FGVC Aircraft <https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.

    The dataset contains 10,200 images of aircraft, with 100 images for each of 102
    different aircraft model variants, most of which are airplanes.
    Aircraft models are organized in a three-levels hierarchy. The three levels, from
    finer to coarser, are:

    - ``variant``, e.g. Boeing 737-700. A variant collapses all the models that are visually
      indistinguishable into one class. The dataset comprises 102 different variants.
    - ``family``, e.g. Boeing 737. The dataset comprises 70 different families.
    - ``manufacturer``, e.g. Boeing. The dataset comprises 41 different manufacturers.

    Args:
        root (string): Root directory of the FGVC Aircraft dataset.
        split (string, optional): The dataset split, supports ``train``, ``val``,
            ``trainval`` and ``test``. Defaults to ``trainval``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. Defaults to ``False``.
        annotation_level (str, optional): The annotation level, supports ``variant``,
            ``family`` and ``manufacturer``. Defaults to ``variant``.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Raises:
        RuntimeError: If ``<root>/fgvc-aircraft-2013b`` is not a directory after the
            optional download step.
    """

    # NOTE(review): the class counts quoted in the docstring are carried over unchanged;
    # the released archive may list fewer classes -- confirm against ``variants.txt`` /
    # ``manufacturers.txt`` (``len(self.classes)`` is always the authoritative number).

    _URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"

    # Annotation level -> file (inside ``data/``) listing one class name per line.
    _CLASS_LIST_FILES = {
        "variant": "variants.txt",
        "family": "families.txt",
        "manufacturer": "manufacturers.txt",
    }

    def __init__(
        self,
        root: str,
        split: str = "trainval",
        download: bool = False,
        annotation_level: str = "variant",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
        self._annotation_level = verify_str_arg(
            annotation_level, "annotation_level", ("variant", "family", "manufacturer")
        )

        self._data_path = os.path.join(self.root, "fgvc-aircraft-2013b")
        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        data_folder = os.path.join(self._data_path, "data")

        # The position of a name in the class-list file is its integer label.
        annotation_file = os.path.join(data_folder, self._CLASS_LIST_FILES[self._annotation_level])
        self.classes = self._read_non_blank_lines(annotation_file)
        self.class_to_idx = {name: index for index, name in enumerate(self.classes)}

        image_data_folder = os.path.join(data_folder, "images")
        labels_file = os.path.join(data_folder, f"images_{self._annotation_level}_{self._split}.txt")

        self._image_files = []
        self._labels = []

        for line in self._read_non_blank_lines(labels_file):
            # Each line is "<image id> <class name>". Class names may contain
            # spaces (e.g. "Hawk T1"), so split on the first space only.
            image_name, label_name = line.split(" ", 1)
            self._image_files.append(os.path.join(image_data_folder, f"{image_name}.jpg"))
            self._labels.append(self.class_to_idx[label_name])

    @staticmethod
    def _read_non_blank_lines(path: str) -> list:
        """Return the stripped, non-empty lines of the text file at ``path``.

        Blank lines (for instance a trailing empty line) are skipped so that they
        neither become an empty class name nor break the ``"<id> <class>"`` parsing.
        The file is decoded as UTF-8 so the result does not depend on the locale.
        """
        with open(path, "r", encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            return [line for line in stripped if line]

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Load the sample at position ``idx``.

        Returns:
            tuple: ``(image, label)`` where ``image`` is the RGB PIL image (after
            ``transform``, if given) and ``label`` is the integer class index
            (after ``target_transform``, if given).
        """
        image_file, label = self._image_files[idx], self._labels[idx]

        # ``convert`` returns a new, fully loaded image, so the handle opened by
        # ``PIL.Image.open`` can be closed right away instead of waiting for GC.
        with PIL.Image.open(image_file) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def _download(self) -> None:
        """
        Download the FGVC Aircraft dataset archive and extract it under root.

        Does nothing when the dataset directory already exists.
        """
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, self.root)

    def _check_exists(self) -> bool:
        """Return True if the extracted dataset directory is present."""
        # ``isdir`` is already False for a missing path, so no ``exists`` check is needed.
        return os.path.isdir(self._data_path)

0 comments on commit adf8466

Please sign in to comment.