Skip to content

Commit

Permalink
[*] bcnet maskrcnn: modelling, engine
Browse files Browse the repository at this point in the history
  • Loading branch information
trqminh committed Mar 2, 2022
1 parent 181a4da commit 33c3984
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 14 deletions.
2 changes: 1 addition & 1 deletion detectron2/engine/train_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def run_step(self):
"""
If you want to do something with the losses, you can wrap the model.
"""
loss_dict = self.model(data)
loss_dict = self.model(data, self.iter)
if isinstance(loss_dict, torch.Tensor):
losses = loss_dict
loss_dict = {"total_loss": loss_dict}
Expand Down
2 changes: 1 addition & 1 deletion detectron2/modeling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
ROI_MASK_HEAD_REGISTRY,
ROIHeads,
StandardROIHeads,
BaseMaskRCNNHead,
# BaseMaskRCNNHead,
BaseKeypointRCNNHead,
FastRCNNOutputLayers,
build_box_head,
Expand Down
4 changes: 2 additions & 2 deletions detectron2/modeling/meta_arch/rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def visualize_training(self, batched_inputs, proposals):
storage.put_image(vis_name, vis_img)
break # only visualize one image in a batch

def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
def forward(self, batched_inputs: List[Dict[str, torch.Tensor]], c_iter=None):
"""
Args:
batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
Expand Down Expand Up @@ -160,7 +160,7 @@ def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
proposals = [x["proposals"].to(self.device) for x in batched_inputs]
proposal_losses = {}

_, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
_, detector_losses = self.roi_heads(images, features, proposals, gt_instances, c_iter)
if self.vis_period > 0:
storage = get_event_storage()
if storage.iter % self.vis_period == 0:
Expand Down
2 changes: 1 addition & 1 deletion detectron2/modeling/roi_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .mask_head import (
ROI_MASK_HEAD_REGISTRY,
build_mask_head,
BaseMaskRCNNHead,
# BaseMaskRCNNHead,
MaskRCNNConvUpsampleHead,
)
from .roi_heads import (
Expand Down
17 changes: 12 additions & 5 deletions detectron2/modeling/roi_heads/mask_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,14 @@ def mask_rcnn_loss(pred_mask_logits, pred_boundary_logits, instances, pred_mask_
else:
indices = torch.arange(total_num_masks)
gt_classes = cat(gt_classes, dim=0)
pred_mask_logits = pred_mask_logits[indices, gt_classes]
pred_mask_logits_gt = pred_mask_logits[indices, gt_classes]
pred_bo_mask_logits = pred_mask_bo_logits[indices, gt_classes]
pred_boundary_logits_bo = pred_boundary_logits_bo[indices, gt_classes]
pred_boundary_logits = pred_boundary_logits[indices, gt_classes]
if pred_a_mask_logits:
pred_a_mask_logits = pred_a_mask_logits[indices, gt_classes]
pred_a_mask_logits_gt = pred_a_mask_logits[indices, gt_classes]


if gt_masks.dtype == torch.bool:
gt_masks_bool = gt_masks
else:
Expand Down Expand Up @@ -222,10 +226,10 @@ def mask_rcnn_loss(pred_mask_logits, pred_boundary_logits, instances, pred_mask_
)

if use_i_mask:
bound_loss = L.JointLoss(L.BceLoss(), L.BceLoss())(
bound_loss = L.JointLoss(L.BalancedBCEWithLogitsLoss(), L.BalancedBCEWithLogitsLoss())(
pred_boundary_logits.unsqueeze(1), gt_i_boundary.to(dtype=torch.float32))
else:
bound_loss = L.JointLoss(L.BceLoss(), L.BceLoss())(
bound_loss = L.JointLoss(L.BalancedBCEWithLogitsLoss(), L.BalancedBCEWithLogitsLoss())(
pred_boundary_logits.unsqueeze(1), gt_boundary.to(dtype=torch.float32))

if new_gt_bo_masks.shape[0] > 0:
Expand All @@ -236,7 +240,7 @@ def mask_rcnn_loss(pred_mask_logits, pred_boundary_logits, instances, pred_mask_
bo_mask_loss = torch.tensor(0.0).cuda(mask_loss.get_device())

if new_gt_bo_bounds.shape[0] > 0:
bo_bound_loss = L.JointLoss(L.BceLoss(), L.BceLoss())(
bo_bound_loss = L.JointLoss(L.BalancedBCEWithLogitsLoss(), L.BalancedBCEWithLogitsLoss())(
new_pred_bo_bounds_logits.unsqueeze(1), new_gt_bo_bounds.to(dtype=torch.float32))
else:
bo_bound_loss = torch.tensor(0.0).cuda(mask_loss.get_device())
Expand Down Expand Up @@ -293,6 +297,9 @@ def mask_rcnn_inference(pred_mask_logits, bo_mask_logits, bound_logits, bo_bound
class_pred = cat([i.pred_classes for i in pred_instances])
indices = torch.arange(num_masks, device=class_pred.device)
mask_probs_pred = pred_mask_logits[indices, class_pred][:, None].sigmoid()
bound_probs_pred = bound_logits.sigmoid()
bo_mask_probs_pred = bo_mask_logits.sigmoid()
bo_bound_probs_pred = bo_bound_logits.sigmoid()
# mask_probs_pred.shape: (B, 1, Hmask, Wmask)

num_boxes_per_image = [len(i) for i in pred_instances]
Expand Down
32 changes: 28 additions & 4 deletions detectron2/modeling/roi_heads/roi_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .box_head import build_box_head
from .fast_rcnn import FastRCNNOutputLayers
from .keypoint_head import build_keypoint_head
from .mask_head import build_mask_head
from .mask_head import build_mask_head, mask_rcnn_inference, mask_rcnn_loss

ROI_HEADS_REGISTRY = Registry("ROI_HEADS")
ROI_HEADS_REGISTRY.__doc__ = """
Expand Down Expand Up @@ -725,10 +725,12 @@ def forward(
features: Dict[str, torch.Tensor],
proposals: List[Instances],
targets: Optional[List[Instances]] = None,
c_iter=None
) -> Tuple[List[Instances], Dict[str, torch.Tensor]]:
"""
See :class:`ROIHeads.forward`.
"""
# This forward
del images
if self.training:
assert targets, "'targets' argument is required during training"
Expand All @@ -740,7 +742,22 @@ def forward(
# Usually the original proposals used by the box head are used by the mask, keypoint
# heads. But when `self.train_on_pred_boxes is True`, proposals will contain boxes
# predicted by the box head.
losses.update(self._forward_mask(features, proposals))

mask_head_results, instances = self._forward_mask(features, proposals, c_iter)
mask_logits, boundary, bo_masks, bo_bound, mask_head_features = mask_head_results

loss_mask, loss_mask_bo, loss_boundary, loss_boundary_bo, loss_a_mask, loss_justify \
= mask_rcnn_loss(mask_logits, boundary, instances, bo_masks, bo_bound)

losses.update({
"loss_mask": loss_mask,
"loss_mask_bo": loss_mask_bo * 0.25,
"loss_boundary_bo": loss_boundary_bo * 0.5,
"loss_boundary": loss_boundary * 0.5,
"loss_a_mask": loss_a_mask,
"loss_justify": loss_justify
})
# losses.update(self._forward_mask(features, proposals))
losses.update(self._forward_keypoint(features, proposals))
return proposals, losses
else:
Expand Down Expand Up @@ -815,7 +832,7 @@ def _forward_box(self, features: Dict[str, torch.Tensor], proposals: List[Instan
pred_instances, _ = self.box_predictor.inference(predictions, proposals)
return pred_instances

def _forward_mask(self, features: Dict[str, torch.Tensor], instances: List[Instances]):
def _forward_mask(self, features: Dict[str, torch.Tensor], instances: List[Instances], c_iter=None):
"""
Forward logic of the mask prediction branch.
Expand Down Expand Up @@ -843,7 +860,14 @@ def _forward_mask(self, features: Dict[str, torch.Tensor], instances: List[Insta
features = self.mask_pooler(features, boxes)
else:
features = {f: features[f] for f in self.mask_in_features}
return self.mask_head(features, instances)

if self.training:
return self.mask_head(features), instances
else:
mask_logits, boundary, bo_masks, bo_bound, mask_head_features = self.mask_head(features)
mask_rcnn_inference(mask_logits, bo_masks, boundary, bo_bound, instances)
return instances


def _forward_keypoint(self, features: Dict[str, torch.Tensor], instances: List[Instances]):
"""
Expand Down

0 comments on commit 33c3984

Please sign in to comment.