Skip to content

Commit

Permalink
Allow only textual inputs to VisualBert (#13687)
Browse files Browse the repository at this point in the history
  • Loading branch information
gchhablani authored Sep 22, 2021
1 parent 93624bf commit 50c746e
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions src/transformers/models/visual_bert/modeling_visual_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,29 +778,30 @@ def forward(
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

if visual_embeds is None:
raise ValueError(
f"`visual_embeds` can not be of type {type(visual_embeds)} when using a VisualBert Model."
)

batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device

visual_input_shape = visual_embeds.size()[:-1]
if visual_embeds is not None:
visual_input_shape = visual_embeds.size()[:-1]

if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)

if visual_attention_mask is None:
if visual_embeds is not None and visual_attention_mask is None:
visual_attention_mask = torch.ones(visual_input_shape, device=device)

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if visual_embeds is not None:
combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
combined_attention_mask, [batch_size, input_shape + visual_input_shape], device
)

combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
combined_attention_mask, [batch_size, input_shape + visual_input_shape], device
)
else:
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
attention_mask, [batch_size, input_shape], device
)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
Expand Down

0 comments on commit 50c746e

Please sign in to comment.