
Commit

Add test and update docs
siddartha-RE committed Feb 12, 2024
1 parent 8ff105d commit eecb76b
Showing 4 changed files with 50 additions and 19 deletions.
6 changes: 5 additions & 1 deletion src/peft/tuners/lora/config.py
@@ -101,6 +101,10 @@ class LoraConfig(PeftConfig):
The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone weights
and initialize Lora layers. Also pass `init_lora_weights='loftq'`. Note that you should not pass a
quantized model in this case, as LoftQ will quantize the model itself.
layer_replication (`List[Tuple[int, int]]`):
Build a new stack of layers by stacking the original model layers according to the ranges specified.
This allows expanding (or shrinking) the model without duplicating the base model weights.
The new layers will all have separate LoRA adapters attached to them.
"""

r: int = field(default=8, metadata={"help": "Lora attention dimension"})
@@ -229,7 +233,7 @@ class LoraConfig(PeftConfig):
default=None,
metadata={
"help": (
"This enable using LoRA to effectively expand a model to a larger size by repeating some layers. "
"This enables using LoRA to effectively expand a model to a larger size by repeating some layers. "
"Base weights are shared so the memory usage is close to the original model."
"The format is a list of (start, end) pairs which specify the layer ranges to stack."
)
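
To illustrate the `(start, end)` format described in this help text, here is a minimal sketch that reuses the model id and settings from the test added later in this commit; the three half-open ranges stack the base layers in the order 0, 0, 1, 1 while sharing the base weights:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("HuggingFaceM4/tiny-random-LlamaForCausalLM")
config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["down_proj", "up_proj"],
    # Half-open ranges: [0, 1) -> layer 0, [0, 2) -> layers 0 and 1, [1, 2) -> layer 1.
    layer_replication=[[0, 1], [0, 2], [1, 2]],
)
peft_model = get_peft_model(base, config)  # 4 stacked layers sharing base weights, each with its own LoRA adapter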
23 changes: 8 additions & 15 deletions src/peft/tuners/lora/model.py
@@ -28,7 +28,13 @@
from tqdm import tqdm

from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists, clone_module, onload_layer
from peft.tuners.tuners_utils import (
BaseTuner,
BaseTunerLayer,
check_target_module_exists,
onload_layer,
replicate_layers,
)
from peft.utils import (
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
ModulesToSaveWrapper,
@@ -131,27 +137,14 @@ def _prepare_model(self, peft_config: LoraConfig, model: nn.Module):
r"""
A private method to modify the model structure before adapter is applied.
Check out `peft.tuner.lora.LoraModel._prepare_adapter_config` for an example.
Args:
peft_config (`PeftConfig`):
The prepared adapter config.
model_config (`nn.Module`):
The model that is going to be adapted.
"""
if peft_config.layer_replication:
new_layers = []
for start, end in peft_config.layer_replication:
for i in range(start, end):
current_idx = len(new_layers)
new_layers.append(clone_module(model.base_model.layers[i], share_weights=True))
# This is a hack needed to work around the layer_idx introduced in HF transformers.
for submodule in new_layers[-1].modules():
if hasattr(submodule, 'layer_idx'):
submodule.layer_idx = current_idx
model.base_model.layers = nn.ModuleList(new_layers)
if hasattr(model.config, 'num_hidden_layers'):
model.config.num_hidden_layers = len(new_layers)
replicate_layers(model, peft_config.layer_replication)

def _create_and_replace(
self,
19 changes: 17 additions & 2 deletions src/peft/tuners/tuners_utils.py
@@ -19,7 +19,7 @@
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Optional, Union
from typing import Any, List, Optional, Tuple, Union

import torch
from accelerate.hooks import AlignDevicesHook
@@ -182,7 +182,7 @@ def _prepare_model(self, peft_config: PeftConfig, model: nn.Module):
r"""
A private method to modify the model structure before adapter is applied.
Check out `peft.tuner.lora.LoraModel._prepare_adapter_config` for an example.
See `peft.tuners.lora.LoraModel._prepare_model` for an example.
Args:
peft_config (`PeftConfig`):
@@ -695,3 +695,18 @@ def _share_weights(src: nn.Module, dst: nn.Module):
_share_weights(submodule, clone.get_submodule(name))

return clone


def replicate_layers(model: nn.Module, layer_map: List[Tuple[int, int]]):
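"""Rebuild `model.base_model.layers` by stacking weight-sharing clones of the layers selected by the half-open `(start, end)` ranges in `layer_map`, updating `model.config.num_hidden_layers` when present."""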
new_layers = []
for start, end in layer_map:
for i in range(start, end):
current_idx = len(new_layers)
new_layers.append(clone_module(model.base_model.layers[i], share_weights=True))
# This is a hack needed to work around the layer_idx introduced in HF transformers.
for submodule in new_layers[-1].modules():
if hasattr(submodule, 'layer_idx'):
submodule.layer_idx = current_idx
model.base_model.layers = nn.ModuleList(new_layers)
if hasattr(model.config, 'num_hidden_layers'):
model.config.num_hidden_layers = len(new_layers)
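
As a quick worked example of the half-open range semantics, here is a sketch that calls the internal helper above directly, reusing the tiny Llama checkpoint from the test added below (which has at least two decoder layers):

from transformers import AutoModelForCausalLM
from peft.tuners.tuners_utils import replicate_layers

model = AutoModelForCausalLM.from_pretrained("HuggingFaceM4/tiny-random-LlamaForCausalLM")
# [0, 2) keeps layers 0 and 1; [1, 2) appends a weight-sharing clone of layer 1.
replicate_layers(model, [(0, 2), (1, 2)])
print(len(model.base_model.layers))    # 3, since the new stack is [0, 1, 1]
print(model.config.num_hidden_layers)  # also updated to 3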
21 changes: 20 additions & 1 deletion tests/test_decoder_models.py
@@ -18,7 +18,7 @@
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import AdaLoraConfig, PromptTuningConfig, PromptTuningInit, get_peft_model
from peft import AdaLoraConfig, LoraConfig, PromptTuningConfig, PromptTuningInit, get_peft_model

from .testing_common import PeftCommonTester, PeftTestConfigManager

@@ -302,3 +302,22 @@ def test_generate_adalora_no_dropout(self):
@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_passing_input_embeds_works(self, test_name, model_id, config_cls, config_kwargs):
self._test_passing_input_embeds_works(test_name, model_id, config_cls, config_kwargs)

def test_lora_layer_replication(self):
model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
config_kwargs = {
"target_modules": ['down_proj', 'up_proj'],
"task_type": "CAUSAL_LM",
"lora_dropout": 0.0,
"layer_replication": [[0, 1], [0, 2], [1, 2]]
}
model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
config = LoraConfig(
base_model_name_or_path=model_id,
**config_kwargs,
)
model = get_peft_model(model, config)
# [[0, 1], [0, 2], [1, 2]] stacks the base layers as [0, 0, 1, 1], so the adapted model has 4 layers.
self.assertEqual(4, len(model.base_model.model.model.layers), 'Expected 4 layers in adapted model.')
# 4 layers x 2 target modules = 8 lora_A parameters.
self.assertEqual(8, len([n for n, _ in model.named_parameters() if '.lora_A.' in n]))
self._test_prepare_for_training(model_id, LoraConfig, config_kwargs)
self._test_generate(model_id, LoraConfig, config_kwargs)
