Skip to content

Commit

Permalink
[Bugfix] Fix Ray Metrics API usage (vllm-project#6354)
Browse files Browse the repository at this point in the history
Signed-off-by: Alvant <[email protected]>
  • Loading branch information
Yard1 authored and Alvant committed Oct 26, 2024
1 parent eb9560b commit a4fcb94
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 40 deletions.
54 changes: 54 additions & 0 deletions tests/metrics/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import List

import pytest
import ray
from prometheus_client import REGISTRY

from vllm import EngineArgs, LLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.metrics import RayPrometheusStatLogger
from vllm.sampling_params import SamplingParams

MODELS = [
Expand Down Expand Up @@ -241,3 +243,55 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
labels)
assert (
metric_value == num_requests), "Metrics should be collected"


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [16])
def test_engine_log_metrics_ray(
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# This test is quite weak - it only checks that we can use
# RayPrometheusStatLogger without exceptions.
# Checking whether the metrics are actually emitted is unfortunately
# non-trivial.

# We have to run in a Ray task for Ray metrics to be emitted correctly
@ray.remote(num_gpus=1)
def _inner():

class _RayPrometheusStatLogger(RayPrometheusStatLogger):

def __init__(self, *args, **kwargs):
self._i = 0
super().__init__(*args, **kwargs)

def log(self, *args, **kwargs):
self._i += 1
return super().log(*args, **kwargs)

engine_args = EngineArgs(
model=model,
dtype=dtype,
disable_log_stats=False,
)
engine = LLMEngine.from_engine_args(engine_args)
logger = _RayPrometheusStatLogger(
local_interval=0.5,
labels=dict(model_name=engine.model_config.served_model_name),
max_model_len=engine.model_config.max_model_len)
engine.add_logger("ray", logger)
for i, prompt in enumerate(example_prompts):
engine.add_request(
f"request-id-{i}",
prompt,
SamplingParams(max_tokens=max_tokens),
)
while engine.has_unfinished_requests():
engine.step()
assert logger._i > 0, ".log must be called at least once"

ray.get(_inner.remote())
19 changes: 19 additions & 0 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.metrics import StatLoggerBase
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger
Expand Down Expand Up @@ -429,6 +430,7 @@ def from_engine_args(
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
Expand Down Expand Up @@ -491,6 +493,7 @@ def from_engine_args(
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
)
return engine

Expand Down Expand Up @@ -997,3 +1000,19 @@ async def is_tracing_enabled(self) -> bool:
)
else:
return self.engine.is_tracing_enabled()

def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
if self.engine_use_ray:
ray.get(
self.engine.add_logger.remote( # type: ignore
logger_name=logger_name, logger=logger))
else:
self.engine.add_logger(logger_name=logger_name, logger=logger)

def remove_logger(self, logger_name: str) -> None:
if self.engine_use_ray:
ray.get(
self.engine.remove_logger.remote( # type: ignore
logger_name=logger_name))
else:
self.engine.remove_logger(logger_name=logger_name)
2 changes: 2 additions & 0 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def from_engine_args(
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
Expand Down Expand Up @@ -423,6 +424,7 @@ def from_engine_args(
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
stat_loggers=stat_loggers,
)
return engine

Expand Down
Loading

0 comments on commit a4fcb94

Please sign in to comment.