From 64b1e279d7963c923bd453de07589a25cb6e8d03 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 11 May 2022 10:48:51 +0100 Subject: [PATCH] Adding ciou and diou support in `_box_loss()` (#5984) * Adding ciou and diou support in `_box_loss()` * Fix linter * Addressing comments for nits --- torchvision/models/detection/_utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index f4c426691c0..d808ecffed3 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -5,7 +5,7 @@ import torch from torch import Tensor, nn from torch.nn import functional as F -from torchvision.ops import FrozenBatchNorm2d, generalized_box_iou_loss +from torchvision.ops import FrozenBatchNorm2d, complete_box_iou_loss, distance_box_iou_loss, generalized_box_iou_loss class BalancedPositiveNegativeSampler: @@ -518,7 +518,7 @@ def _box_loss( bbox_regression_per_image: Tensor, cnf: Optional[Dict[str, float]] = None, ) -> Tensor: - torch._assert(type in ["l1", "smooth_l1", "giou"], f"Unsupported loss: {type}") + torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}") if type == "l1": target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) @@ -527,7 +527,12 @@ def _box_loss( target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0 return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta) - else: # giou + else: bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image) eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7 + if type == "ciou": + return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps) + if type == "diou": + return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps) + # otherwise giou return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)