Skip to content

Commit

Permalink
Fix corner case when single empty pred is provided to detection IOU
Browse files Browse the repository at this point in the history
… metric (#2780)
  • Loading branch information
SkafteNicki authored Oct 15, 2024
1 parent 801cec8 commit a8cbae7
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Fixed corner case in `Iou` metric for single empty prediction tensors ([#2780](https://github.com/Lightning-AI/torchmetrics/pull/2780))


## [1.4.3] - 2024-10-10
Expand Down
3 changes: 2 additions & 1 deletion src/torchmetrics/detection/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,8 @@ def compute(self) -> dict:
"""Computes IoU based on inputs passed in to ``update`` previously."""
score = torch.cat([mat[mat != self._invalid_val] for mat in self.iou_matrix], 0).mean()
results: Dict[str, Tensor] = {f"{self._iou_type}": score}

if torch.isnan(score): # if no valid boxes are found
results[f"{self._iou_type}"] = torch.tensor(0.0, device=score.device)
if self.class_metrics:
gt_labels = dim_zero_cat(self.groundtruth_labels)
classes = gt_labels.unique().tolist() if len(gt_labels) > 0 else []
Expand Down
29 changes: 28 additions & 1 deletion tests/unittests/detection/test_intersection.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def _tv_wrapper_class(preds, target, base_fn, respect_labels, iou_threshold, cla
base_name = {tv_ciou: "ciou", tv_diou: "diou", tv_giou: "giou", tv_iou: "iou"}[base_fn]

result = {f"{base_name}": score.cpu()}
if torch.isnan(score):
result.update({f"{base_name}": torch.tensor(0.0)})
if class_metrics:
for cl in torch.cat(classes).unique().tolist():
class_score, numel = 0, 0
Expand All @@ -71,7 +73,6 @@ def _tv_wrapper_class(preds, target, base_fn, respect_labels, iou_threshold, cla
class_score += masked_s[masked_s != -1].sum()
numel += masked_s[masked_s != -1].numel()
result.update({f"{base_name}/cl_{cl}": class_score.cpu() / numel})

return result


Expand Down Expand Up @@ -328,6 +329,32 @@ def test_functional_error_on_wrong_input_shape(self, class_metric, functional_me
with pytest.raises(ValueError, match="Expected target to be of shape.*"):
functional_metric(torch.randn(25, 4), torch.randn(25, 25))

def test_corner_case_only_one_empty_prediction(self, class_metric, functional_metric, reference_metric):
"""Test that the metric does not crash when there is only one empty prediction."""
target = [
{
"boxes": torch.tensor([
[8.0000, 70.0000, 76.0000, 110.0000],
[247.0000, 131.0000, 315.0000, 175.0000],
[361.0000, 177.0000, 395.0000, 203.0000],
]),
"labels": torch.tensor([0, 0, 0]),
}
]
preds = [
{
"boxes": torch.empty(size=(0, 4)),
"labels": torch.tensor([], dtype=torch.int64),
"scores": torch.tensor([]),
}
]

metric = class_metric()
metric.update(preds, target)
res = metric.compute()
for val in res.values():
assert val == torch.tensor(0.0)


def test_corner_case():
"""See issue: https://github.com/Lightning-AI/torchmetrics/issues/1921."""
Expand Down

0 comments on commit a8cbae7

Please sign in to comment.