[fix] revision of the adapter model can now be specified. (#3079)
* add: revision of the adapter model can now be specified.

* Refactor loading PEFT slightly to support 'revision'

* Add the missing type hints in the Transformer module

---------

Co-authored-by: ryoji.nagata <[email protected]>
Co-authored-by: Tom Aarsen <[email protected]>
3 people authored Nov 27, 2024
1 parent df6a8e8 commit a542b0a
Showing 2 changed files with 77 additions and 22 deletions.
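
In practice, the change can be exercised like this (a minimal sketch mirroring the new test added below; the model repository and revision hash are taken from that test):

    from sentence_transformers import SentenceTransformer

    # Pin the PEFT adapter to a specific revision of its repository.
    # Before this fix, a user-supplied `revision` was also forwarded to the
    # base model load instead of applying to the adapter alone.
    model = SentenceTransformer(
        "sentence-transformers-testing/stsb-bert-tiny-lora",
        revision="3b4f75bcb3dec36a7e05da8c44ee2f7f1d023b1a",
    )
    embeddings = model.encode("Hello, World!")
    print(embeddings.shape)  # (128,) for this tiny test model
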
90 changes: 68 additions & 22 deletions sentence_transformers/models/Transformer.py
@@ -5,17 +5,20 @@
 import os
 from fnmatch import fnmatch
 from pathlib import Path
-from typing import Any, Callable
+from typing import TYPE_CHECKING, Any, Callable
 
 import huggingface_hub
 import torch
 from torch import nn
-from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, T5Config
+from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, PretrainedConfig, T5Config
 from transformers.utils.import_utils import is_peft_available
 from transformers.utils.peft_utils import find_adapter_config_file
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING and is_peft_available():
+    from peft import PeftConfig
+
 
 def _save_pretrained_wrapper(_save_pretrained_fn: Callable, subfolder: str) -> Callable[..., None]:
     def wrapper(save_directory: str | Path, **kwargs) -> None:
@@ -74,8 +77,8 @@ def __init__(
         if config_args is None:
             config_args = {}
 
-        config = self._load_config(model_name_or_path, cache_dir, backend, config_args)
-        self._load_model(model_name_or_path, config, cache_dir, backend, **model_args)
+        config, is_peft_model = self._load_config(model_name_or_path, cache_dir, backend, config_args)
+        self._load_model(model_name_or_path, config, cache_dir, backend, is_peft_model, **model_args)
 
         if max_seq_length is not None and "model_max_length" not in tokenizer_args:
             tokenizer_args["model_max_length"] = max_seq_length
@@ -99,8 +102,21 @@ def __init__(
         if tokenizer_name_or_path is not None:
             self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
 
-    def _load_config(self, model_name_or_path: str, cache_dir: str | None, backend: str, config_args: dict[str, Any]):
-        """Loads the configuration of a model"""
+    def _load_config(
+        self, model_name_or_path: str, cache_dir: str | None, backend: str, config_args: dict[str, Any]
+    ) -> tuple[PeftConfig | PretrainedConfig, bool]:
+        """Loads the transformers or PEFT configuration
+
+        Args:
+            model_name_or_path (str): The model name on Hugging Face (e.g. 'sentence-transformers/all-MiniLM-L6-v2')
+                or the path to a local model directory.
+            cache_dir (str | None): The cache directory to store the model configuration.
+            backend (str): The backend used for model inference. Can be `torch`, `onnx`, or `openvino`.
+            config_args (dict[str, Any]): Keyword arguments passed to the Hugging Face Transformers config.
+
+        Returns:
+            tuple[PeftConfig | PretrainedConfig, bool]: The model configuration and a boolean indicating whether the model is a PEFT model.
+        """
         if (
             find_adapter_config_file(
                 model_name_or_path,
@@ -123,13 +139,39 @@ def _load_config(self, model_name_or_path: str, cache_dir: str | None, backend:
                 )
             from peft import PeftConfig
 
-            return PeftConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
+            return PeftConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir), True
 
-        return AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
+        return AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir), False
 
-    def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
-        """Loads the transformer model"""
+    def _load_model(
+        self,
+        model_name_or_path: str,
+        config: PeftConfig | PretrainedConfig,
+        cache_dir: str,
+        backend: str,
+        is_peft_model: bool,
+        **model_args,
+    ) -> None:
+        """Loads the transformers or PEFT model into the `auto_model` attribute
+
+        Args:
+            model_name_or_path (str): The model name on Hugging Face (e.g. 'sentence-transformers/all-MiniLM-L6-v2')
+                or the path to a local model directory.
+            config ("PeftConfig" | PretrainedConfig): The model configuration.
+            cache_dir (str | None): The cache directory to store the model configuration.
+            backend (str): The backend used for model inference. Can be `torch`, `onnx`, or `openvino`.
+            is_peft_model (bool): Whether the model is a PEFT model.
+            model_args (dict[str, Any]): Keyword arguments passed to the Hugging Face Transformers model.
+        """
         if backend == "torch":
+            # When loading a PEFT model, we need to load the base model first,
+            # but some model_args are only for the adapter
+            adapter_only_kwargs = {}
+            if is_peft_model:
+                for adapter_only_kwarg in ["revision"]:
+                    if adapter_only_kwarg in model_args:
+                        adapter_only_kwargs[adapter_only_kwarg] = model_args.pop(adapter_only_kwarg)
+
             if isinstance(config, T5Config):
                 self._load_t5_model(model_name_or_path, config, cache_dir, **model_args)
             elif isinstance(config, MT5Config):
@@ -138,24 +180,26 @@ def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
                 self.auto_model = AutoModel.from_pretrained(
                     model_name_or_path, config=config, cache_dir=cache_dir, **model_args
                 )
-            self._load_peft_model(model_name_or_path, config, cache_dir, **model_args)
+
+            if is_peft_model:
+                self._load_peft_model(model_name_or_path, config, cache_dir, **model_args, **adapter_only_kwargs)
         elif backend == "onnx":
             self._load_onnx_model(model_name_or_path, config, cache_dir, **model_args)
         elif backend == "openvino":
             self._load_openvino_model(model_name_or_path, config, cache_dir, **model_args)
         else:
             raise ValueError(f"Unsupported backend '{backend}'. `backend` should be `torch`, `onnx`, or `openvino`.")
 
-    def _load_peft_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
-        if is_peft_available():
-            from peft import PeftConfig, PeftModel
+    def _load_peft_model(self, model_name_or_path: str, config: PeftConfig, cache_dir: str, **model_args) -> None:
+        from peft import PeftModel
 
-            if isinstance(config, PeftConfig):
-                self.auto_model = PeftModel.from_pretrained(
-                    self.auto_model, model_name_or_path, config=config, cache_dir=cache_dir, **model_args
-                )
+        self.auto_model = PeftModel.from_pretrained(
+            self.auto_model, model_name_or_path, config=config, cache_dir=cache_dir, **model_args
+        )
 
-    def _load_openvino_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
+    def _load_openvino_model(
+        self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args
+    ) -> None:
         if isinstance(config, T5Config) or isinstance(config, MT5Config):
             raise ValueError("T5 models are not yet supported by the OpenVINO backend.")
@@ -210,7 +254,9 @@ def _load_openvino_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
         if export:
             self._backend_warn_to_save(model_name_or_path, is_local, backend_name)
 
-    def _load_onnx_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
+    def _load_onnx_model(
+        self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args
+    ) -> None:
         try:
             import onnxruntime as ort
             from optimum.onnxruntime import ONNX_WEIGHTS_NAME, ORTModelForFeatureExtraction
@@ -363,7 +409,7 @@ def _backend_warn_to_save(self, model_name_or_path: str, is_local: str, backend_name: str) -> None:
         to_log += f" Do so with `model.push_to_hub({model_name_or_path!r}, create_pr=True)`."
         logger.warning(to_log)
 
-    def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
+    def _load_t5_model(self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args) -> None:
         """Loads the encoder model from T5"""
         from transformers import T5EncoderModel
 
@@ -372,7 +418,7 @@ def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
             model_name_or_path, config=config, cache_dir=cache_dir, **model_args
         )
 
-    def _load_mt5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
+    def _load_mt5_model(self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args) -> None:
         """Loads the encoder model from T5"""
         from transformers import MT5EncoderModel
 
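
The crux of the change is the kwarg routing in `_load_model`: a PEFT checkpoint spans two repositories, the base model and the adapter, and a user-supplied `revision` refers to the adapter repository. A condensed, standalone sketch of that pattern (illustrative names, not the library's public API):

    from typing import Any

    # Kwargs that only make sense for the adapter repository; mirrors the
    # `["revision"]` list used inline in `_load_model` above.
    ADAPTER_ONLY_KWARGS = ("revision",)

    def split_adapter_kwargs(model_args: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
        # Pop adapter-only kwargs before the base model is loaded, so they
        # can be re-applied when the adapter is attached on top.
        adapter_kwargs = {key: model_args.pop(key) for key in ADAPTER_ONLY_KWARGS if key in model_args}
        return model_args, adapter_kwargs

    base_kwargs, adapter_kwargs = split_adapter_kwargs({"revision": "main", "torch_dtype": "auto"})
    assert base_kwargs == {"torch_dtype": "auto"}  # goes to AutoModel.from_pretrained
    assert adapter_kwargs == {"revision": "main"}  # goes to PeftModel.from_pretrained
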
9 changes: 9 additions & 0 deletions tests/test_sentence_transformer.py
@@ -781,3 +781,12 @@ def test_multiple_adapters() -> None:
     model = SentenceTransformer("sentence-transformers/average_word_embeddings_levy_dependency")
     with pytest.raises(ValueError, match="PEFT methods are only supported"):
         model.add_adapter(peft_config)
+
+
+@pytest.mark.skipif(not is_peft_available(), reason="PEFT must be available to test loading PEFT models.")
+def test_load_adapter_with_revision():
+    model = SentenceTransformer(
+        "sentence-transformers-testing/stsb-bert-tiny-lora", revision="3b4f75bcb3dec36a7e05da8c44ee2f7f1d023b1a"
+    )
+    embeddings = model.encode("Hello, World!")
+    assert embeddings.shape == (128,)
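
Assuming a standard development checkout with the `peft` extra installed, the new test can be run in isolation with:

    pytest tests/test_sentence_transformer.py -k test_load_adapter_with_revision
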
