
Fix prompt tuning training for LM & tagging flex heads
calpt committed Nov 6, 2023
1 parent 5ec546c commit 291c265
Showing 26 changed files with 132 additions and 33 deletions.
14 changes: 12 additions & 2 deletions src/adapters/context.py
@@ -82,8 +82,9 @@ class ForwardContext:
"adapter_gating_scores",
"adapter_fusion_attentions",
"adapter_input_parallelized",
"prefix_attention_mask_length",
]
# Additional used attributes not exposed to the user
# - prompt_tokens_length: length of the prompt tokens

def __init__(self, model, *args, **kwargs):
# If the model has a method ``forward_context()``, use it to create the context.
@@ -107,6 +108,8 @@ def wrap(cls, f):
def wrapper_func(self, *args, **kwargs):
if self.adapters_config is not None:
with cls(self, *args, **kwargs) as ctx:
# whether to output the context attributes
output_context = kwargs.pop("output_context", False)
kwargs = {
k: v for k, v in kwargs.items() if k.replace("output_", "") not in cls.context_attributes
}
@@ -121,7 +124,14 @@ def wrapper_func(self, *args, **kwargs):
for attr in cls.context_attributes:
if getattr(ctx, "output_" + attr, False):
results[attr] = dict(getattr(ctx, attr))
return results

if output_context:
context_dict = ctx.__dict__

if output_context:
return results, context_dict
else:
return results
else:
return f(self, *args, **kwargs)

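The wrapper change above is easiest to follow end to end. Below is a self-contained toy sketch, not the library code: an inner layer (standing in for the prompt-tuning module changed further down) writes an attribute onto the active forward context, and a forward wrapped in that context returns the context's attributes as a plain dict alongside its regular outputs when called with output_context=True.

class ToyForwardContext:
    # toy stand-in for ForwardContext; tracks the currently active context
    _current = None

    def __enter__(self):
        ToyForwardContext._current = self
        return self

    def __exit__(self, *exc):
        ToyForwardContext._current = None

    @classmethod
    def get_context(cls):
        return cls._current


def toy_prompt_tuning_layer():
    # corresponds to the prompt_tuning.py change below: record the prompt length
    ToyForwardContext.get_context().prompt_tokens_length = 10


def wrapped_forward(output_context=False):
    with ToyForwardContext() as ctx:
        toy_prompt_tuning_layer()
        results = {"logits": "..."}
        if output_context:
            context_dict = ctx.__dict__
    if output_context:
        return results, context_dict
    return results


outputs, context = wrapped_forward(output_context=True)
assert context["prompt_tokens_length"] == 10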
28 changes: 27 additions & 1 deletion src/adapters/heads/base.py
@@ -325,6 +325,19 @@ def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=Fal
labels = kwargs.pop("labels", None)
if labels is not None:
loss_fct = CrossEntropyLoss()
# adjust labels for prompt tuning
if kwargs.get("prompt_tokens_length", 0) > 0:
prompt_length = kwargs.get("prompt_tokens_length")
prompt_labels = torch.full(
(labels.shape[0], prompt_length), loss_fct.ignore_index, dtype=torch.long, device=labels.device
)
labels = torch.cat((prompt_labels, labels), dim=-1)
if attention_mask is not None:
attention_mask = torch.cat(
(torch.ones_like(prompt_labels, dtype=torch.long, device=labels.device), attention_mask),
dim=-1,
)

# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
@@ -752,7 +765,14 @@ def _get_used_heads(self, head_name: str = None):
return head_modules

def forward_head(
self, all_outputs, head_name=None, cls_output=None, attention_mask=None, return_dict=False, **kwargs
self,
all_outputs,
head_name=None,
cls_output=None,
attention_mask=None,
return_dict=False,
context=None,
**kwargs
):
"""
The forward pass through a prediction head configuration. There are three ways to specify the used prediction
@@ -800,6 +820,12 @@ def _get_head_input(outputs, cls_out, batch):
if inv_adapter:
kwargs["invertible_adapter"] = inv_adapter

# Set prompt tokens length
if context is not None:
prompt_tokens_length = context.get("prompt_tokens_length", None)
if prompt_tokens_length is not None:
kwargs["prompt_tokens_length"] = prompt_tokens_length

if isinstance(self.active_head, BatchSplit):
if sum(self.active_head.batch_sizes) != all_outputs[0].size()[0]:
raise ValueError(
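The label adjustment added above can be checked in isolation. The following standalone sketch uses illustrative shapes (not the head class itself): with prompt tuning active, the tagging logits cover prompt_tokens_length extra positions, so the labels are left-padded with ignore_index and the attention mask with ones before the loss is computed.

import torch
from torch.nn import CrossEntropyLoss

# illustrative shapes; prompt positions are masked out of the loss via ignore_index
loss_fct = CrossEntropyLoss()
batch_size, seq_len, num_labels, prompt_length = 2, 8, 5, 3

logits = torch.randn(batch_size, prompt_length + seq_len, num_labels)
labels = torch.randint(0, num_labels, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

prompt_labels = torch.full((batch_size, prompt_length), loss_fct.ignore_index, dtype=torch.long)
labels = torch.cat((prompt_labels, labels), dim=-1)
attention_mask = torch.cat((torch.ones_like(prompt_labels), attention_mask), dim=-1)

# only keep active positions; prompt positions stay active but are ignored by the loss
active = attention_mask.view(-1) == 1
loss = loss_fct(logits.view(-1, num_labels)[active], labels.view(-1)[active])

Without the padding, CrossEntropyLoss would reject the call with a batch-size mismatch between batch_size * (prompt_length + seq_len) logits rows and batch_size * seq_len labels, which is the failure this commit fixes for tagging heads.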
10 changes: 10 additions & 0 deletions src/adapters/heads/language_modeling.py
@@ -1,3 +1,4 @@
import torch
import torch.nn as nn

from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast, MaskedLMOutput, Seq2SeqLMOutput
@@ -118,6 +119,15 @@ def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=Fal
labels = labels[..., 1:].contiguous()
else:
logits_for_loss = lm_logits

# adjust labels for prompt tuning
if kwargs.get("prompt_tokens_length", 0) > 0:
prompt_length = kwargs.get("prompt_tokens_length")
prompt_labels = torch.full(
(labels.shape[0], prompt_length), loss_fct.ignore_index, dtype=torch.long, device=labels.device
)
labels = torch.cat((prompt_labels, labels), dim=-1)

loss = loss_fct(logits_for_loss.view(-1, self.config["vocab_size"]), labels.view(-1))

if return_dict:
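The same adjustment interacts with label shifting in the causal LM head above. Another standalone sketch with illustrative shapes: after the usual one-position causal shift, the logits still cover the virtual prompt positions, so the shifted labels are left-padded with ignore_index to restore matching lengths.

import torch
import torch.nn as nn

# illustrative shapes for the shift_labels branch
loss_fct = nn.CrossEntropyLoss()
batch_size, seq_len, prompt_length, vocab_size = 2, 8, 3, 50

lm_logits = torch.randn(batch_size, prompt_length + seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

# causal shift: position t predicts token t + 1
logits_for_loss = lm_logits[:, :-1, :].contiguous()  # (batch, prompt + seq - 1, vocab)
labels = labels[..., 1:].contiguous()                 # (batch, seq - 1)

# pad for the virtual prompt tokens so both flatten to the same length
prompt_labels = torch.full((batch_size, prompt_length), loss_fct.ignore_index, dtype=torch.long)
labels = torch.cat((prompt_labels, labels), dim=-1)   # (batch, prompt + seq - 1)

loss = loss_fct(logits_for_loss.view(-1, vocab_size), labels.view(-1))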
2 changes: 1 addition & 1 deletion src/adapters/methods/prompt_tuning.py
@@ -201,6 +201,6 @@ def forward(self, hidden_states: torch.Tensor):

context = ForwardContext.get_context()
if context is not None:
context.prefix_attention_mask_length = prefix_attention_mask_length
context.prompt_tokens_length = prefix_attention_mask_length

return hidden_states
6 changes: 2 additions & 4 deletions src/adapters/model_mixin.py
@@ -5,16 +5,15 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from os.path import join
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import ModelOutput

from .composition import AdapterCompositionBlock, Fuse, Stack, parse_composition
from .configuration import ADAPTER_CONFIG_MAP, AdapterConfigBase, AdapterFusionConfig, BnConfig, ModelAdaptersConfig
from .configuration import ADAPTER_CONFIG_MAP, AdapterConfigBase, AdapterFusionConfig, BnConfig
from .context import AdapterSetup, ForwardContext
from .hub_mixin import PushAdapterToHubMixin
from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader
@@ -955,7 +954,6 @@ def forward_context(self, context: ForwardContext, *args, **kwargs):
context.output_adapter_fusion_attentions = kwargs.get("output_adapter_fusion_attentions", False)
context.adapter_gating_scores = defaultdict(dict)
context.adapter_fusion_attentions = defaultdict(dict)
context.prefix_attention_mask_length = kwargs.get("output_prefix_attention_mask_length", None)

def get_fusion_regularization_loss(self):
reg_loss = None
5 changes: 4 additions & 1 deletion src/adapters/models/albert/adapter_model.py
@@ -64,7 +64,7 @@ def forward(

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.albert(
outputs, context = self.albert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -77,7 +77,10 @@
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context

# BERT & RoBERTa & ALBERT return the pooled output as second item, we don't need that in these heads
if not return_dict:
5 changes: 4 additions & 1 deletion src/adapters/models/bart/adapter_model.py
@@ -76,7 +76,7 @@ def forward(
if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs:
use_cache = False

outputs = self.model(
outputs, context = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
@@ -95,7 +95,10 @@
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context
# sequence classification based on last token in sequence
x = outputs[0] # last hidden state
if input_ids is not None and x.shape[1] == input_ids.shape[1]:
5 changes: 4 additions & 1 deletion src/adapters/models/beit/adapter_model.py
@@ -47,7 +47,7 @@ def forward(
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.beit(
outputs, context = self.beit(
pixel_values,
bool_masked_pos=bool_masked_pos,
head_mask=head_mask,
@@ -57,7 +57,10 @@
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context

# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
5 changes: 4 additions & 1 deletion src/adapters/models/bert/adapter_model.py
@@ -66,7 +66,7 @@ def forward(

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.bert(
outputs, context = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -79,7 +79,10 @@
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
5 changes: 4 additions & 1 deletion src/adapters/models/bert_generation/adapter_model.py
@@ -62,7 +62,7 @@ def forward(

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.bert(
outputs, context = self.bert(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -78,7 +78,10 @@
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
5 changes: 4 additions & 1 deletion src/adapters/models/clip/adapter_model.py
@@ -44,7 +44,7 @@ def forward(
output_adapter_fusion_attentions=False,
**kwargs
):
outputs = self.clip(
outputs, context = self.clip(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
@@ -56,7 +56,10 @@
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context

if head or AdapterSetup.get_context_head_setup() or self.active_head:
head_outputs = self.forward_head(
5 changes: 4 additions & 1 deletion src/adapters/models/deberta/adapter_model.py
@@ -57,7 +57,7 @@ def forward(

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.deberta(
outputs, context = self.deberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -69,7 +69,10 @@
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
5 changes: 4 additions & 1 deletion src/adapters/models/deberta_v2/adapter_model.py
@@ -60,7 +60,7 @@ def forward(

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.deberta(
outputs, context = self.deberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -72,7 +72,10 @@
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
5 changes: 4 additions & 1 deletion src/adapters/models/distilbert/adapter_model.py
@@ -85,7 +85,7 @@ def forward(
else None
)

distilbert_output = self.distilbert(
distilbert_output, context = self.distilbert(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
@@ -96,7 +96,10 @@
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context

outputs = self.forward_head(
distilbert_output, head_name=head, attention_mask=attention_mask, return_dict=return_dict, **kwargs
5 changes: 4 additions & 1 deletion src/adapters/models/electra/adapter_model.py
@@ -66,7 +66,7 @@ def forward(

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.electra(
outputs, context = self.electra(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -79,7 +79,10 @@
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context

head_inputs = outputs

5 changes: 4 additions & 1 deletion src/adapters/models/gpt2/adapter_model.py
@@ -68,7 +68,7 @@ def forward(
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.transformer(
outputs, context = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
@@ -85,7 +85,10 @@
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context

batch_size = outputs[0].shape[0]

5 changes: 4 additions & 1 deletion src/adapters/models/gptj/adapter_model.py
@@ -66,7 +66,7 @@ def forward(
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.transformer(
outputs, context = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
@@ -81,7 +81,10 @@
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context

batch_size = outputs[0].shape[0]

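Taken together, the changes make prompt tuning trainable through the flex-head API. The following usage sketch is illustrative, not part of the commit: the checkpoint, adapter name, prompt length and label count are placeholders, and it assumes the adapters API of this release (AutoAdapterModel, PromptTuningConfig, add_tagging_head).

import torch
from adapters import AutoAdapterModel, PromptTuningConfig

# usage sketch of the scenario fixed by this commit; names and sizes are illustrative
model = AutoAdapterModel.from_pretrained("bert-base-uncased")
model.add_adapter("ner_prompt", config=PromptTuningConfig(prompt_length=10))
model.add_tagging_head("ner_prompt", num_labels=9)
model.train_adapter("ner_prompt")

input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
labels = torch.randint(0, 9, (2, 16))

# before the fix, the loss computation failed because the logits covered the 10
# virtual prompt tokens while the labels and attention mask did not
outputs = model(input_ids, labels=labels)
outputs.loss.backward()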