Merge branch 'master' into optimizer_step/training_step
SeanNaren authored Nov 2, 2020
2 parents fabe833 + 19187d3 commit 840734e
Showing 4 changed files with 20 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -39,6 +39,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed error using `auto_select_gpus=True` with `gpus=-1` ([#4209](https://github.com/PyTorchLightning/pytorch-lightning/pull/4209))

- Fixed that metrics do not store computational graph for all seen data ([#4313](https://github.com/PyTorchLightning/pytorch-lightning/pull/4313))

- Fixed AMP unscale for `on_after_backward` ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))

## [1.0.4] - 2020-10-27
14 changes: 13 additions & 1 deletion docs/source/metrics.rst
@@ -150,6 +150,19 @@ Example implementation:
        def compute(self):
            return self.correct.float() / self.total

Metrics support backpropagation if all computations involved in the metric calculation
are differentiable. However, note that the cached state is detached from the computational
graph and cannot be backpropagated. Not doing this would mean storing the computational
graph for each update call, which can lead to out-of-memory errors.
In practice this means that:

.. code-block:: python

    metric = MyMetric()
    val = metric(pred, target)  # this value can be backpropagated
    val = metric.compute()  # this value cannot be backpropagated

**********
Metric API
**********
@@ -453,4 +466,3 @@ embedding_similarity [func]

.. autofunction:: pytorch_lightning.metrics.functional.self_supervised.embedding_similarity
    :noindex:
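
To illustrate the behaviour documented above, here is a minimal sketch (not taken from the repository; it assumes PyTorch Lightning 1.0.x and a made-up `SumOfSquares` metric whose computation is differentiable). The value returned by calling the metric keeps its autograd graph, while `compute()` is built from state accumulated under `no_grad` and therefore carries none:

    import torch
    from pytorch_lightning.metrics import Metric  # 1.0.x metrics API (assumption)


    class SumOfSquares(Metric):
        """Toy differentiable metric: accumulates the sum of squared errors."""

        def __init__(self):
            super().__init__()
            self.add_state("sse", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, preds, target):
            self.sse += torch.sum((preds - target) ** 2)

        def compute(self):
            return self.sse


    metric = SumOfSquares()
    preds = torch.randn(4, requires_grad=True)
    target = torch.zeros(4)

    val = metric(preds, target)  # per-batch value, still attached to the graph
    val.backward()               # gradients reach `preds`

    print(metric.compute().requires_grad)  # False: accumulated state holds no graph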

6 changes: 3 additions & 3 deletions pytorch_lightning/core/step_result.py
@@ -259,7 +259,7 @@ def get_batch_log_metrics(self, include_forked_originals=True) -> dict:

             if options['logger'] and options['on_step']:
                 if isinstance(self[k], Metric):
-                    result[k] = self[k]._forward_cache
+                    result[k] = self[k]._forward_cache.detach()
                 else:
                     result[k] = self[k]

@@ -281,7 +281,7 @@ def get_epoch_log_metrics(self) -> dict:

             if options['logger'] and options['on_epoch']:
                 if isinstance(self[k], Metric):
-                    result[k] = self[k].compute()
+                    result[k] = self[k].compute().detach()
                 else:
                     result[k] = self[k]

@@ -307,7 +307,7 @@ def get_epoch_pbar_metrics(self):

             if options['prog_bar'] and options['on_epoch']:
                 if isinstance(self[k], Metric):
-                    result[k] = self[k].compute()
+                    result[k] = self[k].compute().detach()
                 else:
                     result[k] = self[k]

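All three hunks make the same change: the tensor copied into the logging or progress-bar dictionary is detached first. A framework-free sketch of why this matters (hypothetical code, not Lightning's logging internals): keeping an attached tensor around for logging also keeps that batch's autograd graph alive for as long as the reference exists, while a detached copy stores only the value.

    import torch

    model = torch.nn.Linear(10, 1)
    logged = []

    for step in range(3):
        x = torch.randn(32, 10)
        loss = model(x).pow(2).mean()
        loss.backward()

        # appending `loss` itself would keep this batch's graph referenced;
        # the detached copy is all that logging needs
        logged.append(loss.detach())

    print([round(v.item(), 4) for v in logged])
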
3 changes: 2 additions & 1 deletion pytorch_lightning/metrics/metric.py
@@ -150,7 +150,8 @@ def forward(self, *args, **kwargs):
         Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True.
         """
         # add current step
-        self.update(*args, **kwargs)
+        with torch.no_grad():
+            self.update(*args, **kwargs)
         self._forward_cache = None

         if self.compute_on_step:
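
The pattern applied here (accumulate state under `torch.no_grad()`, but compute the per-batch value with gradients enabled) can be shown without the framework. A hypothetical `RunningMSE` class, written only for illustration:

    import torch


    class RunningMSE:
        """Hypothetical stand-in for a Metric-like object (not Lightning code)."""

        def __init__(self):
            self.sse = torch.tensor(0.0)
            self.n = torch.tensor(0.0)

        def update(self, preds, target):
            self.sse += torch.sum((preds - target) ** 2)
            self.n += target.numel()

        def __call__(self, preds, target):
            with torch.no_grad():                     # state never stores a graph
                self.update(preds, target)
            return torch.mean((preds - target) ** 2)  # batch value, differentiable

        def compute(self):
            return self.sse / self.n                  # detached state -> no graph


    metric = RunningMSE()
    preds = torch.randn(8, requires_grad=True)
    target = torch.zeros(8)

    metric(preds, target).backward()       # gradients flow through the batch value
    print(metric.compute().requires_grad)  # False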
