Add invertible output layer. Test fixes.
calpt committed Jan 6, 2025
1 parent 3a9e702 commit 074ca66
Showing 7 changed files with 29 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/adapters/methods/bottleneck.py
@@ -403,6 +403,7 @@ def _residual_hook_fn(location_key, module, args):


 def init_bottleneck(model):
+    model = model.base_model
     for _, layer in model.iter_layers():
         if self_attn := multigetattr(layer, model.adapter_interface.layer_self_attn, None):
             if o_proj := multigetattr(self_attn, model.adapter_interface.attn_o_proj, None):
34 changes: 22 additions & 12 deletions src/adapters/methods/invertible.py
@@ -71,24 +71,34 @@ def hook_fn(model, module, args, embedding_output):
     return embedding_output


+def inv_hook_fn(model, module, args):
+    inv_output = model.invertible_adapters(args[0], rev=True)
+    return (inv_output,) + args[1:]
+
+
 def init_invertible_adapters(model):
-    if not hasattr(model, "invertible_adapters"):
-        model.invertible_adapters = InvertibleAdapterLayer(model.config, model.adapters_config)
+    base_model = model.base_model
+    if not hasattr(base_model, "invertible_adapters"):
+        base_model.invertible_adapters = InvertibleAdapterLayer(base_model.config, base_model.adapters_config)

-        embed_layer = multigetattr(model, model.adapter_interface.model_embeddings)
-        embed_layer.register_forward_hook(partial(hook_fn, model))
+        embed_layer = multigetattr(base_model, base_model.adapter_interface.model_embeddings)
+        embed_layer.register_forward_hook(partial(hook_fn, base_model))

         # Add methods from original invertible adapter mixin.
         # This is primarily for backwards compatibility and internal use.
-        model.add_invertible_adapter = types.MethodType(
-            lambda self, *args, **kwargs: self.invertible_adapters.add_adapter(*args, **kwargs), model
+        base_model.add_invertible_adapter = types.MethodType(
+            lambda self, *args, **kwargs: self.invertible_adapters.add_adapter(*args, **kwargs), base_model
         )
-        model.delete_invertible_adapter = types.MethodType(
-            lambda self, *args, **kwargs: self.invertible_adapters.delete_adapter(*args, **kwargs), model
+        base_model.delete_invertible_adapter = types.MethodType(
+            lambda self, *args, **kwargs: self.invertible_adapters.delete_adapter(*args, **kwargs), base_model
         )
-        model.get_invertible_adapter = types.MethodType(
-            lambda self: self.invertible_adapters.get_invertible_adapter(), model
+        base_model.get_invertible_adapter = types.MethodType(
+            lambda self: self.invertible_adapters.get_invertible_adapter(), base_model
         )
-        model.invertible_adapters_forward = types.MethodType(
-            lambda self, *args, **kwargs: self.invertible_adapters(*args, **kwargs), model
+        base_model.invertible_adapters_forward = types.MethodType(
+            lambda self, *args, **kwargs: self.invertible_adapters(*args, **kwargs), base_model
         )
+
+        # Register reverse forward pass
+        if output_embedding := model.get_output_embeddings():
+            output_embedding.register_forward_pre_hook(partial(inv_hook_fn, base_model))
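The reverse pass registered above relies on PyTorch's forward pre-hook semantics: a pre-hook that returns a tuple replaces the hooked module's positional inputs, which is how inv_hook_fn feeds the reversed invertible-adapter output into the output embedding layer. A minimal standalone sketch of that mechanism (toy modules, not the adapters implementation):

import torch
from torch import nn
from functools import partial

def scale_hook_fn(scale, module, args):
    # args holds the positional inputs of the hooked module; returning a new
    # tuple swaps in the transformed first argument, mirroring inv_hook_fn.
    return (args[0] * scale,) + args[1:]

output_proj = nn.Linear(4, 10)  # stand-in for the output embedding layer
output_proj.register_forward_pre_hook(partial(scale_hook_fn, 0.5))

hidden = torch.randn(2, 4)
logits = output_proj(hidden)  # the linear layer now sees hidden * 0.5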
1 change: 1 addition & 0 deletions src/adapters/methods/lora.py
@@ -818,6 +818,7 @@ def T(w):


 def init_lora(model):
+    model = model.base_model
     for _, _, attention in model.iter_attentions():
         if q_proj := multigetattr(attention, model.adapter_interface.attn_q_proj, None):
             lora_proj = LoRALinear.wrap(q_proj, "selfattn", model.config, model.adapters_config, attn_key="q")
1 change: 1 addition & 0 deletions src/adapters/methods/prompt_tuning.py
@@ -191,6 +191,7 @@ def _attn_mask_hook_fn(module, args):


 def init_prompt_tuning(model):
+    model = model.base_model
     if not hasattr(model, "prompt_tuning"):
         model.support_prompt_tuning = True
         model.prompt_tuning = PromptTuningLayer(model.config, model.adapters_config, model.get_input_embeddings())
1 change: 1 addition & 0 deletions src/adapters/methods/reft.py
@@ -213,6 +213,7 @@ def hook_fn(module, args, output):


 def init_reft(model):
+    model = model.base_model
     for _, layer in model.iter_layers():
         if not hasattr(layer, "reft_layer"):
             layer.reft_layer = ReftLayer("output", model.config, model.adapters_config)
2 changes: 1 addition & 1 deletion src/adapters/model_mixin.py
@@ -450,7 +450,7 @@ def init_adapters(self, model_config, adapters_config):
         if getattr(self.base_model, "adapter_interface", None) is not None:
             for adapter_type in self.base_model.adapter_interface.adapter_types:
                 init_func = METHOD_INIT_MAPPING[adapter_type]
-                init_func(self.base_model)
+                init_func(self)
         else:
             self._default_init_adapter_methods(self.config, self.adapters_config)

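For context: init_adapters now hands the full model (rather than self.base_model) to each method init function, and each init function resolves the backbone itself via model.base_model. This keeps head-level accessors such as get_output_embeddings() reachable where needed (see init_invertible_adapters above). A hedged sketch of the resulting pattern; init_example_method is hypothetical and only mirrors the init functions in this commit:

def init_example_method(model):
    # model may be a head-carrying wrapper, e.g. a *ForCausalLM class
    full_model = model
    # the backbone exposes iter_layers() and the adapter_interface
    model = model.base_model
    for _, layer in model.iter_layers():
        pass  # wrap or hook layer submodules here
    # head-level modules stay reachable through the full model
    if output_embedding := full_model.get_output_embeddings():
        pass  # register output-side hooks here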
2 changes: 2 additions & 0 deletions tests/methods/test_adapter_common.py
@@ -222,6 +222,8 @@ def test_forward_bottleneck(self):
     def test_invertible_adapter_forward(self):
         model = self.get_model()
         model.eval()
+        if not model.supports_adapter("invertible"):
+            self.skipTest("Model does not support invertible adapters.")

         for adapter_config, _ in self.inv_adapter_configs_to_test:
             with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__):
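The same supports_adapter check guards user code against backbones whose plugin interface does not declare invertible adapters. A hedged usage sketch (model name and config class are illustrative; the exact API surface may differ across adapters versions):

import adapters
from adapters import SeqBnInvConfig
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
adapters.init(model)  # attach the adapters mixins to the Transformers model

if model.supports_adapter("invertible"):
    # bottleneck config with an invertible adapter at the embedding layer
    model.add_adapter("inv_bottleneck", config=SeqBnInvConfig())
    model.set_active_adapters("inv_bottleneck")
else:
    print("Invertible adapters are not supported for this model.")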
