Commit 6e47f0c
[Bugfix] Fix offline mode when using mistral_common (vllm-project#9457)
Signed-off-by: Amit Garg <[email protected]>
sasha0552 authored and garg-amit committed Oct 28, 2024
1 parent 2ff2ea7 commit 6e47f0c
Showing 2 changed files with 62 additions and 28 deletions.
56 changes: 31 additions & 25 deletions tests/entrypoints/offline_mode/test_offline_mode.py
@@ -1,50 +1,56 @@
 """Tests for HF_HUB_OFFLINE mode"""
 import importlib
 import sys
-import weakref
 
 import pytest
 
 from vllm import LLM
 from vllm.distributed import cleanup_dist_env_and_memory
 
-MODEL_NAME = "facebook/opt-125m"
+MODEL_CONFIGS = [
+    {
+        "model": "facebook/opt-125m",
+        "enforce_eager": True,
+        "gpu_memory_utilization": 0.20,
+        "max_model_len": 64,
+        "max_num_batched_tokens": 64,
+        "max_num_seqs": 64,
+        "tensor_parallel_size": 1,
+    },
+    {
+        "model": "mistralai/Mistral-7B-Instruct-v0.1",
+        "enforce_eager": True,
+        "gpu_memory_utilization": 0.95,
+        "max_model_len": 64,
+        "max_num_batched_tokens": 64,
+        "max_num_seqs": 64,
+        "tensor_parallel_size": 1,
+        "tokenizer_mode": "mistral",
+    },
+]
 
 
 @pytest.fixture(scope="module")
-def llm():
-    # pytest caches the fixture so we use weakref.proxy to
-    # enable garbage collection
-    llm = LLM(model=MODEL_NAME,
-              max_num_batched_tokens=4096,
-              tensor_parallel_size=1,
-              gpu_memory_utilization=0.10,
-              enforce_eager=True)
+def cache_models():
+    # Cache model files first
+    for model_config in MODEL_CONFIGS:
+        LLM(**model_config)
+        cleanup_dist_env_and_memory()
 
-    with llm.deprecate_legacy_api():
-        yield weakref.proxy(llm)
-
-    del llm
-
-    cleanup_dist_env_and_memory()
+    yield
 
 
 @pytest.mark.skip_global_cleanup
-def test_offline_mode(llm: LLM, monkeypatch):
-    # we use the llm fixture to ensure the model files are in-cache
-    del llm
-
+@pytest.mark.usefixtures("cache_models")
+def test_offline_mode(monkeypatch):
     # Set HF to offline mode and ensure we can still construct an LLM
     try:
         monkeypatch.setenv("HF_HUB_OFFLINE", "1")
         # Need to re-import huggingface_hub and friends to setup offline mode
         _re_import_modules()
         # Cached model files should be used in offline mode
-        LLM(model=MODEL_NAME,
-            max_num_batched_tokens=4096,
-            tensor_parallel_size=1,
-            gpu_memory_utilization=0.20,
-            enforce_eager=True)
+        for model_config in MODEL_CONFIGS:
+            LLM(**model_config)
     finally:
         # Reset the environment after the test
        # NB: Assuming tests are run in online mode
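Note: the `_re_import_modules()` helper called by the test lies outside this hunk and is unchanged by the commit. A minimal sketch of the idea, assuming huggingface_hub reads HF_HUB_OFFLINE once at import time, so already-imported modules must be reloaded before a monkeypatched value is seen:

    # Sketch only -- not the repository's exact helper. Reloads every
    # huggingface_hub submodule so the new HF_HUB_OFFLINE value takes effect.
    import importlib
    import sys

    def _re_import_modules():
        for name in list(sys.modules):
            if name.startswith("huggingface_hub"):
                importlib.reload(sys.modules[name])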
34 changes: 31 additions & 3 deletions vllm/transformers_utils/tokenizers/mistral.py
@@ -4,6 +4,7 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
 
+import huggingface_hub
 from huggingface_hub import HfApi, hf_hub_download
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 # yapf: disable
@@ -24,6 +25,26 @@ class Encoding:
     input_ids: List[int]
 
 
+def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
+    repo_cache = os.path.join(
+        huggingface_hub.constants.HF_HUB_CACHE,
+        huggingface_hub.constants.REPO_ID_SEPARATOR.join(
+            ["models", *repo_id.split("/")]))
+
+    if revision is None:
+        revision_file = os.path.join(repo_cache, "refs", "main")
+        if os.path.isfile(revision_file):
+            with open(revision_file) as file:
+                revision = file.read()
+
+    if revision:
+        revision_dir = os.path.join(repo_cache, "snapshots", revision)
+        if os.path.isdir(revision_dir):
+            return os.listdir(revision_dir)
+
+    return []
+
+
 def find_tokenizer_file(files: List[str]):
     file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
 
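The new helper resolves file names purely from the local hub cache, without network access. For orientation, the cache layout it walks (illustrative paths, not part of this commit):

    <HF_HUB_CACHE>/
      models--mistralai--Mistral-7B-Instruct-v0.1/
        refs/
          main              <- text file holding the snapshot commit hash
        snapshots/
          <commit hash>/    <- cached files, e.g. tekken.json, config.json

When no revision is given, refs/main supplies the snapshot hash; if neither the ref file nor the snapshot directory exists locally, the helper returns an empty list.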
@@ -90,9 +111,16 @@ def from_pretrained(cls,
     @staticmethod
     def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
                                             revision: Optional[str]) -> str:
-        api = HfApi()
-        repo_info = api.model_info(tokenizer_name)
-        files = [s.rfilename for s in repo_info.siblings]
+        try:
+            hf_api = HfApi()
+            files = hf_api.list_repo_files(repo_id=tokenizer_name,
+                                           revision=revision)
+        except ConnectionError as exc:
+            files = list_local_repo_files(repo_id=tokenizer_name,
+                                          revision=revision)
+
+            if len(files) == 0:
+                raise exc
 
         filename = find_tokenizer_file(files)
 
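The except branch works because, under HF_HUB_OFFLINE=1, huggingface_hub raises OfflineModeIsEnabled, a ConnectionError subclass, rather than attempting a network call. A hypothetical reproduction of the scenario this commit fixes, assuming the model was cached beforehand (previously the unconditional model_info() call failed offline):

    # Sketch only. HF_HUB_OFFLINE must be set before huggingface_hub is
    # first imported, since the flag is read at import time.
    import os
    os.environ["HF_HUB_OFFLINE"] = "1"

    from vllm import LLM

    # Before this fix: raised a ConnectionError while fetching the
    # tokenizer file list. After: falls back to the local cache listing.
    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.1",
              tokenizer_mode="mistral",
              enforce_eager=True)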
