Skip to content

Commit

Permalink
Fix incorrect caching of MetricCollection (#2571)
Browse files Browse the repository at this point in the history
* Fix incorrect caching of MetricCollection

Also update tests to correctly test for #2211.
Also add return deepcopy values of compute

* Update Changelog

* fix issues with deepcopy by using .clone

---------

Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
4 people authored Jun 1, 2024
1 parent 599991d commit 6a23b38
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed bug in `MetricCollection` when using compute groups and `compute` is called more than once ([#2571](https://github.com/Lightning-AI/torchmetrics/pull/2571))


- Fixed class order of `panoptic_quality(..., return_per_class=True)` output ([#2548](https://github.com/Lightning-AI/torchmetrics/pull/2548))


Expand Down
11 changes: 10 additions & 1 deletion src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@ def __init__(

self.add_metrics(metrics, *additional_metrics)

@property
def metric_state(self) -> Dict[str, Dict[str, Any]]:
"""Get the current state of the metric."""
return {k: m.metric_state for k, m in self.items(keep_base=False, copy_state=False)}

@torch.jit.unused
def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
"""Call forward for each metric sequentially.
Expand All @@ -206,6 +211,11 @@ def update(self, *args: Any, **kwargs: Any) -> None:
"""
# Use compute groups if already initialized and checked
if self._groups_checked:
# Delete the cache of all metrics to invalidate the cache and therefore recent compute calls, forcing new
# compute calls to recompute
for k in self.keys(keep_base=True):
mi = getattr(self, str(k))
mi._computed = None
for cg in self._groups.values():
# only update the first member
m0 = getattr(self, cg[0])
Expand Down Expand Up @@ -304,7 +314,6 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
# Determine if we just should set a reference or a full copy
setattr(mi, state, deepcopy(m0_state) if copy else m0_state)
mi._update_count = deepcopy(m0._update_count) if copy else m0._update_count
mi._computed = deepcopy(m0._computed) if copy else m0._computed
self._state_is_copy = copy

def compute(self) -> Dict[str, Any]:
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,8 @@ def wrapped_func(*args: Any, **kwargs: Any) -> Any:
should_unsync=self._should_unsync,
):
value = _squeeze_if_scalar(compute(*args, **kwargs))
# clone tensor to avoid in-place operations after compute, altering already computed results
value = apply_to_collection(value, Tensor, lambda x: x.clone())

if self.compute_with_cache:
self._computed = value
Expand Down
19 changes: 19 additions & 0 deletions tests/unittests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,10 +452,29 @@ def test_check_compute_groups_correctness(self, metrics, expected, preds, target
for key in res_cg:
assert torch.allclose(res_cg[key], res_without_cg[key])

# Check if second compute is the same
res_cg2 = m.compute()
for key in res_cg2:
assert torch.allclose(res_cg[key], res_cg2[key])

if with_reset:
m.reset()
m2.reset()

# Test if a second compute without a reset is the same
m.reset()
m.update(preds, target)
res_cg = m.compute()
# Simulate different preds by simply inversing them
m.update(1 - preds, target)
res_cg2 = m.compute()
# Now check if the results from the first compute are different from the second
for key in res_cg:
# A different shape is okay, therefore skip (this happens for multidim_average="samplewise")
if res_cg[key].shape != res_cg2[key].shape:
continue
assert not torch.all(res_cg[key] == res_cg2[key])

@pytest.mark.parametrize("method", ["items", "values", "keys"])
def test_check_compute_groups_items_and_values(self, metrics, expected, preds, target, method):
"""Check states are copied instead of passed by ref when a single metric in the collection is access."""
Expand Down

0 comments on commit 6a23b38

Please sign in to comment.