This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Add an option to postprocess masks during inference (#180)
* Add an option to postprocess masks during inference

* Fix COCO evaluation to resize masks only if needed.

* Fix casting

* Fix minor issues in paste_mask_in_image

* Cast mask to uint8

* Make Masker batch compatible

* Remove warnings and stylistic changes
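
As a quick illustration of the first bullet, the new option can be switched on through the usual yacs config flow of this repository. This is a minimal sketch, not code from this commit; it only uses the two keys added in defaults.py below and the standard maskrcnn_benchmark config entry point.

    # Minimal sketch (illustrative): enable mask postprocessing during inference.
    from maskrcnn_benchmark.config import cfg

    cfg.merge_from_list([
        "MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS", True,
        "MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD", 0.5,
    ])
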
fmassa authored Nov 19, 2018
1 parent 13555fc commit ca9531b
Showing 3 changed files with 40 additions and 18 deletions.
3 changes: 3 additions & 0 deletions maskrcnn_benchmark/config/defaults.py
@@ -194,6 +194,9 @@
 _C.MODEL.ROI_MASK_HEAD.CONV_LAYERS = (256, 256, 256, 256)
 _C.MODEL.ROI_MASK_HEAD.RESOLUTION = 14
 _C.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR = True
+# Whether or not to resize and translate masks to the input image.
+_C.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS = False
+_C.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD = 0.5

 # ---------------------------------------------------------------------------- #
 # ResNe[X]t options (ResNets = {ResNet, ResNeXt})
8 changes: 6 additions & 2 deletions maskrcnn_benchmark/engine/inference.py
@@ -87,8 +87,12 @@ def prepare_for_coco_segmentation(predictions, dataset):
         image_height = dataset.coco.imgs[original_id]["height"]
         prediction = prediction.resize((image_width, image_height))
         masks = prediction.get_field("mask")

         # t = time.time()
-        masks = masker(masks, prediction)
+        # Masker is necessary only if masks haven't been already resized.
+        if list(masks.shape[-2:]) != [image_height, image_width]:
+            masks = masker(masks.expand(1, -1, -1, -1, -1), prediction)
+            masks = masks[0]
         # logger.info('Time mask: {}'.format(time.time() - t))
         # prediction = prediction.convert('xywh')
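
For context on the hunk above: the batch-compatible Masker now expects one mask tensor per image, so a single image's N x 1 x H x W masks are wrapped in a leading batch dimension before the call and unwrapped afterwards. A small illustrative sketch; the shapes are invented for the example, not taken from the commit.

    import torch

    masks = torch.rand(4, 1, 28, 28)             # 4 detections for one image
    batched = masks.expand(1, -1, -1, -1, -1)    # adds a leading batch dim without copying
    assert batched.shape == (1, 4, 1, 28, 28)
    single_image_masks = batched[0]              # back to 4 x 1 x 28 x 28
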

@@ -426,6 +430,6 @@ def inference(
     check_expected_results(results, expected_results, expected_results_sigma_tol)
     if output_folder:
         torch.save(results, os.path.join(output_folder, "coco_results.pth"))

+    return results, coco_results, predictions

47 changes: 31 additions & 16 deletions maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py
@@ -1,8 +1,8 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import numpy as np
 import torch
-from PIL import Image
 from torch import nn
+import torch.nn.functional as F

 from maskrcnn_benchmark.structures.bounding_box import BoxList

@@ -44,12 +44,12 @@ def forward(self, x, boxes):
         index = torch.arange(num_masks, device=labels.device)
         mask_prob = mask_prob[index, labels][:, None]

-        if self.masker:
-            mask_prob = self.masker(mask_prob, boxes)
-
         boxes_per_image = [len(box) for box in boxes]
         mask_prob = mask_prob.split(boxes_per_image, dim=0)

+        if self.masker:
+            mask_prob = self.masker(mask_prob, boxes)
+
         results = []
         for prob, box in zip(mask_prob, boxes):
             bbox = BoxList(box.bbox, box.size, mode="xyxy")
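
The reordering above means the masker now runs on per-image tensors: mask_prob is first split by the number of boxes in each image, which matches the batch-compatible Masker.__call__ shown further below. A toy illustration; the shapes are invented for the example.

    import torch

    mask_prob = torch.rand(5, 1, 28, 28)    # 5 detections across the whole batch
    boxes_per_image = [3, 2]                # detections per image
    per_image = mask_prob.split(boxes_per_image, dim=0)
    assert [m.shape[0] for m in per_image] == boxes_per_image
    # Each element of per_image is what self.masker receives together with
    # the corresponding BoxList.
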
@@ -119,25 +119,28 @@ def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
     padded_mask, scale = expand_masks(mask[None], padding=padding)
     mask = padded_mask[0, 0]
     box = expand_boxes(box[None], scale)[0]
-    box = box.numpy().astype(np.int32)
+    box = box.to(dtype=torch.int32)

     TO_REMOVE = 1
     w = box[2] - box[0] + TO_REMOVE
     h = box[3] - box[1] + TO_REMOVE
     w = max(w, 1)
     h = max(h, 1)

-    mask = Image.fromarray(mask.cpu().numpy())
-    mask = mask.resize((w, h), resample=Image.BILINEAR)
-    mask = np.array(mask, copy=False)
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, -1, -1))
+
+    # Resize mask
+    mask = mask.to(torch.float32)
+    mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
+    mask = mask[0][0]

     if thresh >= 0:
-        mask = np.array(mask > thresh, dtype=np.uint8)
-        mask = torch.from_numpy(mask)
+        mask = mask > thresh
     else:
         # for visualization and debugging, we also
         # allow it to return an unmodified mask
-        mask = torch.from_numpy(mask * 255).to(torch.uint8)
+        mask = (mask * 255).to(torch.uint8)

     im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8)
     x_0 = max(box[0], 0)
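
The hunk above replaces the PIL-based resize with pure torch ops. Below is a self-contained sketch of the same resize-and-threshold idea; it is a simplified stand-in, not the function itself, and it omits the pasting into the full image canvas.

    import torch
    import torch.nn.functional as F

    def resize_and_binarize(mask, w, h, thresh=0.5):
        # mask: 2D float tensor of per-pixel probabilities (e.g. 28 x 28)
        mask = mask[None, None].to(torch.float32)   # -> 1 x 1 x H x W
        mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
        mask = mask[0, 0]
        if thresh >= 0:
            return (mask > thresh).to(torch.uint8)
        # thresh < 0: keep a soft mask scaled to 0-255 for visualization
        return (mask * 255).to(torch.uint8)

    binary = resize_and_binarize(torch.rand(28, 28), w=40, h=30)
    assert binary.shape == (30, 40)
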
@@ -175,15 +178,27 @@ def forward_single_image(self, masks, boxes):
         return res

     def __call__(self, masks, boxes):
-        # TODO do this properly
         if isinstance(boxes, BoxList):
             boxes = [boxes]
-        assert len(boxes) == 1, "Only single image batch supported"
-        result = self.forward_single_image(masks, boxes[0])
-        return result
+
+        # Make some sanity check
+        assert len(boxes) == len(masks), "Masks and boxes should have the same length."
+
+        # TODO: Is this JIT compatible?
+        # If not we should make it compatible.
+        results = []
+        for mask, box in zip(masks, boxes):
+            assert mask.shape[0] == len(box), "Number of objects should be the same."
+            result = self.forward_single_image(mask, box)
+            results.append(result)
+        return results


 def make_roi_mask_post_processor(cfg):
-    masker = None
+    if cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS:
+        mask_threshold = cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD
+        masker = Masker(threshold=mask_threshold, padding=1)
+    else:
+        masker = None
     mask_post_processor = MaskPostProcessor(masker)
     return mask_post_processor
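
Putting the pieces together, here is a hedged usage sketch of the updated factory and the batch-compatible Masker. The import paths follow the file shown above; the call pattern is an illustration, not code from the commit.

    from maskrcnn_benchmark.config import cfg
    from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import (
        Masker,
        make_roi_mask_post_processor,
    )

    # With POSTPROCESS_MASKS enabled (see the config sketch near the top),
    # the factory builds the post-processor around a Masker; otherwise the
    # masker stays None and masks are left at ROI resolution.
    post_processor = make_roi_mask_post_processor(cfg)

    # The batch-compatible Masker can also be used directly: it takes one
    # mask tensor per BoxList, e.g.
    #   pasted = Masker(threshold=0.5, padding=1)([masks_img0], [boxlist_img0])
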
