diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py
index 0ab9c63ce4377..311e60ba60f61 100644
--- a/tests/metrics/test_metrics.py
+++ b/tests/metrics/test_metrics.py
@@ -1,4 +1,10 @@
 import pytest
+from prometheus_client import REGISTRY
+
+from vllm import EngineArgs, LLMEngine
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.sampling_params import SamplingParams
 
 MODELS = [
     "facebook/opt-125m",
@@ -68,3 +74,91 @@ def test_metric_counter_generation_tokens(
     assert vllm_generation_count == metric_count, (
         f"generation token count: {vllm_generation_count!r}\n"
         f"metric: {metric_count!r}")
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [4])
+@pytest.mark.parametrize("disable_log_stats", [True, False])
+@pytest.mark.asyncio
+async def test_async_engine_log_metrics_regression(
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    disable_log_stats: bool,
+) -> None:
+    """
+    Regression test ensuring async engine generates metrics
+    when disable_log_stats=False
+    (see: https://github.com/vllm-project/vllm/pull/4150#pullrequestreview-2008176678)
+    """
+    engine_args = AsyncEngineArgs(model=model,
+                                  dtype=dtype,
+                                  disable_log_stats=disable_log_stats)
+    async_engine = AsyncLLMEngine.from_engine_args(engine_args)
+    for i, prompt in enumerate(example_prompts):
+        results = async_engine.generate(
+            prompt,
+            SamplingParams(max_tokens=max_tokens),
+            f"request-id-{i}",
+        )
+        # Exhaust the async iterator to make the async engine work
+        async for _ in results:
+            pass
+
+    assert_metrics(async_engine.engine, disable_log_stats,
+                   len(example_prompts))
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [4])
+@pytest.mark.parametrize("disable_log_stats", [True, False])
+def test_engine_log_metrics_regression(
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    disable_log_stats: bool,
+) -> None:
+    engine_args = EngineArgs(model=model,
+                             dtype=dtype,
+                             disable_log_stats=disable_log_stats)
+    engine = LLMEngine.from_engine_args(engine_args)
+    for i, prompt in enumerate(example_prompts):
+        engine.add_request(
+            f"request-id-{i}",
+            prompt,
+            SamplingParams(max_tokens=max_tokens),
+        )
+    while engine.has_unfinished_requests():
+        engine.step()
+
+    assert_metrics(engine, disable_log_stats, len(example_prompts))
+
+
+def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
+                   num_requests: int) -> None:
+    if disable_log_stats:
+        with pytest.raises(AttributeError):
+            _ = engine.stat_logger
+    else:
+        assert (engine.stat_logger
+                is not None), "engine.stat_logger should be set"
+        # Ensure the count bucket of request-level histogram metrics matches
+        # the number of requests as a simple sanity check to ensure metrics
+        # are generated
+        labels = {'model_name': engine.model_config.model}
+        request_histogram_metrics = [
+            "vllm:e2e_request_latency_seconds",
+            "vllm:request_prompt_tokens",
+            "vllm:request_generation_tokens",
+            "vllm:request_params_best_of",
+            "vllm:request_params_n",
+        ]
+        for metric_name in request_histogram_metrics:
+            metric_value = REGISTRY.get_sample_value(f"{metric_name}_count",
+                                                     labels)
+            assert (
+                metric_value == num_requests), "Metrics should be collected"
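
Note (not part of the patch): the assert_metrics() helper added above checks the "_count" sample of each request-level histogram. Below is a minimal, self-contained sketch of the prometheus_client behaviour it relies on, using a throwaway registry and a made-up metric name rather than vLLM's own metrics: a Histogram exposes a "<name>_count" sample equal to the number of observations, so one observation per finished request makes that sample equal the request count.

# Illustrative sketch only; "demo_e2e_request_latency_seconds" and the values
# below are hypothetical and not registered by vLLM itself.
from prometheus_client import CollectorRegistry, Histogram

registry = CollectorRegistry()
latency = Histogram("demo_e2e_request_latency_seconds",
                    "Demo end-to-end request latency",
                    labelnames=["model_name"],
                    registry=registry)

for _ in range(3):  # pretend three requests finished
    latency.labels(model_name="facebook/opt-125m").observe(0.5)

# The "_count" sample reflects the number of observe() calls.
count = registry.get_sample_value("demo_e2e_request_latency_seconds_count",
                                  {"model_name": "facebook/opt-125m"})
assert count == 3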