Remove usage of local variables related with model parallel and move … #13039

Closed
wants to merge 1 commit into from
10 changes: 9 additions & 1 deletion src/transformers/models/albert/modeling_albert.py
@@ -249,10 +249,14 @@ def forward(
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
@stas00 (Contributor), Aug 9, 2021:

If we are going to move the variable to another device a moment later, as proposed next, this is wasteful.

We should figure out the device first and then create the variable directly on the target device.

Apologies if this comment is confusing: I'm referring to lines 250-254, i.e. this line and the three code lines after it.

@hyunwoongko (Author):

Yes, we can add another else statement for this.
But let's start with the models that don't have token type ids, as you said above.

@stas00 (Contributor):

Why add an else statement? I don't quite follow.

I suggested we figure out the device first and then, in one go, create the variable on the right device, i.e. your rank code goes before the creation, the correct device is set, and then there is no .to() to do.
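
A rough sketch of that pattern (illustrative only; the env-var lookup here is just a stand-in, since how the rank should be obtained is itself debated below):

    # Sketch: resolve the target device first, then allocate directly on it.
    local_rank = os.getenv("LOCAL_RANK")  # placeholder lookup, not an agreed API
    device = torch.device(f"cuda:{int(local_rank)}") if local_rank is not None else self.position_ids.device
    token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)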

@hyunwoongko (Author):

I thought of the following code:

        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            elif model_parallel:  # <-- new variable
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=os.getenv("LOCAL_RANK"))
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

@hyunwoongko (Author):

If model parallelism is not applied, the LOCAL_RANK environment variable does not exist.

@stas00 (Contributor), Aug 9, 2021:

But still, I'd prefer to abstract os.getenv("LOCAL_RANK") and raise meaningful exceptions should it be invalid for whatever reason, i.e. transformers needs a defined API to get the rank rather than relying on an env var alone.

It would also hide all the checking for whether it's defined. Let's perhaps start with a helper util in modeling_utils.py to add the abstraction.
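
For illustration, a minimal sketch of what such a helper could look like (the name get_local_rank and its exact behaviour are assumptions, not an agreed-upon API):

    import os

    def get_local_rank() -> int:
        """Hypothetical helper: return this process's local rank with a meaningful error."""
        local_rank = os.getenv("LOCAL_RANK")
        if local_rank is None:
            raise RuntimeError(
                "LOCAL_RANK is not set; expected the distributed launcher to export it."
            )
        try:
            return int(local_rank)
        except ValueError as e:
            raise ValueError(f"LOCAL_RANK must be an integer, got {local_rank!r}") from e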

@hyunwoongko (Author):

You are right. It would be good to implement an MPU like NVIDIA Megatron's and provide it as a utility function.

@stas00 (Contributor), Aug 9, 2021:

We will absolutely need to implement an MPU anyway, so let's do it right from the get-go!

You have a choice of the Megatron or DeepSpeed MPU versions; I may have seen some others as well.

I think I may have even started to port one while trying to plug my PP PR into DeepSpeed's 3D. Yes, I did:
https://github.com/huggingface/transformers/pull/9765/files#diff-48e672da3865f77a2e1d38954e8e075c0f1e02c7306f163847b9b8ecc56ede24
I see I took it from Megatron.
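
For context, an MPU (model parallel utility) is essentially a small object that owns the process-group bookkeeping. A very rough sketch of the kind of interface Megatron-style MPUs expose (simplified and assumed, not the ported code linked above):

    import torch.distributed as dist

    class MPU:
        """Sketch of a minimal model-parallel utility object."""

        def __init__(self, model_parallel_size: int):
            world_size, rank = dist.get_world_size(), dist.get_rank()
            self._mp_group = self._dp_group = None
            # Consecutive ranks hold shards of the same model replica.
            for start in range(0, world_size, model_parallel_size):
                group = dist.new_group(list(range(start, start + model_parallel_size)))
                if start <= rank < start + model_parallel_size:
                    self._mp_group = group
            # Ranks holding the same shard index form a data-parallel group.
            for offset in range(model_parallel_size):
                ranks = list(range(offset, world_size, model_parallel_size))
                group = dist.new_group(ranks)
                if rank in ranks:
                    self._dp_group = group

        def get_model_parallel_group(self):
            return self._mp_group

        def get_model_parallel_rank(self):
            return dist.get_rank(group=self._mp_group)

        def get_model_parallel_world_size(self):
            return dist.get_world_size(group=self._mp_group)

        def get_data_parallel_group(self):
            return self._dp_group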

@stas00 (Contributor):

Perhaps we can re-use some other parts of my work, but things have probably changed quite a lot in transformers to be able to rescue much of it. There was some MPU use in modeling_t5.py, but if I remember correctly it wasn't quite finished.


local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

Comment on lines +252 to +255
@stas00 (Contributor), Aug 9, 2021:

@sgugger, this is new to transformers: the models now become rank-aware. Should we handle this in a clean way by extending the model's API to get local_rank? Actually, I think this should really return a ready device object instead. What do you think?

And surely this verbatim code shouldn't be replicated everywhere; it doesn't contribute to the models' readability and should be abstracted into an API, either in the super-class or as a util function.
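
As a rough illustration of the "ready device object" idea (the mixin and method names here are hypothetical, not an agreed design):

    import os
    import torch
    from torch import nn

    class RankAwareMixin(nn.Module):
        """Hypothetical super-class helper: resolve the device once instead of repeating env-var checks."""

        def get_parallel_device(self) -> torch.device:
            local_rank = os.getenv("LOCAL_RANK")
            if local_rank is not None:
                return torch.device("cuda", int(local_rank))
            # Fall back to wherever this module's parameters already live.
            return next(self.parameters()).device

    # A model's forward could then write:
    #     token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.get_parallel_device())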

@hyunwoongko (Author), Aug 9, 2021:

OK. First, we can start with models that don't have token type ids, like GPT-2 or GPT Neo. As you said, it's easier. Then we can extend to models that use token type ids, like BERT and RoBERTa.

@stas00 (Contributor), Aug 9, 2021:

I didn't say that.

What I suggested is that we work out all the kinks in the models that are unique, and any similar models get worked out afterwards. So if you find yourself replicating the same code, it's not unique, if that's a good rough guideline to follow.

@hyunwoongko (Author):

Sorry. I think I misunderstood what you said. (As you know, I'm not good at English because I've never lived in an English-speaking country.)

@hyunwoongko (Author):

I think it is better to define a new abstract class or write a utility function than to use the os module directly.

@stas00 (Contributor), Aug 9, 2021:

Please don't worry about that; even those with a great grasp of a language misunderstand each other very often. It's normal.

So we will just try to explain again in a different way until we understand each other. And code samples are the best way to communicate at times ;)

So, to try again: what I was proposing is to pick all the models that require unique handling, and once those are figured out we can replicate the changes to all the similar ones, e.g. we have some models which are 95% identical to 5-10 others.

Alternatively, we can take a very different approach. We can pick just one model, say T5 or GPT-2, completely port it, and then do a few more models, etc. The drawback of this approach is that it'd be more difficult to see how to generalize, but I think getting one model working sooner is more practical, and we can get users to start experimenting and reporting flaws sooner rather than later. It will also be much easier to see that we are doing the right thing; of course tests will be needed, and we can test right away.

@sgugger, what's your take on this? Do the integration in stages across all models? Or do the 1-2 most popular models first, and then replicate to the other models, generalizing along the way where needed?

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)

token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
@@ -711,6 +715,10 @@ def forward(
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
2 changes: 1 addition & 1 deletion src/transformers/models/bart/modeling_bart.py
@@ -253,7 +253,7 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
@stas00 (Contributor), Aug 9, 2021:

If we do that, I propose we change the earlier code:

- bsz, tgt_len, embed_dim = hidden_states.size()
+ bsz, tgt_len, _ = hidden_states.size()

so that there is no confusing embed_dim local variable hanging around.

That way we know we are already using self.embed_dim.

@hyunwoongko (Author):

Great.


attn_output = self.out_proj(attn_output)

6 changes: 5 additions & 1 deletion src/transformers/models/bert/modeling_bert.py
@@ -211,10 +211,14 @@ def forward(
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)

token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
6 changes: 4 additions & 2 deletions src/transformers/models/big_bird/modeling_big_bird.py
@@ -293,16 +293,18 @@ def forward(
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)

if self.rescale_embeddings:
inputs_embeds = inputs_embeds * (self.hidden_size ** 0.5)

token_type_embeddings = self.token_type_embeddings(token_type_ids)

embeddings = inputs_embeds + token_type_embeddings

position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings

@@ -1327,7 +1327,7 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

2 changes: 1 addition & 1 deletion src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -255,7 +255,7 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

@@ -253,7 +253,7 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

4 changes: 3 additions & 1 deletion src/transformers/models/canine/modeling_canine.py
@@ -279,14 +279,16 @@ def forward(

if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

if inputs_embeds is None:
inputs_embeds = self._embed_hash_buckets(
input_ids, self.config.hidden_size, self.config.num_hash_functions, self.config.num_hash_buckets
)

token_type_embeddings = self.token_type_embeddings(token_type_ids)

embeddings = inputs_embeds + token_type_embeddings

if self.position_embedding_type == "absolute":
2 changes: 1 addition & 1 deletion src/transformers/models/clip/modeling_clip.py
@@ -258,7 +258,7 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

4 changes: 4 additions & 0 deletions src/transformers/models/deberta/modeling_deberta.py
@@ -15,6 +15,7 @@
""" PyTorch DeBERTa model. """

import math
import os
from collections.abc import Sequence

import torch
@@ -715,6 +716,9 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=N

if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
4 changes: 4 additions & 0 deletions src/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -15,6 +15,7 @@
""" PyTorch DeBERTa-v2 model. """

import math
import os
from collections.abc import Sequence

import numpy as np
@@ -835,6 +836,9 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=N

if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
2 changes: 1 addition & 1 deletion src/transformers/models/detr/modeling_detr.py
@@ -585,7 +585,7 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

9 changes: 7 additions & 2 deletions src/transformers/models/dpr/modeling_dpr.py
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch DPR model for Open Domain Question Answering."""


import os
from dataclasses import dataclass
from typing import Optional, Tuple, Union

@@ -495,6 +494,9 @@ def forward(
)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

outputs = self.ctx_encoder(
input_ids=input_ids,
@@ -572,6 +574,9 @@ def forward(
)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

outputs = self.question_encoder(
input_ids=input_ids,
6 changes: 5 additions & 1 deletion src/transformers/models/electra/modeling_electra.py
@@ -202,10 +202,14 @@ def forward(
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)

token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
4 changes: 1 addition & 3 deletions src/transformers/models/fsmt/modeling_fsmt.py
@@ -854,8 +854,6 @@ def forward(
"""Input shape: Time(SeqLen) x Batch x Channel"""
static_kv: bool = self.encoder_decoder_attention
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
# get here for encoder decoder cause of static_kv
if layer_state is not None: # reuse k,v and encoder_padding_mask
saved_state = layer_state.get(self.cache_key, {})
@@ -941,7 +939,7 @@ def forward(
assert v is not None
attn_output = torch.bmm(attn_probs, v)
assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
attn_output = self.out_proj(attn_output)

return attn_output, attn_weights_reshaped
3 changes: 3 additions & 0 deletions src/transformers/models/funnel/modeling_funnel.py
@@ -948,6 +948,9 @@ def forward(
attention_mask = torch.ones(input_shape, device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

# TODO: deal with head_mask
if inputs_embeds is None:
2 changes: 0 additions & 2 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -756,8 +756,6 @@ def forward(
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

device = input_ids.device if input_ids is not None else inputs_embeds.device

if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
2 changes: 1 addition & 1 deletion src/transformers/models/hubert/modeling_hubert.py
@@ -411,7 +411,7 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

5 changes: 5 additions & 0 deletions src/transformers/models/ibert/modeling_ibert.py
@@ -18,6 +18,7 @@
"""PyTorch I-BERT model. """

import math
import os

import torch
import torch.utils.checkpoint
@@ -128,11 +129,15 @@ def forward(

if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

if inputs_embeds is None:
inputs_embeds, inputs_embeds_scaling_factor = self.word_embeddings(input_ids)
else:
inputs_embeds_scaling_factor = None

token_type_embeddings, token_type_embeddings_scaling_factor = self.token_type_embeddings(token_type_ids)

embeddings, embeddings_scaling_factor = self.embeddings_act1(
7 changes: 7 additions & 0 deletions src/transformers/models/layoutlm/modeling_layoutlm.py
@@ -16,6 +16,7 @@


import math
import os

import torch
import torch.utils.checkpoint
@@ -96,6 +97,9 @@ def forward(

if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
@@ -790,6 +794,9 @@ def forward(
attention_mask = torch.ones(input_shape, device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

if bbox is None:
bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
5 changes: 1 addition & 4 deletions src/transformers/models/longformer/modeling_longformer.py
@@ -577,9 +577,6 @@ def forward(
value_vectors = self.value(hidden_states)

seq_len, batch_size, embed_dim = hidden_states.size()
assert (
embed_dim == self.embed_dim
), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"

# normalize query
query_vectors /= math.sqrt(self.head_dim)
@@ -678,7 +675,7 @@ def forward(
)

assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size"
attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()
attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, self.embed_dim).contiguous()

# compute value for global attention and overwrite to attention output
# TODO: remove the redundant computation
11 changes: 11 additions & 0 deletions src/transformers/models/luke/modeling_luke.py
@@ -15,6 +15,7 @@
"""PyTorch LUKE model. """

import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple

@@ -246,6 +247,9 @@ def forward(

if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
@@ -296,6 +300,9 @@ def forward(
):
if token_type_ids is None:
token_type_ids = torch.zeros_like(entity_ids)
local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

entity_embeddings = self.entity_embeddings(entity_ids)
if self.config.entity_emb_size != self.config.hidden_size:
@@ -900,6 +907,10 @@ def forward(
attention_mask = torch.ones((batch_size, seq_length), device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

if entity_ids is not None:
entity_seq_length = entity_ids.size(1)
if entity_attention_mask is None:
6 changes: 6 additions & 0 deletions src/transformers/models/lxmert/modeling_lxmert.py
@@ -298,6 +298,9 @@ def forward(self, input_ids, token_type_ids=None, inputs_embeds=None):

if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
@@ -944,6 +947,9 @@ def forward(
attention_mask = torch.ones(input_shape, device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
local_rank = os.getenv("LOCAL_RANK")
if local_rank is not None:
token_type_ids = token_type_ids.to(local_rank)

# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]