Skip to content

Commit

Permalink
Multi-scale testing (facebookresearch#804)
Browse files Browse the repository at this point in the history
* Implement multi-scale testing(bbox aug) like Detectron.

* Add comment.

* Fix missing cfg after merge.
  • Loading branch information
fallingdust authored and fmassa committed May 24, 2019
1 parent 5eca57b commit 7a9b185
Show file tree
Hide file tree
Showing 8 changed files with 224 additions and 10 deletions.
48 changes: 48 additions & 0 deletions configs/test_time_aug/e2e_mask_rcnn_R_50_FPN_1x.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
BACKBONE:
CONV_BODY: "R-50-FPN"
RESNETS:
BACKBONE_OUT_CHANNELS: 256
RPN:
USE_FPN: True
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
ROI_BOX_HEAD:
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
ROI_MASK_HEAD:
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
PREDICTOR: "MaskRCNNC4Predictor"
POOLER_RESOLUTION: 14
POOLER_SAMPLING_RATIO: 2
RESOLUTION: 28
SHARE_BOX_FEATURE_EXTRACTOR: False
MASK_ON: True
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
BASE_LR: 0.02
WEIGHT_DECAY: 0.0001
STEPS: (60000, 80000)
MAX_ITER: 90000
TEST:
BBOX_AUG:
ENABLED: True
H_FLIP: True
SCALES: (400, 500, 600, 700, 900, 1000, 1100, 1200)
MAX_SIZE: 2000
SCALE_H_FLIP: True
21 changes: 21 additions & 0 deletions maskrcnn_benchmark/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,27 @@
# Number of detections per image
_C.TEST.DETECTIONS_PER_IMG = 100

# ---------------------------------------------------------------------------- #
# Test-time augmentations for bounding box detection
# See configs/test_time_aug/e2e_mask_rcnn_R-50-FPN_1x.yaml for an example
# ---------------------------------------------------------------------------- #
_C.TEST.BBOX_AUG = CN()

# Enable test-time augmentation for bounding box detection if True
_C.TEST.BBOX_AUG.ENABLED = False

# Horizontal flip at the original scale (id transform)
_C.TEST.BBOX_AUG.H_FLIP = False

# Each scale is the pixel size of an image's shortest side
_C.TEST.BBOX_AUG.SCALES = ()

# Max pixel size of the longer side
_C.TEST.BBOX_AUG.MAX_SIZE = 4000

# Horizontal flip at each scale
_C.TEST.BBOX_AUG.SCALE_H_FLIP = False


# ---------------------------------------------------------------------------- #
# Misc options
Expand Down
8 changes: 5 additions & 3 deletions maskrcnn_benchmark/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from . import datasets as D
from . import samplers

from .collate_batch import BatchCollator
from .collate_batch import BatchCollator, BBoxAugCollator
from .transforms import build_transforms


Expand Down Expand Up @@ -150,7 +150,8 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
DatasetCatalog = paths_catalog.DatasetCatalog
dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST

transforms = build_transforms(cfg, is_train)
# If bbox aug is enabled in testing, simply set transforms to None and we will apply transforms later
transforms = None if not is_train and cfg.TEST.BBOX_AUG.ENABLED else build_transforms(cfg, is_train)
datasets = build_dataset(dataset_list, transforms, DatasetCatalog, is_train)

data_loaders = []
Expand All @@ -159,7 +160,8 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
batch_sampler = make_batch_data_sampler(
dataset, sampler, aspect_grouping, images_per_gpu, num_iters, start_iter
)
collator = BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY)
collator = BBoxAugCollator() if not is_train and cfg.TEST.BBOX_AUG.ENABLED else \
BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY)
num_workers = cfg.DATALOADER.NUM_WORKERS
data_loader = torch.utils.data.DataLoader(
dataset,
Expand Down
12 changes: 12 additions & 0 deletions maskrcnn_benchmark/data/collate_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,15 @@ def __call__(self, batch):
targets = transposed_batch[1]
img_ids = transposed_batch[2]
return images, targets, img_ids


class BBoxAugCollator(object):
"""
From a list of samples from the dataset,
returns the images and targets.
Images should be converted to batched images in `im_detect_bbox_aug`
"""

def __call__(self, batch):
return list(zip(*batch))

8 changes: 6 additions & 2 deletions maskrcnn_benchmark/data/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ def get_size(self, image_size):

return (oh, ow)

def __call__(self, image, target):
def __call__(self, image, target=None):
size = self.get_size(image.size)
image = F.resize(image, size)
if target is None:
return image
target = target.resize(image.size)
return image, target

Expand Down Expand Up @@ -101,8 +103,10 @@ def __init__(self, mean, std, to_bgr255=True):
self.std = std
self.to_bgr255 = to_bgr255

def __call__(self, image, target):
def __call__(self, image, target=None):
if self.to_bgr255:
image = image[[2, 1, 0]] * 255
image = F.normalize(image, mean=self.mean, std=self.std)
if target is None:
return image
return image, target
118 changes: 118 additions & 0 deletions maskrcnn_benchmark/engine/bbox_aug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import torch
import torchvision.transforms as TT

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.data import transforms as T
from maskrcnn_benchmark.structures.image_list import to_image_list
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.modeling.roi_heads.box_head.inference import make_roi_box_post_processor


def im_detect_bbox_aug(model, images, device):
# Collect detections computed under different transformations
boxlists_ts = []
for _ in range(len(images)):
boxlists_ts.append([])

def add_preds_t(boxlists_t):
for i, boxlist_t in enumerate(boxlists_t):
if len(boxlists_ts[i]) == 0:
# The first one is identity transform, no need to resize the boxlist
boxlists_ts[i].append(boxlist_t)
else:
# Resize the boxlist as the first one
boxlists_ts[i].append(boxlist_t.resize(boxlists_ts[i][0].size))

# Compute detections for the original image (identity transform)
boxlists_i = im_detect_bbox(
model, images, cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MAX_SIZE_TEST, device
)
add_preds_t(boxlists_i)

# Perform detection on the horizontally flipped image
if cfg.TEST.BBOX_AUG.H_FLIP:
boxlists_hf = im_detect_bbox_hflip(
model, images, cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MAX_SIZE_TEST, device
)
add_preds_t(boxlists_hf)

# Compute detections at different scales
for scale in cfg.TEST.BBOX_AUG.SCALES:
max_size = cfg.TEST.BBOX_AUG.MAX_SIZE
boxlists_scl = im_detect_bbox_scale(
model, images, scale, max_size, device
)
add_preds_t(boxlists_scl)

if cfg.TEST.BBOX_AUG.SCALE_H_FLIP:
boxlists_scl_hf = im_detect_bbox_scale(
model, images, scale, max_size, device, hflip=True
)
add_preds_t(boxlists_scl_hf)

# Merge boxlists detected by different bbox aug params
boxlists = []
for i, boxlist_ts in enumerate(boxlists_ts):
bbox = torch.cat([boxlist_t.bbox for boxlist_t in boxlist_ts])
scores = torch.cat([boxlist_t.get_field('scores') for boxlist_t in boxlist_ts])
boxlist = BoxList(bbox, boxlist_ts[0].size, boxlist_ts[0].mode)
boxlist.add_field('scores', scores)
boxlists.append(boxlist)

# Apply NMS and limit the final detections
results = []
post_processor = make_roi_box_post_processor(cfg)
for boxlist in boxlists:
results.append(post_processor.filter_results(boxlist, cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES))

return results


def im_detect_bbox(model, images, target_scale, target_max_size, device):
"""
Performs bbox detection on the original image.
"""
transform = TT.Compose([
T.Resize(target_scale, target_max_size),
TT.ToTensor(),
T.Normalize(
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, to_bgr255=cfg.INPUT.TO_BGR255
)
])
images = [transform(image) for image in images]
images = to_image_list(images, cfg.DATALOADER.SIZE_DIVISIBILITY)
return model(images.to(device))


def im_detect_bbox_hflip(model, images, target_scale, target_max_size, device):
"""
Performs bbox detection on the horizontally flipped image.
Function signature is the same as for im_detect_bbox.
"""
transform = TT.Compose([
T.Resize(target_scale, target_max_size),
TT.RandomHorizontalFlip(1.0),
TT.ToTensor(),
T.Normalize(
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, to_bgr255=cfg.INPUT.TO_BGR255
)
])
images = [transform(image) for image in images]
images = to_image_list(images, cfg.DATALOADER.SIZE_DIVISIBILITY)
boxlists = model(images.to(device))

# Invert the detections computed on the flipped image
boxlists_inv = [boxlist.transpose(0) for boxlist in boxlists]
return boxlists_inv


def im_detect_bbox_scale(model, images, target_scale, target_max_size, device, hflip=False):
"""
Computes bbox detections at the given scale.
Returns predictions in the scaled image space.
"""
if hflip:
boxlists_scl = im_detect_bbox_hflip(model, images, target_scale, target_max_size, device)
else:
boxlists_scl = im_detect_bbox(model, images, target_scale, target_max_size, device)
return boxlists_scl
8 changes: 6 additions & 2 deletions maskrcnn_benchmark/engine/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import torch
from tqdm import tqdm

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.data.datasets.evaluation import evaluate
from ..utils.comm import is_main_process, get_world_size
from ..utils.comm import all_gather
from ..utils.comm import synchronize
from ..utils.timer import Timer, get_time_str
from .bbox_aug import im_detect_bbox_aug


def compute_on_dataset(model, data_loader, device, timer=None):
Expand All @@ -19,11 +21,13 @@ def compute_on_dataset(model, data_loader, device, timer=None):
cpu_device = torch.device("cpu")
for _, batch in enumerate(tqdm(data_loader)):
images, targets, image_ids = batch
images = images.to(device)
with torch.no_grad():
if timer:
timer.tic()
output = model(images)
if cfg.TEST.BBOX_AUG.ENABLED:
output = im_detect_bbox_aug(model, images, device)
else:
output = model(images.to(device))
if timer:
torch.cuda.synchronize()
timer.toc()
Expand Down
11 changes: 8 additions & 3 deletions maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def __init__(
nms=0.5,
detections_per_img=100,
box_coder=None,
cls_agnostic_bbox_reg=False
cls_agnostic_bbox_reg=False,
bbox_aug_enabled=False
):
"""
Arguments:
Expand All @@ -39,6 +40,7 @@ def __init__(
box_coder = BoxCoder(weights=(10., 10., 5., 5.))
self.box_coder = box_coder
self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg
self.bbox_aug_enabled = bbox_aug_enabled

def forward(self, x, boxes):
"""
Expand Down Expand Up @@ -79,7 +81,8 @@ def forward(self, x, boxes):
):
boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape)
boxlist = boxlist.clip_to_image(remove_empty=False)
boxlist = self.filter_results(boxlist, num_classes)
if not self.bbox_aug_enabled: # If bbox aug is enabled, we will do it later
boxlist = self.filter_results(boxlist, num_classes)
results.append(boxlist)
return results

Expand Down Expand Up @@ -156,12 +159,14 @@ def make_roi_box_post_processor(cfg):
nms_thresh = cfg.MODEL.ROI_HEADS.NMS
detections_per_img = cfg.MODEL.ROI_HEADS.DETECTIONS_PER_IMG
cls_agnostic_bbox_reg = cfg.MODEL.CLS_AGNOSTIC_BBOX_REG
bbox_aug_enabled = cfg.TEST.BBOX_AUG.ENABLED

postprocessor = PostProcessor(
score_thresh,
nms_thresh,
detections_per_img,
box_coder,
cls_agnostic_bbox_reg
cls_agnostic_bbox_reg,
bbox_aug_enabled
)
return postprocessor

0 comments on commit 7a9b185

Please sign in to comment.