FEAT Integrate X-LoRA (#1491)
Implements X-LoRA: Mixture of Low-Rank Adapter Experts
Paper: https://arxiv.org/abs/2402.07148
EricLBuehler authored Jul 5, 2024
1 parent 01f1b99 commit 58afb34
Showing 15 changed files with 1,313 additions and 6 deletions.
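Before the per-file diffs, a rough usage sketch of what this integration enables. It is based on the mappings added in this commit plus the X-LoRA paper; the base model id, the adapter paths, and any XLoraConfig fields other than `adapters` and `task_type` (notably `hidden_size`) are illustrative assumptions, not content of this diff.

# Hypothetical sketch, not part of this commit: wrap a base model with X-LoRA
# using the newly exported XLoraConfig/XLoraModel.
from transformers import AutoModelForCausalLM

from peft import XLoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # illustrative base model
config = XLoraConfig(
    task_type="CAUSAL_LM",
    hidden_size=base.config.hidden_size,  # assumed field: the X-LoRA scaling head reads hidden states
    adapters={
        # `adapters` maps adapter names to trained LoRA checkpoints (paths here are placeholders)
        "adapter_1": "./path/to/adapter_1",
        "adapter_2": "./path/to/adapter_2",
    },
)
model = get_peft_model(base, config)  # dispatched to XLoraModel via PEFT_TYPE_TO_TUNER_MAPPING["XLORA"]

Loading a saved X-LoRA checkpoint instead goes through the `PeftModel.from_pretrained` changes shown further down in src/peft/peft_model.py.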
2 changes: 2 additions & 0 deletions src/peft/__init__.py
@@ -82,6 +82,8 @@
     LNTuningModel,
     VeraConfig,
     VeraModel,
+    XLoraConfig,
+    XLoraModel,
 )
 from .utils import (
     TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
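With these two re-exports in place, the X-LoRA classes become importable from the top-level package:

from peft import XLoraConfig, XLoraModel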
5 changes: 5 additions & 0 deletions src/peft/mapping.py
@@ -19,6 +19,8 @@
 
 import torch
 
+from peft.tuners.xlora.model import XLoraModel
+
 from .config import PeftConfig
 from .mixed_model import PeftMixedModel
 from .peft_model import (
@@ -56,6 +58,7 @@
     PromptTuningConfig,
     VeraConfig,
     VeraModel,
+    XLoraConfig,
 )
 from .tuners.tuners_utils import BaseTuner as _BaseTuner
 from .utils import _prepare_prompt_learning_config
@@ -90,6 +93,7 @@
     "POLY": PolyConfig,
     "LN_TUNING": LNTuningConfig,
     "VERA": VeraConfig,
+    "XLORA": XLoraConfig,
 }
 
 PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[_BaseTuner]] = {
@@ -103,6 +107,7 @@
     "POLY": PolyModel,
     "LN_TUNING": LNTuningModel,
     "VERA": VeraModel,
+    "XLORA": XLoraModel,
 }
 
 
3 changes: 1 addition & 2 deletions src/peft/mixed_model.py
@@ -23,8 +23,6 @@
 from torch import nn
 from transformers.utils import PushToHubMixin
 
-from peft.tuners.mixed import COMPATIBLE_TUNER_TYPES
-
 from .config import PeftConfig
 from .peft_model import PeftModel
 from .tuners import (
@@ -36,6 +34,7 @@
     MixedModel,
     OFTModel,
 )
+from .tuners.mixed import COMPATIBLE_TUNER_TYPES
 from .utils import PeftType, _set_adapter, _set_trainable
 
 
30 changes: 29 additions & 1 deletion src/peft/peft_model.py
@@ -29,7 +29,7 @@
 from accelerate import dispatch_model, infer_auto_device_map
 from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules
 from accelerate.utils import get_balanced_memory, named_module_tensors
-from huggingface_hub import ModelCard, ModelCardData, hf_hub_download
+from huggingface_hub import HfFileSystem, ModelCard, ModelCardData, hf_hub_download
 from safetensors import safe_open
 from safetensors.torch import save_file as safe_save_file
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -55,6 +55,8 @@
     PromptEmbedding,
     PromptEncoder,
     VeraModel,
+    XLoraConfig,
+    XLoraModel,
 )
 from .tuners.tuners_utils import BaseTuner, BaseTunerLayer
 from .utils import (
@@ -91,6 +93,7 @@
     PeftType.POLY: PolyModel,
     PeftType.LN_TUNING: LNTuningModel,
     PeftType.VERA: VeraModel,
+    PeftType.XLORA: XLoraModel,
 }
 
 
@@ -479,13 +482,38 @@ def from_pretrained(
             raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
         else:
             config.inference_mode = not is_trainable
+        if isinstance(getattr(model, "base_model", None), XLoraModel):
+            if not isinstance(config, XLoraConfig):
+                raise TypeError(f"Expected 'XLoraConfig', got '{type(config)}' instead.")
+            if "adapters" in kwargs:
+                config.adapters = kwargs["adapters"]
+            else:
+                # If the path is on HF hub, then we get the adapter names to create a subfolders list which tells
+                # `load_adapter` where the adapters are.
+                if not os.path.exists(model_id):
+                    s = HfFileSystem()
+
+                    # The names of the adapters which must be in folders
+                    adapter_names = [
+                        file["name"][len(model_id) + 1 :] for file in s.ls(model_id) if file["type"] == "directory"
+                    ]
+                    # Prepare a dict of adapter paths, which really just point to the hf id; we will use the subfolders
+                    adapter_paths = {}
+                    for adapter_name in adapter_names:
+                        adapter_paths[adapter_name] = os.path.join(model_id, model_id)
+                    config.adapters = adapter_paths
+                    config._subfolders = adapter_names
+                else:
+                    if "adapters" not in kwargs:
+                        raise ValueError("If model_id is a local path, then `adapters` must be passed in kwargs.")
+
         if config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
             model = cls(model, config, adapter_name, autocast_adapter_dtype=autocast_adapter_dtype)
         else:
             model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](
                 model, config, adapter_name, autocast_adapter_dtype=autocast_adapter_dtype
             )
 
         model.load_adapter(
             model_id, adapter_name, is_trainable=is_trainable, autocast_adapter_dtype=autocast_adapter_dtype, **kwargs
         )
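As a standalone illustration of the Hub branch above: `HfFileSystem.ls()` yields one detail dict per repo entry, and every entry of type "directory" is treated as an adapter subfolder whose name is later recorded in `config._subfolders` for `load_adapter`. The repo id below is hypothetical.

# Hedged sketch of the subfolder discovery performed above; the repo id is made up.
from huggingface_hub import HfFileSystem

model_id = "some-user/xlora-opt-125m"  # hypothetical X-LoRA repo with one subfolder per adapter
fs = HfFileSystem()
adapter_names = [
    entry["name"][len(model_id) + 1 :]  # strip the "repo_id/" prefix from the full path
    for entry in fs.ls(model_id)
    if entry["type"] == "directory"
]
print(adapter_names)  # e.g. ["adapter_1", "adapter_2"]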
1 change: 1 addition & 0 deletions src/peft/tuners/__init__.py
@@ -33,3 +33,4 @@
 from .poly import PolyConfig, PolyModel
 from .ln_tuning import LNTuningConfig, LNTuningModel
 from .vera import VeraConfig, VeraModel
+from .xlora import XLoraConfig, XLoraModel
22 changes: 19 additions & 3 deletions src/peft/tuners/tuners_utils.py
@@ -30,6 +30,8 @@
 from transformers.pytorch_utils import Conv1D
 
 from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
+from peft.utils.constants import DUMMY_TARGET_MODULES
+from peft.utils.peft_types import PeftType
 
 from ..config import PeftConfig
 from ..utils import ModulesToSaveWrapper, _get_submodules
@@ -141,7 +143,12 @@ class BaseTuner(nn.Module, ABC):
         double-check that the `config.target_modules` were specified correctly.
     """
 
-    def __init__(self, model, peft_config: Union[PeftConfig, dict[str, PeftConfig]], adapter_name: str) -> None:
+    def __init__(
+        self,
+        model,
+        peft_config: Union[PeftConfig, dict[str, PeftConfig]],
+        adapter_name: str,
+    ) -> None:
         super().__init__()
 
         self.model = model
@@ -164,7 +171,8 @@ def __init__(self, model, peft_config: Union[PeftConfig, dict[str, PeftConfig]],
 
         self.active_adapter: str | list[str] = adapter_name
         self._pre_injection_hook(self.model, self.peft_config[adapter_name], adapter_name)
-        self.inject_adapter(self.model, adapter_name)
+        if peft_config != PeftType.XLORA or peft_config[adapter_name] != PeftType.XLORA:
+            self.inject_adapter(self.model, adapter_name)
 
         # Copy the peft_config in the injected model.
         self.model.peft_config = self.peft_config
@@ -389,6 +397,11 @@ def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_d
         is_target_modules_in_base_model = False
         key_list = [key for key, _ in model.named_modules()]
 
+        if getattr(peft_config, "target_modules", None) == DUMMY_TARGET_MODULES:
+            # dummy adapter, we allow not matching any module
+            key_list = []
+            is_target_modules_in_base_model = True
+
         # update peft_config.target_modules if required
         peft_config = _maybe_include_all_linear_layers(peft_config, model)
 
@@ -417,7 +430,8 @@ def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_d
             parent, target, target_name = _get_submodules(model, key)
             self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
 
-        if not is_target_modules_in_base_model:
+        # Handle X-LoRA case.
+        if not is_target_modules_in_base_model and hasattr(peft_config, "target_modules"):
             raise ValueError(
                 f"Target modules {peft_config.target_modules} not found in the base model. "
                 f"Please check the target modules and try again."
@@ -776,6 +790,8 @@ def _maybe_include_all_linear_layers(peft_config: PeftConfig, model: nn.Module)
     Helper function to update `target_modules` to all linear/Conv1D layers if provided as 'all-linear'. Adapted from
     the QLoRA repository: https://github.com/artidoro/qlora/blob/main/qlora.py
     """
+    if not hasattr(peft_config, "target_modules"):
+        return peft_config
 
     # if `target_modules` is a string, convert to lower case and check if it matches "all-linear"
     if not (
19 changes: 19 additions & 0 deletions src/peft/tuners/xlora/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .config import XLoraConfig
+from .model import XLoraModel
+
+
+__all__ = ["XLoraConfig", "XLoraModel"]