From 50c746eeb71f7b8f95a264b09249c9555cdd2e17 Mon Sep 17 00:00:00 2001 From: Gunjan Chhablani Date: Wed, 22 Sep 2021 21:21:53 +0530 Subject: [PATCH] Allow only textual inputs to VisualBert (#13687) --- .../visual_bert/modeling_visual_bert.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index c6c01010081ece..85b3c75781ade1 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -778,29 +778,30 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - if visual_embeds is None: - raise ValueError( - f"`visual_embeds` can not be of type {type(visual_embeds)} when using a VisualBert Model." - ) - batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - visual_input_shape = visual_embeds.size()[:-1] + if visual_embeds is not None: + visual_input_shape = visual_embeds.size()[:-1] if attention_mask is None: attention_mask = torch.ones(input_shape, device=device) - if visual_attention_mask is None: + if visual_embeds is not None and visual_attention_mask is None: visual_attention_mask = torch.ones(visual_input_shape, device=device) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. + if visual_embeds is not None: + combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1) + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + combined_attention_mask, [batch_size, input_shape + visual_input_shape], device + ) - combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1) - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( - combined_attention_mask, [batch_size, input_shape + visual_input_shape], device - ) + else: + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, [batch_size, input_shape], device + ) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head