Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small change to Wav2Vec2 model to support Tensor-Parallelism with DeepSpeed #14298

Conversation

RezaYazdaniAminabadi
Copy link
Contributor

@RezaYazdaniAminabadi RezaYazdaniAminabadi commented Nov 5, 2021

What does this PR do?

This PR adds a minor modification to BartAttention and its copies to support tensor-parallelism with DeepSpeed. This relates to this PR on DeepSpeed side.

Please see the added comments in the code that explain the change.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@stas00

@stas00 stas00 requested a review from sgugger November 6, 2021 00:27
Copy link
Contributor

@stas00 stas00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have been discussing this offline - paving the road for dynamic TP-support in transformers via deepspeed. It's going to be super-neat.

Thank you, Reza!

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for your PR!
I made one comment on the first model that can be replicated to all the other ones.

Could you also add a test to make sure the feature works (we might not be able to run it on our 2 GPUs machine, but a 4 GPUs one is coming).

Comment on lines 174 to 176

# Use the class's parameter as the hidden_state's last dimension.
# This dimension cannot be used in case of enabling tensor-parallelism.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this comment is useful when reading the new code. It creates more confusion than help, only the next one is really important.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, I can remove this.

Comment on lines 264 to 268
# Use the embed_dim from class rather than hidden_state, this is due to
# the reason that attn_output can be partitioned across GPUs
# when using tensor-parallelism, in which case the embed_dimension from
# the input is not equal to the attention's last dimension after merging
# heads.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a 119 char limits so you can use more horizontal space :-)

Also, I suggest the following change, more to the point:

Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be partitioned across GPUs when using tensor-parallelism.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I will reformat this :)

@stas00
Copy link
Contributor

stas00 commented Nov 8, 2021

Could you also add a test to make sure the feature works (we might not be able to run it on our 2 GPUs machine, but a 4 GPUs one is coming).

We will have a full battery of tests for Deepspeed Inference. I will take care of this, Sylvain. The plan is to have a model zoo style test - identical to Deepspeed ZeRO tests, so to cover as many models as possible. (there will be also Deepspeed ZeRO Inference tests #14253, which is different from Deepspeed Inference)

We didn't feel a test was needed for this particular PR since it doesn't change anything for a normal application.

@sgugger
Copy link
Collaborator

sgugger commented Nov 9, 2021

You now need to run make style on your branch to fix the code quality issue :-)

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for all your work on this!

@sgugger sgugger merged commit a503012 into huggingface:master Nov 9, 2021
@RezaYazdaniAminabadi
Copy link
Contributor Author

Thanks @sgugger and @stas00

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants