From 33c3984018da9bce003ff127306beddb3a8ae6b4 Mon Sep 17 00:00:00 2001 From: trqminh Date: Wed, 2 Mar 2022 14:16:06 -0600 Subject: [PATCH] [*] bcnet maskrcnn: modelling, engine --- detectron2/engine/train_loop.py | 2 +- detectron2/modeling/__init__.py | 2 +- detectron2/modeling/meta_arch/rcnn.py | 4 +-- detectron2/modeling/roi_heads/__init__.py | 2 +- detectron2/modeling/roi_heads/mask_head.py | 17 ++++++++---- detectron2/modeling/roi_heads/roi_heads.py | 32 +++++++++++++++++++--- 6 files changed, 45 insertions(+), 14 deletions(-) diff --git a/detectron2/engine/train_loop.py b/detectron2/engine/train_loop.py index c4a86b5..b21416c 100644 --- a/detectron2/engine/train_loop.py +++ b/detectron2/engine/train_loop.py @@ -270,7 +270,7 @@ def run_step(self): """ If you want to do something with the losses, you can wrap the model. """ - loss_dict = self.model(data) + loss_dict = self.model(data, self.iter) if isinstance(loss_dict, torch.Tensor): losses = loss_dict loss_dict = {"total_loss": loss_dict} diff --git a/detectron2/modeling/__init__.py b/detectron2/modeling/__init__.py index 576493d..eceee50 100644 --- a/detectron2/modeling/__init__.py +++ b/detectron2/modeling/__init__.py @@ -38,7 +38,7 @@ ROI_MASK_HEAD_REGISTRY, ROIHeads, StandardROIHeads, - BaseMaskRCNNHead, + # BaseMaskRCNNHead, BaseKeypointRCNNHead, FastRCNNOutputLayers, build_box_head, diff --git a/detectron2/modeling/meta_arch/rcnn.py b/detectron2/modeling/meta_arch/rcnn.py index 7b45363..6962165 100644 --- a/detectron2/modeling/meta_arch/rcnn.py +++ b/detectron2/modeling/meta_arch/rcnn.py @@ -119,7 +119,7 @@ def visualize_training(self, batched_inputs, proposals): storage.put_image(vis_name, vis_img) break # only visualize one image in a batch - def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]): + def forward(self, batched_inputs: List[Dict[str, torch.Tensor]], c_iter=None): """ Args: batched_inputs: a list, batched outputs of :class:`DatasetMapper` . @@ -160,7 +160,7 @@ def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]): proposals = [x["proposals"].to(self.device) for x in batched_inputs] proposal_losses = {} - _, detector_losses = self.roi_heads(images, features, proposals, gt_instances) + _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, c_iter) if self.vis_period > 0: storage = get_event_storage() if storage.iter % self.vis_period == 0: diff --git a/detectron2/modeling/roi_heads/__init__.py b/detectron2/modeling/roi_heads/__init__.py index d13e9c5..5a81648 100644 --- a/detectron2/modeling/roi_heads/__init__.py +++ b/detectron2/modeling/roi_heads/__init__.py @@ -9,7 +9,7 @@ from .mask_head import ( ROI_MASK_HEAD_REGISTRY, build_mask_head, - BaseMaskRCNNHead, + # BaseMaskRCNNHead, MaskRCNNConvUpsampleHead, ) from .roi_heads import ( diff --git a/detectron2/modeling/roi_heads/mask_head.py b/detectron2/modeling/roi_heads/mask_head.py index fa45a1a..fcd32e9 100644 --- a/detectron2/modeling/roi_heads/mask_head.py +++ b/detectron2/modeling/roi_heads/mask_head.py @@ -145,10 +145,14 @@ def mask_rcnn_loss(pred_mask_logits, pred_boundary_logits, instances, pred_mask_ else: indices = torch.arange(total_num_masks) gt_classes = cat(gt_classes, dim=0) - pred_mask_logits = pred_mask_logits[indices, gt_classes] + pred_mask_logits_gt = pred_mask_logits[indices, gt_classes] + pred_bo_mask_logits = pred_mask_bo_logits[indices, gt_classes] + pred_boundary_logits_bo = pred_boundary_logits_bo[indices, gt_classes] + pred_boundary_logits = pred_boundary_logits[indices, gt_classes] if pred_a_mask_logits: - pred_a_mask_logits = pred_a_mask_logits[indices, gt_classes] + pred_a_mask_logits_gt = pred_a_mask_logits[indices, gt_classes] + if gt_masks.dtype == torch.bool: gt_masks_bool = gt_masks else: @@ -222,10 +226,10 @@ def mask_rcnn_loss(pred_mask_logits, pred_boundary_logits, instances, pred_mask_ ) if use_i_mask: - bound_loss = L.JointLoss(L.BceLoss(), L.BceLoss())( + bound_loss = L.JointLoss(L.BalancedBCEWithLogitsLoss(), L.BalancedBCEWithLogitsLoss())( pred_boundary_logits.unsqueeze(1), gt_i_boundary.to(dtype=torch.float32)) else: - bound_loss = L.JointLoss(L.BceLoss(), L.BceLoss())( + bound_loss = L.JointLoss(L.BalancedBCEWithLogitsLoss(), L.BalancedBCEWithLogitsLoss())( pred_boundary_logits.unsqueeze(1), gt_boundary.to(dtype=torch.float32)) if new_gt_bo_masks.shape[0] > 0: @@ -236,7 +240,7 @@ def mask_rcnn_loss(pred_mask_logits, pred_boundary_logits, instances, pred_mask_ bo_mask_loss = torch.tensor(0.0).cuda(mask_loss.get_device()) if new_gt_bo_bounds.shape[0] > 0: - bo_bound_loss = L.JointLoss(L.BceLoss(), L.BceLoss())( + bo_bound_loss = L.JointLoss(L.BalancedBCEWithLogitsLoss(), L.BalancedBCEWithLogitsLoss())( new_pred_bo_bounds_logits.unsqueeze(1), new_gt_bo_bounds.to(dtype=torch.float32)) else: bo_bound_loss = torch.tensor(0.0).cuda(mask_loss.get_device()) @@ -293,6 +297,9 @@ def mask_rcnn_inference(pred_mask_logits, bo_mask_logits, bound_logits, bo_bound class_pred = cat([i.pred_classes for i in pred_instances]) indices = torch.arange(num_masks, device=class_pred.device) mask_probs_pred = pred_mask_logits[indices, class_pred][:, None].sigmoid() + bound_probs_pred = bound_logits.sigmoid() + bo_mask_probs_pred = bo_mask_logits.sigmoid() + bo_bound_probs_pred = bo_bound_logits.sigmoid() # mask_probs_pred.shape: (B, 1, Hmask, Wmask) num_boxes_per_image = [len(i) for i in pred_instances] diff --git a/detectron2/modeling/roi_heads/roi_heads.py b/detectron2/modeling/roi_heads/roi_heads.py index 13dd57a..5292f87 100644 --- a/detectron2/modeling/roi_heads/roi_heads.py +++ b/detectron2/modeling/roi_heads/roi_heads.py @@ -20,7 +20,7 @@ from .box_head import build_box_head from .fast_rcnn import FastRCNNOutputLayers from .keypoint_head import build_keypoint_head -from .mask_head import build_mask_head +from .mask_head import build_mask_head, mask_rcnn_inference, mask_rcnn_loss ROI_HEADS_REGISTRY = Registry("ROI_HEADS") ROI_HEADS_REGISTRY.__doc__ = """ @@ -725,10 +725,12 @@ def forward( features: Dict[str, torch.Tensor], proposals: List[Instances], targets: Optional[List[Instances]] = None, + c_iter=None ) -> Tuple[List[Instances], Dict[str, torch.Tensor]]: """ See :class:`ROIHeads.forward`. """ + # This forward del images if self.training: assert targets, "'targets' argument is required during training" @@ -740,7 +742,22 @@ def forward( # Usually the original proposals used by the box head are used by the mask, keypoint # heads. But when `self.train_on_pred_boxes is True`, proposals will contain boxes # predicted by the box head. - losses.update(self._forward_mask(features, proposals)) + + mask_head_results, instances = self._forward_mask(features, proposals, c_iter) + mask_logits, boundary, bo_masks, bo_bound, mask_head_features = mask_head_results + + loss_mask, loss_mask_bo, loss_boundary, loss_boundary_bo, loss_a_mask, loss_justify \ + = mask_rcnn_loss(mask_logits, boundary, instances, bo_masks, bo_bound) + + losses.update({ + "loss_mask": loss_mask, + "loss_mask_bo": loss_mask_bo * 0.25, + "loss_boundary_bo": loss_boundary_bo * 0.5, + "loss_boundary": loss_boundary * 0.5, + "loss_a_mask": loss_a_mask, + "loss_justify": loss_justify + }) + # losses.update(self._forward_mask(features, proposals)) losses.update(self._forward_keypoint(features, proposals)) return proposals, losses else: @@ -815,7 +832,7 @@ def _forward_box(self, features: Dict[str, torch.Tensor], proposals: List[Instan pred_instances, _ = self.box_predictor.inference(predictions, proposals) return pred_instances - def _forward_mask(self, features: Dict[str, torch.Tensor], instances: List[Instances]): + def _forward_mask(self, features: Dict[str, torch.Tensor], instances: List[Instances], c_iter=None): """ Forward logic of the mask prediction branch. @@ -843,7 +860,14 @@ def _forward_mask(self, features: Dict[str, torch.Tensor], instances: List[Insta features = self.mask_pooler(features, boxes) else: features = {f: features[f] for f in self.mask_in_features} - return self.mask_head(features, instances) + + if self.training: + return self.mask_head(features), instances + else: + mask_logits, boundary, bo_masks, bo_bound, mask_head_features = self.mask_head(features) + mask_rcnn_inference(mask_logits, bo_masks, boundary, bo_bound, instances) + return instances + def _forward_keypoint(self, features: Dict[str, torch.Tensor], instances: List[Instances]): """