Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for layer replication in LoRA #1368

Merged
merged 8 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/source/developer_guides/lora.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,19 @@ The default LoRA settings in PEFT add trainable weights to the query and value l
config = LoraConfig(target_modules="all-linear", ...)
```

### Memory efficient Layer Replication with LoRA

An approach used to improve the performance of models is to expand a model by duplicating layers in the model to build a larger model from a pretrained model of a given size. For example increasing a 7B model to a 10B model as described in the [SOLAR](https://arxiv.org/abs/2312.15166) paper. PEFT LoRA supports this kind of expansion in a memory efficient manner that supports further fine-tuning using LoRA adapters attached to the layers post replication of the layers. The replicated layers do not take additional memory as they share the underlying weights so the only additional memory required is the memory for the adapter weights. To use this feature you would create a config with the `layer_replication` argument.

```py
config = LoraConfig(layer_replication=[[0,4], [2,5]], ...)
```

Assuming the original model had 5 layers `[0, 1, 2 ,3, 4]`, this would create a model with 7 layers arranged as `[0, 1, 2, 3, 2, 3, 4]`. This follows the [mergekit](https://github.com/arcee-ai/mergekit) pass through merge convention where sequences of layers specified as start inclusive and end exclusive tuples are stacked to build the final model. Each layer in the final model gets its own distinct set of LoRA adpaters.

[Fewshot-Metamath-OrcaVicuna-Mistral-10B](https://huggingface.co/abacusai/Fewshot-Metamath-OrcaVicuna-Mistral-10B) is an example of a model trained using this method on Mistral-7B expanded to 10B. The
(adapter_config.json)[https://huggingface.co/abacusai/Fewshot-Metamath-OrcaVicuna-Mistral-10B/blob/main/adapter_config.json] shows a sample LoRA adapter config applying this method for fine-tuning.

## Merge adapters

While LoRA is significantly smaller and faster to train, you may encounter latency issues during inference due to separately loading the base model and the LoRA adapter. To eliminate latency, use the [`~LoraModel.merge_and_unload`] function to merge the adapter weights with the base model. This allows you to use the newly merged model as a standalone model. The [`~LoraModel.merge_and_unload`] function doesn't keep the adapter weights in memory.
Expand Down
24 changes: 24 additions & 0 deletions src/peft/tuners/lora/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ class LoraConfig(PeftConfig):
ranks. Right now, DoRA only supports non-quantized linear layers. DoRA introduces a bigger overhead than
pure LoRA, so it is recommended to merge weights for inference. For more information, see
https://arxiv.org/abs/2402.09353.
layer_replication(`List[Tuple[int, int]]`):
Build a new stack of layers by stacking the original model layers according to the ranges specified. This
allows expanding (or shrinking) the model without duplicating the base model weights. The new layers will
all have separate LoRA adapters attached to them.
"""

r: int = field(default=8, metadata={"help": "Lora attention dimension"})
Expand Down Expand Up @@ -244,6 +248,26 @@ class LoraConfig(PeftConfig):
)
},
)
# Enables replicating layers in a model to expand it to a larger model.
layer_replication: Optional[list[tuple[int, int]]] = field(
default=None,
metadata={
"help": (
"This enables using LoRA to effectively expand a transformer model to a larger size by repeating some layers. "
"The transformation handles models (currently Llama, Bert or Falcon compatible architectures) with "
"a module list in the model which it modifies to expand the number of modules. "
"Base weights are shared so the memory usage is close to the original model. The intended use is these base weights "
"remain fixed during finetuning but each layer has a separate LoRA adapter so the layers can be specialed via "
"the adapter layers fit during fine tuning."
"The format is a list of [start, end) pairs which specify the layer ranges to stack. For example:\n"
BenjaminBossan marked this conversation as resolved.
Show resolved Hide resolved
" Original model has 5 layers labelled by their position in the model: `[0, 1, 2, 3, 4]`\n"
" layer_replication: `[[0, 4], [2, 5]]`\n"
" Final model will have this arrangement of original layers: `[0, 1, 2, 3, 2, 3, 4]`\n"
"This format is based on what is used for pass-through merges in mergekit. It makes it simple to select sequential "
"ranges of a model and stack them while reusing layers at either end of each sequence."
)
},
)

def __post_init__(self):
self.peft_type = PeftType.LORA
Expand Down
34 changes: 31 additions & 3 deletions src/peft/tuners/lora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@
from tqdm import tqdm

from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists, onload_layer
from peft.tuners.tuners_utils import (
BaseTuner,
BaseTunerLayer,
check_target_module_exists,
onload_layer,
replicate_layers,
)
from peft.utils import (
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
ModulesToSaveWrapper,
Expand Down Expand Up @@ -129,6 +135,19 @@ def _check_new_adapter_config(self, config: LoraConfig) -> None:
def _check_target_module_exists(lora_config, key):
return check_target_module_exists(lora_config, key)

def _prepare_model(self, peft_config: LoraConfig, model: nn.Module):
r"""
A private method to modify the model structure before adapter is applied.

Args:
peft_config (`PeftConfig`):
The prepared adapter config.
model (`nn.Module`):
The model that is going to be adapted.
"""
if peft_config.layer_replication:
replicate_layers(model, peft_config.layer_replication)

def _create_and_replace(
self,
lora_config,
Expand Down Expand Up @@ -333,6 +352,16 @@ def set_adapter(self, adapter_name: str | list[str]) -> None:
module.set_adapter(adapter_name)
self.active_adapter = adapter_name

def _check_merge_allowed(self):
BenjaminBossan marked this conversation as resolved.
Show resolved Hide resolved
"""Verify that the configuration supports merging.

Currently gptq quantization and replicated layers do not support merging.
"""
if getattr(self.model, "quantization_method", None) == "gptq":
raise ValueError("Cannot merge LORA layers when the model is gptq quantized")
if self.peft_config.get("layer_replication"):
raise ValueError("Cannot merge LORA layers when base model layers are replicated")

@staticmethod
def _prepare_adapter_config(peft_config, model_config):
if peft_config.target_modules is None:
Expand All @@ -351,8 +380,7 @@ def _unload_and_optionally_merge(
adapter_names: Optional[list[str]] = None,
):
if merge:
if getattr(self.model, "quantization_method", None) == "gptq":
raise ValueError("Cannot merge LORA layers when the model is gptq quantized")
self._check_merge_allowed()

key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
desc = "Unloading " + ("and merging " if merge else "") + "model"
Expand Down
105 changes: 100 additions & 5 deletions src/peft/tuners/tuners_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import annotations

import copy
import logging
import re
import warnings
Expand Down Expand Up @@ -170,13 +171,27 @@ def _prepare_adapter_config(self, peft_config: PeftConfig, model_config: dict) -
Check out `peft.tuner.lora.LoraModel._prepare_adapter_config` for an example.

Args:
peft_config (`str`):
peft_config (`PeftConfig`):
The adapter config.
model_config (`str`):
model_config (`dict`):
The transformers model config, that config should contain the `model_type` key.
"""
...

def _prepare_model(self, peft_config: PeftConfig, model: nn.Module):
r"""
A private method to modify the model structure before adapter is applied.

See `peft.tuner.lora.LoraModel._prepare_model` for an example.

Args:
peft_config (`PeftConfig`):
The prepared adapter config.
model (`nn.Module`):
The model that is going to be adapted.
"""
pass

@abstractmethod
def _check_target_module_exists(peft_config: PeftConfig, key: str) -> bool:
r"""
Expand Down Expand Up @@ -242,6 +257,13 @@ def _check_new_adapter_config(self, config: PeftConfig) -> None:
"""
pass

def _check_merge_allowed(self):
"""Helper method to check whether the adapter can be merged.

Raise a ValueError if it is not possible to merge the adapter with the given configuration.
"""
pass

def inject_adapter(self, model: nn.Module, adapter_name: str):
r"""
Creates adapter layers and replaces the target modules with the adapter layers. This method is called under the
Expand All @@ -261,9 +283,6 @@ def inject_adapter(self, model: nn.Module, adapter_name: str):
# in a bad (half-initialized) state.
self._check_new_adapter_config(peft_config)

is_target_modules_in_base_model = False
key_list = [key for key, _ in model.named_modules()]

_check_for_modules_to_save = getattr(peft_config, "modules_to_save", None) is not None
_has_modules_to_save = False

Expand All @@ -273,6 +292,10 @@ def inject_adapter(self, model: nn.Module, adapter_name: str):

peft_config = self._prepare_adapter_config(peft_config, model_config)

self._prepare_model(peft_config, model)
is_target_modules_in_base_model = False
key_list = [key for key, _ in model.named_modules()]

# update peft_config.target_modules if required
peft_config = _maybe_include_all_linear_layers(peft_config, model)

Expand Down Expand Up @@ -337,6 +360,7 @@ def merge_adapter(self, adapter_names: Optional[list[str]] = None) -> None:
The list of adapter names that should be merged. If `None`, all active adapters will be merged.
Defaults to `None`.
"""
self._check_merge_allowed()
for module in self.model.modules():
if isinstance(module, BaseTunerLayer):
with onload_layer(module):
Expand Down Expand Up @@ -665,3 +689,74 @@ def check_adapters_to_merge(module: BaseTunerLayer, adapter_names: Optional[list
warnings.warn("All adapters are already merged, nothing to do.")

return adapter_names


def clone_module(module: nn.Module, share_weights=False):
BenjaminBossan marked this conversation as resolved.
Show resolved Hide resolved
"""Clone a module in a pytorch model.

Clones a module of a model, optionally sharing all the parameters between the original and the clone. Simplifies
reusing a module when manipulating the architecture of a model.
"""
clone = copy.deepcopy(module)

def _share_weights(src: nn.Module, dst: nn.Module):
for name, param in src.named_parameters(recurse=False):
dst.register_parameter(name, param)

if share_weights:
for name, submodule in module.named_modules():
_share_weights(submodule, clone.get_submodule(name))

return clone


def replicate_layers(model: nn.Module, layer_map: list[tuple[int, int]]):
"""Replicate layers in a transfomer model with weight sharing.

This function looks for a module list attribute at model[(.model)*].layers and replicates the layers in the module
list according to the layer map. For example the map `[[0, 4], [2, 5]]` will take the set of layers `[0, 1, 2, 3,
4]` and replace them with a module list containing `[0, 1, 2, 3, 2, 3, 4]`.
"""
while hasattr(model, "model"):
model = model.model
# Some variants of the bert model nest the main model under the bert attribute.
if hasattr(model, "bert"):
model = model.bert

model_type = None
layers: nn.ModuleList = None
if hasattr(model, "layers"):
model_type = "llama"
layers = model.layers
elif hasattr(model, "encoder") and hasattr(model.encoder, "layer"):
model_type = "bert"
layers = model.encoder.layer
elif hasattr(model, "h"):
model_type = "falcon"
layers = model.h
if not model_type or not isinstance(layers, nn.ModuleList):
raise ValueError(
"Could not locate the layers attribute in the model. "
"Expected Llama, Bert or Falcon compatible architectures."
)

new_layers = []
for start, end in layer_map:
for i in range(start, end):
current_idx = len(new_layers)
new_layers.append(clone_module(layers[i], share_weights=True))
# This is a hack needed to work around the layer_idx introduced in HF transformers.
for submodule in new_layers[-1].modules():
if hasattr(submodule, "layer_idx"):
submodule.layer_idx = current_idx
layers = nn.ModuleList(new_layers)
if model_type == "llama":
model.layers = layers
elif model_type == "bert":
model.encoder.layer = layers
elif model_type == "falcon":
model.h = layers
else:
raise ValueError("Unexpected model type, need to handle post-processing of layers.")
if hasattr(model.config, "num_hidden_layers"): # Common to Llama, Bert, Falcon.
model.config.num_hidden_layers = len(new_layers)
41 changes: 40 additions & 1 deletion tests/test_decoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import AdaLoraConfig, PromptTuningConfig, PromptTuningInit, get_peft_model
from peft import AdaLoraConfig, LoraConfig, PromptTuningConfig, PromptTuningInit, get_peft_model

from .testing_common import PeftCommonTester, PeftTestConfigManager

Expand Down Expand Up @@ -302,3 +302,42 @@ def test_generate_adalora_no_dropout(self):
@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_passing_input_embeds_works(self, test_name, model_id, config_cls, config_kwargs):
self._test_passing_input_embeds_works(test_name, model_id, config_cls, config_kwargs)

def test_lora_layer_replication(self):
model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
config_kwargs = {
"target_modules": ["down_proj", "up_proj"],
"task_type": "CAUSAL_LM",
"lora_dropout": 0.0,
"layer_replication": [[0, 1], [0, 2], [1, 2]],
}
model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
config = LoraConfig(
base_model_name_or_path=model_id,
**config_kwargs,
)
BenjaminBossan marked this conversation as resolved.
Show resolved Hide resolved
assert len(model.model.layers), "Expected 2 layers in original model." == 2
model = get_peft_model(model, config)
layers = model.base_model.model.model.layers
assert len(layers) == 4, "Expected 4 layers in adapted model."
assert (
layers[0].mlp.up_proj.base_layer.weight.data.storage().data_ptr()
== layers[1].mlp.up_proj.base_layer.weight.data.storage().data_ptr()
and layers[2].mlp.up_proj.base_layer.weight.data.storage().data_ptr()
== layers[3].mlp.up_proj.base_layer.weight.data.storage().data_ptr()
), "Expected layers 0-1 and 2-3 to share weights"
assert (
layers[0].mlp.up_proj.base_layer.weight.data.storage().data_ptr()
!= layers[2].mlp.up_proj.base_layer.weight.data.storage().data_ptr()
), "Expected layers 0 and 2 to have different weights"
assert (
layers[0].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr()
!= layers[1].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr()
and layers[2].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr()
!= layers[3].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr()
), "Expected all LoRA adapters to have distinct weights"
assert (
len([n for n, _ in model.named_parameters() if ".lora_A." in n]) == 8
), "Expected 8 LoRA adapters since we are adding one each for up and down."
self._test_prepare_for_training(model_id, LoraConfig, config_kwargs)
self._test_generate(model_id, LoraConfig, config_kwargs)
Loading