forked from facebookresearch/maskrcnn-benchmark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Multi-scale testing (facebookresearch#804)
* Implement multi-scale testing(bbox aug) like Detectron. * Add comment. * Fix missing cfg after merge.
- Loading branch information
1 parent
5eca57b
commit 7a9b185
Showing
8 changed files
with
224 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
MODEL: | ||
META_ARCHITECTURE: "GeneralizedRCNN" | ||
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50" | ||
BACKBONE: | ||
CONV_BODY: "R-50-FPN" | ||
RESNETS: | ||
BACKBONE_OUT_CHANNELS: 256 | ||
RPN: | ||
USE_FPN: True | ||
ANCHOR_STRIDE: (4, 8, 16, 32, 64) | ||
PRE_NMS_TOP_N_TRAIN: 2000 | ||
PRE_NMS_TOP_N_TEST: 1000 | ||
POST_NMS_TOP_N_TEST: 1000 | ||
FPN_POST_NMS_TOP_N_TEST: 1000 | ||
ROI_HEADS: | ||
USE_FPN: True | ||
ROI_BOX_HEAD: | ||
POOLER_RESOLUTION: 7 | ||
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125) | ||
POOLER_SAMPLING_RATIO: 2 | ||
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor" | ||
PREDICTOR: "FPNPredictor" | ||
ROI_MASK_HEAD: | ||
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125) | ||
FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor" | ||
PREDICTOR: "MaskRCNNC4Predictor" | ||
POOLER_RESOLUTION: 14 | ||
POOLER_SAMPLING_RATIO: 2 | ||
RESOLUTION: 28 | ||
SHARE_BOX_FEATURE_EXTRACTOR: False | ||
MASK_ON: True | ||
DATASETS: | ||
TRAIN: ("coco_2014_train", "coco_2014_valminusminival") | ||
TEST: ("coco_2014_minival",) | ||
DATALOADER: | ||
SIZE_DIVISIBILITY: 32 | ||
SOLVER: | ||
BASE_LR: 0.02 | ||
WEIGHT_DECAY: 0.0001 | ||
STEPS: (60000, 80000) | ||
MAX_ITER: 90000 | ||
TEST: | ||
BBOX_AUG: | ||
ENABLED: True | ||
H_FLIP: True | ||
SCALES: (400, 500, 600, 700, 900, 1000, 1100, 1200) | ||
MAX_SIZE: 2000 | ||
SCALE_H_FLIP: True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import torch | ||
import torchvision.transforms as TT | ||
|
||
from maskrcnn_benchmark.config import cfg | ||
from maskrcnn_benchmark.data import transforms as T | ||
from maskrcnn_benchmark.structures.image_list import to_image_list | ||
from maskrcnn_benchmark.structures.bounding_box import BoxList | ||
from maskrcnn_benchmark.modeling.roi_heads.box_head.inference import make_roi_box_post_processor | ||
|
||
|
||
def im_detect_bbox_aug(model, images, device): | ||
# Collect detections computed under different transformations | ||
boxlists_ts = [] | ||
for _ in range(len(images)): | ||
boxlists_ts.append([]) | ||
|
||
def add_preds_t(boxlists_t): | ||
for i, boxlist_t in enumerate(boxlists_t): | ||
if len(boxlists_ts[i]) == 0: | ||
# The first one is identity transform, no need to resize the boxlist | ||
boxlists_ts[i].append(boxlist_t) | ||
else: | ||
# Resize the boxlist as the first one | ||
boxlists_ts[i].append(boxlist_t.resize(boxlists_ts[i][0].size)) | ||
|
||
# Compute detections for the original image (identity transform) | ||
boxlists_i = im_detect_bbox( | ||
model, images, cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MAX_SIZE_TEST, device | ||
) | ||
add_preds_t(boxlists_i) | ||
|
||
# Perform detection on the horizontally flipped image | ||
if cfg.TEST.BBOX_AUG.H_FLIP: | ||
boxlists_hf = im_detect_bbox_hflip( | ||
model, images, cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MAX_SIZE_TEST, device | ||
) | ||
add_preds_t(boxlists_hf) | ||
|
||
# Compute detections at different scales | ||
for scale in cfg.TEST.BBOX_AUG.SCALES: | ||
max_size = cfg.TEST.BBOX_AUG.MAX_SIZE | ||
boxlists_scl = im_detect_bbox_scale( | ||
model, images, scale, max_size, device | ||
) | ||
add_preds_t(boxlists_scl) | ||
|
||
if cfg.TEST.BBOX_AUG.SCALE_H_FLIP: | ||
boxlists_scl_hf = im_detect_bbox_scale( | ||
model, images, scale, max_size, device, hflip=True | ||
) | ||
add_preds_t(boxlists_scl_hf) | ||
|
||
# Merge boxlists detected by different bbox aug params | ||
boxlists = [] | ||
for i, boxlist_ts in enumerate(boxlists_ts): | ||
bbox = torch.cat([boxlist_t.bbox for boxlist_t in boxlist_ts]) | ||
scores = torch.cat([boxlist_t.get_field('scores') for boxlist_t in boxlist_ts]) | ||
boxlist = BoxList(bbox, boxlist_ts[0].size, boxlist_ts[0].mode) | ||
boxlist.add_field('scores', scores) | ||
boxlists.append(boxlist) | ||
|
||
# Apply NMS and limit the final detections | ||
results = [] | ||
post_processor = make_roi_box_post_processor(cfg) | ||
for boxlist in boxlists: | ||
results.append(post_processor.filter_results(boxlist, cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES)) | ||
|
||
return results | ||
|
||
|
||
def im_detect_bbox(model, images, target_scale, target_max_size, device): | ||
""" | ||
Performs bbox detection on the original image. | ||
""" | ||
transform = TT.Compose([ | ||
T.Resize(target_scale, target_max_size), | ||
TT.ToTensor(), | ||
T.Normalize( | ||
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, to_bgr255=cfg.INPUT.TO_BGR255 | ||
) | ||
]) | ||
images = [transform(image) for image in images] | ||
images = to_image_list(images, cfg.DATALOADER.SIZE_DIVISIBILITY) | ||
return model(images.to(device)) | ||
|
||
|
||
def im_detect_bbox_hflip(model, images, target_scale, target_max_size, device): | ||
""" | ||
Performs bbox detection on the horizontally flipped image. | ||
Function signature is the same as for im_detect_bbox. | ||
""" | ||
transform = TT.Compose([ | ||
T.Resize(target_scale, target_max_size), | ||
TT.RandomHorizontalFlip(1.0), | ||
TT.ToTensor(), | ||
T.Normalize( | ||
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, to_bgr255=cfg.INPUT.TO_BGR255 | ||
) | ||
]) | ||
images = [transform(image) for image in images] | ||
images = to_image_list(images, cfg.DATALOADER.SIZE_DIVISIBILITY) | ||
boxlists = model(images.to(device)) | ||
|
||
# Invert the detections computed on the flipped image | ||
boxlists_inv = [boxlist.transpose(0) for boxlist in boxlists] | ||
return boxlists_inv | ||
|
||
|
||
def im_detect_bbox_scale(model, images, target_scale, target_max_size, device, hflip=False): | ||
""" | ||
Computes bbox detections at the given scale. | ||
Returns predictions in the scaled image space. | ||
""" | ||
if hflip: | ||
boxlists_scl = im_detect_bbox_hflip(model, images, target_scale, target_max_size, device) | ||
else: | ||
boxlists_scl = im_detect_bbox(model, images, target_scale, target_max_size, device) | ||
return boxlists_scl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters