diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py
index fe00640c0021e..50add84087a95 100644
--- a/tests/entrypoints/openai/test_completion.py
+++ b/tests/entrypoints/openai/test_completion.py
@@ -541,6 +541,28 @@ async def test_logits_bias(client: openai.AsyncOpenAI):
     assert first_response != completion.choices[0].text
 
 
+@pytest.mark.asyncio
+async def test_allowed_token_ids(client: openai.AsyncOpenAI):
+    prompt = "Hello, my name is"
+    max_tokens = 1
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+
+    # Test exclusive selection
+    allowed_ids = [21555, 21557, 21558]
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+        seed=42,
+        extra_body=dict(allowed_token_ids=allowed_ids),
+        logprobs=1,
+    )
+    response_tokens = completion.choices[0].logprobs.tokens
+    assert len(response_tokens) == 1
+    assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py
new file mode 100644
index 0000000000000..31eb5aa628c52
--- /dev/null
+++ b/vllm/entrypoints/openai/logits_processors.py
@@ -0,0 +1,74 @@
+from functools import lru_cache
+from typing import Dict, FrozenSet, Iterable, List, Optional, Union
+
+import torch
+from transformers import PreTrainedTokenizer
+
+from vllm.sampling_params import LogitsProcessor
+
+
+class AllowedTokenIdsLogitsProcessor:
+    """Logits processor for constraining generated tokens to a
+    specific set of token ids."""
+
+    def __init__(self, allowed_ids: Iterable[int]):
+        self.allowed_ids: Optional[List[int]] = list(allowed_ids)
+        self.mask: Optional[torch.Tensor] = None
+
+    def __call__(self, token_ids: List[int],
+                 logits: torch.Tensor) -> torch.Tensor:
+        if self.mask is None:
+            self.mask = torch.ones((logits.shape[-1], ),
+                                   dtype=torch.bool,
+                                   device=logits.device)
+            self.mask[self.allowed_ids] = False
+            self.allowed_ids = None
+        logits.masked_fill_(self.mask, float("-inf"))
+        return logits
+
+
+@lru_cache(maxsize=32)
+def _get_allowed_token_ids_logits_processor(
+    allowed_token_ids: FrozenSet[int],
+    vocab_size: int,
+) -> LogitsProcessor:
+    if not allowed_token_ids:
+        raise ValueError("Empty allowed_token_ids provided")
+    if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
+        raise ValueError("allowed_token_ids contains "
+                         "out-of-vocab token id")
+    return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
+
+
+def get_logits_processors(
+        logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
+        allowed_token_ids: Optional[List[int]],
+        tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]:
+    logits_processors = []
+    if logit_bias:
+        try:
+            # Convert token_id to integer
+            # Clamp the bias between -100 and 100 per OpenAI API spec
+            clamped_logit_bias: Dict[int, float] = {
+                int(token_id): min(100.0, max(-100.0, bias))
+                for token_id, bias in logit_bias.items()
+            }
+        except ValueError as exc:
+            raise ValueError(
+                "Found token_id in logit_bias that is not "
+                "an integer or string representing an integer") from exc
+
+        def logit_bias_logits_processor(token_ids: List[int],
+                                        logits: torch.Tensor) -> torch.Tensor:
+            for token_id, bias in clamped_logit_bias.items():
+                logits[token_id] += bias
+            return logits
+
+        logits_processors.append(logit_bias_logits_processor)
+
+    if allowed_token_ids is not None:
+        logits_processors.append(
+            _get_allowed_token_ids_logits_processor(
+                frozenset(allowed_token_ids), tokenizer.vocab_size))
+
+    return logits_processors
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index c024bbc07c069..205860aa8e722 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -5,9 +5,11 @@
 
 import torch
 from pydantic import BaseModel, ConfigDict, Field, model_validator
+from transformers import PreTrainedTokenizer
 from typing_extensions import Annotated
 
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
+from vllm.entrypoints.openai.logits_processors import get_logits_processors
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
@@ -213,30 +215,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
     # doc: end-chat-completion-extra-params
 
-    def to_sampling_params(self) -> SamplingParams:
+    def to_sampling_params(self,
+                           tokenizer: PreTrainedTokenizer) -> SamplingParams:
         # We now allow logprobs being true without top_logrobs.
 
-        logits_processors = None
-        if self.logit_bias:
-            logit_bias: Dict[int, float] = {}
-            try:
-                for token_id, bias in self.logit_bias.items():
-                    # Convert token_id to integer before we add to LLMEngine
-                    # Clamp the bias between -100 and 100 per OpenAI API spec
-                    logit_bias[int(token_id)] = min(100, max(-100, bias))
-            except ValueError as exc:
-                raise ValueError(f"Found token_id `{token_id}` in logit_bias "
-                                 f"but token_id must be an integer or string "
-                                 f"representing an integer") from exc
-
-            def logit_bias_logits_processor(
-                    token_ids: List[int],
-                    logits: torch.Tensor) -> torch.Tensor:
-                for token_id, bias in logit_bias.items():
-                    logits[token_id] += bias
-                return logits
-
-            logits_processors = [logit_bias_logits_processor]
+        logits_processors = get_logits_processors(
+            logit_bias=self.logit_bias,
+            allowed_token_ids=None,
+            tokenizer=tokenizer,
+        )
 
         return SamplingParams(
             n=self.n,
@@ -358,6 +345,7 @@ class CompletionRequest(OpenAIBaseModel):
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    allowed_token_ids: Optional[List[int]] = None
     # doc: end-completion-sampling-params
 
     # doc: begin-completion-extra-params
@@ -407,30 +395,14 @@ class CompletionRequest(OpenAIBaseModel):
 
     # doc: end-completion-extra-params
 
-    def to_sampling_params(self):
+    def to_sampling_params(self, tokenizer: PreTrainedTokenizer):
         echo_without_generation = self.echo and self.max_tokens == 0
 
-        logits_processors = None
-        if self.logit_bias:
-            logit_bias: Dict[int, float] = {}
-            try:
-                for token_id, bias in self.logit_bias.items():
-                    # Convert token_id to integer
-                    # Clamp the bias between -100 and 100 per OpenAI API spec
-                    logit_bias[int(token_id)] = min(100, max(-100, bias))
-            except ValueError as exc:
-                raise ValueError(f"Found token_id `{token_id}` in logit_bias "
-                                 f"but token_id must be an integer or string "
-                                 f"representing an integer") from exc
-
-            def logit_bias_logits_processor(
-                    token_ids: List[int],
-                    logits: torch.Tensor) -> torch.Tensor:
-                for token_id, bias in logit_bias.items():
-                    logits[token_id] += bias
-                return logits
-
-            logits_processors = [logit_bias_logits_processor]
+        logits_processors = get_logits_processors(
+            logit_bias=self.logit_bias,
+            allowed_token_ids=self.allowed_token_ids,
+            tokenizer=tokenizer,
+        )
 
         return SamplingParams(
             n=self.n,
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 012f70e661100..01843930bf11d 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -134,7 +134,7 @@ async def create_chat_completion(
 
         request_id = f"chat-{random_uuid()}"
         try:
-            sampling_params = request.to_sampling_params()
+            sampling_params = request.to_sampling_params(tokenizer)
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
                 or decoding_config.guided_decoding_backend
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 73e420141813e..8548352791680 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -95,7 +95,7 @@ async def create_completion(self, request: CompletionRequest,
 
             tokenizer = await self.engine.get_tokenizer(lora_request)
 
-            sampling_params = request.to_sampling_params()
+            sampling_params = request.to_sampling_params(tokenizer)
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
                 or decoding_config.guided_decoding_backend
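Example client usage (not part of the patch): a minimal sketch of how the new `allowed_token_ids` extension could be exercised against a running vLLM OpenAI-compatible server, mirroring the test added above. The base URL, API key, and model name are placeholders, not values from the diff.

    import asyncio

    from openai import AsyncOpenAI

    # Placeholder endpoint/key for a locally running vLLM server.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")


    async def main() -> None:
        completion = await client.completions.create(
            model="my-model",  # placeholder served model name
            prompt="Hello, my name is",
            max_tokens=1,
            temperature=0.0,
            # vLLM-specific extra parameter added by this change:
            # restrict sampling to this set of token ids.
            extra_body=dict(allowed_token_ids=[21555, 21557, 21558]),
            logprobs=1,
        )
        print(completion.choices[0].logprobs.tokens)


    asyncio.run(main())

Because the constraint is implemented as a logits processor, every token id outside the allowed set has its logit masked to -inf before sampling, so each generated token is guaranteed to come from the allowed list.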