Add widerface dataset (#2883)
* initial commit of widerface dataset

* comment out old code

* improve parsing of annotation files

* code cleanup and fix docstring comments

* speed up check for quota exceeded

* cleanup print statements

* reformat code and remove print statements

* minor code cleanup and reformatting

* add more comments

* reuse variable

* reverse formatting changes

* fix flake8 errors

* add type annotations

* fix mypy errors

* add a base_folder to root directory

* some formatting fixes

* GDrive threshold does not throw 403 error

* testing new download logic

* cleanup logic for download and integrity check

* use a better variable name

* format fix

* reorder list in docstring

* initial widerface unit test - fails on MD5 check

* use list of dictionaries to store dataset

* fix docstring formatting

* remove unnecessary error checking

* fix type checker error

* revert typo fix

* rename var constants, use file context manager, verify str args

* fix flake8 error

* fix checking target_type argument values

* create uncompressed dataset folders

* cleanup unit tests for widerface

* use correct os function

* add more info to docstring

* disable unittests for windows

* fix _check_integrity logic

* update docstring

* remove citation

* remove target_type option

* fix formatting issue

Co-authored-by: Philip Meier <[email protected]>

* remove comment and add more info to docstring

* update type annotations

* restart CI jobs

Co-authored-by: Joshua Bradley <[email protected]>
Co-authored-by: Philip Meier <[email protected]>
Co-authored-by: vfdev <[email protected]>
4 people authored Jan 11, 2021
1 parent 3ee34eb commit d0063f3
Showing 4 changed files with 275 additions and 3 deletions.
67 changes: 67 additions & 0 deletions test/fakedata_generation.py
@@ -171,6 +171,73 @@ def _make_devkit_archive(root):
yield root


@contextlib.contextmanager
def widerface_root():
"""
Generates a dataset with the following folder structure and returns the root path:
<root>
└── widerface
├── wider_face_split
├── WIDER_train
├── WIDER_val
└── WIDER_test
The dataset consists of 1 image for each split (train, val, test)
and annotation files for each split
"""

def _make_image(file):
PIL.Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8)).save(file)

def _make_train_archive(root):
extracted_dir = os.path.join(root, 'WIDER_train', 'images', '0--Parade')
os.makedirs(extracted_dir)
_make_image(os.path.join(extracted_dir, '0_Parade_marchingband_1_1.jpg'))

def _make_val_archive(root):
extracted_dir = os.path.join(root, 'WIDER_val', 'images', '0--Parade')
os.makedirs(extracted_dir)
_make_image(os.path.join(extracted_dir, '0_Parade_marchingband_1_2.jpg'))

def _make_test_archive(root):
extracted_dir = os.path.join(root, 'WIDER_test', 'images', '0--Parade')
os.makedirs(extracted_dir)
_make_image(os.path.join(extracted_dir, '0_Parade_marchingband_1_3.jpg'))

def _make_annotations_archive(root):
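        # annotation format, one block per image: the image path, the number of
        # faces, then one "x y w h blur expression illumination occlusion pose
        # invalid" line per face (field order as parsed in widerface.py)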
train_bbox_contents = '0--Parade/0_Parade_marchingband_1_1.jpg\n1\n449 330 122 149 0 0 0 0 0 0\n'
val_bbox_contents = '0--Parade/0_Parade_marchingband_1_2.jpg\n1\n501 160 285 443 0 0 0 0 0 0\n'
test_filelist_contents = '0--Parade/0_Parade_marchingband_1_3.jpg\n'
extracted_dir = os.path.join(root, 'wider_face_split')
os.mkdir(extracted_dir)

# bbox training file
bbox_file = os.path.join(extracted_dir, "wider_face_train_bbx_gt.txt")
with open(bbox_file, "w") as txt_file:
txt_file.write(train_bbox_contents)

# bbox validation file
bbox_file = os.path.join(extracted_dir, "wider_face_val_bbx_gt.txt")
with open(bbox_file, "w") as txt_file:
txt_file.write(val_bbox_contents)

# test filelist file
filelist_file = os.path.join(extracted_dir, "wider_face_test_filelist.txt")
with open(filelist_file, "w") as txt_file:
txt_file.write(test_filelist_contents)

with get_tmp_dir() as root:
root_base = os.path.join(root, "widerface")
os.mkdir(root_base)
_make_train_archive(root_base)
_make_val_archive(root_base)
_make_test_archive(root_base)
_make_annotations_archive(root_base)

yield root


@contextlib.contextmanager
def cityscapes_root():

22 changes: 21 additions & 1 deletion test/test_datasets.py
@@ -9,7 +9,7 @@
import torchvision
from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
-    cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root
+    cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
import itertools
@@ -139,6 +139,26 @@ def test_imagenet(self, mock_verify):
dataset = torchvision.datasets.ImageNet(root, split='val')
self.generic_classification_dataset_test(dataset)

@mock.patch('torchvision.datasets.WIDERFace._check_integrity')
@unittest.skipIf(sys.platform.startswith('win'), 'temporarily disabled on Windows')
def test_widerface(self, mock_check_integrity):
mock_check_integrity.return_value = True
with widerface_root() as root:
dataset = torchvision.datasets.WIDERFace(root, split='train')
self.assertEqual(len(dataset), 1)
img, target = dataset[0]
self.assertIsInstance(img, PIL.Image.Image)

dataset = torchvision.datasets.WIDERFace(root, split='val')
self.assertEqual(len(dataset), 1)
img, target = dataset[0]
self.assertIsInstance(img, PIL.Image.Image)

dataset = torchvision.datasets.WIDERFace(root, split='test')
self.assertEqual(len(dataset), 1)
img, target = dataset[0]
self.assertIsInstance(img, PIL.Image.Image)

@mock.patch('torchvision.datasets.cifar.check_integrity')
@mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
def test_cifar10(self, mock_ext_check, mock_int_check):
6 changes: 4 additions & 2 deletions torchvision/datasets/__init__.py
@@ -16,6 +16,7 @@
from .imagenet import ImageNet
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .widerface import WIDERFace
from .sbd import SBDataset
from .vision import VisionDataset
from .usps import USPS
@@ -31,5 +32,6 @@
'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
-    'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset',
-    'USPS', 'Kinetics400', 'HMDB51', 'UCF101', 'Places365')
+    'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset',
+    'VisionDataset', 'USPS', 'Kinetics400', 'HMDB51', 'UCF101',
+    'Places365')
183 changes: 183 additions & 0 deletions torchvision/datasets/widerface.py
@@ -0,0 +1,183 @@
from PIL import Image
import os
from os.path import abspath, expanduser
import torch
from typing import Any, Callable, List, Dict, Optional, Tuple, Union
from .utils import check_integrity, download_file_from_google_drive, \
download_and_extract_archive, extract_archive, verify_str_arg
from .vision import VisionDataset


class WIDERFace(VisionDataset):
"""`WIDERFace <http://shuoyang1213.me/WIDERFACE/>`_ Dataset.
Args:
root (string): Root directory where images and annotations are downloaded to.
Expects the following folder structure if download=False:
<root>
└── widerface
├── wider_face_split ('wider_face_split.zip' if compressed)
├── WIDER_train ('WIDER_train.zip' if compressed)
├── WIDER_val ('WIDER_val.zip' if compressed)
└── WIDER_test ('WIDER_test.zip' if compressed)
split (string): The dataset split to use. One of {``train``, ``val``, ``test``}.
Defaults to ``train``.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in the root directory. If the dataset is already downloaded, it is
not downloaded again.
"""

BASE_FOLDER = "widerface"
FILE_LIST = [
# File ID MD5 Hash Filename
("0B6eKvaijfFUDQUUwd21EckhUbWs", "3fedf70df600953d25982bcd13d91ba2", "WIDER_train.zip"),
("0B6eKvaijfFUDd3dIRmpvSk8tLUk", "dfa7d7e790efa35df3788964cf0bbaea", "WIDER_val.zip"),
("0B6eKvaijfFUDbW4tdGpaYjgzZkU", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip")
]
ANNOTATIONS_FILE = (
"http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/support/bbx_annotation/wider_face_split.zip",
"0e3767bcf0e326556d407bf5bff5d27c",
"wider_face_split.zip"
)

def __init__(
self,
root: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(WIDERFace, self).__init__(root=os.path.join(root, self.BASE_FOLDER),
transform=transform,
target_transform=target_transform)
# check arguments
self.split = verify_str_arg(split, "split", ("train", "val", "test"))

if download:
self.download()

if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted. " +
"You can use download=True to download and prepare it")

self.img_info: List[Dict[str, Union[str, Dict[str, torch.Tensor]]]] = []
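        # each entry stores the absolute image path and, for the train/val
        # splits, a dict of per-face annotation tensors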
if self.split in ("train", "val"):
self.parse_train_val_annotations_file()
else:
self.parse_test_annotations_file()

def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is a dict of annotations for all faces in the image.
target=None for the test split.
"""

# stay consistent with other datasets and return a PIL Image
img = Image.open(self.img_info[index]["img_path"])

if self.transform is not None:
img = self.transform(img)

target = None if self.split == "test" else self.img_info[index]["annotations"]
if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self) -> int:
return len(self.img_info)

def extra_repr(self) -> str:
lines = ["Split: {split}"]
return '\n'.join(lines).format(**self.__dict__)

def parse_train_val_annotations_file(self) -> None:
filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt"
filepath = os.path.join(self.root, "wider_face_split", filename)

with open(filepath, "r") as f:
lines = f.readlines()
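            # simple three-state parser: an image path line, a face-count line,
            # then that many box annotation lines, repeated until end of file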
file_name_line, num_boxes_line, box_annotation_line = True, False, False
num_boxes, box_counter = 0, 0
labels = []
for line in lines:
line = line.rstrip()
if file_name_line:
img_path = os.path.join(self.root, "WIDER_" + self.split, "images", line)
img_path = abspath(expanduser(img_path))
file_name_line = False
num_boxes_line = True
elif num_boxes_line:
num_boxes = int(line)
num_boxes_line = False
box_annotation_line = True
elif box_annotation_line:
box_counter += 1
line_split = line.split(" ")
line_values = [int(x) for x in line_split]
labels.append(line_values)
if box_counter >= num_boxes:
box_annotation_line = False
file_name_line = True
labels_tensor = torch.tensor(labels)
self.img_info.append({
"img_path": img_path,
"annotations": {"bbox": labels_tensor[:, 0:4], # x, y, width, height
"blur": labels_tensor[:, 4],
"expression": labels_tensor[:, 5],
"illumination": labels_tensor[:, 6],
"occlusion": labels_tensor[:, 7],
"pose": labels_tensor[:, 8],
"invalid": labels_tensor[:, 9]}
})
box_counter = 0
labels.clear()
else:
raise RuntimeError("Error parsing annotation file {}".format(filepath))

def parse_test_annotations_file(self) -> None:
filepath = os.path.join(self.root, "wider_face_split", "wider_face_test_filelist.txt")
filepath = abspath(expanduser(filepath))
with open(filepath, "r") as f:
lines = f.readlines()
for line in lines:
line = line.rstrip()
img_path = os.path.join(self.root, "WIDER_test", "images", line)
img_path = abspath(expanduser(img_path))
self.img_info.append({"img_path": img_path})

def _check_integrity(self) -> bool:
# Allow original archive to be deleted (zip). Only need the extracted images
all_files = self.FILE_LIST.copy()
all_files.append(self.ANNOTATIONS_FILE)
for (_, md5, filename) in all_files:
file, _ = os.path.splitext(filename)
extracted_dir = os.path.join(self.root, file)
if not os.path.exists(extracted_dir):
return False
return True

def download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
return

# download and extract image data
for (file_id, md5, filename) in self.FILE_LIST:
download_file_from_google_drive(file_id, self.root, filename, md5)
filepath = os.path.join(self.root, filename)
extract_archive(filepath)

# download and extract annotation files
download_and_extract_archive(url=self.ANNOTATIONS_FILE[0],
download_root=self.root,
md5=self.ANNOTATIONS_FILE[1])
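
For reference, a minimal usage sketch of the new class, assuming the archives listed in FILE_LIST and ANNOTATIONS_FILE have already been downloaded and extracted under <root>/widerface ("path/to/root" below is a hypothetical placeholder):

import torchvision

dataset = torchvision.datasets.WIDERFace(root="path/to/root", split="train", download=False)
img, target = dataset[0]       # img is a PIL Image; target is a dict of face annotations
print(len(dataset))            # number of images in the split
print(target["bbox"].shape)    # one (x, y, width, height) row per face in the image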
