From a74dee9b62d10767eb0580f196f5e508e9e80a2d Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 26 Apr 2024 10:10:48 +0800 Subject: [PATCH] [Bugfix] Fix parameter name in `get_tokenizer` (#4107) --- tests/tokenization/test_tokenizer.py | 20 ++++++++++++++++++++ vllm/transformers_utils/tokenizer.py | 11 ++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) create mode 100644 tests/tokenization/test_tokenizer.py diff --git a/tests/tokenization/test_tokenizer.py b/tests/tokenization/test_tokenizer.py new file mode 100644 index 0000000000000..8db7204f15d4e --- /dev/null +++ b/tests/tokenization/test_tokenizer.py @@ -0,0 +1,20 @@ +import pytest +from transformers import PreTrainedTokenizerBase + +from vllm.transformers_utils.tokenizer import get_tokenizer + +TOKENIZER_NAMES = [ + "facebook/opt-125m", + "gpt2", +] + + +@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES) +def test_tokenizer_revision(tokenizer_name: str): + # Assume that "main" branch always exists + tokenizer = get_tokenizer(tokenizer_name, revision="main") + assert isinstance(tokenizer, PreTrainedTokenizerBase) + + # Assume that "never" branch always does not exist + with pytest.raises(OSError, match='not a valid git identifier'): + get_tokenizer(tokenizer_name, revision="never") diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index c98a673bfed4b..afc02c434dd43 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -58,11 +58,12 @@ def get_tokenizer( *args, tokenizer_mode: str = "auto", trust_remote_code: bool = False, - tokenizer_revision: Optional[str] = None, + revision: Optional[str] = None, download_dir: Optional[str] = None, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - """Gets a tokenizer for the given model name via Huggingface/modelscope.""" + """Gets a tokenizer for the given model name via HuggingFace or ModelScope. + """ if VLLM_USE_MODELSCOPE: # download model from ModelScope hub, # lazy import so that modelscope is not required for normal use. @@ -74,7 +75,7 @@ def get_tokenizer( tokenizer_path = snapshot_download( model_id=tokenizer_name, cache_dir=download_dir, - revision=tokenizer_revision, + revision=revision, # Ignore weights - we only need the tokenizer. ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"]) tokenizer_name = tokenizer_path @@ -90,7 +91,7 @@ def get_tokenizer( tokenizer_name, *args, trust_remote_code=trust_remote_code, - tokenizer_revision=tokenizer_revision, + revision=revision, **kwargs) except ValueError as e: # If the error pertains to the tokenizer class not existing or not @@ -114,7 +115,7 @@ def get_tokenizer( tokenizer_name, *args, trust_remote_code=trust_remote_code, - tokenizer_revision=tokenizer_revision, + revision=revision, **kwargs) else: raise e