Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Permalink
Make Masker batch compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
hadim committed Nov 8, 2018
1 parent 19b84bb commit 8d00e40
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
3 changes: 2 additions & 1 deletion maskrcnn_benchmark/engine/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def prepare_for_coco_segmentation(predictions, dataset):
# t = time.time()
# Masker is necessary only if masks haven't been already resized.
if list(masks.shape[-2:]) != [image_height, image_width]:
masks = masker(masks, prediction)
masks = masker(masks.expand(1, -1, -1, -1, -1), prediction)
masks = masks[0]
# logger.info('Time mask: {}'.format(time.time() - t))
# prediction = prediction.convert('xywh')

Expand Down
22 changes: 15 additions & 7 deletions maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ def forward(self, x, boxes):
index = torch.arange(num_masks, device=labels.device)
mask_prob = mask_prob[index, labels][:, None]

if self.masker:
mask_prob = self.masker(mask_prob, boxes)

boxes_per_image = [len(box) for box in boxes]
mask_prob = mask_prob.split(boxes_per_image, dim=0)

if self.masker:
mask_prob = self.masker(mask_prob, boxes)

results = []
for prob, box in zip(mask_prob, boxes):
bbox = BoxList(box.bbox, box.size, mode="xyxy")
Expand Down Expand Up @@ -178,12 +178,20 @@ def forward_single_image(self, masks, boxes):
return res

def __call__(self, masks, boxes):
# TODO do this properly
if isinstance(boxes, BoxList):
boxes = [boxes]
assert len(boxes) == 1, "Only single image batch supported"
result = self.forward_single_image(masks, boxes[0])
return result

# Make some sanity check
assert len(boxes) == len(masks), "Masks and boxes should have the same length."

# TODO: Is this JIT compatible?
# If not we should make it compatible.
results = []
for mask, box in zip(masks, boxes):
assert mask.shape[0] == box.bbox.shape[0], "Number of objects should be the same."
result = self.forward_single_image(mask, box)
results.append(result)
return results


def make_roi_mask_post_processor(cfg):
Expand Down

0 comments on commit 8d00e40

Please sign in to comment.