Factor out trace headers code
DarkLight1337 committed Oct 29, 2024
1 parent a79d3b2 commit 8b950dd
Showing 4 changed files with 30 additions and 26 deletions.
vllm/entrypoints/openai/serving_chat.py (12 changes: 2 additions & 10 deletions)
@@ -29,8 +29,6 @@
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
-from vllm.tracing import (contains_trace_headers, extract_trace_headers,
-                          log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.utils import iterate_with_cancellation

@@ -183,14 +181,8 @@ async def create_chat_completion(
                 lora_request=lora_request,
                 prompt_adapter_request=prompt_adapter_request)

-            is_tracing_enabled = (await
-                                  self.engine_client.is_tracing_enabled())
-            trace_headers = None
-            if is_tracing_enabled and raw_request:
-                trace_headers = extract_trace_headers(raw_request.headers)
-            if (not is_tracing_enabled and raw_request
-                    and contains_trace_headers(raw_request.headers)):
-                log_tracing_disabled_warning()
+            trace_headers = (None if raw_request is None else await
+                             self._get_trace_headers(raw_request.headers))

             if isinstance(sampling_params, BeamSearchParams):
                 generator = self.engine_client.beam_search(
vllm/entrypoints/openai/serving_completion.py (12 changes: 2 additions & 10 deletions)
@@ -29,8 +29,6 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
-from vllm.tracing import (contains_trace_headers, extract_trace_headers,
-                          log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import merge_async_iterators, random_uuid

@@ -136,14 +134,8 @@ async def create_completion(
                 lora_request=lora_request,
                 prompt_adapter_request=prompt_adapter_request)

-            is_tracing_enabled = (await
-                                  self.engine_client.is_tracing_enabled())
-            trace_headers = None
-            if is_tracing_enabled:
-                trace_headers = extract_trace_headers(raw_request.headers)
-            if not is_tracing_enabled and contains_trace_headers(
-                    raw_request.headers):
-                log_tracing_disabled_warning()
+            trace_headers = (await
+                             self._get_trace_headers(raw_request.headers))

             if isinstance(sampling_params, BeamSearchParams):
                 generator = self.engine_client.beam_search(
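The two hunks above replace the same duplicated tracing block in the chat and completion endpoints with one call to a shared helper (added in serving_engine.py further down). For orientation, here is a minimal, self-contained sketch of that caller-side pattern. EngineClientStub, TRACE_HEADER, and the module-level get_trace_headers are invented for this example and only approximate the real vLLM code:

import asyncio
from typing import Mapping, Optional

TRACE_HEADER = "traceparent"  # assumed W3C trace header name, for illustration only


class EngineClientStub:
    """Minimal stand-in for the engine client's tracing flag."""

    def __init__(self, tracing_enabled: bool) -> None:
        self._tracing_enabled = tracing_enabled

    async def is_tracing_enabled(self) -> bool:
        return self._tracing_enabled


async def get_trace_headers(
    engine_client: EngineClientStub,
    headers: Mapping[str, str],
) -> Optional[Mapping[str, str]]:
    # Mirrors the factored-out helper: return the trace headers only when
    # tracing is enabled; otherwise warn (here: print) if any were sent.
    if await engine_client.is_tracing_enabled():
        return {k: v for k, v in headers.items() if k == TRACE_HEADER}
    if TRACE_HEADER in headers:
        print("warning: trace headers received but tracing is disabled")
    return None


async def main() -> None:
    request_headers = {TRACE_HEADER: "00-abc-def-01"}

    # Caller-side pattern after the refactor: a single guarded expression
    # instead of the old is_tracing_enabled / extract / warn block in each
    # endpoint (the None guard matters where raw_request can be missing).
    for enabled in (True, False):
        client = EngineClientStub(tracing_enabled=enabled)
        trace_headers = (None if request_headers is None else await
                         get_trace_headers(client, request_headers))
        print(f"tracing_enabled={enabled} -> {trace_headers}")


asyncio.run(main())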
vllm/entrypoints/openai/serving_embedding.py (11 changes: 7 additions & 4 deletions)
@@ -150,6 +150,10 @@ async def create_embedding(

             tokenizer = await self.engine_client.get_tokenizer(lora_request)

+            if prompt_adapter_request is not None:
+                raise NotImplementedError("Prompt adapter is not supported "
+                                          "for embedding models")
+
             if isinstance(request, EmbeddingChatRequest):
                 (
                     conversation,
@@ -191,16 +195,15 @@
                 lora_request=lora_request,
                 prompt_adapter_request=prompt_adapter_request)

-            if prompt_adapter_request is not None:
-                raise NotImplementedError(
-                    "Prompt adapter is not supported "
-                    "for embedding models")
+            trace_headers = (None if raw_request is None else await
+                             self._get_trace_headers(raw_request.headers))

             generator = self.engine_client.encode(
                 engine_prompt,
                 pooling_params,
                 request_id_item,
                 lora_request=lora_request,
+                trace_headers=trace_headers,
                 priority=request.priority,
             )

vllm/entrypoints/openai/serving_engine.py (21 changes: 19 additions & 2 deletions)
@@ -2,10 +2,11 @@
 import pathlib
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import (Any, Callable, Dict, Iterable, Iterator, List, Optional,
-                    Sequence, Tuple, TypedDict, Union)
+from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
+                    Optional, Sequence, Tuple, TypedDict, Union)

 from pydantic import Field
+from starlette.datastructures import Headers
 from typing_extensions import Annotated

 from vllm.config import ModelConfig
@@ -40,6 +41,8 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
+from vllm.tracing import (contains_trace_headers, extract_trace_headers,
+                          log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.utils import AtomicCounter, is_list_of

@@ -522,6 +525,20 @@ def _log_inputs(
             prompt_adapter_request=prompt_adapter_request,
         )

+    async def _get_trace_headers(
+        self,
+        headers: Headers,
+    ) -> Optional[Mapping[str, str]]:
+        is_tracing_enabled = await self.engine_client.is_tracing_enabled()
+
+        if is_tracing_enabled:
+            return extract_trace_headers(headers)
+
+        if contains_trace_headers(headers):
+            log_tracing_disabled_warning()
+
+        return None
+
     @staticmethod
     def _get_decoded_token(logprob: Logprob,
                            token_id: int,
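For a quick manual check of the new helper, the throwaway sketch below calls the unbound method with a SimpleNamespace standing in for self, so no full OpenAIServing instance is needed. It assumes vllm and starlette are importable, that the standard traceparent header is among those vllm.tracing recognizes, and it is not part of the project's test suite:

import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock

from starlette.datastructures import Headers

from vllm.entrypoints.openai.serving_engine import OpenAIServing


async def main() -> None:
    # A syntactically valid (made-up) W3C traceparent value.
    headers = Headers({"traceparent": "00-" + "ab" * 16 + "-" + "cd" * 8 + "-01"})

    # Tracing enabled: the helper should hand back the extracted trace headers.
    fake_self = SimpleNamespace(engine_client=SimpleNamespace(
        is_tracing_enabled=AsyncMock(return_value=True)))
    print(await OpenAIServing._get_trace_headers(fake_self, headers))

    # Tracing disabled: the helper returns None and logs a warning because
    # trace headers were supplied anyway.
    fake_self.engine_client.is_tracing_enabled = AsyncMock(return_value=False)
    print(await OpenAIServing._get_trace_headers(fake_self, headers))


asyncio.run(main())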
