diff --git a/tests/utils.py b/tests/utils.py index 955431bbd3014..b73a05b5fe67f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -11,13 +11,14 @@ import openai import requests -from huggingface_hub import snapshot_download from transformers import AutoTokenizer from typing_extensions import ParamSpec from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) +from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip @@ -60,39 +61,50 @@ class RemoteOpenAIServer: def __init__(self, model: str, - cli_args: List[str], + vllm_serve_args: List[str], *, env_dict: Optional[Dict[str, str]] = None, auto_port: bool = True, max_wait_seconds: Optional[float] = None) -> None: - if not model.startswith("/"): - # download the model if it's not a local path - # to exclude the model download time from the server start time - snapshot_download(model) if auto_port: - if "-p" in cli_args or "--port" in cli_args: - raise ValueError("You have manually specified the port" + if "-p" in vllm_serve_args or "--port" in vllm_serve_args: + raise ValueError("You have manually specified the port " "when `auto_port=True`.") - cli_args = cli_args + ["--port", str(get_open_port())] + # Don't mutate the input args + vllm_serve_args = vllm_serve_args + [ + "--port", str(get_open_port()) + ] parser = FlexibleArgumentParser( description="vLLM's remote OpenAI server.") parser = make_arg_parser(parser) - args = parser.parse_args(cli_args) + args = parser.parse_args(["--model", model, *vllm_serve_args]) self.host = str(args.host or 'localhost') self.port = int(args.port) + # download the model before starting the server to avoid timeout + is_local = os.path.isdir(model) + if not is_local: + engine_args = AsyncEngineArgs.from_cli_args(args) + engine_config = engine_args.create_engine_config() + dummy_loader = DefaultModelLoader(engine_config.load_config) + dummy_loader._prepare_weights(engine_config.model_config.model, + engine_config.model_config.revision, + fall_back_to_pt=True) + env = os.environ.copy() # the current process might initialize cuda, # to be safe, we should use spawn method env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' if env_dict is not None: env.update(env_dict) - self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args, - env=env, - stdout=sys.stdout, - stderr=sys.stderr) + self.proc = subprocess.Popen( + ["vllm", "serve", model, *vllm_serve_args], + env=env, + stdout=sys.stdout, + stderr=sys.stderr, + ) max_wait_seconds = max_wait_seconds or 240 self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e39d7fbc21b21..d759ce04d75e7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -743,7 +743,7 @@ def from_cli_args(cls, args: argparse.Namespace): engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) return engine_args - def create_engine_config(self, ) -> EngineConfig: + def create_engine_config(self) -> EngineConfig: # gguf file needs a specific model loader and doesn't use hf_repo if self.model.endswith(".gguf"): self.quantization = self.load_format = "gguf"