Skip to content

Commit

Permalink
Add essential scripts for training (#5)
Browse files Browse the repository at this point in the history
* init commit

* Add essential scripts for training

* Minor fixes

* Add doc and type annotations

* Minor fixes

* Fix variables loading

* Fix type annotations

* Minor fixes

* Add type annotations

* Fix type annotations
  • Loading branch information
zhiqwang authored Nov 29, 2020
1 parent 2bbd1f7 commit 32b2d83
Show file tree
Hide file tree
Showing 12 changed files with 2,280 additions and 4 deletions.
34 changes: 34 additions & 0 deletions datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch.utils.data
import torchvision

from .coco import build as build_coco
from .voc import build as build_voc


def get_coco_api_from_dataset(dataset):
for _ in range(10):
# if isinstance(dataset, torchvision.datasets.CocoDetection):
# break
if isinstance(dataset, torch.utils.data.Subset):
dataset = dataset.dataset
if isinstance(dataset, torchvision.datasets.CocoDetection):
return dataset.coco


def build_dataset(image_set, dataset_year, args):

datasets = []
for year in dataset_year:
if args.dataset_file == 'coco':
dataset = build_coco(image_set, year, args)
elif args.dataset_file == 'voc':
dataset = build_voc(image_set, year, args)
else:
raise ValueError(f'dataset {args.dataset_file} not supported')
datasets.append(dataset)

if len(datasets) == 1:
return datasets[0]
else:
return torch.utils.data.ConcatDataset(datasets)
204 changes: 204 additions & 0 deletions datasets/coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
COCO dataset which returns image_id for evaluation.
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""
import os
from pathlib import Path

import torch
import torch.utils.data
import torchvision
from pycocotools import mask as coco_mask

from . import transforms as T


class ConvertCocoPolysToMask(object):
def __init__(self, return_masks=False):
self.return_masks = return_masks

def __call__(self, image, target):
w, h = image.size

image_id = target["image_id"]
image_id = torch.tensor([image_id])

anno = target["annotations"]

anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]

boxes = [obj["bbox"] for obj in anno]
# guard against no boxes via resizing
boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
# BoxMode: convert from XYWH_ABS to XYXY_ABS
boxes[:, 2:] += boxes[:, :2]
boxes[:, 0::2].clamp_(min=0, max=w)
boxes[:, 1::2].clamp_(min=0, max=h)

classes = [obj["category_id"] for obj in anno]
classes = torch.tensor(classes, dtype=torch.int64)

if self.return_masks:
segmentations = [obj["segmentation"] for obj in anno]
masks = convert_coco_poly_to_mask(segmentations, h, w)

keypoints = None
if anno and "keypoints" in anno[0]:
keypoints = [obj["keypoints"] for obj in anno]
keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
num_keypoints = keypoints.shape[0]
if num_keypoints:
keypoints = keypoints.view(num_keypoints, -1, 3)

keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
boxes = boxes[keep]
classes = classes[keep]
if self.return_masks:
masks = masks[keep]
if keypoints is not None:
keypoints = keypoints[keep]

target = {}
target["boxes"] = boxes
target["labels"] = classes
if self.return_masks:
target["masks"] = masks
target["image_id"] = image_id
if keypoints is not None:
target["keypoints"] = keypoints

# for conversion to coco api
area = torch.tensor([obj["area"] for obj in anno])
iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
target["area"] = area[keep]
target["iscrowd"] = iscrowd[keep]

target["orig_size"] = torch.as_tensor([int(h), int(w)])
target["size"] = torch.as_tensor([int(h), int(w)])

return image, target


class CocoDetection(torchvision.datasets.CocoDetection):
def __init__(self, img_folder, ann_file, transforms, return_masks):
super().__init__(img_folder, ann_file)
self._transforms = transforms
self.prepare = ConvertCocoPolysToMask(return_masks)

def __getitem__(self, idx):
img, target = super().__getitem__(idx)
image_id = self.ids[idx]
target = {'image_id': image_id, 'annotations': target}
img, target = self.prepare(img, target)
if self._transforms is not None:
img, target = self._transforms(img, target)
return img, target


def convert_coco_poly_to_mask(segmentations, height, width):
masks = []
for polygons in segmentations:
rles = coco_mask.frPyObjects(polygons, height, width)
mask = coco_mask.decode(rles)
if len(mask.shape) < 3:
mask = mask[..., None]
mask = torch.as_tensor(mask, dtype=torch.uint8)
mask = mask.any(dim=2)
masks.append(mask)
if masks:
masks = torch.stack(masks, dim=0)
else:
masks = torch.zeros((0, height, width), dtype=torch.uint8)
return masks


def _coco_remove_images_without_annotations(dataset, cat_list=None):
def _has_only_empty_bbox(anno):
return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)

def _count_visible_keypoints(anno):
return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)

min_keypoints_per_image = 10

def _has_valid_annotation(anno):
# if it's empty, there is no annotation
if len(anno) == 0:
return False
# if all boxes have close to zero area, there is no annotation
if _has_only_empty_bbox(anno):
return False
# keypoints task have a slight different critera for considering
# if an annotation is valid
if "keypoints" not in anno[0]:
return True
# for keypoint detection tasks, only consider valid images those
# containing at least min_keypoints_per_image
if _count_visible_keypoints(anno) >= min_keypoints_per_image:
return True
return False

assert isinstance(dataset, torchvision.datasets.CocoDetection)
ids = []
for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
anno = dataset.coco.loadAnns(ann_ids)
if cat_list:
anno = [obj for obj in anno if obj["category_id"] in cat_list]
if _has_valid_annotation(anno):
ids.append(ds_idx)

dataset = torch.utils.data.Subset(dataset, ids)
return dataset


def make_coco_transforms(image_set, image_size=300):

normalize = T.Compose([
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

if image_set == 'train' or image_set == 'trainval':
return T.Compose([
T.RandomHorizontalFlip(),
T.RandomSelect(
T.Resize(image_size),
T.Compose([
T.RandomResize([400, 500, 600]),
T.RandomSizeCrop(384, 600),
T.Resize(image_size),
])
),
normalize,
])
elif image_set == 'val' or image_set == 'test':
return T.Compose([
T.Resize(image_size),
normalize,
])
else:
raise ValueError(f'unknown {image_set}')


def build(image_set, year, args):
root = Path(args.data_path)
assert root.exists(), f'provided COCO path {root} does not exist'
mode = args.dataset_mode

img_folder = os.path.join(root, 'images')
ann_file = os.path.join("annotations", f'{mode}_{image_set}{year}.json')
ann_file = os.path.join(root, ann_file)

dataset = CocoDetection(
img_folder,
ann_file,
transforms=make_coco_transforms(image_set, image_size=args.image_size),
return_masks=args.masks,
)

if image_set == 'train':
dataset = _coco_remove_images_without_annotations(dataset)

return dataset
Loading

0 comments on commit 32b2d83

Please sign in to comment.