From 37f8166c589be498f284f07b60fde72c04ea27a4 Mon Sep 17 00:00:00 2001
From: Liyang90 <liyanglu@google.com>
Date: Mon, 12 Dec 2022 10:20:36 -0800
Subject: [PATCH 1/4] Fix DDP on XLA

Fix DDP on XLA by inserting `xm.reduce_gradients` in `closure` instead of
using `xm.optimizer_step`. Also add an `xm.mark_step()` call after
`optimizer.step` for performance.
---
 src/pytorch_lightning/plugins/precision/tpu.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/plugins/precision/tpu.py b/src/pytorch_lightning/plugins/precision/tpu.py
index efa61dd8fca04..3f1b59059e695 100644
--- a/src/pytorch_lightning/plugins/precision/tpu.py
+++ b/src/pytorch_lightning/plugins/precision/tpu.py
@@ -29,6 +29,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             raise ModuleNotFoundError(str(_XLA_AVAILABLE))
         super().__init__(*args, **kwargs)
 
+    def _tpu_wrap_closure(self, optimizer, closure: Callable[[], Any]) -> Any:
+        import torch_xla.core.xla_model as xm
+
+        closure_result = closure()
+        xm.reduce_gradients(optimizer)
+        return closure_result
+
     def optimizer_step(  # type: ignore[override]
         self,
         optimizer: Optimizable,
@@ -39,8 +46,10 @@ def optimizer_step(  # type: ignore[override]
     ) -> Any:
         import torch_xla.core.xla_model as xm
 
+        closure = partial(self._tpu_wrap_closure, optimizer, closure)
         closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
-        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
+        closure_result = optimizer.step(closure=closure, **kwargs)
+        xm.mark_step()
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if model.automatic_optimization and skipped_backward:

From de01b1a263fbd03f75ba251a4d4b548f409357ef Mon Sep 17 00:00:00 2001
From: awaelchli <aedu.waelchli@gmail.com>
Date: Mon, 12 Dec 2022 22:39:09 +0100
Subject: [PATCH 2/4] fix mypy error

---
 src/pytorch_lightning/plugins/precision/tpu.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/plugins/precision/tpu.py b/src/pytorch_lightning/plugins/precision/tpu.py
index 3f1b59059e695..ef98cd8a436f2 100644
--- a/src/pytorch_lightning/plugins/precision/tpu.py
+++ b/src/pytorch_lightning/plugins/precision/tpu.py
@@ -29,7 +29,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             raise ModuleNotFoundError(str(_XLA_AVAILABLE))
         super().__init__(*args, **kwargs)
 
-    def _tpu_wrap_closure(self, optimizer, closure: Callable[[], Any]) -> Any:
+    def _tpu_wrap_closure(self, optimizer: Optimizable, closure: Callable[[], Any]) -> Any:
         import torch_xla.core.xla_model as xm
 
         closure_result = closure()

From a9ad4690c755cb1d7a790dbf2e2427afb549ff39 Mon Sep 17 00:00:00 2001
From: awaelchli <aedu.waelchli@gmail.com>
Date: Mon, 12 Dec 2022 22:39:15 +0100
Subject: [PATCH 3/4] add changelog entry

---
 src/pytorch_lightning/CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 60afab1d0bcc9..fdc4d0c47a0cc 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -91,6 +91,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Enhanced `reduce_boolean_decision` to accommodate `any`-analogous semantics expected by the `EarlyStopping` callback ([#15253](https://github.com/Lightning-AI/lightning/pull/15253))
 
+- Fixed the incorrect optimizer step synchronization when running across multiple TPU devices ([#16020](https://github.com/Lightning-AI/lightning/pull/16020))
+
+
 ## [1.8.4] - 2022-12-08
 
 ### Changed

From 101c6c2a67e715ce4b4f5849d7ff60de071e4173 Mon Sep 17 00:00:00 2001
From: awaelchli <aedu.waelchli@gmail.com>
Date: Sun, 1 Jan 2023 19:58:56 +0100
Subject: [PATCH 4/4] add test

---
 .../plugins/precision/test_tpu.py             | 28 +++++++++++++++++++
 1 file changed, 28 insertions(+)
 create mode 100644 tests/tests_pytorch/plugins/precision/test_tpu.py

diff --git a/tests/tests_pytorch/plugins/precision/test_tpu.py b/tests/tests_pytorch/plugins/precision/test_tpu.py
new file mode 100644
index 0000000000000..a44ab5bc08b12
--- /dev/null
+++ b/tests/tests_pytorch/plugins/precision/test_tpu.py
@@ -0,0 +1,28 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest import mock
+from unittest.mock import Mock
+
+from pytorch_lightning.plugins import TPUPrecisionPlugin
+from tests_pytorch.helpers.runif import RunIf
+
+
+@RunIf(tpu=True)
+def test_optimizer_step_calls_mark_step():
+    plugin = TPUPrecisionPlugin()
+    optimizer = Mock()
+    with mock.patch("torch_xla.core.xla_model") as xm_mock:
+        plugin.optimizer_step(optimizer=optimizer, model=Mock(), optimizer_idx=0, closure=Mock())
+    optimizer.step.assert_called_once()
+    xm_mock.mark_step.assert_called_once()
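Note: for readers unfamiliar with the XLA idiom, the patched `optimizer_step` follows the manual per-step pattern sketched below: run the closure (forward and backward), all-reduce the gradients across replicas with `xm.reduce_gradients`, let the optimizer step, then call `xm.mark_step()` so the lazily built graph is executed. This is only an illustrative sketch of the pattern, not code from the patch; the model, optimizer, and batch are arbitrary placeholders and assume a host with `torch_xla` installed.

```python
import torch
import torch_xla.core.xla_model as xm

# Placeholder setup: any model/optimizer/batch on the XLA device works.
device = xm.xla_device()
model = torch.nn.Linear(4, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch = torch.randn(8, 4, device=device)


def closure():
    # Forward and backward as usual ...
    optimizer.zero_grad()
    loss = model(batch).sum()
    loss.backward()
    # ... then all-reduce the gradients across replicas before the step.
    xm.reduce_gradients(optimizer)
    return loss


# The optimizer evaluates the closure (which now reduces gradients) and steps.
loss = optimizer.step(closure=closure)
# Cut the lazy-tensor graph so the queued operations actually execute.
xm.mark_step()
```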