[fix] revision of the adapter model can now be specified. #3079

Merged 3 commits on Nov 27, 2024
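In practice, the fix lets a pinned adapter revision be passed straight through `SentenceTransformer`: the `revision` is applied when the adapter is loaded instead of being forwarded to the base-model load as well (the diff's comment notes that some `model_args` are only for the adapter). A minimal usage sketch, using the repository and revision hash from the new test at the bottom of this diff:

    from sentence_transformers import SentenceTransformer

    # Load a LoRA adapter checkpoint pinned to a specific commit of the
    # adapter repository. With this fix, the commit hash is used only for
    # the adapter, not for the base model it is applied to.
    model = SentenceTransformer(
        "sentence-transformers-testing/stsb-bert-tiny-lora",
        revision="3b4f75bcb3dec36a7e05da8c44ee2f7f1d023b1a",
    )
    embeddings = model.encode("Hello, World!")
    print(embeddings.shape)  # (128,)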
90 changes: 68 additions & 22 deletions sentence_transformers/models/Transformer.py
@@ -5,17 +5,20 @@
import os
from fnmatch import fnmatch
from pathlib import Path
from typing import Any, Callable
from typing import TYPE_CHECKING, Any, Callable

import huggingface_hub
import torch
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, T5Config
from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, PretrainedConfig, T5Config
from transformers.utils.import_utils import is_peft_available
from transformers.utils.peft_utils import find_adapter_config_file

logger = logging.getLogger(__name__)

if TYPE_CHECKING and is_peft_available():
from peft import PeftConfig


def _save_pretrained_wrapper(_save_pretrained_fn: Callable, subfolder: str) -> Callable[..., None]:
def wrapper(save_directory: str | Path, **kwargs) -> None:
@@ -74,8 +77,8 @@ def __init__(
if config_args is None:
config_args = {}

config = self._load_config(model_name_or_path, cache_dir, backend, config_args)
self._load_model(model_name_or_path, config, cache_dir, backend, **model_args)
config, is_peft_model = self._load_config(model_name_or_path, cache_dir, backend, config_args)
self._load_model(model_name_or_path, config, cache_dir, backend, is_peft_model, **model_args)

if max_seq_length is not None and "model_max_length" not in tokenizer_args:
tokenizer_args["model_max_length"] = max_seq_length
@@ -99,8 +102,21 @@ def __init__(
if tokenizer_name_or_path is not None:
self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__

def _load_config(self, model_name_or_path: str, cache_dir: str | None, backend: str, config_args: dict[str, Any]):
"""Loads the configuration of a model"""
def _load_config(
self, model_name_or_path: str, cache_dir: str | None, backend: str, config_args: dict[str, Any]
) -> tuple[PeftConfig | PretrainedConfig, bool]:
"""Loads the transformers or PEFT configuration

Args:
model_name_or_path (str): The model name on Hugging Face (e.g. 'sentence-transformers/all-MiniLM-L6-v2')
or the path to a local model directory.
cache_dir (str | None): The cache directory to store the model configuration.
backend (str): The backend used for model inference. Can be `torch`, `onnx`, or `openvino`.
config_args (dict[str, Any]): Keyword arguments passed to the Hugging Face Transformers config.

Returns:
tuple[PretrainedConfig, bool]: The model configuration and a boolean indicating whether the model is a PEFT model.
"""
if (
find_adapter_config_file(
model_name_or_path,
@@ -123,13 +139,39 @@ def _load_config(self, model_name_or_path: str, cache_dir: str | None, backend:
)
from peft import PeftConfig

return PeftConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
return PeftConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir), True

return AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir), False

return AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
def _load_model(
self,
model_name_or_path: str,
config: PeftConfig | PretrainedConfig,
cache_dir: str,
backend: str,
is_peft_model: bool,
**model_args,
) -> None:
"""Loads the transformers or PEFT model into the `auto_model` attribute

def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
"""Loads the transformer model"""
Args:
model_name_or_path (str): The model name on Hugging Face (e.g. 'sentence-transformers/all-MiniLM-L6-v2')
or the path to a local model directory.
config ("PeftConfig" | PretrainedConfig): The model configuration.
cache_dir (str | None): The cache directory to store the model configuration.
backend (str): The backend used for model inference. Can be `torch`, `onnx`, or `openvino`.
is_peft_model (bool): Whether the model is a PEFT model.
model_args (dict[str, Any]): Keyword arguments passed to the Hugging Face Transformers model.
"""
if backend == "torch":
# When loading a PEFT model, we need to load the base model first,
# but some model_args are only for the adapter
adapter_only_kwargs = {}
if is_peft_model:
for adapter_only_kwarg in ["revision"]:
if adapter_only_kwarg in model_args:
adapter_only_kwargs[adapter_only_kwarg] = model_args.pop(adapter_only_kwarg)

if isinstance(config, T5Config):
self._load_t5_model(model_name_or_path, config, cache_dir, **model_args)
elif isinstance(config, MT5Config):
@@ -138,24 +180,26 @@ def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
self.auto_model = AutoModel.from_pretrained(
model_name_or_path, config=config, cache_dir=cache_dir, **model_args
)
self._load_peft_model(model_name_or_path, config, cache_dir, **model_args)

if is_peft_model:
self._load_peft_model(model_name_or_path, config, cache_dir, **model_args, **adapter_only_kwargs)
elif backend == "onnx":
self._load_onnx_model(model_name_or_path, config, cache_dir, **model_args)
elif backend == "openvino":
self._load_openvino_model(model_name_or_path, config, cache_dir, **model_args)
else:
raise ValueError(f"Unsupported backend '{backend}'. `backend` should be `torch`, `onnx`, or `openvino`.")

def _load_peft_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
if is_peft_available():
from peft import PeftConfig, PeftModel
def _load_peft_model(self, model_name_or_path: str, config: PeftConfig, cache_dir: str, **model_args) -> None:
from peft import PeftModel

if isinstance(config, PeftConfig):
self.auto_model = PeftModel.from_pretrained(
self.auto_model, model_name_or_path, config=config, cache_dir=cache_dir, **model_args
)
self.auto_model = PeftModel.from_pretrained(
self.auto_model, model_name_or_path, config=config, cache_dir=cache_dir, **model_args
)

def _load_openvino_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
def _load_openvino_model(
self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args
) -> None:
if isinstance(config, T5Config) or isinstance(config, MT5Config):
raise ValueError("T5 models are not yet supported by the OpenVINO backend.")

@@ -210,7 +254,9 @@ def _load_openvino_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
if export:
self._backend_warn_to_save(model_name_or_path, is_local, backend_name)

def _load_onnx_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
def _load_onnx_model(
self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args
) -> None:
try:
import onnxruntime as ort
from optimum.onnxruntime import ONNX_WEIGHTS_NAME, ORTModelForFeatureExtraction
@@ -363,7 +409,7 @@ def _backend_warn_to_save(self, model_name_or_path: str, is_local: str, backend_name: str) -> None:
to_log += f" Do so with `model.push_to_hub({model_name_or_path!r}, create_pr=True)`."
logger.warning(to_log)

def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
def _load_t5_model(self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args) -> None:
"""Loads the encoder model from T5"""
from transformers import T5EncoderModel

@@ -372,7 +418,7 @@ def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
model_name_or_path, config=config, cache_dir=cache_dir, **model_args
)

def _load_mt5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
def _load_mt5_model(self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args) -> None:
"""Loads the encoder model from T5"""
from transformers import MT5EncoderModel

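The core of the change is the kwarg routing in `_load_model`: when the checkpoint is a PEFT adapter, `revision` is popped from `model_args` before the base model is loaded and is passed along only when the adapter itself is loaded via `_load_peft_model`. A standalone sketch of that routing, with hypothetical names (`ADAPTER_ONLY_KWARGS`, `split_adapter_kwargs`) chosen for illustration; this is not the library code:

    # Illustrative sketch: route adapter-only kwargs away from the base-model
    # load. "revision" refers to a commit of the adapter repository, so it
    # must not reach the base model's from_pretrained call.
    ADAPTER_ONLY_KWARGS = ("revision",)

    def split_adapter_kwargs(model_args: dict, is_peft_model: bool) -> tuple[dict, dict]:
        """Pop adapter-only kwargs so only the PEFT load receives them."""
        adapter_only_kwargs = {}
        if is_peft_model:
            for key in ADAPTER_ONLY_KWARGS:
                if key in model_args:
                    adapter_only_kwargs[key] = model_args.pop(key)
        return model_args, adapter_only_kwargs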
9 changes: 9 additions & 0 deletions tests/test_sentence_transformer.py
@@ -781,3 +781,12 @@ def test_multiple_adapters() -> None:
model = SentenceTransformer("sentence-transformers/average_word_embeddings_levy_dependency")
with pytest.raises(ValueError, match="PEFT methods are only supported"):
model.add_adapter(peft_config)


@pytest.mark.skipif(not is_peft_available(), reason="PEFT must be available to test loading PEFT models.")
def test_load_adapter_with_revision():
model = SentenceTransformer(
"sentence-transformers-testing/stsb-bert-tiny-lora", revision="3b4f75bcb3dec36a7e05da8c44ee2f7f1d023b1a"
)
embeddings = model.encode("Hello, World!")
assert embeddings.shape == (128,)
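The test pins the adapter repository to a specific commit and checks that encoding succeeds with the expected dimensionality, 128 presumably being the hidden size of the tiny BERT base model; before this fix, passing `revision` for an adapter checkpoint would have been forwarded to the base-model load as well.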