Skip to content

Commit

Permalink
Adding ciou and diou support in _box_loss() (#5984)
Browse files Browse the repository at this point in the history
* Adding ciou and diou support in `_box_loss()`

* Fix linter

* Addressing comments for nits
  • Loading branch information
datumbox authored May 11, 2022
1 parent 3ec4b94 commit 64b1e27
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions torchvision/models/detection/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torchvision.ops import FrozenBatchNorm2d, generalized_box_iou_loss
from torchvision.ops import FrozenBatchNorm2d, complete_box_iou_loss, distance_box_iou_loss, generalized_box_iou_loss


class BalancedPositiveNegativeSampler:
Expand Down Expand Up @@ -518,7 +518,7 @@ def _box_loss(
bbox_regression_per_image: Tensor,
cnf: Optional[Dict[str, float]] = None,
) -> Tensor:
torch._assert(type in ["l1", "smooth_l1", "giou"], f"Unsupported loss: {type}")
torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")

if type == "l1":
target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
Expand All @@ -527,7 +527,12 @@ def _box_loss(
target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
else: # giou
else:
bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
if type == "ciou":
return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
if type == "diou":
return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
# otherwise giou
return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)

0 comments on commit 64b1e27

Please sign in to comment.