
set supports_gradient_checkpointing = False in SwitchTransformer #35249

Open · wants to merge 4 commits into main
Conversation

nhamanasu (Contributor) commented on Dec 12, 2024

What does this PR do?

Recently, CircleCI's tests_torch job sometimes fails (Reference: tests_torch - Failed), because some models that don't support gradient_checkpointing, or that have MoE modules like SwitchTransformer, occasionally end up with weights that receive no gradients.
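To make the failure mode concrete, here is a minimal sketch of the kind of check involved (illustrative only, not the actual test code; the helper name and the assumption that the model's forward returns an object with a `.loss` attribute are just for this sketch):

```python
import torch

def check_all_params_received_grads(model, inputs):
    """Run one training step and assert every parameter got a gradient."""
    model.train()
    loss = model(**inputs).loss
    loss.backward()

    # In an MoE model, the router may send no tokens to some experts in a
    # given batch, so those experts' parameters end up with grad == None
    # and an assertion like this fails.
    for name, param in model.named_parameters():
        assert param.grad is not None, f"{name} has no gradient"
```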

I found a possibly related PR: #34806, and the situation described in that PR has already been resolved (pytest-dev/pytest-subtests#169 has been merged).

In #34806, @ydshieh wrote "TODO (ydshieh): use skipTest once pytest-dev/pytest-subtests/pull/169 is merged", so I reverted the stopgap fix.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

cc @ydshieh @muellerzr @SunMarc

ydshieh (Collaborator) commented on Dec 12, 2024

Hi @nhamanasu

Thank you for the PR. Although pytest-dev/pytest-subtests#169 is merged, we will still need to review the impact of installing and using pytest-subtests as it will change the report format.

In the meantime, #34806 should not be the cause of the failing check_training_gradient_checkpointing, since it skips some cases (using continue). That means there are some model classes that are still failing this test.

Changing it to skipTest along with pytest-dev/pytest-subtests#169 won't help here, as that is roughly equivalent to continue.

Could you check whether SwitchTransformers is an MoE-type model? Maybe this model just doesn't support gradient checkpointing, and we need to set supports_gradient_checkpointing = False for it.
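(For readers following along: a simplified sketch of the test-loop pattern being discussed, not the exact test code, showing why skipTest is roughly equivalent to continue here. Classes that claim support but still fail are unaffected by either choice.)

```python
def check_training_gradient_checkpointing(self):
    for model_class in self.all_model_classes:
        if not model_class.supports_gradient_checkpointing:
            # Current stopgap: silently move on to the next model class.
            continue
            # With pytest-subtests, this could instead be reported as a skip,
            # but either way the class is simply not exercised:
            # self.skipTest(reason=f"`supports_gradient_checkpointing` is False for {model_class.__name__}.")

        # ... run a training step with gradient checkpointing enabled and
        # assert that every parameter received a gradient ...
```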

nhamanasu (Contributor, Author) commented on Dec 13, 2024

Hi @ydshieh

Thank you for your response! I understand the situation now (sorry for my misunderstanding).

Regarding SwitchTransformer: it currently has supports_gradient_checkpointing = True.

It also has experts in its MLP modules.

I think the expert routing is the cause of the weights without gradients. So, as you suggested, we should set supports_gradient_checkpointing = False for this model.
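(A rough illustration of why routed experts can end up without gradients: a hand-written toy top-1 MoE layer, not the actual SwitchTransformers code.)

```python
import torch
import torch.nn as nn

class ToyMoE(nn.Module):
    def __init__(self, hidden=8, num_experts=4):
        super().__init__()
        self.router = nn.Linear(hidden, num_experts)
        self.experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_experts))

    def forward(self, x):
        # Top-1 routing: each token is dispatched to a single expert.
        expert_idx = self.router(x).argmax(dim=-1)
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = expert_idx == i
            if mask.any():
                out[mask] = expert(x[mask])
        return out

moe = ToyMoE()
moe(torch.randn(2, 8)).sum().backward()

# Experts that received no tokens never entered the autograd graph, so their
# parameters still have grad == None, which is exactly what the training
# gradient-checkpointing test complains about.
for i, expert in enumerate(moe.experts):
    print(i, all(p.grad is None for p in expert.parameters()))
```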

nhamanasu (Contributor, Author) commented on Dec 13, 2024

  • CircleCI's tests_generate has failed, but it will be resolved by this PR: skip Fuyu from test_generate #35246
  • tests_torch (the target of this PR) passes after setting supports_gradient_checkpointing = False in SwitchTransformersPreTrainedModel.

@@ -769,7 +769,7 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel):

     config_class = SwitchTransformersConfig
     base_model_prefix = "switch_transformers"
-    supports_gradient_checkpointing = True
+    supports_gradient_checkpointing = False
ydshieh (Collaborator) commented on the diff:
you can simply remove this line, as False is the default value defined in:

class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
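(In other words, subclasses only need to set the attribute when it differs from the base-class default; a tiny sketch of the inheritance pattern, with illustrative class names:)

```python
class PreTrainedModelSketch:
    # The base class declares the conservative default.
    supports_gradient_checkpointing = False

class ModelThatSupportsIt(PreTrainedModelSketch):
    # Models that do support it opt in explicitly.
    supports_gradient_checkpointing = True

class SwitchTransformersSketch(PreTrainedModelSketch):
    # Nothing declared here: the attribute is inherited as False.
    pass

print(SwitchTransformersSketch.supports_gradient_checkpointing)  # False
```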

@@ -863,8 +863,7 @@ def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):

            ]
            or not model_class.supports_gradient_checkpointing
        ):
-           # TODO (ydshieh): use `skipTest` once pytest-dev/pytest-subtests/pull/169 is merged
-           # self.skipTest(reason=f"`supports_gradient_checkpointing` is False for {model_class.__name__}.")
+           self.skipTest(reason=f"`supports_gradient_checkpointing` is False for {model_class.__name__}.")
ydshieh (Collaborator) commented on the diff:
Let's revert the change here :-)

ydshieh (Collaborator) commented on Dec 13, 2024

Thanks. Just 2 nits and we are ready to go

nhamanasu (Contributor, Author) commented:
Thank you! I've addressed your comments.

ydshieh (Collaborator) left a review comment:

Thank you for the iteration!

ydshieh requested a review from ArthurZucker on December 13, 2024 at 09:37
nhamanasu changed the title from "revert skipTest in check_training_gradient_checkpointing" to "set supports_gradient_checkpointing = False in SwitchTransformer" on Dec 19, 2024