WA for Torch-compile + ZeRO-3 + activation-checkpointing accuracy issue from the Pytorch repo (microsoft#5590)

We have encountered an accuracy issue when running torch compile +
zero3 + activation checkpointing: specifically, some grads get zeroed
(the issue does not occur when running without torch compile). This
issue was also reproduced by Umesh Chand from the DS team. We found that
in the Pytorch repo, torch compile has been explicitly disabled on the
checkpoint function using the decorator @torch._disable_dynamo().
Reference to the WA in the Pytorch repo:
https://github.com/pytorch/pytorch/blob/ec8b254ef49b4a057cf89c2ae64520fb7b423a3e/torch/utils/checkpoint.py#L324
This indicates that there is some issue with torch compile and
checkpointing (not necessarily DS related).

Given that the checkpointing function in DeepSpeed is based on the
Pytorch function, we propose adopting this WA to ensure correct behavior
(it can be removed later if the underlying issue is fixed).
Note: this shouldn't impact non-torch-compile cases.
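The mechanism behind the WA can be sketched as follows. Note this is an illustrative stand-in, not DeepSpeed's actual `deepspeed.runtime.compiler.disable` implementation: the `disable` decorator below keeps the wrapped function out of torch.compile (Dynamo) graph capture via `torch._dynamo.disable`, and falls back to a no-op when torch is not installed.

```python
# Hypothetical sketch of a compile-disabling decorator (names `disable` and
# `checkpoint` here are illustrative, not the DeepSpeed source).
try:
    from torch._dynamo import disable as _dynamo_disable  # real torch hook
except ImportError:
    def _dynamo_disable(fn):
        # No torch installed: nothing to disable, return the function as-is.
        return fn

def disable(fn):
    """Run `fn` eagerly, outside any torch.compile'd graph."""
    return _dynamo_disable(fn)

@disable
def checkpoint(function, *args):
    # Eager-only region: this call bypasses Dynamo graph capture entirely,
    # sidestepping the compile + checkpointing interaction described above.
    return function(*args)

print(checkpoint(lambda x: x + 1, 41))  # → 42
```

Because the decorator is a no-op wrapper in eager mode, applying it costs nothing for non-compiled runs, which is why the commit message notes that non-torch-compile cases are unaffected.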

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
NirSonnenschein and tjruwase authored Jun 10, 2024
1 parent 31a57fa commit 6e2899f
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions deepspeed/runtime/activation_checkpointing/checkpointing.py
@@ -29,6 +29,7 @@
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers, FORWARD_GLOBAL_TIMER
from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
from deepspeed.accelerator import get_accelerator
+from deepspeed.runtime import compiler

# DeepSpeed Checkpointing Enabled or Disabled
deepspeed_checkpointing_enabled = False
@@ -987,6 +988,7 @@ def after_backward_hook(_nonuse_grads):
return tuple(all_outputs)


+@compiler.disable  # WA from Pytorch repo for compile + zero 3 accuracy issue
def checkpoint(function, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint. """
