
WA for Torch-compile-Z3-act-apt accuracy issue from the Pytorch repo #5590

Merged

Conversation

NirSonnenschein
Contributor

We have encountered an accuracy issue when running torch compile + ZeRO-3 + activation checkpointing: some grads are zeroed. (Without torch compile, the issue does not occur.) The issue was also reproduced by Umesh Chand from the DS team. We found that in the PyTorch repo, torch compile has been specifically disabled on the checkpoint function using the decorator @torch._disable_dynamo().
Reference to the WA in the PyTorch repo: https://github.com/pytorch/pytorch/blob/ec8b254ef49b4a057cf89c2ae64520fb7b423a3e/torch/utils/checkpoint.py#L324. This indicates there is some issue with torch compile and checkpointing (not necessarily DS related).

Given that the checkpointing function in DeepSpeed is based on the PyTorch function, we propose adopting this WA to ensure correct behavior (it can be removed later if the underlying issue is fixed).
Note: this shouldn't impact non-torch-compile cases.
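
A minimal sketch of the idea (illustrative names only, not the actual diff; the stand-in `checkpoint` below delegates to PyTorch's checkpoint just to keep the example runnable, whereas DeepSpeed's real entry point has its own partitioning/offloading logic):

```python
import torch
from torch.utils.checkpoint import checkpoint as torch_checkpoint

# Requires a PyTorch version that exposes torch._disable_dynamo.
@torch._disable_dynamo()  # TorchDynamo skips tracing this function entirely
def checkpoint(function, *args):
    # Stand-in body for the sketch; the real function lives in
    # deepspeed/runtime/activation_checkpointing/checkpointing.py.
    return torch_checkpoint(function, *args, use_reentrant=False)
```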

@tjruwase
Contributor

@umchand, @tohtana FYI

@umchand umchand assigned umchand and unassigned umchand May 30, 2024
@umchand umchand self-requested a review May 30, 2024 20:05
@tohtana
Contributor

tohtana commented May 31, 2024

@NirSonnenschein Thank you for the PR.

It seems older versions of PyTorch don't have @torch._disable_dynamo(). Can you try deepspeed.runtime.compiler.disable instead?
https://github.com/microsoft/DeepSpeed/blob/2fc702ed9f84f3c88179cb713bb9f4effbbdba33/deepspeed/runtime/compiler.py#L19C5-L19C12
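
For reference, the helper linked above behaves roughly like the following sketch (an approximation of deepspeed/runtime/compiler.py at that commit, not its exact contents):

```python
import torch

def is_compile_supported():
    # torch.compiler (and torch.compile) only exist on newer PyTorch releases.
    return hasattr(torch, "compiler")

def disable(func):
    # When compile is supported, tell TorchDynamo not to trace `func`;
    # on older PyTorch versions, return `func` unchanged (a no-op).
    if is_compile_supported():
        return torch.compiler.disable(func)
    return func
```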

@NirSonnenschein NirSonnenschein force-pushed the torch_compile_Z3_WA_from_pytorch branch from e068adf to bc8f511 Compare June 4, 2024 09:03
@NirSonnenschein
Contributor Author

> @NirSonnenschein Thank you for the PR.
>
> It seems older versions of PyTorch don't have @torch._disable_dynamo(). Can you try deepspeed.runtime.compiler.disable instead? https://github.com/microsoft/DeepSpeed/blob/2fc702ed9f84f3c88179cb713bb9f4effbbdba33/deepspeed/runtime/compiler.py#L19C5-L19C12

Thanks for the comment. I've uploaded a new version that should fix this.
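
In sketch form, the change amounts to swapping the decorator for the version-safe helper (names illustrative; body elided, see the diff):

```python
from deepspeed.runtime.compiler import disable

@disable  # version-safe: falls back to a no-op on PyTorch builds without torch.compiler
def checkpoint(function, *args):
    ...  # DeepSpeed's activation-checkpointing logic (unchanged)
```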

@NirSonnenschein
Contributor Author

Hi @tohtana,
Would it be possible to re-run CI to check that the new version passes?

@NirSonnenschein
Contributor Author

Hi @mrwyattii,
Would it be possible to review this PR when you have time?

@tjruwase tjruwase added this pull request to the merge queue Jun 9, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Jun 9, 2024
@NirSonnenschein
Contributor Author

Thanks @tjruwase,
the failure in the merge-queue tests doesn't seem related to this commit.
The error appears to be an HTTP error while fetching a model from Hugging Face:
E huggingface_hub.utils._errors.HfHubHTTPError: 429 Client Error: Too Many Requests for url: https://huggingface.co/api/models?cursor=eyJfaWQiOnsiJGd0IjoiNjY0ZTIxNTAyZmI1MWJjZjFlNzdkZjBkIn19

@tjruwase tjruwase added this pull request to the merge queue Jun 10, 2024
Merged via the queue into microsoft:master with commit 6e2899f Jun 10, 2024
12 checks passed
sfc-gh-reyazda pushed a commit to Snowflake-Labs/DeepSpeed that referenced this pull request Jun 10, 2024