Skip to content

Commit

Permalink
Implement pre and postfix for Classwise Wrapper (#1866)
Browse files Browse the repository at this point in the history
Co-authored-by: SkafteNicki <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
3 people authored Jul 4, 2023
1 parent 15cf3a4 commit d9f7c35
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 8 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `prefix` and `postfix` arguments to `ClasswiseWrapper` ([#1866](https://github.com/Lightning-AI/torchmetrics/pull/1866))


- Added speech-to-reverberation modulation energy ratio (SRMR) metric ([#1792](https://github.com/Lightning-AI/torchmetrics/pull/1792), [#1872](https://github.com/Lightning-AI/torchmetrics/pull/1872))


Expand Down
67 changes: 59 additions & 8 deletions src/torchmetrics/wrappers/classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@ class ClasswiseWrapper(Metric):
metric: base metric that should be wrapped. It is assumed that the metric outputs a single
tensor that is split along the first dimension.
labels: list of strings indicating the different classes.
prefix: string that is prepended to the metric names.
postfix: string that is appended to the metric names.
Example::
Basic example where the ouput of a metric is unwrapped into a dictionary with the class index as keys:
Example:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.wrappers import ClasswiseWrapper
Expand All @@ -47,7 +51,29 @@ class ClasswiseWrapper(Metric):
'multiclassaccuracy_1': tensor(0.7500),
'multiclassaccuracy_2': tensor(0.)}
Example (labels as list of strings):
Example::
Using custom name via prefix and postfix:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric_pre = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="acc-")
>>> metric_post = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), postfix="-acc")
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric_pre(preds, target) # doctest: +NORMALIZE_WHITESPACE
{'acc-0': tensor(0.5000),
'acc-1': tensor(0.7500),
'acc-2': tensor(0.)}
>>> metric_post(preds, target) # doctest: +NORMALIZE_WHITESPACE
{'0-acc': tensor(0.5000),
'1-acc': tensor(0.7500),
'2-acc': tensor(0.)}
Example::
Providing labels as a list of strings:
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(
Expand All @@ -61,7 +87,10 @@ class ClasswiseWrapper(Metric):
'multiclassaccuracy_fish': tensor(0.6667),
'multiclassaccuracy_dog': tensor(0.)}
Example (in metric collection):
Example::
Classwise can also be used in combination with :class:`~torchmetrics.MetricCollection`. In this case, everything
will be flattened into a single dictionary:
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall
Expand All @@ -81,21 +110,43 @@ class ClasswiseWrapper(Metric):
'multiclassrecall_dog': tensor(0.4000)}
"""

def __init__(self, metric: Metric, labels: Optional[List[str]] = None) -> None:
def __init__(
self,
metric: Metric,
labels: Optional[List[str]] = None,
prefix: Optional[str] = None,
postfix: Optional[str] = None,
) -> None:
super().__init__()
if not isinstance(metric, Metric):
raise ValueError(f"Expected argument `metric` to be an instance of `torchmetrics.Metric` but got {metric}")
self.metric = metric

if labels is not None and not (isinstance(labels, list) and all(isinstance(lab, str) for lab in labels)):
raise ValueError(f"Expected argument `labels` to either be `None` or a list of strings but got {labels}")
self.metric = metric
self.labels = labels

if prefix is not None and not isinstance(prefix, str):
raise ValueError(f"Expected argument `prefix` to either be `None` or a string but got {prefix}")
self.prefix = prefix

if postfix is not None and not isinstance(postfix, str):
raise ValueError(f"Expected argument `postfix` to either be `None` or a string but got {postfix}")
self.postfix = postfix

self._update_count = 1

def _convert(self, x: Tensor) -> Dict[str, Any]:
name = self.metric.__class__.__name__.lower()
# Will set the class name as prefix if neither prefix nor postfix is given
if not self.prefix and not self.postfix:
prefix = f"{self.metric.__class__.__name__.lower()}_"
postfix = ""
else:
prefix = self.prefix or ""
postfix = self.postfix or ""
if self.labels is None:
return {f"{name}_{i}": val for i, val in enumerate(x)}
return {f"{name}_{lab}": val for lab, val in zip(self.labels, x)}
return {f"{prefix}{i}{postfix}": val for i, val in enumerate(x)}
return {f"{prefix}{lab}{postfix}": val for lab, val in zip(self.labels, x)}

def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Calculate on batch and accumulate to global state."""
Expand Down
38 changes: 38 additions & 0 deletions tests/unittests/wrappers/test_classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ def test_raises_error_on_wrong_input():
with pytest.raises(ValueError, match="Expected argument `labels` to either be `None` or a list of strings.*"):
ClasswiseWrapper(MulticlassAccuracy(num_classes=3), "hest")

with pytest.raises(ValueError, match="Expected argument `prefix` to either be `None` or a string.*"):
ClasswiseWrapper(MulticlassAccuracy(num_classes=3), prefix=1)

with pytest.raises(ValueError, match="Expected argument `postfix` to either be `None` or a string.*"):
ClasswiseWrapper(MulticlassAccuracy(num_classes=3), postfix=1)


def test_output_no_labels():
"""Test that wrapper works with no label input."""
Expand Down Expand Up @@ -54,6 +60,38 @@ def test_output_with_labels():
assert val[f"multiclassaccuracy_{lab}"] == val_base[i]


def test_output_with_prefix():
"""Test that wrapper works with prefix."""
base = MulticlassAccuracy(num_classes=3, average=None)
metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="pre_")
for _ in range(2):
preds = torch.randn(20, 3).softmax(dim=-1)
target = torch.randint(3, (20,))
val = metric(preds, target)
val_base = base(preds, target)
assert isinstance(val, dict)
assert len(val) == 3
for i in range(3):
assert f"pre_{i}" in val
assert val[f"pre_{i}"] == val_base[i]


def test_output_with_postfix():
"""Test that wrapper works with postfix."""
base = MulticlassAccuracy(num_classes=3, average=None)
metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), postfix="_post")
for _ in range(2):
preds = torch.randn(20, 3).softmax(dim=-1)
target = torch.randint(3, (20,))
val = metric(preds, target)
val_base = base(preds, target)
assert isinstance(val, dict)
assert len(val) == 3
for i in range(3):
assert f"{i}_post" in val
assert val[f"{i}_post"] == val_base[i]


@pytest.mark.parametrize("prefix", [None, "pre_"])
@pytest.mark.parametrize("postfix", [None, "_post"])
def test_using_metriccollection(prefix, postfix):
Expand Down

0 comments on commit d9f7c35

Please sign in to comment.