DDP training on TPU is not working #15878
Comments
Thank you so much! Do you want to open a PR implementing your solution?
Looking at the source code (https://github.com/pytorch/xla/blob/6e6bb07e696c51a555a7d33433508ba236703a35/torch_xla/core/xla_model.py#L1027-L1032), your suggestion LGTM. You should still call the parent's `_wrap_closure`:

```diff
diff --git a/src/pytorch_lightning/plugins/precision/tpu.py b/src/pytorch_lightning/plugins/precision/tpu.py
index efa61dd8f..30c64751b 100644
--- a/src/pytorch_lightning/plugins/precision/tpu.py
+++ b/src/pytorch_lightning/plugins/precision/tpu.py
@@ -29,6 +29,13 @@ class TPUPrecisionPlugin(PrecisionPlugin):
             raise ModuleNotFoundError(str(_XLA_AVAILABLE))
         super().__init__(*args, **kwargs)
 
+    def _wrap_closure(self, model: "pl.LightningModule", optimizer: Optimizable, **kwargs: Any) -> Any:
+        closure_result = super()._wrap_closure(model, optimizer, **kwargs)
+        import torch_xla.core.xla_model as xm
+
+        xm.reduce_gradients(optimizer)
+        return closure_result
+
     def optimizer_step(  # type: ignore[override]
         self,
         optimizer: Optimizable,
@@ -37,10 +44,8 @@ class TPUPrecisionPlugin(PrecisionPlugin):
         closure: Callable[[], Any],
         **kwargs: Any,
     ) -> Any:
-        import torch_xla.core.xla_model as xm
-
         closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
-        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
+        closure_result = optimizer.step(closure=closure, **kwargs)
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if model.automatic_optimization and skipped_backward:
```
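For reference, the linked lines of `torch_xla/core/xla_model.py` boil down to roughly the following (a paraphrased sketch, not the verbatim source; the barrier and layout-pinning arguments are omitted):

```python
# Paraphrased sketch of xm.optimizer_step() per the source linked above.
# reduce_gradients is defined in the same module (xla_model.py).
def optimizer_step(optimizer, optimizer_args={}):
    # All-reduce the gradients across replicas first...
    reduce_gradients(optimizer)
    # ...then run the step. With Lightning's automatic optimization, the
    # closure passed via optimizer_args runs forward + backward only here,
    # i.e. after the all_reduce above has already executed.
    loss = optimizer.step(**optimizer_args)
    return loss
```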
Shouldn't grad clipping be done after grad all_reduce?
Yes. My bad!
I notice there are other places that have called `xm.optimizer_step()`, e.g. in Lite. Do those have the same issue?
AFAIK Lite does not suffer from this issue because it doesn't need to run the closure separately. This fix is only relevant for the Trainer.
The fix is merged to master. |
Bug description
PyTorch-Lightning DDP training on TPU might be broken in `automatic_optimization` mode.
`automatic_optimization` is the default mode when a `LightningModule` is trained by the Trainer: the user only defines the training forward pass in the module, and the Trainer automatically runs the backward pass and the optimizer step.
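For context, this is what automatic optimization looks like from the user side (a generic minimal example, not code from this issue):

```python
import pytorch_lightning as pl
import torch
import torch.nn.functional as F


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # Only the forward pass and loss are defined here; with
        # automatic_optimization (the default), the Trainer runs
        # backward() and optimizer.step() on the returned loss.
        x, y = batch
        return F.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.02)
```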
On TPU (PT/XLA), DDP is usually achieved by calling `xm.optimizer_step()` for the optimization step, which adds a gradient all_reduce op before calling `optimizer.step()`. In PyTorch-Lightning, this is done in the `TPUPrecisionPlugin` class.
However, in `automatic_optimization` mode, PyTorch-Lightning actually puts the forward and backward passes into the `closure` callable given to `optimizer.step()`, so forward and backward happen within `optimizer.step()`. What ends up happening in an iteration over a batch is: `xm.optimizer_step()` all-reduces the gradients before the closure has run, i.e. before the gradients for the current batch even exist. The all_reduce is therefore a no-op, and the gradients are not synchronized between the DDP processes before being applied to the models.
What should happen instead is: the gradient all_reduce op needs to be inserted into the closure. A possible fix could be the one sketched below.
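Reconstructed from the diff in the comments above (simplified: type annotations and some arguments are abbreviated; `PrecisionPlugin` is the pytorch_lightning base class):

```python
from functools import partial

import torch_xla.core.xla_model as xm
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin


class TPUPrecisionPlugin(PrecisionPlugin):
    def _wrap_closure(self, model, optimizer, optimizer_idx, closure):
        # The parent's _wrap_closure runs the closure, i.e. the forward
        # and backward passes that produce this batch's gradients.
        closure_result = super()._wrap_closure(model, optimizer, optimizer_idx, closure)
        # All-reduce the fresh gradients across TPU processes *inside*
        # the closure, after backward but before the weight update.
        # (Per the discussion above, gradient clipping should also come
        # after this all_reduce.)
        xm.reduce_gradients(optimizer)
        return closure_result

    def optimizer_step(self, optimizer, model, optimizer_idx, closure, **kwargs):
        closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
        # Plain optimizer.step() instead of xm.optimizer_step(): the
        # all_reduce now happens inside the closure, not before it.
        closure_result = optimizer.step(closure=closure, **kwargs)
        return closure_result
```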
A comparison with and without the patch was done on the modified MNIST TPU tutorial with `BATCH_SIZE = 128`: Tensorboard log. The training loss curve there shows the model converging faster with the fixed code, as the gradients are correctly reduced and the models stay in sync across processes.
IR graphs can also be dumped before and after the fix. Before the fix, no `xla::cross_replica_sum()` op can be found in an iteration, while after the fix it correctly appears after the backward ops.
How to reproduce the bug
MNIST TPU tutorial
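To confirm whether the gradient all_reduce actually runs, the dumped IR can be inspected; a sketch, assuming torch_xla's documented debug environment variables (`mnist_tpu.py` stands in for the tutorial script):

```python
# Dump the lazy-tensor IR while training, then search it for the
# all_reduce op. XLA_SAVE_TENSORS_FILE / XLA_SAVE_TENSORS_FMT are
# documented torch_xla debug env vars; set them before launching, e.g.
#   XLA_SAVE_TENSORS_FILE=/tmp/xla_ir.txt XLA_SAVE_TENSORS_FMT=text python mnist_tpu.py
#   grep -c cross_replica_sum /tmp/xla_ir.txt   # 0 before the fix
import torch_xla.debug.metrics as met

# torch_xla's metrics report can also be printed after a few training
# steps as a quick sanity check of which ops were traced and executed.
print(met.metrics_report())
```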
Error messages and logs
Environment
More info
No response