Add support for layer replication in LoRA
siddartha-RE committed Feb 12, 2024
1 parent 7716dd8 commit 8ff105d
Showing 3 changed files with 85 additions and 9 deletions.
13 changes: 12 additions & 1 deletion src/peft/tuners/lora/config.py
@@ -15,7 +15,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Literal, Optional, Union
from typing import List, Literal, Optional, Tuple, Union

from peft.config import PeftConfig
from peft.utils import PeftType
@@ -224,6 +224,17 @@ class LoraConfig(PeftConfig):
)
},
)
# Enables replicating layers in a model to expand it to a larger model.
layer_replication: Optional[List[Tuple[int, int]]] = field(
default=None,
metadata={
"help": (
"This enable using LoRA to effectively expand a model to a larger size by repeating some layers. "
"Base weights are shared so the memory usage is close to the original model."
"The format is a list of (start, end) pairs which specify the layer ranges to stack."
)
}
)

def __post_init__(self):
self.peft_type = PeftType.LORA
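For orientation, here is a minimal sketch of how the new option might be configured; the rank, target modules, and layer counts are illustrative assumptions, not taken from this commit. Each (start, end) pair is half-open, so it contributes layers start through end - 1 in order.

from peft import LoraConfig

# Hypothetical example: expand a 24-layer base model into a 32-layer adapted model.
# [(0, 16), (8, 24)] stacks layers 0-15 followed by layers 8-23, so the middle
# block is traversed twice while its base weights are stored only once.
config = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj"],
    layer_replication=[(0, 16), (8, 24)],
)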
41 changes: 38 additions & 3 deletions src/peft/tuners/lora/model.py
@@ -28,7 +28,7 @@
from tqdm import tqdm

from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists, onload_layer
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists, clone_module, onload_layer
from peft.utils import (
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
ModulesToSaveWrapper,
@@ -127,6 +127,32 @@ def _check_new_adapter_config(self, config: LoraConfig) -> None:
def _check_target_module_exists(lora_config, key):
return check_target_module_exists(lora_config, key)

def _prepare_model(self, peft_config: LoraConfig, model: nn.Module):
r"""
A private method to modify the model structure before the adapter is applied.
Args:
peft_config (`LoraConfig`):
The prepared adapter config.
model (`nn.Module`):
The model that is going to be adapted.
"""
if peft_config.layer_replication:
new_layers = []
for start, end in peft_config.layer_replication:
for i in range(start, end):
current_idx = len(new_layers)
new_layers.append(clone_module(model.base_model.layers[i], share_weights=True))
# This is a hack needed to work around the layer_idx attribute introduced in HF transformers.
for submodule in new_layers[-1].modules():
if hasattr(submodule, 'layer_idx'):
submodule.layer_idx = current_idx
model.base_model.layers = nn.ModuleList(new_layers)
if hasattr(model.config, 'num_hidden_layers'):
model.config.num_hidden_layers = len(new_layers)
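To make the replication loop concrete, the following standalone sketch (the helper name is assumed, not part of the commit) reproduces the index mapping that _prepare_model builds before cloning.

def replicated_layer_indices(layer_replication):
    # Mirrors the nested loop above: each (start, end) range contributes
    # the original layer indices start..end-1, in order.
    return [i for start, end in layer_replication for i in range(start, end)]

# For a 4-layer model, [(0, 3), (1, 4)] -> [0, 1, 2, 1, 2, 3]: the expanded
# model has six layers, with original layers 1 and 2 each backed twice by
# the same shared base weights.
print(replicated_layer_indices([(0, 3), (1, 4)]))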

def _create_and_replace(
self,
lora_config,
@@ -327,6 +353,16 @@ def set_adapter(self, adapter_name: str | list[str]) -> None:
module.set_adapter(adapter_name)
self.active_adapter = adapter_name

def _check_merge_allowed(self):
if getattr(self.model, "quantization_method", None) == "gptq":
raise ValueError("Cannot merge LORA layers when the model is gptq quantized")
if self.peft_config.get('layer_replication'):
raise ValueError("Cannot merge LORA layers when base model layers are replicated")

def merge_adapter(self, adapter_names: Optional[list[str]] = None) -> None:
self._check_merge_allowed()
super().merge_adapter(adapter_names=adapter_names)

@staticmethod
def _prepare_adapter_config(peft_config, model_config):
if peft_config.target_modules is None:
@@ -345,8 +381,7 @@ def _unload_and_optionally_merge(
adapter_names: Optional[list[str]] = None,
):
if merge:
if getattr(self.model, "quantization_method", None) == "gptq":
raise ValueError("Cannot merge LORA layers when the model is gptq quantized")
self._check_merge_allowed()

key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
desc = "Unloading " + ("and merging " if merge else "") + "model"
40 changes: 35 additions & 5 deletions src/peft/tuners/tuners_utils.py
@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import annotations

import copy
import logging
import re
import warnings
@@ -170,13 +171,27 @@ def _prepare_adapter_config(self, peft_config: PeftConfig, model_config: dict) -
Check out `peft.tuner.lora.LoraModel._prepare_adapter_config` for an example.
Args:
peft_config (`str`):
peft_config (`PeftConfig`):
The adapter config.
model_config (`str`):
model_config (`dict`):
The transformers model config, that config should contain the `model_type` key.
"""
...

def _prepare_model(self, peft_config: PeftConfig, model: nn.Module):
r"""
A private method to modify the model structure before the adapter is applied.
See `peft.tuner.lora.LoraModel._prepare_model` for an example.
Args:
peft_config (`PeftConfig`):
The prepared adapter config.
model (`nn.Module`):
The model that is going to be adapted.
"""
pass

@abstractmethod
def _check_target_module_exists(peft_config: PeftConfig, key: str) -> bool:
r"""
@@ -261,9 +276,6 @@ def inject_adapter(self, model: nn.Module, adapter_name: str):
# in a bad (half-initialized) state.
self._check_new_adapter_config(peft_config)

is_target_modules_in_base_model = False
key_list = [key for key, _ in model.named_modules()]

_check_for_modules_to_save = getattr(peft_config, "modules_to_save", None) is not None
_has_modules_to_save = False

@@ -273,6 +285,10 @@ def inject_adapter(self, model: nn.Module, adapter_name: str):

peft_config = self._prepare_adapter_config(peft_config, model_config)

self._prepare_model(peft_config, model)
is_target_modules_in_base_model = False
key_list = [key for key, _ in model.named_modules()]

# update peft_config.target_modules if required
peft_config = _maybe_include_all_linear_layers(peft_config, model)

@@ -665,3 +681,17 @@ def check_adapters_to_merge(module: BaseTunerLayer, adapter_names: Optional[list
warnings.warn("All adapters are already merged, nothing to do.")

return adapter_names


def clone_module(module: nn.Module, share_weights=False):
clone = copy.deepcopy(module)

def _share_weights(src: nn.Module, dst: nn.Module):
for name, param in src.named_parameters(recurse=False):
dst.register_parameter(name, param)

if share_weights:
for name, submodule in module.named_modules():
_share_weights(submodule, clone.get_submodule(name))

return clone
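As a quick illustration of the share_weights path (a standalone sketch assuming clone_module is in scope, not part of the commit), the clone ends up holding the very same parameter objects as the source module, so no additional weight memory is allocated.

import torch.nn as nn

layer = nn.Linear(4, 4)
shared_copy = clone_module(layer, share_weights=True)

# register_parameter re-points the clone at the original nn.Parameter objects,
# so both modules read and update the same tensors.
assert shared_copy.weight is layer.weight
assert shared_copy.bias is layer.bias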
