ENH Remove redundant initialization layer calls (#887)
This should lead to a big speedup when initializing LoRA layers.

---------

Co-authored-by: poedator <[email protected]>
BenjaminBossan and poedator authored Sep 6, 2023
1 parent 20d9c17 commit 08368a1
Showing 3 changed files with 25 additions and 8 deletions.
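The change targets construction cost: the old LoRA Linear wrapper ran nn.Linear.__init__, which allocates a weight tensor and kaiming-initializes it, even though that tensor is discarded the moment the pretrained weight is copied in by _replace_module. Below is a minimal, self-contained sketch of the two patterns (not PEFT code; EagerInitLinear and LazyInitLinear are made-up stand-ins):

import time

import torch
import torch.nn as nn


class EagerInitLinear(nn.Linear):
    """Old pattern: nn.Linear.__init__ allocates and initializes a weight that is
    thrown away as soon as the pretrained weight is copied in."""

    def __init__(self, in_features: int, out_features: int) -> None:
        nn.Linear.__init__(self, in_features, out_features, bias=False)


class LazyInitLinear(nn.Linear):
    """New pattern: only nn.Module.__init__ runs, so no weight is allocated or
    initialized; the pretrained weight is assigned afterwards."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super(nn.Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bias = None


pretrained = nn.Linear(4096, 4096, bias=False)
x = torch.randn(2, 4096)

t0 = time.perf_counter()
eager = EagerInitLinear(4096, 4096)
eager.weight = pretrained.weight
t1 = time.perf_counter()
lazy = LazyInitLinear(4096, 4096)
lazy.weight = pretrained.weight
t2 = time.perf_counter()

print(f"redundant init: {t1 - t0:.4f}s, skipped init: {t2 - t1:.4f}s")
torch.testing.assert_close(eager(x), lazy(x))

On a large layer the first construction spends noticeable time in the redundant initialization while the second is essentially free, which is where the claimed speedup comes from; the diffs below apply the same idea inside PEFT.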
14 changes: 8 additions & 6 deletions src/peft/tuners/lora/layer.py
@@ -59,7 +59,11 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
         self.scaling[adapter_name] = lora_alpha / r
         if init_lora_weights:
             self.reset_lora_parameters(adapter_name)
-        self.to(self.weight.device)
+
+        weight = getattr(self, "weight", None)
+        if weight is not None:
+            # the layer is already completely initialized, this is an update
+            self.to(weight.device)
 
     def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
         self.r[adapter_name] = r
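The getattr guard above is needed because update_layer runs in two situations: during __init__, before the pretrained weight has been assigned (so there is no device to follow), and later, when another adapter is added to an already initialized layer (then the new LoRA parameters should move to the existing weight's device). A standalone sketch of the pattern, with made-up names (Host, add_adapter), assuming this reading:

import torch
import torch.nn as nn


class Host(nn.Module):
    """Toy module that can receive extra adapter submodules after construction."""

    def add_adapter(self, name: str) -> None:
        setattr(self, f"adapter_{name}", nn.Linear(4, 4))
        weight = getattr(self, "weight", None)
        if weight is not None:
            # the host already holds its real weight: move everything to its device
            self.to(weight.device)


host = Host()
host.add_adapter("default")                    # no weight yet, the .to() call is skipped
host.weight = nn.Parameter(torch.zeros(4, 4))  # pretrained weight arrives later
host.add_adapter("extra")                      # now the new adapter follows weight.device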
@@ -138,17 +142,15 @@ def __init__(
         **kwargs,
     ) -> None:
         init_lora_weights = kwargs.pop("init_lora_weights", True)
+        # this gets the init from nn.Linear's super perspective, i.e.
+        # nn.Module.__init__, which should always be called
+        super(nn.Linear, self).__init__()
 
-        nn.Linear.__init__(self, in_features, out_features, **kwargs)
         LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
         # Freezing the pre-trained weight matrix
-        self.weight.requires_grad = False
 
         self.fan_in_fan_out = fan_in_fan_out
-        if fan_in_fan_out:
-            self.weight.data = self.weight.data.T
 
-        nn.Linear.reset_parameters(self)
         self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
         self.active_adapter = adapter_name
         self.is_target_conv_1d_layer = is_target_conv_1d_layer
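The swap from nn.Linear.__init__ to super(nn.Linear, self).__init__() works because of Python's method resolution order: starting the lookup after nn.Linear lands on nn.Module, so only nn.Module.__init__ runs and no weight tensor is ever allocated or initialized. A small demonstration with a stand-in subclass (MyLoraLinear is not the PEFT class):

import torch.nn as nn


class MyLoraLinear(nn.Linear):
    def __init__(self) -> None:
        # resolves to the class after nn.Linear in the MRO, i.e. nn.Module.__init__
        super(nn.Linear, self).__init__()


print([cls.__name__ for cls in MyLoraLinear.__mro__])
# ['MyLoraLinear', 'Linear', 'Module', 'object']
layer = MyLoraLinear()
print(hasattr(layer, "weight"))  # False: nn.Linear.__init__ never ran

The pretrained weight itself arrives later, when _replace_module (next file) assigns child.weight, which is why initializing, freezing, or transposing a temporary weight here was redundant work that could be dropped.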
5 changes: 3 additions & 2 deletions src/peft/tuners/lora/model.py
@@ -210,10 +210,11 @@ def _create_and_replace(
     @staticmethod
     def _replace_module(parent, child_name, new_module, child):
         setattr(parent, child_name, new_module)
+        # It's not necessary to set requires_grad here, as that is handled by
+        # _mark_only_adapters_as_trainable
         new_module.weight = child.weight
         if hasattr(child, "bias"):
-            if child.bias is not None:
-                new_module.bias = child.bias
+            new_module.bias = child.bias
 
         if getattr(child, "state", None) is not None:
             new_module.state = child.state
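The new comment defers freezing to _mark_only_adapters_as_trainable. As a rough, simplified sketch of that idea (not PEFT's exact implementation, which among other things also honors the config's bias setting), the pattern is to freeze every parameter whose name lacks the adapter prefix:

import torch.nn as nn


def mark_only_adapters_as_trainable(model: nn.Module, prefix: str = "lora_") -> None:
    for name, param in model.named_parameters():
        # pretrained weights are frozen, LoRA parameters stay trainable
        param.requires_grad = prefix in name


base = nn.Sequential(nn.Linear(8, 8))
base[0].lora_A = nn.Linear(8, 2, bias=False)  # stand-in adapter parameters
base[0].lora_B = nn.Linear(2, 8, bias=False)
mark_only_adapters_as_trainable(base)
print([n for n, p in base.named_parameters() if p.requires_grad])
# ['0.lora_A.weight', '0.lora_B.weight']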
14 changes: 14 additions & 0 deletions tests/test_custom_models.py
@@ -363,3 +363,17 @@ def run_with_disable(config_kwargs, bias):
     @parameterized.expand(TEST_CASES)
     def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
         self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)
+
+
+class TestRepr(unittest.TestCase):
+    """Tests related to the repr of adapted models"""
+
+    def test_repr_lora(self):
+        config = LoraConfig(target_modules=["lin0"])
+        model = get_peft_model(MLP(), config)
+        print_output = repr(model.model.lin0)
+        self.assertTrue(print_output.startswith("Linear"))
+        self.assertTrue("in_features=10, out_features=20" in print_output)
+        self.assertTrue("lora_A" in print_output)
+        self.assertTrue("lora_B" in print_output)
+        self.assertTrue("default" in print_output)

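The new test leans on the MLP fixture defined earlier in tests/test_custom_models.py. Judging only from the assertions (a submodule named lin0 with in_features=10 and out_features=20), the relevant part of that fixture presumably looks roughly like this hypothetical reconstruction (not copied from the file):

import torch.nn as nn


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)  # the layer targeted by LoraConfig(target_modules=["lin0"])
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(20, 2)   # hypothetical output size

    def forward(self, x):
        return self.lin1(self.relu(self.lin0(x)))

The assertions check that the adapted lin0 still prints as a Linear with its original shape and that its repr mentions lora_A, lora_B, and the "default" adapter name. The test can be run in isolation with, for example, pytest tests/test_custom_models.py -k test_repr_lora.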