From 9eb4831f742ae6a13b8edb61d07b619392fb6543 Mon Sep 17 00:00:00 2001 From: Yuxin Wu Date: Fri, 3 Jul 2020 03:09:03 -0700 Subject: [PATCH] configurable for dataset_mapper Reviewed By: rbgirshick Differential Revision: D22251484 fbshipit-source-id: 2d2cfb99f40e10b7af4e87a99bea14b3ee98a48c --- detectron2/data/dataset_mapper.py | 114 ++++++++++++++++++++--------- detectron2/data/detection_utils.py | 2 - 2 files changed, 78 insertions(+), 38 deletions(-) diff --git a/detectron2/data/dataset_mapper.py b/detectron2/data/dataset_mapper.py index 2abb1350a0..e3e8a2dca5 100644 --- a/detectron2/data/dataset_mapper.py +++ b/detectron2/data/dataset_mapper.py @@ -2,8 +2,11 @@ import copy import logging import numpy as np +from typing import List, Optional, Union import torch +from detectron2.config import configurable + from . import detection_utils as utils from . import transforms as T @@ -31,38 +34,81 @@ class DatasetMapper: 3. Prepare data and annotations to Tensor and :class:`Instances` """ - def __init__(self, cfg, is_train=True): - self.augmentation = utils.build_augmentation(cfg, is_train) - if cfg.INPUT.CROP.ENABLED and is_train: - self.augmentation.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)) - logging.getLogger(__name__).info( - "Cropping used in training: " + str(self.augmentation[0]) - ) - self.compute_tight_boxes = True - else: - self.compute_tight_boxes = False + @configurable + def __init__( + self, + is_train: bool, + *, + augmentations: List[Union[T.Augmentation, T.Transform]], + image_format: str, + use_instance_mask: bool = False, + use_keypoint: bool = False, + instance_mask_format: str = "polygon", + keypoint_hflip_indices: Optional[np.ndarray] = None, + precomputed_proposal_topk: Optional[int] = None, + recompute_boxes: bool = False + ): + """ + NOTE: this interface is experimental. + Args: + is_train: whether it's used in training or inference + augmentations: a list of augmentations or deterministic transforms to apply + image_format: an image format supported by :func:`detection_utils.read_image`. + use_instance_mask: whether to process instance segmentation annotations, if available + use_keypoint: whether to process keypoint annotations if available + instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation + masks into this format. + keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices` + precomputed_proposal_topk: if given, will load pre-computed + proposals from dataset_dict and keep the top k proposals for each image. + recompute_boxes: whether to overwrite bounding box annotations + by computing tight bounding boxes from instance mask annotations. + """ + if recompute_boxes: + assert use_instance_mask, "recompute_boxes requires instance masks" # fmt: off - self.img_format = cfg.INPUT.FORMAT - self.mask_on = cfg.MODEL.MASK_ON - self.mask_format = cfg.INPUT.MASK_FORMAT - self.keypoint_on = cfg.MODEL.KEYPOINT_ON - self.load_proposals = cfg.MODEL.LOAD_PROPOSALS + self.is_train = is_train + self.augmentations = augmentations + self.image_format = image_format + self.use_instance_mask = use_instance_mask + self.instance_mask_format = instance_mask_format + self.use_keypoint = use_keypoint + self.keypoint_hflip_indices = keypoint_hflip_indices + self.proposal_topk = precomputed_proposal_topk + self.recompute_boxes = recompute_boxes # fmt: on - if self.keypoint_on and is_train: - # Flip only makes sense in training - self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN) - else: - self.keypoint_hflip_indices = None + logger = logging.getLogger(__name__) + logger.info("Augmentations used in training: " + str(augmentations)) - if self.load_proposals: - self.proposal_min_box_size = cfg.MODEL.PROPOSAL_GENERATOR.MIN_SIZE - self.proposal_topk = ( + @classmethod + def from_config(cls, cfg, is_train: bool = True): + augs = utils.build_augmentation(cfg, is_train) + if cfg.INPUT.CROP.ENABLED and is_train: + augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)) + recompute_boxes = cfg.MODEL.MASK_ON + else: + recompute_boxes = False + + ret = { + "is_train": is_train, + "augmentations": augs, + "image_format": cfg.INPUT.FORMAT, + "use_instance_mask": cfg.MODEL.MASK_ON, + "instance_mask_format": cfg.INPUT.MASK_FORMAT, + "use_keypoint": cfg.MODEL.KEYPOINT_ON, + "recompute_boxes": recompute_boxes, + } + if cfg.MODEL.KEYPOINT_ON: + ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN) + + if cfg.MODEL.LOAD_PROPOSALS: + ret["precomputed_proposal_topk"] = ( cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN if is_train else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST ) - self.is_train = is_train + return ret def __call__(self, dataset_dict): """ @@ -74,7 +120,7 @@ def __call__(self, dataset_dict): """ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below # USER: Write your own image loading if it's not from a file - image = utils.read_image(dataset_dict["file_name"], format=self.img_format) + image = utils.read_image(dataset_dict["file_name"], format=self.image_format) utils.check_image_size(dataset_dict, image) # USER: Remove if you don't do semantic/panoptic segmentation. @@ -84,7 +130,7 @@ def __call__(self, dataset_dict): sem_seg_gt = None aug_input = T.StandardAugInput(image, sem_seg=sem_seg_gt) - transforms = aug_input.apply_augmentations(self.augmentation) + transforms = aug_input.apply_augmentations(self.augmentations) image, sem_seg_gt = aug_input.image, aug_input.sem_seg image_shape = image.shape[:2] # h, w @@ -97,13 +143,9 @@ def __call__(self, dataset_dict): # USER: Remove if you don't use pre-computed proposals. # Most users would not need this feature. - if self.load_proposals: + if self.proposal_topk is not None: utils.transform_proposals( - dataset_dict, - image_shape, - transforms, - proposal_topk=self.proposal_topk, - min_box_size=self.proposal_min_box_size, + dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk ) if not self.is_train: @@ -115,9 +157,9 @@ def __call__(self, dataset_dict): if "annotations" in dataset_dict: # USER: Modify this if you want to keep them for some reason. for anno in dataset_dict["annotations"]: - if not self.mask_on: + if not self.use_instance_mask: anno.pop("segmentation", None) - if not self.keypoint_on: + if not self.use_keypoint: anno.pop("keypoints", None) # USER: Implement additional transformations if you have other types of data @@ -129,7 +171,7 @@ def __call__(self, dataset_dict): if obj.get("iscrowd", 0) == 0 ] instances = utils.annotations_to_instances( - annos, image_shape, mask_format=self.mask_format + annos, image_shape, mask_format=self.instance_mask_format ) # After transforms such as cropping are applied, the bounding box may no longer @@ -137,7 +179,7 @@ def __call__(self, dataset_dict): # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to # the intersection of original bounding box and the cropping box. - if self.compute_tight_boxes and instances.has("gt_masks"): + if self.recompute_boxes: instances.gt_boxes = instances.gt_masks.get_bounding_boxes() dataset_dict["instances"] = utils.filter_empty_instances(instances) return dataset_dict diff --git a/detectron2/data/detection_utils.py b/detectron2/data/detection_utils.py index ffd4b3dbf2..7ab4bbdcfc 100644 --- a/detectron2/data/detection_utils.py +++ b/detectron2/data/detection_utils.py @@ -579,12 +579,10 @@ def build_augmentation(cfg, is_train): len(min_size) ) - logger = logging.getLogger(__name__) augmentation = [] augmentation.append(T.ResizeShortestEdge(min_size, max_size, sample_style)) if is_train: augmentation.append(T.RandomFlip()) - logger.info("Augmentations used in training: " + str(augmentation)) return augmentation