
Remove usage of local variables related with model parallel and move … #13039

Closed - wants to merge 1 commit

Conversation

@hyunwoongko (Contributor):

What does this PR do?

This PR is related to the model parallel integration from Parallelformers.
You can find the details of the PR here: tunib-ai/parallelformers#11 (comment)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@stas00

@stas00 (Contributor) left a comment:

Thank you for this PR, @hyunwoongko

This is a general comment on the whole effort of integrating parallelformers into transformers.

  1. Let's always start with one model, review, agree, and when happy replicate to the rest of the models in the same PR.
  2. Start with more than one model only if that other model is different from the others and requires a special, different change.

This will make the reviewing process easier and it'll be much less work for you, should changes be requested.

So, to stick to this proposal, please don't adjust any files but the ones that were commented on until we all agree; then all other files can follow suit.

  3. Please always paste the description of the proposed change into the PR, rather than linking to other projects, since the latter may disappear in the future and it'll be difficult to understand why a certain change was made down the road.

As such I'm going to comment only on the first 2 unique models.

@sgugger, a question for you: this is going to be a potentially very extensive process. Do you prefer a step-by-step process which would be easy to review / integrate - like the proposed PR - or would you rather have a massive single PR? The problem is that it'll take touching dozens of files, and I think it'd be difficult to review. To me, doing it in stages is a simpler approach for all involved.

Comment on lines +252 to +255
local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

@stas00 (Contributor), Aug 9, 2021:

@sgugger, this is new to transformers - the models now become rank aware. Should we handle this in a clean way by extending the model's API to get local_rank? Actually, I think this really should be returning a ready device object instead. What do you think?

And surely the verbatim code shouldn't be replicated everywhere, as it doesn't contribute to the model's readability and should be abstracted into an API - put it in the super-class or have a util function?

@hyunwoongko (Contributor Author), Aug 9, 2021:

OK. First, we can start with models that don't have token type ids, like GPT-2 or GPT-Neo.
As you said, it's easier. And next, we can extend to models that use token type ids, like BERT and RoBERTa.

@stas00 (Contributor), Aug 9, 2021:

I didn't say that.

What I suggested is that we work out all the kinks in all models that are unique, and any similar models get worked out afterwards. So if you find yourself replicating the same code, it's not unique - that's a good rough guideline to follow.

@hyunwoongko (Contributor Author):

Sorry. I think I misunderstood what you said. (As you know, I'm not good at English because I've never lived in an English-speaking country.)

@hyunwoongko (Contributor Author):

I think it is better to define a new abstract class or write a utility function than to use the os module directly.

@stas00 (Contributor), Aug 9, 2021:

Please don't worry about that; even those with a great grasp of a language misunderstand each other very often. It's normal.

So we will just try to explain again in a different way until we understand each other. And code samples are the best way to communicate at times ;)

So, to try again: what I was proposing is to pick all the models that require unique handling. And once those are figured out, we can replicate to all the similar ones. E.g. we have some models which are 95% identical to 5-10 others.

Alternatively, we can also take a very different approach. We can pick just one model - say T5 or GPT2 - and completely port it, and then do a few more models, etc. The drawback of this approach is that it'd be more difficult to see how to generalize, but I think getting one model working sooner is more practical, and we can get users to start experimenting and report flaws sooner rather than later. Also, it'll be much easier to see that we are doing the right thing; of course tests will be needed, and we can test right away.

@sgugger, what's your take on this? Do the integration in stages across all models? Or do the 1-2 most popular models first, and then replicate to the other models, generalizing along the way where needed?

@@ -249,10 +249,14 @@ def forward(
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
@stas00 (Contributor), Aug 9, 2021:

If we are going to move the variable to another device a moment later, as proposed next, this is wasteful.

We should get the device figured out first and then create the variable directly on the target device.

Apologies if this comment is confusing: I'm referring to lines 250-254, i.e. this line and the 3 code lines after it.

@hyunwoongko (Contributor Author):

Yes, we can add another else branch for this.
But let's start with a model without token type ids, as you said above.

@stas00 (Contributor):

Why add an else statement? I don't quite follow.

I suggested we figure out the device first and then create the variable on the right device in one go, i.e. your rank code goes before the creation, the correct device is set, and then there is no .to() to do.

@hyunwoongko (Contributor Author):

I thought of the following code:

        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            elif model_parallel:  # <-- new variable
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=os.getenv("LOCAL_RANK"))
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

@hyunwoongko (Contributor Author):

If model parallel is not applied, the LOCAL_RANK variable does not exist.

@stas00 (Contributor), Aug 9, 2021:

But still, I'd prefer to abstract os.getenv("LOCAL_RANK") and include meaningful exceptions should it be invalid for whatever reason, i.e. transformers needs to have a defined API to get the rank and not rely just on an env var.

It'd also hide all that checking for whether it's defined. Let's perhaps start with a helper util in modeling_utils.py to add the abstraction.
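
For illustration, a minimal sketch of what such a helper could look like; the function name and exact behavior below are hypothetical, not something agreed on in this PR:

    import os

    import torch


    def get_model_parallel_device(default_device=None):
        """Return the device this process should place tensors on when model
        parallelism is active, falling back to `default_device` otherwise."""
        local_rank = os.getenv("LOCAL_RANK")
        if local_rank is None:
            return default_device
        try:
            return torch.device("cuda", int(local_rank))
        except ValueError:
            raise ValueError(f"LOCAL_RANK must be an integer rank, got {local_rank!r}")

A model's forward pass could then call get_model_parallel_device(self.position_ids.device) once and create token_type_ids directly on the returned device, instead of repeating the os.getenv check in every modeling file.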

@hyunwoongko (Contributor Author):

You are right. It would be good to implement an mpu like NVIDIA Megatron's and provide it as a utility function.

@stas00 (Contributor), Aug 9, 2021:

We will absolutely need to implement an MPU anyway, so let's do it right from the get-go!

You have a choice of the Megatron or DeepSpeed MPU versions; I may have seen some others.

I think I may have even started to port one while trying to plug my PP PR into DeepSpeed's 3D. Yes, I did:
https://github.com/huggingface/transformers/pull/9765/files#diff-48e672da3865f77a2e1d38954e8e075c0f1e02c7306f163847b9b8ecc56ede24
I see I took it from Megatron.
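
For context, an "mpu" in the Megatron/DeepSpeed sense is a small module that owns the model-parallel process groups and answers rank/world-size queries. A rough, hypothetical sketch of that kind of interface (not code from either project, and not part of this PR):

    import torch.distributed as dist

    _MODEL_PARALLEL_GROUP = None


    def initialize_model_parallel(model_parallel_size):
        """Split the world into consecutive groups of `model_parallel_size` ranks.

        Every process must create every group, but only keeps the one it belongs to.
        """
        global _MODEL_PARALLEL_GROUP
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        for start in range(0, world_size, model_parallel_size):
            ranks = list(range(start, start + model_parallel_size))
            group = dist.new_group(ranks)
            if rank in ranks:
                _MODEL_PARALLEL_GROUP = group


    def get_model_parallel_group():
        assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
        return _MODEL_PARALLEL_GROUP


    def get_model_parallel_rank():
        return dist.get_rank(group=get_model_parallel_group())


    def get_model_parallel_world_size():
        return dist.get_world_size(group=get_model_parallel_group())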

@stas00 (Contributor):

Perhaps we can re-use some other parts of my work, but things have probably changed quite a lot in transformers to be able to rescue much. There was some mpu use in modeling_t5.py, but if I remember correctly it wasn't quite finished.

@@ -253,7 +253,7 @@ def forward(

         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
@stas00 (Contributor), Aug 9, 2021:

If we do that, I propose we change the earlier code:

- bsz, tgt_len, embed_dim = hidden_states.size()
+ bsz, tgt_len, _ = hidden_states.size()

so that there will be no confusing embed_dim local variable hanging around.

That way we know we already use self.embed_dim.

@hyunwoongko (Contributor Author):

Great.

@@ -762,7 +762,7 @@ def forward(
         attn_output = (
             attn_output.view(batch_size, self.num_attn_heads, tgt_len, self.head_dim)
             .transpose(1, 2)
-            .reshape(batch_size, tgt_len, hidden_size)
+            .reshape(batch_size, tgt_len, self.hidden_size)
Contributor:

The test suite flagged this:

AttributeError: 'ProphetNetAttention' object has no attribute 'hidden_size'

Contributor:

You probably want config.hidden_size here.

Contributor:

Or stash config.hidden_size into self.hidden_size in __init__.
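
For instance, a minimal sketch of stashing the value in __init__, assuming config is available there the way it is for the other config-derived attributes (illustrative only, not part of the PR):

    # in ProphetNetAttention.__init__, next to the other attributes
    self.hidden_size = config.hidden_size

    # the reshape in forward() can then keep referencing self.hidden_size
    attn_output = (
        attn_output.view(batch_size, self.num_attn_heads, tgt_len, self.head_dim)
        .transpose(1, 2)
        .reshape(batch_size, tgt_len, self.hidden_size)
    )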

@hyunwoongko (Contributor Author):

I'll check again.

@hyunwoongko (Contributor Author), Aug 10, 2021:

Most of these modifications are to encoder-decoder models derived from BART's code and to encoder models that have token type ids. As you said, it is difficult to work on all models at once, so I will exclude the cases where the model needs to be modified. I also agree that modifying multiple models at the same time makes it difficult to test. First, let's start with one decoder model like GPT-Neo. I will close this PR and upload a new one soon.

@hyunwoongko closed this on Aug 9, 2021.
@stas00 (Contributor) commented on Aug 9, 2021:

One more note: besides the dozens of models, we also have a template. In this case it's mostly https://github.com/huggingface/transformers/blob/master/templates/adding_a_new_model/cookiecutter-template-%7B%7Bcookiecutter.modelname%7D%7D/modeling_%7B%7Bcookiecutter.lowercase_modelname%7D%7D.py - so when all is happy here, please let's not forget to apply the changes there as well.

@sgugger (Collaborator) commented on Aug 10, 2021:

I would like an approach that does one model first, so we can clearly comment on the design, then all models after (unless it's very different for each model, in which case similar models by similar models, if that makes sense).

As for the changes themselves, I would need a clear explanation as to why the token_type_ids device needs to be changed from the position_ids device. That kind of code should not be present in the modeling files as is, as people adding or tweaking models won't need/understand it. We can abstract things away in PreTrainedModel as you suggest, @stas00; that seems like a better approach. Or maybe a method that creates those token_type_ids properly, at the very least.
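
As an illustration of that last idea, a helper along these lines might live on the embeddings module or on PreTrainedModel; the method name, the get_model_parallel_device() helper sketched earlier, and the reliance on the position_ids buffer are all assumptions rather than an agreed design:

    def create_default_token_type_ids(self, input_shape):
        """Create all-zero token_type_ids directly on the right device, so no
        later .to() call is needed inside the model code."""
        # get_model_parallel_device() is the hypothetical helper sketched earlier;
        # self.position_ids is the registered buffer the current code already uses
        device = get_model_parallel_device(default_device=self.position_ids.device)
        return torch.zeros(input_shape, dtype=torch.long, device=device)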
