WIP: invertible adapters support
calpt committed Jan 6, 2025
1 parent 7d346db commit 3a9e702
Showing 7 changed files with 137 additions and 18 deletions.
13 changes: 9 additions & 4 deletions src/adapters/interface.py
@@ -12,6 +12,7 @@ class AdapterMethod:
     lora = "lora"
     prompt_tuning = "prompt_tuning"
     reft = "reft"
+    invertible = "invertible"
 
     @staticmethod
     def get_from_config(config) -> List[str]:
@@ -22,14 +23,18 @@ def get_from_config(config) -> List[str]:
             config: The adapter config.
         Returns:
-            str: The adapter type.
+            List[str]: The adapter type.
         """
+        methods = []
+        if getattr(config, "inv_adapter", False):
+            methods.append(AdapterMethod.invertible)
         if config.architecture is None:
-            return [AdapterMethod.bottleneck]
+            methods.append(AdapterMethod.bottleneck)
        elif config.architecture == "union":
-            return [AdapterMethod.get_from_config(sub_config) for sub_config in config.configs]
+            methods.extend([AdapterMethod.get_from_config(sub_config) for sub_config in config.configs])
         else:
-            return [config.architecture]
+            methods.append(config.architecture)
+        return methods
 
 
 @dataclass
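For illustration, a hedged sketch of the reworked lookup (assuming AdapterMethod is importable from adapters.interface, matching the file path above, and that SeqBnInvConfig sets inv_adapter, as the tests below suggest): instead of returning early, get_from_config now collects every method a config contributes.

# Illustrative sketch only, not part of this commit.
from adapters import SeqBnConfig, SeqBnInvConfig
from adapters.interface import AdapterMethod

# A bottleneck config with an invertible adapter reports both methods.
print(AdapterMethod.get_from_config(SeqBnInvConfig()))  # expected: ["invertible", "bottleneck"]
# A plain bottleneck config is unchanged.
print(AdapterMethod.get_from_config(SeqBnConfig()))  # expected: ["bottleneck"]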
2 changes: 2 additions & 0 deletions src/adapters/methods/__init__.py
@@ -1,4 +1,5 @@
 from .bottleneck import init_bottleneck
+from .invertible import init_invertible_adapters
 from .lora import init_lora
 from .prompt_tuning import init_prompt_tuning
 from .reft import init_reft
@@ -9,4 +10,5 @@
     "lora": init_lora,
     "prompt_tuning": init_prompt_tuning,
     "reft": init_reft,
+    "invertible": init_invertible_adapters,
 }
94 changes: 94 additions & 0 deletions src/adapters/methods/invertible.py
@@ -0,0 +1,94 @@
import types
from functools import partial

import torch
import torch.nn as nn

from ..configuration.adapter_config import BnConfig
from ..utils import multigetattr
from .adapter_layer_base import AdapterLayerBase
from .modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock


class InvertibleAdapterLayer(AdapterLayerBase, nn.ModuleDict):
    adapter_modules_name = "_modules"

    def __init__(self, model_config, adapters_config):
        super().__init__()
        self.location_key = "inv_adapter"
        self.model_config = model_config
        self.adapters_config = adapters_config

    def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
        self.layer_idx = layer_idx
        embedding_size = getattr(self.model_config, "embedding_size", self.model_config.hidden_size)
        adapter_config = self.adapters_config.match(
            adapter_name,
            config_type=BnConfig,
            location_key="inv_adapter",
        )
        if adapter_config is not None and adapter_config["inv_adapter"]:
            if adapter_config["inv_adapter"] == "nice":
                inv_adap = NICECouplingBlock(
                    [[embedding_size]],
                    non_linearity=adapter_config["non_linearity"],
                    reduction_factor=adapter_config["inv_adapter_reduction_factor"],
                )
            elif adapter_config["inv_adapter"] == "glow":
                inv_adap = GLOWCouplingBlock(
                    [[embedding_size]],
                    non_linearity=adapter_config["non_linearity"],
                    reduction_factor=adapter_config["inv_adapter_reduction_factor"],
                )
            else:
                raise ValueError(f"Invalid invertible adapter type '{adapter_config['inv_adapter']}'.")
            self[adapter_name] = inv_adap
            self[adapter_name].apply(Adapter.init_bert_weights)
            return True

        return False

    def get_invertible_adapter(self):
        # HACK: returns the first adapter of the currently active setup. for backwards compatibility
        adapter_setup = self.get_active_setup()
        if adapter_setup is not None and len(adapter_setup) > 0:
            first_adapter = adapter_setup.first()
            if first_adapter in self:
                return self[first_adapter]
        return None

    def forward(self, hidden_states: torch.Tensor, rev=False):
        adapter_setup = self.get_active_setup()
        if adapter_setup is not None and len(adapter_setup) > 0:
            first_adapter = adapter_setup.first()
            if first_adapter in self:
                hidden_states = self[first_adapter](hidden_states, rev=rev)
        return hidden_states


def hook_fn(model, module, args, embedding_output):
    embedding_output = model.invertible_adapters(embedding_output)
    return embedding_output


def init_invertible_adapters(model):
    if not hasattr(model, "invertible_adapters"):
        model.invertible_adapters = InvertibleAdapterLayer(model.config, model.adapters_config)

        embed_layer = multigetattr(model, model.adapter_interface.model_embeddings)
        embed_layer.register_forward_hook(partial(hook_fn, model))

        # Add methods from original invertible adapter mixin.
        # This is primarily for backwards compatibility and internal use.
        model.add_invertible_adapter = types.MethodType(
            lambda self, *args, **kwargs: self.invertible_adapters.add_adapter(*args, **kwargs), model
        )
        model.delete_invertible_adapter = types.MethodType(
            lambda self, *args, **kwargs: self.invertible_adapters.delete_adapter(*args, **kwargs), model
        )
        model.get_invertible_adapter = types.MethodType(
            lambda self: self.invertible_adapters.get_invertible_adapter(), model
        )
        model.invertible_adapters_forward = types.MethodType(
            lambda self, *args, **kwargs: self.invertible_adapters(*args, **kwargs), model
        )
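As a standalone illustration of the hook pattern used by init_invertible_adapters (toy class and attribute names, not library API): a forward hook registered on the embedding module returns a new value, which PyTorch then uses as that module's output, so the invertible adapter is applied without touching the base model's forward code.

# Minimal sketch of the forward-hook mechanism; all names here are hypothetical.
from functools import partial

import torch
import torch.nn as nn


class ToyInvertible(nn.Module):
    # Stand-in for InvertibleAdapterLayer: a transformation that can be reversed with rev=True.
    def forward(self, hidden_states, rev=False):
        return hidden_states - 1.0 if rev else hidden_states + 1.0


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 4)
        self.invertible_adapters = ToyInvertible()


def hook_fn(model, module, args, embedding_output):
    # Returning a value from a forward hook replaces the hooked module's output.
    return model.invertible_adapters(embedding_output)


model = ToyModel()
model.embed_tokens.register_forward_hook(partial(hook_fn, model))
adapted = model.embed_tokens(torch.tensor([1, 2, 3]))  # embedding output passes through the adapter
restored = model.invertible_adapters(adapted, rev=True)  # reverse direction undoes the transformation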
15 changes: 8 additions & 7 deletions src/adapters/methods/prompt_tuning.py
@@ -191,10 +191,11 @@ def _attn_mask_hook_fn(module, args):
 
 
 def init_prompt_tuning(model):
-    model.support_prompt_tuning = True
-    model.prompt_tuning = PromptTuningLayer(model.config, model.adapters_config, model.get_input_embeddings())
-    embed_layer = multigetattr(model, model.adapter_interface.model_embeddings)
-    embed_layer.register_forward_hook(partial(hook_fn, model))
-
-    for _, layer in model.iter_layers():
-        layer.register_forward_pre_hook(_attn_mask_hook_fn)
+    if not hasattr(model, "prompt_tuning"):
+        model.support_prompt_tuning = True
+        model.prompt_tuning = PromptTuningLayer(model.config, model.adapters_config, model.get_input_embeddings())
+        embed_layer = multigetattr(model, model.adapter_interface.model_embeddings)
+        embed_layer.register_forward_hook(partial(hook_fn, model))
+
+        for _, layer in model.iter_layers():
+            layer.register_forward_pre_hook(_attn_mask_hook_fn)
10 changes: 6 additions & 4 deletions src/adapters/model_mixin.py
@@ -494,6 +494,10 @@ def supports_adapter(self, type_or_config: Union[str, AdapterConfig]) -> bool:
                 supported.append(_type in self.base_model.adapter_interface.adapter_types)
             elif _type == AdapterMethod.prompt_tuning:
                 supported.append(self.base_model.support_prompt_tuning)
+            elif _type == AdapterMethod.invertible:
+                supported.append(
+                    isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin)
+                )
             else:
                 supported.append(True)
         return all(supported)
@@ -1077,12 +1081,10 @@ def get_adapter(self, name) -> dict:
         # global weights are saved at index -1
         if name in self.base_model.shared_parameters:
             destination[-1]["shared"] = self.base_model.shared_parameters[name]
-        if (
-            isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin)
-        ) and name in self.invertible_adapters:
+        if self.supports_adapter("invertible") and name in self.invertible_adapters:
             destination[-1]["invertible"] = self.invertible_adapters[name]
 
-        if self.support_prompt_tuning:
+        if self.supports_adapter("prompt_tuning"):
             prompt_tuning = self.prompt_tuning.get_adapter(name)
             if prompt_tuning is not None:
                 destination[-1]["prompt"] = prompt_tuning
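With the new branch in supports_adapter, callers can gate invertible-specific logic on the declared capability instead of checking mixin classes directly. A hedged end-to-end sketch for the standard-model path (where the existing invertible mixins apply; the adapter name is invented):

# Hypothetical usage sketch; relies on the standard adapters.init / add_adapter API.
import adapters
from adapters import SeqBnInvConfig
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")
adapters.init(model)
model.add_adapter("inv_bn", config=SeqBnInvConfig())

if model.supports_adapter("invertible"):
    weights = model.get_adapter("inv_bn")
    # per the diff above, the returned dict stores the invertible weights at index -1
    # alongside the other global weights, under the "invertible" key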
2 changes: 1 addition & 1 deletion tests/test_custom_interface.py
@@ -36,7 +36,7 @@ class CustomInterfaceModelTestBase(AdapterTestBase):
     )
     tokenizer_name = "yujiepan/gemma-2-tiny-random"
     adapter_interface = AdapterModelInterface(
-        adapter_types=["bottleneck", "lora", "reft"],
+        adapter_types=["bottleneck", "lora", "reft", "invertible"],
         model_embeddings="embed_tokens",
         model_layers="layers",
         layer_self_attn="self_attn",
19 changes: 17 additions & 2 deletions tests/test_custom_interface_compat.py
@@ -35,7 +35,7 @@ class CustomInterfaceCompatTest(unittest.TestCase):
         pad_token_id=0,
     )
     llama_interface = AdapterModelInterface(
-        adapter_types=["bottleneck", "lora", "reft"],
+        adapter_types=["bottleneck", "lora", "reft", "invertible"],
         model_embeddings="embed_tokens",
         model_layers="layers",
         layer_self_attn="self_attn",
@@ -53,7 +53,7 @@ class CustomInterfaceCompatTest(unittest.TestCase):
         layer_ln_2=None,
     )
     bert_interface = AdapterModelInterface(
-        adapter_types=["bottleneck", "lora", "reft", "prompt_tuning"],
+        adapter_types=["bottleneck", "lora", "reft", "prompt_tuning", "invertible"],
         model_embeddings="embeddings",
         model_layers="encoder.layer",
         layer_self_attn="attention",
@@ -94,6 +94,13 @@ def create_twin_models(self, config, adapter_interface, hf_auto_model_class):
             llama_interface,
             AutoModelForCausalLM,
         ),
+        (
+            "BnSeqInv_Llama",
+            adapters.SeqBnInvConfig(),
+            llama_config,
+            llama_interface,
+            AutoModelForCausalLM,
+        ),
         (
             "BnSeqPreLN_Llama",
             adapters.SeqBnConfig(original_ln_before=True),
@@ -124,6 +131,14 @@ def create_twin_models(self, config, adapter_interface, hf_auto_model_class):
             AutoModel,
             bert_bn_rewrites,
         ),
+        (
+            "BnSeqInv_BERT",
+            adapters.SeqBnInvConfig(),
+            bert_config,
+            bert_interface,
+            AutoModel,
+            bert_bn_rewrites,
+        ),
         (
             "BnSeqPreLN_BERT",
             adapters.SeqBnConfig(original_ln_before=True),
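And a hedged sketch of the custom-interface path these tests exercise (llama_config and llama_interface refer to the objects defined in the test class above, now with "invertible" in adapter_types; the adapter name is invented, and the interface keyword of adapters.init is assumed to accept them as in these tests):

# Hypothetical usage of the plugin interface with invertible support.
import adapters
from adapters import SeqBnInvConfig
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_config(llama_config)
adapters.init(model, interface=llama_interface)
model.add_adapter("inv_bn", config=SeqBnInvConfig())
model.set_active_adapters("inv_bn")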
