
DDP training on TPU is not working #15878

Closed
Liyang90 opened this issue Nov 30, 2022 · 7 comments
Labels
accelerator: tpu (Tensor Processing Unit), bug (Something isn't working)
Milestone
v1.8.x

Comments


Liyang90 commented Nov 30, 2022

Bug description

PyTorch-Lightning DDP training on TPU appears to be broken in “automatic_optimization” mode.

“automatic_optimization” mode is the default mode in which the Trainer trains a LightningModule: users only define the training forward pass in the module, and the Trainer automatically performs the backward pass and the optimizer step.
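
For reference, a minimal sketch of a module trained under automatic optimization (the model, data shapes, and optimizer here are hypothetical, just to illustrate that only the forward pass is user code):

import torch
from torch import nn
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Only the forward pass / loss computation lives here; with
        # automatic_optimization (the default) the Trainer calls backward()
        # and optimizer.step() on our behalf.
        loss = nn.functional.cross_entropy(self.layer(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)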

On TPU (PT/XLA), DDP is usually achieved by calling xm.optimizer_step() for the optimization step, which adds a gradient all_reduce op before calling optimizer.step(). In PyTorch-Lightning, this is done in the TPUPrecisionPlugin class as:

closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
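
For context, xm.optimizer_step() behaves roughly like the following (a simplified sketch of the torch_xla source; the exact signature and defaults may differ between releases):

import torch_xla.core.xla_model as xm

def optimizer_step(optimizer, barrier=False, optimizer_args={}):
    # Average the gradients currently stored on the optimizer's parameters
    # across all replicas...
    xm.reduce_gradients(optimizer)
    # ...then apply them.
    loss = optimizer.step(**optimizer_args)
    if barrier:
        xm.mark_step()
    return loss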

However, in “automatic_optimization” mode, PyTorch-Lightning actually puts the forward and backward passes into the closure callable given to optimizer.step(), so forward and backward happen inside optimizer.step(). What ends up happening in an iteration over a batch (see the sketch after this list) is:

  1. All_reduce gradients
  2. Forward pass
  3. Zero gradients
  4. Backward pass
  5. Optimizer step
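
To make the ordering concrete, here is a hedged sketch of what effectively happens (names such as model, batch, and batch_idx are placeholders; the real Lightning closure also handles hooks, logging, etc.):

def closure():
    optimizer.zero_grad()
    loss = model.training_step(batch, batch_idx)  # forward
    loss.backward()                               # backward
    return loss

# xm.optimizer_step() reduces gradients BEFORE calling optimizer.step(closure),
# but this batch's gradients are only produced inside the closure, i.e. inside step():
xm.optimizer_step(optimizer, optimizer_args={"closure": closure})
#   -> xm.reduce_gradients(optimizer)   # runs first: this batch's grads don't exist yet
#   -> optimizer.step(closure=closure)  # closure runs forward/backward, then step
#                                       # applies the un-reduced local gradients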

The all_reduce is effectively a no-op, and the gradients are not synchronized across DDP processes before being applied to the model. What should happen instead is:

  1. Forward pass
  2. Zero gradients
  3. Backward pass
  4. All_reduce gradients
  5. Optimizer step

The gradient all_reduce op needs to be inserted into the closure. A possible fix could be:

diff --git a/src/pytorch_lightning/plugins/precision/tpu.py b/src/pytorch_lightning/plugins/precision/tpu.py
index efa61dd8f..2dda36d3f 100644
--- a/src/pytorch_lightning/plugins/precision/tpu.py
+++ b/src/pytorch_lightning/plugins/precision/tpu.py
@@ -29,6 +29,13 @@ class TPUPrecisionPlugin(PrecisionPlugin):
            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
        super().__init__(*args, **kwargs)
+    def _tpu_wrap_closure(self, optimizer, closure: Callable[[], Any]) -> Any:
+        import torch_xla.core.xla_model as xm
+
+        closure_result = closure()
+        xm.reduce_gradients(optimizer)
+        return closure_result
+
    def optimizer_step(  # type: ignore[override]
        self,
        optimizer: Optimizable,
@@ -39,8 +46,9 @@ class TPUPrecisionPlugin(PrecisionPlugin):
    ) -> Any:
        import torch_xla.core.xla_model as xm
+        closure = partial(self._tpu_wrap_closure, optimizer, closure)
        closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
-        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
+        closure_result = optimizer.step(closure=closure, **kwargs)
        skipped_backward = closure_result is None
        # in manual optimization, the closure does not return a value
        if model.automatic_optimization and skipped_backward:

A comparison with and without the patch was done on the modified MNIST TPU tutorial, with BATCH_SIZE = 128: Tensorboard log. The training loss curve there shows the model converging faster with the fix, as the gradients are correctly reduced and the model replicas stay synchronized across processes.

IR graphs can also be dumped before and after the fix. Before the fix, no xla::cross_replica_sum() op can be found in an iteration; after the fix, it correctly appears after the backward ops.
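
One way to produce such a dump (an assumption based on torch_xla's debugging environment variables; names and formats may vary between releases) is:

import os

# Ask torch_xla to save the lazy-tensor IR for every executed graph.
# These must be set before torch_xla starts executing graphs.
os.environ["XLA_SAVE_TENSORS_FILE"] = "/tmp/xla_ir.txt"
os.environ["XLA_SAVE_TENSORS_FMT"] = "text"

# ...run one training iteration, then search the dump, e.g.:
#   grep cross_replica_sum /tmp/xla_ir.txt
# Before the fix the op is absent; after the fix it shows up right after the backward ops.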

How to reproduce the bug

MNIST TPU tutorial

Error messages and logs


# Error messages and logs here please

Environment


#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): PyTorch-Lightning
#- PyTorch Lightning Version (e.g., 1.5.0): master
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 1.10): 1.14
#- Python version (e.g., 3.9):
#- OS (e.g., Linux): Linux
#- CUDA/cuDNN version: None
#- TPU models and configuration: Google Cloud TPU V3-8
#- How you installed Lightning(`conda`, `pip`, source): source
#- Running environment of LightningApp (e.g. local, cloud): cloud

More info

No response

Liyang90 added the "needs triage" (Waiting to be triaged by maintainers) label on Nov 30, 2022
carmocca added the "bug" (Something isn't working) and "accelerator: tpu" (Tensor Processing Unit) labels and removed the "needs triage" label on Dec 1, 2022
carmocca added this to the v1.8.x milestone on Dec 1, 2022

carmocca commented Dec 1, 2022

Thank you so much! Do you want to open a PR implementing your solution?


carmocca commented Dec 1, 2022

Looking at the source code (https://github.com/pytorch/xla/blob/6e6bb07e696c51a555a7d33433508ba236703a35/torch_xla/core/xla_model.py#L1027-L1032), your suggestion LGTM. You should still call the parent's _wrap_closure method:

diff --git a/src/pytorch_lightning/plugins/precision/tpu.py b/src/pytorch_lightning/plugins/precision/tpu.py
index efa61dd8f..30c64751b 100644
--- a/src/pytorch_lightning/plugins/precision/tpu.py
+++ b/src/pytorch_lightning/plugins/precision/tpu.py
@@ -29,6 +29,13 @@ class TPUPrecisionPlugin(PrecisionPlugin):
             raise ModuleNotFoundError(str(_XLA_AVAILABLE))
         super().__init__(*args, **kwargs)
 
+    def _wrap_closure(self, model: "pl.LightningModule", optimizer: Optimizable, **kwargs: Any) -> Any:
+        closure_result = super()._wrap_closure(model, optimizer, **kwargs)
+        import torch_xla.core.xla_model as xm
+
+        xm.reduce_gradients(optimizer)
+        return closure_result
+
     def optimizer_step(  # type: ignore[override]
         self,
         optimizer: Optimizable,
@@ -37,10 +44,8 @@ class TPUPrecisionPlugin(PrecisionPlugin):
         closure: Callable[[], Any],
         **kwargs: Any,
     ) -> Any:
-        import torch_xla.core.xla_model as xm
-
         closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
-        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
+        closure_result = optimizer.step(closure=closure, **kwargs)
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if model.automatic_optimization and skipped_backward:


Liyang90 commented Dec 1, 2022

Shouldn't grad clipping be done after grad all_reduce?


carmocca commented Dec 1, 2022

Yes. My bad!


Liyang90 commented Dec 1, 2022

> Thank you so much! Do you want to open a PR implementing your solution?

I notice there are other places that call xm.optimizer_step(), such as in lightning, lightning_lite, and pytorch_lightning. I don't know enough about the package structure to incorporate this change everywhere that has a similar issue, so I hope someone more familiar with it can help.


carmocca commented Dec 12, 2022

AFAIK Lite does not suffer from this issue because it doesn't need to run the closure separately. This fix is only relevant for the src/pytorch_lightning/plugins/precision/tpu.py file. You can get started on a PR and we'll continue reviewing there.

Liyang90 mentioned this issue on Dec 12, 2022

Liyang90 commented Jan 3, 2023

The fix has been merged into master.

Liyang90 closed this as completed on Jan 3, 2023