Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Torch-compile-Z3-act-apt accuracy issue Pytorch WA
We have been encountered an accuracy issue when running Torch compile + zero3 + activation checkpoiting. Sepcifically some grads gets is zeroed (running without torch compile, this issue is not encountered). This issue was also reproduced by Umesh Chand from the DS team. We found that in the Pytorch repo torch compile has been specifically disabled using the label: @torch._disable_dynamo() reference to the WA in the pytorch repo (https://github.com/pytorch/pytorch/blob/ec8b254ef49b4a057cf89c2ae64520fb7b423a3e/torch/utils/checkpoint.py#L324) this indicates that there is some issue with torch compile and checkpoiting (not necessarily DS related). given that the checkpointing function in Deepspeed is based on the Pytorch function, We propose to adopt this WA to ensure correct behavior (it can be removed later if the underlying issue is fixed) Note: this shouldn't impact non-troch compile cases.
- Loading branch information