diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py
index 48a3bc768..61c14e36c 100644
--- a/sentence_transformers/models/Transformer.py
+++ b/sentence_transformers/models/Transformer.py
@@ -5,17 +5,20 @@
 import os
 from fnmatch import fnmatch
 from pathlib import Path
-from typing import Any, Callable
+from typing import TYPE_CHECKING, Any, Callable
 
 import huggingface_hub
 import torch
 from torch import nn
-from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, T5Config
+from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, PretrainedConfig, T5Config
 from transformers.utils.import_utils import is_peft_available
 from transformers.utils.peft_utils import find_adapter_config_file
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING and is_peft_available():
+    from peft import PeftConfig
+
 
 def _save_pretrained_wrapper(_save_pretrained_fn: Callable, subfolder: str) -> Callable[..., None]:
     def wrapper(save_directory: str | Path, **kwargs) -> None:
@@ -74,8 +77,8 @@ def __init__(
         if config_args is None:
             config_args = {}
 
-        config = self._load_config(model_name_or_path, cache_dir, backend, config_args)
-        self._load_model(model_name_or_path, config, cache_dir, backend, **model_args)
+        config, is_peft_model = self._load_config(model_name_or_path, cache_dir, backend, config_args)
+        self._load_model(model_name_or_path, config, cache_dir, backend, is_peft_model, **model_args)
 
         if max_seq_length is not None and "model_max_length" not in tokenizer_args:
             tokenizer_args["model_max_length"] = max_seq_length
@@ -99,8 +102,21 @@ def __init__(
         if tokenizer_name_or_path is not None:
            self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
 
-    def _load_config(self, model_name_or_path: str, cache_dir: str | None, backend: str, config_args: dict[str, Any]):
-        """Loads the configuration of a model"""
+    def _load_config(
+        self, model_name_or_path: str, cache_dir: str | None, backend: str, config_args: dict[str, Any]
+    ) -> tuple[PeftConfig | PretrainedConfig, bool]:
+        """Loads the transformers or PEFT configuration
+
+        Args:
+            model_name_or_path (str): The model name on Hugging Face (e.g. 'sentence-transformers/all-MiniLM-L6-v2')
+                or the path to a local model directory.
+            cache_dir (str | None): The cache directory to store the model configuration.
+            backend (str): The backend used for model inference. Can be `torch`, `onnx`, or `openvino`.
+            config_args (dict[str, Any]): Keyword arguments passed to the Hugging Face Transformers config.
+
+        Returns:
+            tuple[PeftConfig | PretrainedConfig, bool]: The model configuration and a boolean indicating whether the model is a PEFT model.
+ """ if ( find_adapter_config_file( model_name_or_path, @@ -123,13 +139,39 @@ def _load_config(self, model_name_or_path: str, cache_dir: str | None, backend: ) from peft import PeftConfig - return PeftConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir) + return PeftConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir), True + + return AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir), False - return AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir) + def _load_model( + self, + model_name_or_path: str, + config: PeftConfig | PretrainedConfig, + cache_dir: str, + backend: str, + is_peft_model: bool, + **model_args, + ) -> None: + """Loads the transformers or PEFT model into the `auto_model` attribute - def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None: - """Loads the transformer model""" + Args: + model_name_or_path (str): The model name on Hugging Face (e.g. 'sentence-transformers/all-MiniLM-L6-v2') + or the path to a local model directory. + config ("PeftConfig" | PretrainedConfig): The model configuration. + cache_dir (str | None): The cache directory to store the model configuration. + backend (str): The backend used for model inference. Can be `torch`, `onnx`, or `openvino`. + is_peft_model (bool): Whether the model is a PEFT model. + model_args (dict[str, Any]): Keyword arguments passed to the Hugging Face Transformers model. + """ if backend == "torch": + # When loading a PEFT model, we need to load the base model first, + # but some model_args are only for the adapter + adapter_only_kwargs = {} + if is_peft_model: + for adapter_only_kwarg in ["revision"]: + if adapter_only_kwarg in model_args: + adapter_only_kwargs[adapter_only_kwarg] = model_args.pop(adapter_only_kwarg) + if isinstance(config, T5Config): self._load_t5_model(model_name_or_path, config, cache_dir, **model_args) elif isinstance(config, MT5Config): @@ -138,7 +180,9 @@ def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_ar self.auto_model = AutoModel.from_pretrained( model_name_or_path, config=config, cache_dir=cache_dir, **model_args ) - self._load_peft_model(model_name_or_path, config, cache_dir, **model_args) + + if is_peft_model: + self._load_peft_model(model_name_or_path, config, cache_dir, **model_args, **adapter_only_kwargs) elif backend == "onnx": self._load_onnx_model(model_name_or_path, config, cache_dir, **model_args) elif backend == "openvino": @@ -146,16 +190,16 @@ def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_ar else: raise ValueError(f"Unsupported backend '{backend}'. 
`backend` should be `torch`, `onnx`, or `openvino`.") - def _load_peft_model(self, model_name_or_path, config, cache_dir, **model_args) -> None: - if is_peft_available(): - from peft import PeftConfig, PeftModel + def _load_peft_model(self, model_name_or_path: str, config: PeftConfig, cache_dir: str, **model_args) -> None: + from peft import PeftModel - if isinstance(config, PeftConfig): - self.auto_model = PeftModel.from_pretrained( - self.auto_model, model_name_or_path, config=config, cache_dir=cache_dir, **model_args - ) + self.auto_model = PeftModel.from_pretrained( + self.auto_model, model_name_or_path, config=config, cache_dir=cache_dir, **model_args + ) - def _load_openvino_model(self, model_name_or_path, config, cache_dir, **model_args) -> None: + def _load_openvino_model( + self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args + ) -> None: if isinstance(config, T5Config) or isinstance(config, MT5Config): raise ValueError("T5 models are not yet supported by the OpenVINO backend.") @@ -210,7 +254,9 @@ def _load_openvino_model(self, model_name_or_path, config, cache_dir, **model_ar if export: self._backend_warn_to_save(model_name_or_path, is_local, backend_name) - def _load_onnx_model(self, model_name_or_path, config, cache_dir, **model_args) -> None: + def _load_onnx_model( + self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args + ) -> None: try: import onnxruntime as ort from optimum.onnxruntime import ONNX_WEIGHTS_NAME, ORTModelForFeatureExtraction @@ -363,7 +409,7 @@ def _backend_warn_to_save(self, model_name_or_path: str, is_local: str, backend_ to_log += f" Do so with `model.push_to_hub({model_name_or_path!r}, create_pr=True)`." logger.warning(to_log) - def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None: + def _load_t5_model(self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args) -> None: """Loads the encoder model from T5""" from transformers import T5EncoderModel @@ -372,7 +418,7 @@ def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args) -> model_name_or_path, config=config, cache_dir=cache_dir, **model_args ) - def _load_mt5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None: + def _load_mt5_model(self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args) -> None: """Loads the encoder model from T5""" from transformers import MT5EncoderModel diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py index 790145ac0..28636d728 100644 --- a/tests/test_sentence_transformer.py +++ b/tests/test_sentence_transformer.py @@ -781,3 +781,12 @@ def test_multiple_adapters() -> None: model = SentenceTransformer("sentence-transformers/average_word_embeddings_levy_dependency") with pytest.raises(ValueError, match="PEFT methods are only supported"): model.add_adapter(peft_config) + + +@pytest.mark.skipif(not is_peft_available(), reason="PEFT must be available to test loading PEFT models.") +def test_load_adapter_with_revision(): + model = SentenceTransformer( + "sentence-transformers-testing/stsb-bert-tiny-lora", revision="3b4f75bcb3dec36a7e05da8c44ee2f7f1d023b1a" + ) + embeddings = model.encode("Hello, World!") + assert embeddings.shape == (128,)
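
Taken together, the change routes a user-supplied `revision` to the PEFT adapter repository instead of the base model: `revision` is popped from `model_args` before the base model is loaded and passed only to `PeftModel.from_pretrained`. A minimal usage sketch of that behavior, assuming `peft` is installed; the repository, revision, and expected shape below are the ones pinned in `test_load_adapter_with_revision` above:

    from sentence_transformers import SentenceTransformer

    # `revision` pins the commit of the LoRA adapter repository; the base model
    # named in the adapter config is resolved separately, so the revision is no
    # longer forwarded to it.
    model = SentenceTransformer(
        "sentence-transformers-testing/stsb-bert-tiny-lora",
        revision="3b4f75bcb3dec36a7e05da8c44ee2f7f1d023b1a",
    )

    embeddings = model.encode("Hello, World!")
    print(embeddings.shape)  # (128,), the tiny model's embedding dimension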