Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix DDP on XLA #16020

Merged
merged 9 commits into from
Jan 3, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enhanced `reduce_boolean_decision` to accommodate `any`-analogous semantics expected by the `EarlyStopping` callback ([#15253](https://github.com/Lightning-AI/lightning/pull/15253))


- Fixed the incorrect optimizer step synchronization when running across multiple TPU devices ([#16020](https://github.com/Lightning-AI/lightning/pull/16020))


## [1.8.6] - 2022-12-21

- minor cleaning
Expand Down
11 changes: 10 additions & 1 deletion src/pytorch_lightning/plugins/precision/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(*args, **kwargs)

def _tpu_wrap_closure(self, optimizer: Optimizable, closure: Callable[[], Any]) -> Any:
import torch_xla.core.xla_model as xm

closure_result = closure()
xm.reduce_gradients(optimizer)
return closure_result

def optimizer_step( # type: ignore[override]
self,
optimizer: Optimizable,
Expand All @@ -39,8 +46,10 @@ def optimizer_step( # type: ignore[override]
) -> Any:
import torch_xla.core.xla_model as xm

closure = partial(self._tpu_wrap_closure, optimizer, closure)
closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
closure_result = optimizer.step(closure=closure, **kwargs)
xm.mark_step()
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if model.automatic_optimization and skipped_backward:
Expand Down
28 changes: 28 additions & 0 deletions tests/tests_pytorch/plugins/precision/test_tpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import Mock

from pytorch_lightning.plugins import TPUPrecisionPlugin
from tests_pytorch.helpers.runif import RunIf


@RunIf(tpu=True)
def test_optimizer_step_calls_mark_step():
plugin = TPUPrecisionPlugin()
optimizer = Mock()
with mock.patch("torch_xla.core.xla_model") as xm_mock:
plugin.optimizer_step(optimizer=optimizer, model=Mock(), optimizer_idx=0, closure=Mock())
optimizer.step.assert_called_once()
xm_mock.mark_step.assert_called_once()