diff --git a/CHANGELOG.md b/CHANGELOG.md index 9be496340e2..a6da9e6eed7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,6 +56,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed bug to prevent users from going into a infinite loop if trying to iterate of a single metric ([#1320](https://github.com/Lightning-AI/metrics/pull/1320)) +- Fixed bug when `compute_on_cpu` arg used together with `forward` method ([#1333](https://github.com/Lightning-AI/metrics/pull/1333)) + + ## [0.10.2] - 2022-10-31 ### Changed diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 26c3b9ac529..207a3fa701a 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -288,6 +288,8 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any: self._computed = None self._enable_grad = False self.compute_on_cpu = _temp_compute_on_cpu + if self.compute_on_cpu: + self._move_list_states_to_cpu() return batch_val @@ -325,6 +327,8 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: self._computed = None self._enable_grad = False self.compute_on_cpu = _temp_compute_on_cpu + if self.compute_on_cpu: + self._move_list_states_to_cpu() return batch_val diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py index 3ea416dc059..ad00f910264 100644 --- a/tests/unittests/bases/test_metric.py +++ b/tests/unittests/bases/test_metric.py @@ -471,3 +471,19 @@ def test_no_iteration_allowed(): with pytest.raises(NotImplementedError, match="Metrics does not support iteration."): for m in metric: continue + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") +@pytest.mark.parametrize("method", ["forward", "update"]) +def test_compute_on_cpu_arg_forward(method): + metric = DummyListMetric(compute_on_cpu=True) + x = torch.randn(10).cuda() + if method == "update": + metric.update(x) + metric.update(x) + else: + _ = metric(x) + _ = metric(x) + val = metric.compute() + assert all(str(v.device) == "cpu" for v in val) + assert all(torch.allclose(v, x.cpu()) for v in val)