Return also classes for MAP metric (#1419)
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka <[email protected]>
3 people authored Jan 17, 2023
1 parent 7b5f8eb commit ac64e63
Showing 4 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .azure/gpu-pipeline.yml
@@ -109,7 +109,7 @@ jobs:
ls -lh $(HF_CACHE_DIR) # show what was restored...
displayName: 'Show HF cache'
- - bash: python -m pytest torchmetrics --cov=torchmetrics --timeout=120 --durations=50
+ - bash: python -m pytest torchmetrics --cov=torchmetrics --timeout=150 --durations=50
workingDirectory: src
displayName: 'DocTesting'

1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added support for plotting of metrics through `.plot()` method ([#1328](https://github.com/Lightning-AI/metrics/pull/1328))

+ - Added `classes` to output from `MAP` metric ([#1419](https://github.com/Lightning-AI/metrics/pull/1419))

### Changed

8 changes: 5 additions & 3 deletions src/torchmetrics/detection/mean_ap.py
@@ -94,7 +94,7 @@ def __delattr__(self, key: str) -> None:
class MAPMetricResults(BaseMetricResults):
"""Class to wrap the final mAP results."""

__slots__ = ("map", "map_50", "map_75", "map_small", "map_medium", "map_large")
__slots__ = ("map", "map_50", "map_75", "map_small", "map_medium", "map_large", "classes")


class MARMetricResults(BaseMetricResults):
@@ -248,6 +248,7 @@ class MeanAveragePrecision(Metric):
- map_75: (:class:`~torch.Tensor`) (-1 if 0.75 not in the list of iou thresholds)
- map_per_class: (:class:`~torch.Tensor`) (-1 if class metrics are disabled)
- mar_100_per_class: (:class:`~torch.Tensor`) (-1 if class metrics are disabled)
+ - classes (:class:`~torch.Tensor`)
For an example on how to use this metric check the `torchmetrics examples
<https://github.com/Lightning-AI/metrics/blob/master/examples/detection_map.py>`_
@@ -332,7 +333,8 @@ class MeanAveragePrecision(Metric):
>>> metric.update(preds, target)
>>> from pprint import pprint
>>> pprint(metric.compute())
- {'map': tensor(0.6000),
+ {'classes': tensor(0, dtype=torch.int32),
+  'map': tensor(0.6000),
'map_50': tensor(1.),
'map_75': tensor(1.),
'map_large': tensor(0.6000),
@@ -923,5 +925,5 @@ def compute(self) -> dict:
metrics.update(mar_val)
metrics.map_per_class = map_per_class_values
metrics[f"mar_{self.max_detection_thresholds[-1]}_per_class"] = mar_max_dets_per_class_values

+ metrics.classes = torch.tensor(classes, dtype=torch.int)
return metrics
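
For context, a minimal usage sketch (not part of the commit) showing where the new `classes` entry surfaces in the `compute()` output. The boxes, scores, and labels below are illustrative assumptions; only the presence of the `classes` key and its integer dtype follow from this change.

```python
import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

# Illustrative single-image, single-class inputs (assumed values, not repository data).
preds = [
    {
        "boxes": torch.tensor([[258.0, 41.0, 606.0, 285.0]]),
        "scores": torch.tensor([0.536]),
        "labels": torch.tensor([0]),
    }
]
target = [
    {
        "boxes": torch.tensor([[214.0, 41.0, 562.0, 285.0]]),
        "labels": torch.tensor([0]),
    }
]

metric = MeanAveragePrecision()
metric.update(preds, target)
result = metric.compute()

# With this change, the result dict also reports which class ids were observed:
print(result["classes"])  # e.g. tensor(0, dtype=torch.int32)
```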
2 changes: 2 additions & 0 deletions tests/unittests/detection/test_map.py
@@ -282,6 +282,7 @@ def _compare_fn(preds, target) -> dict:
"mar_large": Tensor([0.633]),
"map_per_class": Tensor([0.725, 0.800, 0.454, -1.000, 0.650, 0.556]),
"mar_100_per_class": Tensor([0.780, 0.800, 0.450, -1.000, 0.650, 0.580]),
"classes": Tensor([0, 1, 2, 3, 4, 49]),
}


@@ -317,6 +318,7 @@ def _compare_fn_segm(preds, target) -> dict:
"mar_large": Tensor([0.35]),
"map_per_class": Tensor([0.4039604, -1.0, 0.3]),
"mar_100_per_class": Tensor([0.4, -1.0, 0.3]),
"classes": Tensor([2, 3, 4]),
}
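
Outside the reference comparisons above, the new key can also be asserted on directly. Below is a small sketch under the assumption that `results` is the dict returned by `MeanAveragePrecision.compute()` for the segmentation inputs behind `_compare_fn_segm`; the helper name is hypothetical and not part of the test suite.

```python
import torch
from torch import Tensor


def assert_classes_reported(results: dict) -> None:
    """Hypothetical check: compute() output should expose the observed class ids."""
    expected = Tensor([2, 3, 4])  # mirrors the expected dict in _compare_fn_segm above
    assert torch.equal(results["classes"].to(torch.float), expected)
```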


