From 3a9e702091ae53fb6ff370f6ec846c1c8ce7a7ed Mon Sep 17 00:00:00 2001 From: calpt Date: Mon, 6 Jan 2025 22:37:57 +0100 Subject: [PATCH] WIP: invertible adapters support --- src/adapters/interface.py | 13 ++-- src/adapters/methods/__init__.py | 2 + src/adapters/methods/invertible.py | 94 +++++++++++++++++++++++++++ src/adapters/methods/prompt_tuning.py | 15 +++-- src/adapters/model_mixin.py | 10 +-- tests/test_custom_interface.py | 2 +- tests/test_custom_interface_compat.py | 19 +++++- 7 files changed, 137 insertions(+), 18 deletions(-) create mode 100644 src/adapters/methods/invertible.py diff --git a/src/adapters/interface.py b/src/adapters/interface.py index 2d7db7906..7f094bbc1 100644 --- a/src/adapters/interface.py +++ b/src/adapters/interface.py @@ -12,6 +12,7 @@ class AdapterMethod: lora = "lora" prompt_tuning = "prompt_tuning" reft = "reft" + invertible = "invertible" @staticmethod def get_from_config(config) -> List[str]: @@ -22,14 +23,18 @@ def get_from_config(config) -> List[str]: config: The adapter config. Returns: - str: The adapter type. + List[str]: The adapter type. """ + methods = [] + if getattr(config, "inv_adapter", False): + methods.append(AdapterMethod.invertible) if config.architecture is None: - return [AdapterMethod.bottleneck] + methods.append(AdapterMethod.bottleneck) elif config.architecture == "union": - return [AdapterMethod.get_from_config(sub_config) for sub_config in config.configs] + methods.extend([AdapterMethod.get_from_config(sub_config) for sub_config in config.configs]) else: - return [config.architecture] + methods.append(config.architecture) + return methods @dataclass diff --git a/src/adapters/methods/__init__.py b/src/adapters/methods/__init__.py index 650b015f9..5082ff0b6 100644 --- a/src/adapters/methods/__init__.py +++ b/src/adapters/methods/__init__.py @@ -1,4 +1,5 @@ from .bottleneck import init_bottleneck +from .invertible import init_invertible_adapters from .lora import init_lora from .prompt_tuning import init_prompt_tuning from .reft import init_reft @@ -9,4 +10,5 @@ "lora": init_lora, "prompt_tuning": init_prompt_tuning, "reft": init_reft, + "invertible": init_invertible_adapters, } diff --git a/src/adapters/methods/invertible.py b/src/adapters/methods/invertible.py new file mode 100644 index 000000000..5d10a6fe5 --- /dev/null +++ b/src/adapters/methods/invertible.py @@ -0,0 +1,94 @@ +import types +from functools import partial + +import torch +import torch.nn as nn + +from ..configuration.adapter_config import BnConfig +from ..utils import multigetattr +from .adapter_layer_base import AdapterLayerBase +from .modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock + + +class InvertibleAdapterLayer(AdapterLayerBase, nn.ModuleDict): + adapter_modules_name = "_modules" + + def __init__(self, model_config, adapters_config): + super().__init__() + self.location_key = "inv_adapter" + self.model_config = model_config + self.adapters_config = adapters_config + + def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: + self.layer_idx = layer_idx + embedding_size = getattr(self.model_config, "embedding_size", self.model_config.hidden_size) + adapter_config = self.adapters_config.match( + adapter_name, + config_type=BnConfig, + location_key="inv_adapter", + ) + if adapter_config is not None and adapter_config["inv_adapter"]: + if adapter_config["inv_adapter"] == "nice": + inv_adap = NICECouplingBlock( + [[embedding_size]], + non_linearity=adapter_config["non_linearity"], + 
reduction_factor=adapter_config["inv_adapter_reduction_factor"], + ) + elif adapter_config["inv_adapter"] == "glow": + inv_adap = GLOWCouplingBlock( + [[embedding_size]], + non_linearity=adapter_config["non_linearity"], + reduction_factor=adapter_config["inv_adapter_reduction_factor"], + ) + else: + raise ValueError(f"Invalid invertible adapter type '{adapter_config['inv_adapter']}'.") + self[adapter_name] = inv_adap + self[adapter_name].apply(Adapter.init_bert_weights) + return True + + return False + + def get_invertible_adapter(self): + # HACK: returns the first adapter of the currently active setup. for backwards compatibility + adapter_setup = self.get_active_setup() + if adapter_setup is not None and len(adapter_setup) > 0: + first_adapter = adapter_setup.first() + if first_adapter in self: + return self[first_adapter] + return None + + def forward(self, hidden_states: torch.Tensor, rev=False): + adapter_setup = self.get_active_setup() + if adapter_setup is not None and len(adapter_setup) > 0: + first_adapter = adapter_setup.first() + if first_adapter in self: + hidden_states = self[first_adapter](hidden_states, rev=rev) + return hidden_states + + +def hook_fn(model, module, args, embedding_output): + embedding_output = model.invertible_adapters(embedding_output) + return embedding_output + + +def init_invertible_adapters(model): + if not hasattr(model, "invertible_adapters"): + model.invertible_adapters = InvertibleAdapterLayer(model.config, model.adapters_config) + + embed_layer = multigetattr(model, model.adapter_interface.model_embeddings) + embed_layer.register_forward_hook(partial(hook_fn, model)) + + # Add methods from original invertible adapter mixin. + # This is primarily for backwards compatibility and internal use. + model.add_invertible_adapter = types.MethodType( + lambda self, *args, **kwargs: self.invertible_adapters.add_adapter(*args, **kwargs), model + ) + model.delete_invertible_adapter = types.MethodType( + lambda self, *args, **kwargs: self.invertible_adapters.delete_adapter(*args, **kwargs), model + ) + model.get_invertible_adapter = types.MethodType( + lambda self: self.invertible_adapters.get_invertible_adapter(), model + ) + model.invertible_adapters_forward = types.MethodType( + lambda self, *args, **kwargs: self.invertible_adapters(*args, **kwargs), model + ) diff --git a/src/adapters/methods/prompt_tuning.py b/src/adapters/methods/prompt_tuning.py index 4d451bd3f..03025f64c 100644 --- a/src/adapters/methods/prompt_tuning.py +++ b/src/adapters/methods/prompt_tuning.py @@ -191,10 +191,11 @@ def _attn_mask_hook_fn(module, args): def init_prompt_tuning(model): - model.support_prompt_tuning = True - model.prompt_tuning = PromptTuningLayer(model.config, model.adapters_config, model.get_input_embeddings()) - embed_layer = multigetattr(model, model.adapter_interface.model_embeddings) - embed_layer.register_forward_hook(partial(hook_fn, model)) - - for _, layer in model.iter_layers(): - layer.register_forward_pre_hook(_attn_mask_hook_fn) + if not hasattr(model, "prompt_tuning"): + model.support_prompt_tuning = True + model.prompt_tuning = PromptTuningLayer(model.config, model.adapters_config, model.get_input_embeddings()) + embed_layer = multigetattr(model, model.adapter_interface.model_embeddings) + embed_layer.register_forward_hook(partial(hook_fn, model)) + + for _, layer in model.iter_layers(): + layer.register_forward_pre_hook(_attn_mask_hook_fn) diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 4ffe4c8be..1897aae0c 100644 --- 
a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -494,6 +494,10 @@ def supports_adapter(self, type_or_config: Union[str, AdapterConfig]) -> bool: supported.append(_type in self.base_model.adapter_interface.adapter_types) elif _type == AdapterMethod.prompt_tuning: supported.append(self.base_model.support_prompt_tuning) + elif _type == AdapterMethod.invertible: + supported.append( + isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin) + ) else: supported.append(True) return all(supported) @@ -1077,12 +1081,10 @@ def get_adapter(self, name) -> dict: # global weights are saved at index -1 if name in self.base_model.shared_parameters: destination[-1]["shared"] = self.base_model.shared_parameters[name] - if ( - isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin) - ) and name in self.invertible_adapters: + if self.supports_adapter("invertible") and name in self.invertible_adapters: destination[-1]["invertible"] = self.invertible_adapters[name] - if self.support_prompt_tuning: + if self.supports_adapter("prompt_tuning"): prompt_tuning = self.prompt_tuning.get_adapter(name) if prompt_tuning is not None: destination[-1]["prompt"] = prompt_tuning diff --git a/tests/test_custom_interface.py b/tests/test_custom_interface.py index 9e96b8a7b..70eb4054e 100644 --- a/tests/test_custom_interface.py +++ b/tests/test_custom_interface.py @@ -36,7 +36,7 @@ class CustomInterfaceModelTestBase(AdapterTestBase): ) tokenizer_name = "yujiepan/gemma-2-tiny-random" adapter_interface = AdapterModelInterface( - adapter_types=["bottleneck", "lora", "reft"], + adapter_types=["bottleneck", "lora", "reft", "invertible"], model_embeddings="embed_tokens", model_layers="layers", layer_self_attn="self_attn", diff --git a/tests/test_custom_interface_compat.py b/tests/test_custom_interface_compat.py index 922ececdd..9e21723be 100644 --- a/tests/test_custom_interface_compat.py +++ b/tests/test_custom_interface_compat.py @@ -35,7 +35,7 @@ class CustomInterfaceCompatTest(unittest.TestCase): pad_token_id=0, ) llama_interface = AdapterModelInterface( - adapter_types=["bottleneck", "lora", "reft"], + adapter_types=["bottleneck", "lora", "reft", "invertible"], model_embeddings="embed_tokens", model_layers="layers", layer_self_attn="self_attn", @@ -53,7 +53,7 @@ class CustomInterfaceCompatTest(unittest.TestCase): layer_ln_2=None, ) bert_interface = AdapterModelInterface( - adapter_types=["bottleneck", "lora", "reft", "prompt_tuning"], + adapter_types=["bottleneck", "lora", "reft", "prompt_tuning", "invertible"], model_embeddings="embeddings", model_layers="encoder.layer", layer_self_attn="attention", @@ -94,6 +94,13 @@ def create_twin_models(self, config, adapter_interface, hf_auto_model_class): llama_interface, AutoModelForCausalLM, ), + ( + "BnSeqInv_Llama", + adapters.SeqBnInvConfig(), + llama_config, + llama_interface, + AutoModelForCausalLM, + ), ( "BnSeqPreLN_Llama", adapters.SeqBnConfig(original_ln_before=True), @@ -124,6 +131,14 @@ def create_twin_models(self, config, adapter_interface, hf_auto_model_class): AutoModel, bert_bn_rewrites, ), + ( + "BnSeqInv_BERT", + adapters.SeqBnInvConfig(), + bert_config, + bert_interface, + AutoModel, + bert_bn_rewrites, + ), ( "BnSeqPreLN_BERT", adapters.SeqBnConfig(original_ln_before=True),
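
Illustrative note (not part of the patch): init_invertible_adapters attaches the new
InvertibleAdapterLayer to the model by registering a forward hook on the embedding
module resolved via adapter_interface.model_embeddings, mirroring the pattern already
used by init_prompt_tuning. The minimal standalone sketch below reproduces only that
hook pattern in plain PyTorch; TinyModel and invertible_stub are made-up names used
for illustration, with an nn.Linear standing in for the NICE/GLOW coupling block.

    import torch
    import torch.nn as nn
    from functools import partial


    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed_tokens = nn.Embedding(100, 16)

        def forward(self, input_ids):
            return self.embed_tokens(input_ids)


    def hook_fn(model, module, args, output):
        # As in invertible.py: post-process the embedding output with an extra module
        # attached to the model (there, model.invertible_adapters).
        return model.invertible_stub(output)


    model = TinyModel()
    model.invertible_stub = nn.Linear(16, 16)  # stand-in for InvertibleAdapterLayer
    model.embed_tokens.register_forward_hook(partial(hook_fn, model))

    out = model(torch.tensor([[1, 2, 3]]))
    print(out.shape)  # torch.Size([1, 3, 16]); embeddings passed through the hooked module

Returning a value from a forward hook replaces the wrapped module's output, which is
why the patch can inject the invertible adapter without touching the model's forward code.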