Fix DDP on XLA (#16020)
Co-authored-by: awaelchli <[email protected]>
2 people authored and carmocca committed Jan 4, 2023
1 parent a36801c commit d8e3435
Showing 3 changed files with 41 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -177,6 +177,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed issue with unsupported torch.inference_mode() on hpu backends ([#15918](https://github.com/Lightning-AI/lightning/pull/15918))

- Fixed the incorrect optimizer step synchronization when running across multiple TPU devices ([#16020](https://github.com/Lightning-AI/lightning/pull/16020))


- Fixed a type error when dividing the chunk size in the ColossalAI strategy ([#16212](https://github.com/Lightning-AI/lightning/pull/16212))


11 changes: 10 additions & 1 deletion src/pytorch_lightning/plugins/precision/tpu.py
@@ -29,6 +29,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             raise ModuleNotFoundError(str(_XLA_AVAILABLE))
         super().__init__(*args, **kwargs)
 
+    def _tpu_wrap_closure(self, optimizer: Optimizable, closure: Callable[[], Any]) -> Any:
+        import torch_xla.core.xla_model as xm
+
+        closure_result = closure()
+        xm.reduce_gradients(optimizer)
+        return closure_result
+
     def optimizer_step(  # type: ignore[override]
         self,
         optimizer: Optimizable,
@@ -39,8 +46,10 @@ def optimizer_step(  # type: ignore[override]
     ) -> Any:
         import torch_xla.core.xla_model as xm
 
+        closure = partial(self._tpu_wrap_closure, optimizer, closure)
         closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
-        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
+        closure_result = optimizer.step(closure=closure, **kwargs)
+        xm.mark_step()
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if model.automatic_optimization and skipped_backward:
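
In short: `xm.optimizer_step` all-reduces gradients and then calls `optimizer.step(**optimizer_args)`, but in Lightning the closure passed to `optimizer.step` is what runs forward and backward, so the all-reduce ran before the current step's gradients had been computed. The change above moves the reduction into the wrapped closure (right after the closure executes) and dispatches the lazily recorded XLA graph with `xm.mark_step` after the step. A minimal sketch of the corrected ordering, with the Lightning hook wrapping (`_wrap_closure`) omitted and illustrative function names:

```python
from functools import partial

import torch_xla.core.xla_model as xm  # requires an XLA/TPU environment


def tpu_wrapped_closure(optimizer, closure):
    # Run forward/backward first so this step's gradients exist locally ...
    closure_result = closure()
    # ... then all-reduce them across TPU devices before the optimizer consumes them.
    xm.reduce_gradients(optimizer)
    return closure_result


def tpu_optimizer_step(optimizer, closure, **kwargs):
    # Step with the wrapped closure instead of calling xm.optimizer_step,
    # then flush the pending XLA graph.
    closure_result = optimizer.step(closure=partial(tpu_wrapped_closure, optimizer, closure), **kwargs)
    xm.mark_step()
    return closure_result
```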
28 changes: 28 additions & 0 deletions tests/tests_pytorch/plugins/precision/test_tpu.py
@@ -0,0 +1,28 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import Mock

from pytorch_lightning.plugins import TPUPrecisionPlugin
from tests_pytorch.helpers.runif import RunIf


@RunIf(tpu=True)
def test_optimizer_step_calls_mark_step():
    plugin = TPUPrecisionPlugin()
    optimizer = Mock()
    with mock.patch("torch_xla.core.xla_model") as xm_mock:
        plugin.optimizer_step(optimizer=optimizer, model=Mock(), optimizer_idx=0, closure=Mock())
    optimizer.step.assert_called_once()
    xm_mock.mark_step.assert_called_once()
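
A possible extension of this test (hypothetical, not part of the commit) would also pull the wrapped closure out of the mocked `optimizer.step` call and execute it, checking that gradients are all-reduced exactly once; this assumes the wrapped closure runs cleanly with the `Mock` model used above:

```python
@RunIf(tpu=True)
def test_optimizer_step_reduces_gradients():  # hypothetical extension, not in this commit
    plugin = TPUPrecisionPlugin()
    optimizer = Mock()
    with mock.patch("torch_xla.core.xla_model") as xm_mock:
        plugin.optimizer_step(optimizer=optimizer, model=Mock(), optimizer_idx=0, closure=Mock())
        # The Mock optimizer never runs the closure itself, so grab the closure
        # that was handed to optimizer.step and call it while xm is still patched.
        wrapped_closure = optimizer.step.call_args[1]["closure"]
        wrapped_closure()
    xm_mock.reduce_gradients.assert_called_once_with(optimizer)
```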
