Skip to content

Commit

Permalink
Fix compute_on_cpu arg + forward method (#1333)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Nov 14, 2022
1 parent 3636182 commit 3188728
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed bug to prevent users from going into a infinite loop if trying to iterate of a single metric ([#1320](https://github.com/Lightning-AI/metrics/pull/1320))


- Fixed bug when `compute_on_cpu` arg used together with `forward` method ([#1333](https://github.com/Lightning-AI/metrics/pull/1333))


## [0.10.2] - 2022-10-31

### Changed
Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any:
self._computed = None
self._enable_grad = False
self.compute_on_cpu = _temp_compute_on_cpu
if self.compute_on_cpu:
self._move_list_states_to_cpu()

return batch_val

Expand Down Expand Up @@ -325,6 +327,8 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
self._computed = None
self._enable_grad = False
self.compute_on_cpu = _temp_compute_on_cpu
if self.compute_on_cpu:
self._move_list_states_to_cpu()

return batch_val

Expand Down
16 changes: 16 additions & 0 deletions tests/unittests/bases/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,3 +471,19 @@ def test_no_iteration_allowed():
with pytest.raises(NotImplementedError, match="Metrics does not support iteration."):
for m in metric:
continue


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
@pytest.mark.parametrize("method", ["forward", "update"])
def test_compute_on_cpu_arg_forward(method):
metric = DummyListMetric(compute_on_cpu=True)
x = torch.randn(10).cuda()
if method == "update":
metric.update(x)
metric.update(x)
else:
_ = metric(x)
_ = metric(x)
val = metric.compute()
assert all(str(v.device) == "cpu" for v in val)
assert all(torch.allclose(v, x.cpu()) for v in val)

0 comments on commit 3188728

Please sign in to comment.