diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py index 2a21cf7ca17a..40b58d425a58 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -29,6 +29,7 @@ from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers, FORWARD_GLOBAL_TIMER from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank from deepspeed.accelerator import get_accelerator +from deepspeed.runtime import compiler # DeepSpeed Checkpointing Enabled or Disabled deepspeed_checkpointing_enabled = False @@ -987,6 +988,7 @@ def after_backward_hook(_nonuse_grads): return tuple(all_outputs) +@compiler.disable # Workaround from PyTorch repo for compile + ZeRO-3 accuracy issue def checkpoint(function, *args): """Checkpoint a model or part of the model. This has been directly copied from torch.utils.checkpoint. """