diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 11cda053260ec..0c96caa714623 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -94,7 +94,7 @@ steps:
   command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader
 
 - label: Metrics Test
-  command: pytest -v -s metrics
+  command: pytest -v -s metrics --forked
 
 - label: Quantization Test
   command: pytest -v -s quantization
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 9da5a4991734d..e707c172fd957 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -86,7 +86,7 @@ def setup(app):
     "torch",
     "transformers",
     "psutil",
-    "prometheus_client",
+    "aioprometheus",
     "sentencepiece",
     "vllm.cuda_utils",
     "vllm._C",
diff --git a/requirements-common.txt b/requirements-common.txt
index 3abb828116680..847ac8f360cae 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -11,8 +11,7 @@ fastapi
 openai
 uvicorn[standard]
 pydantic >= 2.0  # Required for OpenAI server.
-prometheus_client >= 0.18.0
-prometheus-fastapi-instrumentator >= 7.0.0
+aioprometheus[starlette]
 tiktoken == 0.6.0  # Required for DBRX tokenizer
 lm-format-enforcer == 0.9.8
 outlines == 0.0.34  # Requires torch >= 2.1.0
diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py
index 0ab9c63ce4377..902532d476bc2 100644
--- a/tests/metrics/test_metrics.py
+++ b/tests/metrics/test_metrics.py
@@ -30,9 +30,8 @@ def test_metric_counter_prompt_tokens(
 
     _ = vllm_model.generate_greedy(example_prompts, max_tokens)
     stat_logger = vllm_model.model.llm_engine.stat_logger
-    metric_count = stat_logger.metrics.counter_prompt_tokens.labels(
-        **stat_logger.labels)._value.get()
-
+    metric_count = stat_logger.metrics.counter_prompt_tokens.get_value(
+        stat_logger.labels)
     assert vllm_prompt_token_count == metric_count, (
         f"prompt token count: {vllm_prompt_token_count!r}\n"
         f"metric: {metric_count!r}")
@@ -55,8 +54,8 @@ def test_metric_counter_generation_tokens(
     vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     tokenizer = vllm_model.model.get_tokenizer()
     stat_logger = vllm_model.model.llm_engine.stat_logger
-    metric_count = stat_logger.metrics.counter_generation_tokens.labels(
-        **stat_logger.labels)._value.get()
+    metric_count = stat_logger.metrics.counter_generation_tokens.get_value(
+        stat_logger.labels)
     vllm_generation_count = 0
     for i in range(len(example_prompts)):
         vllm_output_ids, vllm_output_str = vllm_outputs[i]
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 835803fd4e75d..db1eed5b65211 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -220,7 +220,6 @@ def __init__(
                 local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                 labels=dict(model_name=model_config.model),
                 max_model_len=self.model_config.max_model_len)
-            self.stat_logger.info("cache_config", self.cache_config)
 
         # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
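
Note (not part of the patch): the test update above works because aioprometheus metrics take the label set as an explicit argument instead of prometheus_client's labels(**...) binding. A minimal standalone sketch of that difference, using only the calls that appear in this diff; the metric and label names below are illustrative, not vLLM's:

    # Illustrative names; only the Counter constructor, add(), and get_value()
    # calls shown in this patch are assumed.
    from aioprometheus import Counter

    counter = Counter("demo_prompt_tokens_total", "Prompt tokens processed (demo).")
    labels = {"model_name": "demo-model"}

    counter.add(labels, 128)                 # prometheus_client: counter.labels(**labels).inc(128)
    assert counter.get_value(labels) == 128  # prometheus_client: counter.labels(**labels)._value.get()
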
diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py
index 45bfad03ec867..eb9cd0e638ebe 100644
--- a/vllm/engine/metrics.py
+++ b/vllm/engine/metrics.py
@@ -2,11 +2,10 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING
 from typing import Counter as CollectionsCounter
-from typing import Dict, List, Optional, Protocol, Union
+from typing import Dict, List, Optional, Union
 
 import numpy as np
-from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
-                               disable_created_metrics)
+from aioprometheus import Counter, Gauge, Histogram
 
 from vllm.logger import init_logger
 
@@ -15,8 +14,6 @@
 
 logger = init_logger(__name__)
 
-disable_created_metrics()
-
 # The begin-* and end* here are used by the documentation generator
 # to extract the metrics definitions.
 
@@ -25,62 +22,41 @@ class Metrics:
     labelname_finish_reason = "finished_reason"
 
-    def __init__(self, labelnames: List[str], max_model_len: int):
-        # Unregister any existing vLLM collectors
-        for collector in list(REGISTRY._collector_to_names):
-            if hasattr(collector, "_name") and "vllm" in collector._name:
-                REGISTRY.unregister(collector)
-
-        # Config Information
-        self.info_cache_config = Info(
-            name='vllm:cache_config',
-            documentation='information of cache_config')
-
+    def __init__(self, max_model_len: int):
         # System stats
         # Scheduler State
         self.gauge_scheduler_running = Gauge(
-            name="vllm:num_requests_running",
-            documentation="Number of requests currently running on GPU.",
-            labelnames=labelnames)
+            "vllm:num_requests_running",
+            "Number of requests currently running on GPU.")
         self.gauge_scheduler_waiting = Gauge(
-            name="vllm:num_requests_waiting",
-            documentation="Number of requests waiting to be processed.",
-            labelnames=labelnames)
+            "vllm:num_requests_waiting",
+            "Number of requests waiting to be processed.")
         self.gauge_scheduler_swapped = Gauge(
-            name="vllm:num_requests_swapped",
-            documentation="Number of requests swapped to CPU.",
-            labelnames=labelnames)
+            "vllm:num_requests_swapped", "Number of requests swapped to CPU.")
 
         # KV Cache Usage in %
         self.gauge_gpu_cache_usage = Gauge(
-            name="vllm:gpu_cache_usage_perc",
-            documentation="GPU KV-cache usage. 1 means 100 percent usage.",
-            labelnames=labelnames)
+            "vllm:gpu_cache_usage_perc",
+            "GPU KV-cache usage. 1 means 100 percent usage.")
         self.gauge_cpu_cache_usage = Gauge(
-            name="vllm:cpu_cache_usage_perc",
-            documentation="CPU KV-cache usage. 1 means 100 percent usage.",
-            labelnames=labelnames)
+            "vllm:cpu_cache_usage_perc",
+            "CPU KV-cache usage. 1 means 100 percent usage.")
 
         # Iteration stats
         self.counter_prompt_tokens = Counter(
-            name="vllm:prompt_tokens_total",
-            documentation="Number of prefill tokens processed.",
-            labelnames=labelnames)
+            "vllm:prompt_tokens_total", "Number of prefill tokens processed.")
         self.counter_generation_tokens = Counter(
-            name="vllm:generation_tokens_total",
-            documentation="Number of generation tokens processed.",
-            labelnames=labelnames)
+            "vllm:generation_tokens_total",
+            "Number of generation tokens processed.")
         self.histogram_time_to_first_token = Histogram(
-            name="vllm:time_to_first_token_seconds",
-            documentation="Histogram of time to first token in seconds.",
-            labelnames=labelnames,
+            "vllm:time_to_first_token_seconds",
+            "Histogram of time to first token in seconds.",
             buckets=[
                 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
                 0.75, 1.0, 2.5, 5.0, 7.5, 10.0
             ])
         self.histogram_time_per_output_token = Histogram(
-            name="vllm:time_per_output_token_seconds",
-            documentation="Histogram of time per output token in seconds.",
-            labelnames=labelnames,
+            "vllm:time_per_output_token_seconds",
+            "Histogram of time per output token in seconds.",
             buckets=[
                 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
                 1.0, 2.5
@@ -89,51 +65,42 @@ def __init__(self, labelnames: List[str], max_model_len: int):
         # Request stats
         # Latency
         self.histogram_e2e_time_request = Histogram(
-            name="vllm:e2e_request_latency_seconds",
-            documentation="Histogram of end to end request latency in seconds.",
-            labelnames=labelnames,
+            "vllm:e2e_request_latency_seconds",
+            "Histogram of end to end request latency in seconds.",
             buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
         # Metadata
         self.histogram_num_prompt_tokens_request = Histogram(
-            name="vllm:request_prompt_tokens",
-            documentation="Number of prefill tokens processed.",
-            labelnames=labelnames,
+            "vllm:request_prompt_tokens",
+            "Number of prefill tokens processed.",
             buckets=build_1_2_5_buckets(max_model_len),
         )
         self.histogram_num_generation_tokens_request = Histogram(
-            name="vllm:request_generation_tokens",
-            documentation="Number of generation tokens processed.",
-            labelnames=labelnames,
+            "vllm:request_generation_tokens",
+            "Number of generation tokens processed.",
             buckets=build_1_2_5_buckets(max_model_len),
         )
         self.histogram_best_of_request = Histogram(
-            name="vllm:request_params_best_of",
-            documentation="Histogram of the best_of request parameter.",
-            labelnames=labelnames,
+            "vllm:request_params_best_of",
+            "Histogram of the best_of request parameter.",
             buckets=[1, 2, 5, 10, 20],
         )
         self.histogram_n_request = Histogram(
-            name="vllm:request_params_n",
-            documentation="Histogram of the n request parameter.",
-            labelnames=labelnames,
+            "vllm:request_params_n",
+            "Histogram of the n request parameter.",
             buckets=[1, 2, 5, 10, 20],
         )
         self.counter_request_success = Counter(
-            name="vllm:request_success",
-            documentation="Count of successfully processed requests.",
-            labelnames=labelnames + [Metrics.labelname_finish_reason])
+            "vllm:request_success_total",
+            "Count of successfully processed requests.")
 
         # Deprecated in favor of vllm:prompt_tokens_total
         self.gauge_avg_prompt_throughput = Gauge(
-            name="vllm:avg_prompt_throughput_toks_per_s",
-            documentation="Average prefill throughput in tokens/s.",
-            labelnames=labelnames,
-        )
+            "vllm:avg_prompt_throughput_toks_per_s",
+            "Average prefill throughput in tokens/s.")
         # Deprecated in favor of vllm:generation_tokens_total
         self.gauge_avg_generation_throughput = Gauge(
-            name="vllm:avg_generation_throughput_toks_per_s",
-            documentation="Average generation throughput in tokens/s.",
-            labelnames=labelnames,
+            "vllm:avg_generation_throughput_toks_per_s",
+            "Average generation throughput in tokens/s.",
         )
 
 
@@ -195,12 +162,6 @@ class Stats:
     spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
 
 
-class SupportsMetricsInfo(Protocol):
-
-    def metrics_info(self) -> Dict[str, str]:
-        ...
-
-
 class StatLogger:
     """StatLogger is used LLMEngine to log to Promethus and Stdout."""
 
@@ -216,12 +177,7 @@ def __init__(self, local_interval: float, labels: Dict[str, str],
 
         # Prometheus metrics
         self.labels = labels
-        self.metrics = Metrics(labelnames=list(labels.keys()),
-                               max_model_len=max_model_len)
-
-    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
-        if type == "cache_config":
-            self.metrics.info_cache_config.info(obj.metrics_info())
+        self.metrics = Metrics(max_model_len=max_model_len)
 
     def _get_throughput(self, tracked_stats: List[int], now: float) -> float:
         return float(np.sum(tracked_stats) / (now - self.last_local_log))
@@ -274,23 +230,23 @@ def _log_prometheus(self, stats: Stats) -> None:
 
     def _log_gauge(self, gauge: Gauge, data: Union[int, float]) -> None:
         # Convenience function for logging to gauge.
-        gauge.labels(**self.labels).set(data)
+        gauge.set(self.labels, data)
 
     def _log_counter(self, counter: Counter, data: Union[int, float]) -> None:
         # Convenience function for logging to counter.
-        counter.labels(**self.labels).inc(data)
+        counter.add(self.labels, data)
 
     def _log_counter_labels(self, counter: Counter, data: CollectionsCounter,
                             label_key: str) -> None:
         # Convenience function for collection counter of labels.
         for label, count in data.items():
-            counter.labels(**{**self.labels, label_key: label}).inc(count)
+            counter.add({**self.labels, label_key: label}, count)
 
     def _log_histogram(self, histogram: Histogram,
                        data: Union[List[int], List[float]]) -> None:
         # Convenience function for logging list to histogram.
         for datum in data:
-            histogram.labels(**self.labels).observe(datum)
+            histogram.observe(self.labels, datum)
 
     def _log_prometheus_interval(self, prompt_throughput: float,
                                  generation_throughput: float) -> None:
@@ -301,10 +257,10 @@ def _log_prometheus_interval(self, prompt_throughput: float,
         # Which log raw data and calculate summaries using rate() on the
        # grafana/prometheus side. See
         # https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
-        self.metrics.gauge_avg_prompt_throughput.labels(
-            **self.labels).set(prompt_throughput)
-        self.metrics.gauge_avg_generation_throughput.labels(
-            **self.labels).set(generation_throughput)
+        self._log_gauge(self.metrics.gauge_avg_prompt_throughput,
+                        prompt_throughput)
+        self._log_gauge(self.metrics.gauge_avg_generation_throughput,
+                        generation_throughput)
 
     def log(self, stats: Stats) -> None:
         """Called by LLMEngine.
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index af9ba7a3bc825..0ba4aa3824363 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -7,11 +7,12 @@
 
 import fastapi
 import uvicorn
+from aioprometheus import MetricsMiddleware
+from aioprometheus.asgi.starlette import metrics
 from fastapi import Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
-from prometheus_client import make_asgi_app
 
 import vllm
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -54,9 +55,8 @@ def parse_args():
     return parser.parse_args()
 
 
-# Add prometheus asgi middleware to route /metrics requests
-metrics_app = make_asgi_app()
-app.mount("/metrics", metrics_app)
+app.add_middleware(MetricsMiddleware)  # Trace HTTP server metrics
+app.add_route("/metrics", metrics)  # Exposes HTTP metrics
 
 
 @app.exception_handler(RequestValidationError)
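
Note (not part of the patch): a standalone sketch of the /metrics wiring added above, applied to a bare FastAPI app. The middleware and route come straight from the two added lines; the module name, port, and curl command are only an example:

    import fastapi
    from aioprometheus import MetricsMiddleware
    from aioprometheus.asgi.starlette import metrics

    app = fastapi.FastAPI()
    app.add_middleware(MetricsMiddleware)  # records per-route HTTP request metrics
    app.add_route("/metrics", metrics)     # endpoint Prometheus scrapes

    # e.g. run with: uvicorn example_app:app --port 8000
    # then:          curl localhost:8000/metrics
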