Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Support Binary Mask with transparent SegmentationMask interface #473

Merged
merged 15 commits into from
Apr 9, 2019
2 changes: 1 addition & 1 deletion maskrcnn_benchmark/data/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __getitem__(self, idx):
target.add_field("labels", classes)

masks = [obj["segmentation"] for obj in anno]
masks = SegmentationMask(masks, img.size)
masks = SegmentationMask(masks, img.size, mode='poly')
botcs marked this conversation as resolved.
Show resolved Hide resolved
target.add_field("masks", masks)

if anno and "keypoints" in anno[0]:
Expand Down
10 changes: 4 additions & 6 deletions maskrcnn_benchmark/modeling/roi_heads/mask_head/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,15 @@ def project_masks_on_boxes(segmentation_masks, proposals, discretization_size):
assert segmentation_masks.size == proposals.size, "{}, {}".format(
segmentation_masks, proposals
)
# TODO put the proposals on the CPU, as the representation for the
botcs marked this conversation as resolved.
Show resolved Hide resolved
# masks is not efficient GPU-wise (possibly several small tensors for
# representing a single instance mask)

# FIXME: CPU computation bottleneck, this should be parallelized
proposals = proposals.bbox.to(torch.device("cpu"))
for segmentation_mask, proposal in zip(segmentation_masks, proposals):
# crop the masks, resize them to the desired resolution and
# then convert them to the tensor representation,
# instead of the list representation that was used
# then convert them to the tensor representation.
cropped_mask = segmentation_mask.crop(proposal)
scaled_mask = cropped_mask.resize((M, M))
mask = scaled_mask.convert(mode="mask")
mask = scaled_mask.get_mask_tensor()
botcs marked this conversation as resolved.
Show resolved Hide resolved
masks.append(mask)
if len(masks) == 0:
return torch.empty(0, dtype=torch.float32, device=device)
Expand Down
Loading