Skip to content

Commit

Permalink
Torch-compile-Z3-act-apt accuracy issue Pytorch WA
Browse files Browse the repository at this point in the history
We have been encountered an accuracy issue when running
Torch compile + zero3 + activation checkpoiting. Sepcifically some grads gets is zeroed
(running without torch compile, this issue is not encountered). This issue was also reproduced
by Umesh Chand from the DS team. We found that in the Pytorch repo
torch compile has been specifically disabled using the label:
@torch._disable_dynamo()
reference to the WA in the pytorch repo (https://github.com/pytorch/pytorch/blob/ec8b254ef49b4a057cf89c2ae64520fb7b423a3e/torch/utils/checkpoint.py#L324) this indicates that there is some issue with torch compile and checkpoiting (not necessarily DS related).

given that the checkpointing function in Deepspeed is based on the
Pytorch function, We propose to adopt this WA to ensure correct
behavior (it can be removed later if the underlying issue is fixed)
Note: this shouldn't impact non-troch compile cases.
  • Loading branch information
NirSonnenschein committed Jun 4, 2024
1 parent 2fc702e commit bc8f511
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions deepspeed/runtime/activation_checkpointing/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers, FORWARD_GLOBAL_TIMER
from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime import compiler

# DeepSpeed Checkpointing Enabled or Disabled
deepspeed_checkpointing_enabled = False
Expand Down Expand Up @@ -987,6 +988,7 @@ def after_backward_hook(_nonuse_grads):
return tuple(all_outputs)


@compiler.disable # WA from Pytorch repo for compile + zero 3 accuracy issue
def checkpoint(function, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint. """
Expand Down

0 comments on commit bc8f511

Please sign in to comment.