From 37f8166c589be498f284f07b60fde72c04ea27a4 Mon Sep 17 00:00:00 2001
From: Liyang90 <liyanglu@google.com>
Date: Mon, 12 Dec 2022 10:20:36 -0800
Subject: [PATCH 1/4] Fix DDP on XLA

Fix DDP on XLA by inserting `xm.reduce_gradients` in `closure` instead of using `xm.optimizer_step`.
Adding a `xm.mark_step()` call after `optimizer.step` for performance.
---
 src/pytorch_lightning/plugins/precision/tpu.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/plugins/precision/tpu.py b/src/pytorch_lightning/plugins/precision/tpu.py
index efa61dd8fca04..3f1b59059e695 100644
--- a/src/pytorch_lightning/plugins/precision/tpu.py
+++ b/src/pytorch_lightning/plugins/precision/tpu.py
@@ -29,6 +29,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             raise ModuleNotFoundError(str(_XLA_AVAILABLE))
         super().__init__(*args, **kwargs)
 
+    def _tpu_wrap_closure(self, optimizer, closure: Callable[[], Any]) -> Any:
+        import torch_xla.core.xla_model as xm
+
+        closure_result = closure()
+        xm.reduce_gradients(optimizer)
+        return closure_result
+
     def optimizer_step(  # type: ignore[override]
         self,
         optimizer: Optimizable,
@@ -39,8 +46,10 @@ def optimizer_step(  # type: ignore[override]
     ) -> Any:
         import torch_xla.core.xla_model as xm
 
+        closure = partial(self._tpu_wrap_closure, optimizer, closure)
         closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
-        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
+        closure_result = optimizer.step(closure=closure, **kwargs)
+        xm.mark_step()
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if model.automatic_optimization and skipped_backward:

From de01b1a263fbd03f75ba251a4d4b548f409357ef Mon Sep 17 00:00:00 2001
From: awaelchli <aedu.waelchli@gmail.com>
Date: Mon, 12 Dec 2022 22:39:09 +0100
Subject: [PATCH 2/4] fix mypy error

---
 src/pytorch_lightning/plugins/precision/tpu.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/plugins/precision/tpu.py b/src/pytorch_lightning/plugins/precision/tpu.py
index 3f1b59059e695..ef98cd8a436f2 100644
--- a/src/pytorch_lightning/plugins/precision/tpu.py
+++ b/src/pytorch_lightning/plugins/precision/tpu.py
@@ -29,7 +29,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             raise ModuleNotFoundError(str(_XLA_AVAILABLE))
         super().__init__(*args, **kwargs)
 
-    def _tpu_wrap_closure(self, optimizer, closure: Callable[[], Any]) -> Any:
+    def _tpu_wrap_closure(self, optimizer: Optimizable, closure: Callable[[], Any]) -> Any:
         import torch_xla.core.xla_model as xm
 
         closure_result = closure()

From a9ad4690c755cb1d7a790dbf2e2427afb549ff39 Mon Sep 17 00:00:00 2001
From: awaelchli <aedu.waelchli@gmail.com>
Date: Mon, 12 Dec 2022 22:39:15 +0100
Subject: [PATCH 3/4] add changelog entry

---
 src/pytorch_lightning/CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 60afab1d0bcc9..fdc4d0c47a0cc 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -91,6 +91,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Enhanced `reduce_boolean_decision` to accommodate `any`-analogous semantics expected by the `EarlyStopping` callback ([#15253](https://github.com/Lightning-AI/lightning/pull/15253))
 
 
+- Fixed the incorrect optimizer step synchronization when running across multiple TPU devices ([#16020](https://github.com/Lightning-AI/lightning/pull/16020))
+
+
 ## [1.8.4] - 2022-12-08
 
 ### Changed

From 101c6c2a67e715ce4b4f5849d7ff60de071e4173 Mon Sep 17 00:00:00 2001
From: awaelchli <aedu.waelchli@gmail.com>
Date: Sun, 1 Jan 2023 19:58:56 +0100
Subject: [PATCH 4/4] add test

---
 .../plugins/precision/test_tpu.py             | 28 +++++++++++++++++++
 1 file changed, 28 insertions(+)
 create mode 100644 tests/tests_pytorch/plugins/precision/test_tpu.py

diff --git a/tests/tests_pytorch/plugins/precision/test_tpu.py b/tests/tests_pytorch/plugins/precision/test_tpu.py
new file mode 100644
index 0000000000000..a44ab5bc08b12
--- /dev/null
+++ b/tests/tests_pytorch/plugins/precision/test_tpu.py
@@ -0,0 +1,28 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest import mock
+from unittest.mock import Mock
+
+from pytorch_lightning.plugins import TPUPrecisionPlugin
+from tests_pytorch.helpers.runif import RunIf
+
+
+@RunIf(tpu=True)
+def test_optimizer_step_calls_mark_step():
+    plugin = TPUPrecisionPlugin()
+    optimizer = Mock()
+    with mock.patch("torch_xla.core.xla_model") as xm_mock:
+        plugin.optimizer_step(optimizer=optimizer, model=Mock(), optimizer_idx=0, closure=Mock())
+    optimizer.step.assert_called_once()
+    xm_mock.mark_step.assert_called_once()