Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Add an option to postprocess masks during inference #180

Merged
merged 8 commits into from
Nov 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions maskrcnn_benchmark/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@
_C.MODEL.ROI_MASK_HEAD.CONV_LAYERS = (256, 256, 256, 256)
_C.MODEL.ROI_MASK_HEAD.RESOLUTION = 14
_C.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR = True
# Whether or not to resize and translate masks to the input image.
_C.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS = False
_C.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD = 0.5

# ---------------------------------------------------------------------------- #
# ResNe[X]t options (ResNets = {ResNet, ResNeXt})
Expand Down
8 changes: 6 additions & 2 deletions maskrcnn_benchmark/engine/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,12 @@ def prepare_for_coco_segmentation(predictions, dataset):
image_height = dataset.coco.imgs[original_id]["height"]
prediction = prediction.resize((image_width, image_height))
masks = prediction.get_field("mask")

# t = time.time()
masks = masker(masks, prediction)
# Masker is necessary only if masks haven't already been resized.
if list(masks.shape[-2:]) != [image_height, image_width]:
masks = masker(masks.expand(1, -1, -1, -1, -1), prediction)
masks = masks[0]
# logger.info('Time mask: {}'.format(time.time() - t))
# prediction = prediction.convert('xywh')

Expand Down Expand Up @@ -426,6 +430,6 @@ def inference(
check_expected_results(results, expected_results, expected_results_sigma_tol)
if output_folder:
torch.save(results, os.path.join(output_folder, "coco_results.pth"))

return results, coco_results, predictions

47 changes: 31 additions & 16 deletions maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import numpy as np
import torch
from PIL import Image
from torch import nn
import torch.nn.functional as F

from maskrcnn_benchmark.structures.bounding_box import BoxList

Expand Down Expand Up @@ -44,12 +44,12 @@ def forward(self, x, boxes):
index = torch.arange(num_masks, device=labels.device)
mask_prob = mask_prob[index, labels][:, None]

if self.masker:
mask_prob = self.masker(mask_prob, boxes)

boxes_per_image = [len(box) for box in boxes]
mask_prob = mask_prob.split(boxes_per_image, dim=0)

if self.masker:
mask_prob = self.masker(mask_prob, boxes)

results = []
for prob, box in zip(mask_prob, boxes):
bbox = BoxList(box.bbox, box.size, mode="xyxy")
Expand Down Expand Up @@ -119,25 +119,28 @@ def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
padded_mask, scale = expand_masks(mask[None], padding=padding)
mask = padded_mask[0, 0]
box = expand_boxes(box[None], scale)[0]
box = box.numpy().astype(np.int32)
box = box.to(dtype=torch.int32)

TO_REMOVE = 1
w = box[2] - box[0] + TO_REMOVE
h = box[3] - box[1] + TO_REMOVE
w = max(w, 1)
h = max(h, 1)

mask = Image.fromarray(mask.cpu().numpy())
mask = mask.resize((w, h), resample=Image.BILINEAR)
mask = np.array(mask, copy=False)
# Set shape to [batch x C x H x W]
mask = mask.expand((1, 1, -1, -1))

# Resize mask
mask = mask.to(torch.float32)
mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
mask = mask[0][0]

if thresh >= 0:
mask = np.array(mask > thresh, dtype=np.uint8)
mask = torch.from_numpy(mask)
mask = mask > thresh
else:
# for visualization and debugging, we also
# allow it to return an unmodified mask
mask = torch.from_numpy(mask * 255).to(torch.uint8)
mask = (mask * 255).to(torch.uint8)

im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8)
x_0 = max(box[0], 0)
Expand Down Expand Up @@ -175,15 +178,27 @@ def forward_single_image(self, masks, boxes):
return res

def __call__(self, masks, boxes):
# TODO do this properly
if isinstance(boxes, BoxList):
boxes = [boxes]
assert len(boxes) == 1, "Only single image batch supported"
result = self.forward_single_image(masks, boxes[0])
return result

# Perform some sanity checks
assert len(boxes) == len(masks), "Masks and boxes should have the same length."

# TODO: Is this JIT compatible?
# If not, we should make it compatible.
results = []
for mask, box in zip(masks, boxes):
assert mask.shape[0] == len(box), "Number of objects should be the same."
result = self.forward_single_image(mask, box)
results.append(result)
return results


def make_roi_mask_post_processor(cfg):
    """Build the mask post-processor described by ``cfg``.

    When ``cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS`` is truthy, a ``Masker``
    is created with ``cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD`` as
    its threshold and a padding of 1. Otherwise no masker is used and ``None``
    is handed to the post-processor.

    Args:
        cfg: configuration object exposing the ``MODEL.ROI_MASK_HEAD`` options
            read above.

    Returns:
        A ``MaskPostProcessor`` constructed with the masker (or ``None``).
    """
    # Single default instead of assigning None both up front and in an
    # ``else`` branch.
    masker = None
    if cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS:
        masker = Masker(
            threshold=cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD,
            padding=1,
        )
    return MaskPostProcessor(masker)