Add essential scripts for training #5

Merged · 10 commits · Nov 29, 2020
34 changes: 34 additions & 0 deletions datasets/__init__.py
@@ -0,0 +1,34 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch.utils.data
import torchvision

from .coco import build as build_coco
from .voc import build as build_voc


def get_coco_api_from_dataset(dataset):
    # unwrap (possibly nested) Subset wrappers until the underlying
    # torchvision CocoDetection is reached, then return its pycocotools COCO object
    for _ in range(10):
        if isinstance(dataset, torch.utils.data.Subset):
            dataset = dataset.dataset
        if isinstance(dataset, torchvision.datasets.CocoDetection):
            return dataset.coco


def build_dataset(image_set, dataset_year, args):

datasets = []
for year in dataset_year:
if args.dataset_file == 'coco':
dataset = build_coco(image_set, year, args)
elif args.dataset_file == 'voc':
dataset = build_voc(image_set, year, args)
else:
raise ValueError(f'dataset {args.dataset_file} not supported')
datasets.append(dataset)

if len(datasets) == 1:
return datasets[0]
else:
return torch.utils.data.ConcatDataset(datasets)
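
A minimal usage sketch for the two helpers above, assuming an argparse-style args object; the flag names mirror the attributes this diff reads (data_path, dataset_file, dataset_mode, image_size, masks), but the parser, the defaults, and the dataset year are illustrative, not part of this PR:

# Illustrative only: build the training split and recover the pycocotools
# COCO handle that a COCO-style evaluator would need.
import argparse

from datasets import build_dataset, get_coco_api_from_dataset

parser = argparse.ArgumentParser()
parser.add_argument('--data-path', default='path/to/coco')
parser.add_argument('--dataset-file', default='coco')
parser.add_argument('--dataset-mode', default='instances')
parser.add_argument('--image-size', type=int, default=300)
parser.add_argument('--masks', action='store_true')
args = parser.parse_args([])

dataset_train = build_dataset('train', dataset_year=['2017'], args=args)
base_ds = get_coco_api_from_dataset(dataset_train)  # unwraps the train-time Subset
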
204 changes: 204 additions & 0 deletions datasets/coco.py
@@ -0,0 +1,204 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
COCO dataset which returns image_id for evaluation.
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""
import os
from pathlib import Path

import torch
import torch.utils.data
import torchvision
from pycocotools import mask as coco_mask

from . import transforms as T


class ConvertCocoPolysToMask(object):
def __init__(self, return_masks=False):
self.return_masks = return_masks

def __call__(self, image, target):
w, h = image.size

image_id = target["image_id"]
image_id = torch.tensor([image_id])

anno = target["annotations"]

anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]

boxes = [obj["bbox"] for obj in anno]
        # guard against images with no boxes by reshaping to (-1, 4)
boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
# BoxMode: convert from XYWH_ABS to XYXY_ABS
boxes[:, 2:] += boxes[:, :2]
boxes[:, 0::2].clamp_(min=0, max=w)
boxes[:, 1::2].clamp_(min=0, max=h)

classes = [obj["category_id"] for obj in anno]
classes = torch.tensor(classes, dtype=torch.int64)

if self.return_masks:
segmentations = [obj["segmentation"] for obj in anno]
masks = convert_coco_poly_to_mask(segmentations, h, w)

keypoints = None
if anno and "keypoints" in anno[0]:
keypoints = [obj["keypoints"] for obj in anno]
keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
num_keypoints = keypoints.shape[0]
if num_keypoints:
keypoints = keypoints.view(num_keypoints, -1, 3)

keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
boxes = boxes[keep]
classes = classes[keep]
if self.return_masks:
masks = masks[keep]
if keypoints is not None:
keypoints = keypoints[keep]

target = {}
target["boxes"] = boxes
target["labels"] = classes
if self.return_masks:
target["masks"] = masks
target["image_id"] = image_id
if keypoints is not None:
target["keypoints"] = keypoints

# for conversion to coco api
area = torch.tensor([obj["area"] for obj in anno])
iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
target["area"] = area[keep]
target["iscrowd"] = iscrowd[keep]

target["orig_size"] = torch.as_tensor([int(h), int(w)])
target["size"] = torch.as_tensor([int(h), int(w)])

return image, target


class CocoDetection(torchvision.datasets.CocoDetection):
def __init__(self, img_folder, ann_file, transforms, return_masks):
super().__init__(img_folder, ann_file)
self._transforms = transforms
self.prepare = ConvertCocoPolysToMask(return_masks)

def __getitem__(self, idx):
img, target = super().__getitem__(idx)
image_id = self.ids[idx]
target = {'image_id': image_id, 'annotations': target}
img, target = self.prepare(img, target)
if self._transforms is not None:
img, target = self._transforms(img, target)
return img, target


def convert_coco_poly_to_mask(segmentations, height, width):
masks = []
for polygons in segmentations:
rles = coco_mask.frPyObjects(polygons, height, width)
mask = coco_mask.decode(rles)
if len(mask.shape) < 3:
mask = mask[..., None]
mask = torch.as_tensor(mask, dtype=torch.uint8)
mask = mask.any(dim=2)
masks.append(mask)
if masks:
masks = torch.stack(masks, dim=0)
else:
masks = torch.zeros((0, height, width), dtype=torch.uint8)
return masks
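
# Illustrative usage (hypothetical values): a single triangular polygon,
# given as a flat [x0, y0, x1, y1, x2, y2] list for one annotation,
# decodes to one binary mask per annotation:
#   segs = [[[10.0, 10.0, 50.0, 10.0, 30.0, 40.0]]]
#   convert_coco_poly_to_mask(segs, height=64, width=64).shape
#   # -> torch.Size([1, 64, 64])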


def _coco_remove_images_without_annotations(dataset, cat_list=None):
def _has_only_empty_bbox(anno):
return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)

def _count_visible_keypoints(anno):
return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)

min_keypoints_per_image = 10

def _has_valid_annotation(anno):
# if it's empty, there is no annotation
if len(anno) == 0:
return False
# if all boxes have close to zero area, there is no annotation
if _has_only_empty_bbox(anno):
return False
        # the keypoints task has a slightly different criterion for
        # deciding whether an annotation is valid
if "keypoints" not in anno[0]:
return True
        # for keypoint detection tasks, only consider images valid if they
        # contain at least min_keypoints_per_image visible keypoints
if _count_visible_keypoints(anno) >= min_keypoints_per_image:
return True
return False

assert isinstance(dataset, torchvision.datasets.CocoDetection)
ids = []
for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
anno = dataset.coco.loadAnns(ann_ids)
if cat_list:
anno = [obj for obj in anno if obj["category_id"] in cat_list]
if _has_valid_annotation(anno):
ids.append(ds_idx)

dataset = torch.utils.data.Subset(dataset, ids)
return dataset


def make_coco_transforms(image_set, image_size=300):

normalize = T.Compose([
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

if image_set == 'train' or image_set == 'trainval':
return T.Compose([
T.RandomHorizontalFlip(),
T.RandomSelect(
T.Resize(image_size),
T.Compose([
T.RandomResize([400, 500, 600]),
T.RandomSizeCrop(384, 600),
T.Resize(image_size),
])
),
normalize,
])
elif image_set == 'val' or image_set == 'test':
return T.Compose([
T.Resize(image_size),
normalize,
])
else:
        raise ValueError(f'unknown image_set: {image_set}')


def build(image_set, year, args):
root = Path(args.data_path)
assert root.exists(), f'provided COCO path {root} does not exist'
mode = args.dataset_mode

img_folder = os.path.join(root, 'images')
ann_file = os.path.join("annotations", f'{mode}_{image_set}{year}.json')
ann_file = os.path.join(root, ann_file)

dataset = CocoDetection(
img_folder,
ann_file,
transforms=make_coco_transforms(image_set, image_size=args.image_size),
return_masks=args.masks,
)

if image_set == 'train':
dataset = _coco_remove_images_without_annotations(dataset)

return dataset
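
A sketch of how build() might feed a training DataLoader, reusing the illustrative args from the note after datasets/__init__.py; the tuple-based collate function is a placeholder for whatever batching strategy the training script actually uses, and the batch size and worker count are arbitrary:

# Illustrative only: wrap the dataset returned by build() in a DataLoader.
import torch.utils.data

from datasets.coco import build as build_coco

dataset_train = build_coco('train', year='2017', args=args)

data_loader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=2,
    shuffle=True,
    num_workers=0,
    collate_fn=lambda batch: tuple(zip(*batch)),  # keep per-image target dicts as-is
)

images, targets = next(iter(data_loader_train))
# images: tuple of normalized CHW image tensors
# targets: tuple of dicts carrying the fields set in ConvertCocoPolysToMask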