Add essential scripts for training (#5)
* init commit
* Add essential scripts for training
* Minor fixes
* Add doc and type annotations
* Minor fixes
* Fix variables loading
* Fix type annotations
* Minor fixes
* Add type annotations
* Fix type annotations
Showing 12 changed files with 2,280 additions and 4 deletions.
@@ -0,0 +1,34 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch.utils.data
import torchvision

from .coco import build as build_coco
from .voc import build as build_voc


def get_coco_api_from_dataset(dataset):
    # Unwrap up to 10 levels of Subset wrappers to reach the underlying
    # CocoDetection dataset, then return its pycocotools COCO object.
    for _ in range(10):
        # if isinstance(dataset, torchvision.datasets.CocoDetection):
        #     break
        if isinstance(dataset, torch.utils.data.Subset):
            dataset = dataset.dataset
    if isinstance(dataset, torchvision.datasets.CocoDetection):
        return dataset.coco


def build_dataset(image_set, dataset_year, args):
    # Build one dataset per requested year; multiple years are concatenated.
    datasets = []
    for year in dataset_year:
        if args.dataset_file == 'coco':
            dataset = build_coco(image_set, year, args)
        elif args.dataset_file == 'voc':
            dataset = build_voc(image_set, year, args)
        else:
            raise ValueError(f'dataset {args.dataset_file} not supported')
        datasets.append(dataset)

    if len(datasets) == 1:
        return datasets[0]
    else:
        return torch.utils.data.ConcatDataset(datasets)
@@ -0,0 +1,204 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
COCO dataset which returns image_id for evaluation.
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""
import os
from pathlib import Path

import torch
import torch.utils.data
import torchvision
from pycocotools import mask as coco_mask

from . import transforms as T


class ConvertCocoPolysToMask(object):
    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    def __call__(self, image, target):
        w, h = image.size

        image_id = target["image_id"]
        image_id = torch.tensor([image_id])

        anno = target["annotations"]

        anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # BoxMode: convert from XYWH_ABS to XYXY_ABS
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        if self.return_masks:
            segmentations = [obj["segmentation"] for obj in anno]
            masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        if self.return_masks:
            target["masks"] = masks
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])

        return image, target


class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, transforms, return_masks):
        super().__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)

    def __getitem__(self, idx):
        img, target = super().__getitem__(idx)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}
        img, target = self.prepare(img, target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target


def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks


def _coco_remove_images_without_annotations(dataset, cat_list=None):
    def _has_only_empty_bbox(anno):
        return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)

    def _count_visible_keypoints(anno):
        return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)

    min_keypoints_per_image = 10

    def _has_valid_annotation(anno):
        # if it's empty, there is no annotation
        if len(anno) == 0:
            return False
        # if all boxes have close to zero area, there is no annotation
        if _has_only_empty_bbox(anno):
            return False
        # keypoint tasks have slightly different criteria for considering
        # whether an annotation is valid
        if "keypoints" not in anno[0]:
            return True
        # for keypoint detection tasks, only consider valid images those
        # containing at least min_keypoints_per_image
        if _count_visible_keypoints(anno) >= min_keypoints_per_image:
            return True
        return False

    assert isinstance(dataset, torchvision.datasets.CocoDetection)
    ids = []
    for ds_idx, img_id in enumerate(dataset.ids):
        ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
        anno = dataset.coco.loadAnns(ann_ids)
        if cat_list:
            anno = [obj for obj in anno if obj["category_id"] in cat_list]
        if _has_valid_annotation(anno):
            ids.append(ds_idx)

    dataset = torch.utils.data.Subset(dataset, ids)
    return dataset


def make_coco_transforms(image_set, image_size=300):

    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    if image_set == 'train' or image_set == 'trainval':
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomSelect(
                T.Resize(image_size),
                T.Compose([
                    T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600),
                    T.Resize(image_size),
                ])
            ),
            normalize,
        ])
    elif image_set == 'val' or image_set == 'test':
        return T.Compose([
            T.Resize(image_size),
            normalize,
        ])
    else:
        raise ValueError(f'unknown {image_set}')


def build(image_set, year, args):
    root = Path(args.data_path)
    assert root.exists(), f'provided COCO path {root} does not exist'
    mode = args.dataset_mode

    img_folder = os.path.join(root, 'images')
    ann_file = os.path.join("annotations", f'{mode}_{image_set}{year}.json')
    ann_file = os.path.join(root, ann_file)

    dataset = CocoDetection(
        img_folder,
        ann_file,
        transforms=make_coco_transforms(image_set, image_size=args.image_size),
        return_masks=args.masks,
    )

    if image_set == 'train':
        dataset = _coco_remove_images_without_annotations(dataset)

    return dataset
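A sketch of wiring the resulting dataset into a DataLoader. Because each target is a dict of variable-length tensors (boxes, labels, area, iscrowd, ...), the default collate function cannot batch it; keeping the targets in a list is the usual approach. The batch size, year, and worker count below are placeholders, and stacking the images assumes the repo's T.Resize(image_size) yields fixed-size tensors, which this diff does not show; if it preserves aspect ratio, keep the images in a list as well.

import torch
from torch.utils.data import DataLoader

def collate_fn(batch):
    # Targets are per-image dicts, so collect them in a list instead of stacking.
    images, targets = zip(*batch)
    return torch.stack(images, dim=0), list(targets)

# `build` is the function defined above; `args` as in the earlier sketch.
val_dataset = build('val', year='2017', args=args)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False,
                        num_workers=4, collate_fn=collate_fn)

images, targets = next(iter(val_loader))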