From 7ee9fa84d9ddc5b5a3bf4a0784f86255be6a338a Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 26 Nov 2024 14:38:09 +0100 Subject: [PATCH] Update the lacking type-hinting in the Transformer module --- sentence_transformers/models/Transformer.py | 60 +++++++++++++++++---- 1 file changed, 49 insertions(+), 11 deletions(-) diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py index e3651ac32..61c14e36c 100644 --- a/sentence_transformers/models/Transformer.py +++ b/sentence_transformers/models/Transformer.py @@ -5,17 +5,20 @@ import os from fnmatch import fnmatch from pathlib import Path -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable import huggingface_hub import torch from torch import nn -from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, T5Config +from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, PretrainedConfig, T5Config from transformers.utils.import_utils import is_peft_available from transformers.utils.peft_utils import find_adapter_config_file logger = logging.getLogger(__name__) +if TYPE_CHECKING and is_peft_available(): + from peft import PeftConfig + def _save_pretrained_wrapper(_save_pretrained_fn: Callable, subfolder: str) -> Callable[..., None]: def wrapper(save_directory: str | Path, **kwargs) -> None: @@ -99,8 +102,21 @@ def __init__( if tokenizer_name_or_path is not None: self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__ - def _load_config(self, model_name_or_path: str, cache_dir: str | None, backend: str, config_args: dict[str, Any]): - """Loads the configuration of a model""" + def _load_config( + self, model_name_or_path: str, cache_dir: str | None, backend: str, config_args: dict[str, Any] + ) -> tuple[PeftConfig | PretrainedConfig, bool]: + """Loads the transformers or PEFT configuration + + Args: + model_name_or_path (str): The model name on Hugging Face (e.g. 
'sentence-transformers/all-MiniLM-L6-v2') + or the path to a local model directory. + cache_dir (str | None): The cache directory to store the model configuration. + backend (str): The backend used for model inference. Can be `torch`, `onnx`, or `openvino`. + config_args (dict[str, Any]): Keyword arguments passed to the Hugging Face Transformers config. + + Returns: + tuple[PeftConfig | PretrainedConfig, bool]: The model configuration and a boolean indicating whether the model is a PEFT model. + """ if ( find_adapter_config_file( model_name_or_path, @@ -127,8 +143,26 @@ def _load_config(self, model_name_or_path: str, cache_dir: str | None, backend: return AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir), False - def _load_model(self, model_name_or_path, config, cache_dir, backend, is_peft_model, **model_args) -> None: - """Loads the transformer model""" + def _load_model( + self, + model_name_or_path: str, + config: PeftConfig | PretrainedConfig, + cache_dir: str, + backend: str, + is_peft_model: bool, + **model_args, + ) -> None: + """Loads the transformers or PEFT model into the `auto_model` attribute + + Args: + model_name_or_path (str): The model name on Hugging Face (e.g. 'sentence-transformers/all-MiniLM-L6-v2') + or the path to a local model directory. + config ("PeftConfig" | PretrainedConfig): The model configuration. + cache_dir (str | None): The cache directory to store the model configuration. + backend (str): The backend used for model inference. Can be `torch`, `onnx`, or `openvino`. + is_peft_model (bool): Whether the model is a PEFT model. + model_args (dict[str, Any]): Keyword arguments passed to the Hugging Face Transformers model. 
+ """ if backend == "torch": # When loading a PEFT model, we need to load the base model first, # but some model_args are only for the adapter @@ -156,14 +190,16 @@ def _load_model(self, model_name_or_path, config, cache_dir, backend, is_peft_mo else: raise ValueError(f"Unsupported backend '{backend}'. `backend` should be `torch`, `onnx`, or `openvino`.") - def _load_peft_model(self, model_name_or_path, config, cache_dir, **model_args) -> None: + def _load_peft_model(self, model_name_or_path: str, config: PeftConfig, cache_dir: str, **model_args) -> None: from peft import PeftModel self.auto_model = PeftModel.from_pretrained( self.auto_model, model_name_or_path, config=config, cache_dir=cache_dir, **model_args ) - def _load_openvino_model(self, model_name_or_path, config, cache_dir, **model_args) -> None: + def _load_openvino_model( + self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args + ) -> None: if isinstance(config, T5Config) or isinstance(config, MT5Config): raise ValueError("T5 models are not yet supported by the OpenVINO backend.") @@ -218,7 +254,9 @@ def _load_openvino_model(self, model_name_or_path, config, cache_dir, **model_ar if export: self._backend_warn_to_save(model_name_or_path, is_local, backend_name) - def _load_onnx_model(self, model_name_or_path, config, cache_dir, **model_args) -> None: + def _load_onnx_model( + self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args + ) -> None: try: import onnxruntime as ort from optimum.onnxruntime import ONNX_WEIGHTS_NAME, ORTModelForFeatureExtraction @@ -371,7 +409,7 @@ def _backend_warn_to_save(self, model_name_or_path: str, is_local: str, backend_ to_log += f" Do so with `model.push_to_hub({model_name_or_path!r}, create_pr=True)`." 
logger.warning(to_log) - def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None: + def _load_t5_model(self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args) -> None: """Loads the encoder model from T5""" from transformers import T5EncoderModel @@ -380,7 +418,7 @@ def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args) -> model_name_or_path, config=config, cache_dir=cache_dir, **model_args ) - def _load_mt5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None: + def _load_mt5_model(self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args) -> None: """Loads the encoder model from T5""" from transformers import MT5EncoderModel