Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CI/Build] Avoid downloading all HF files in RemoteOpenAIServer #7836

Merged
merged 5 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@

import openai
import requests
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from typing_extensions import ParamSpec

from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip

Expand Down Expand Up @@ -60,39 +61,47 @@ class RemoteOpenAIServer:

def __init__(self,
model: str,
cli_args: List[str],
vllm_serve_args: List[str],
*,
env_dict: Optional[Dict[str, str]] = None,
auto_port: bool = True,
max_wait_seconds: Optional[float] = None) -> None:
if not model.startswith("/"):
# download the model if it's not a local path
# to exclude the model download time from the server start time
snapshot_download(model)
if auto_port:
if "-p" in cli_args or "--port" in cli_args:
raise ValueError("You have manually specified the port"
if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
raise ValueError("You have manually specified the port "
"when `auto_port=True`.")

cli_args = cli_args + ["--port", str(get_open_port())]
vllm_serve_args += ["--port", str(get_open_port())]

parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args(cli_args)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you just need args = parser.parse_args(["--model", model, *cli_args])

args = parser.parse_args(["--model", model, *vllm_serve_args])
self.host = str(args.host or 'localhost')
self.port = int(args.port)

# download the model before starting the server to avoid timeout
is_local = os.path.isdir(model)
if not is_local:
engine_args = AsyncEngineArgs.from_cli_args(args)
engine_config = engine_args.create_engine_config()
dummy_loader = DefaultModelLoader(engine_config.load_config)
dummy_loader._prepare_weights(engine_config.model_config.model,
engine_config.model_config.revision,
fall_back_to_pt=True)

env = os.environ.copy()
# the current process might initialize cuda,
# to be safe, we should use spawn method
env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
if env_dict is not None:
env.update(env_dict)
self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args,
env=env,
stdout=sys.stdout,
stderr=sys.stderr)
self.proc = subprocess.Popen(
["vllm", "serve", model, *vllm_serve_args],
env=env,
stdout=sys.stdout,
stderr=sys.stderr,
)
max_wait_seconds = max_wait_seconds or 240
self._wait_for_server(url=self.url_for("health"),
timeout=max_wait_seconds)
Expand Down
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ def from_cli_args(cls, args: argparse.Namespace):
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
return engine_args

def create_engine_config(self, ) -> EngineConfig:
def create_engine_config(self) -> EngineConfig:
# gguf file needs a specific model loader and doesn't use hf_repo
if self.model.endswith(".gguf"):
self.quantization = self.load_format = "gguf"
Expand Down
Loading