[Frontend] New allowed_token_ids decoding request parameter (vllm-project#6753)

Signed-off-by: Alvant <[email protected]>
njhill authored and Alvant committed Oct 26, 2024
1 parent acb14eb commit c34fff2
Showing 5 changed files with 114 additions and 46 deletions.
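On the client side, the new parameter rides along in extra_body, mirroring the test added below. A minimal sketch, assuming a vLLM OpenAI-compatible server running locally and the openai>=1.0 Python client; the base URL, model name, and token ids are placeholders:

import openai

# Placeholder base URL and model name; adjust to the running vLLM server.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="my-model",
    prompt="Hello, my name is",
    max_tokens=1,
    temperature=0.0,
    # vLLM-specific request fields are passed through extra_body.
    extra_body=dict(allowed_token_ids=[21555, 21557, 21558]),
)
print(completion.choices[0].text)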
22 changes: 22 additions & 0 deletions tests/entrypoints/openai/test_completion.py
@@ -541,6 +541,28 @@ async def test_logits_bias(client: openai.AsyncOpenAI):
assert first_response != completion.choices[0].text


@pytest.mark.asyncio
async def test_allowed_token_ids(client: openai.AsyncOpenAI):
prompt = "Hello, my name is"
max_tokens = 1
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

# Test exclusive selection
allowed_ids = [21555, 21557, 21558]
completion = await client.completions.create(
model=MODEL_NAME,
prompt=prompt,
max_tokens=max_tokens,
temperature=0.0,
seed=42,
extra_body=dict(allowed_token_ids=allowed_ids),
logprobs=1,
)
response_tokens = completion.choices[0].logprobs.tokens
assert len(response_tokens) == 1
assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
74 changes: 74 additions & 0 deletions vllm/entrypoints/openai/logits_processors.py
@@ -0,0 +1,74 @@
from functools import lru_cache
from typing import Dict, FrozenSet, Iterable, List, Optional, Union

import torch
from transformers import PreTrainedTokenizer

from vllm.sampling_params import LogitsProcessor


class AllowedTokenIdsLogitsProcessor:
"""Logits processor for constraining generated tokens to a
specific set of token ids."""

def __init__(self, allowed_ids: Iterable[int]):
self.allowed_ids: Optional[List[int]] = list(allowed_ids)
self.mask: Optional[torch.Tensor] = None

def __call__(self, token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
if self.mask is None:
self.mask = torch.ones((logits.shape[-1], ),
dtype=torch.bool,
device=logits.device)
self.mask[self.allowed_ids] = False
self.allowed_ids = None
logits.masked_fill_(self.mask, float("-inf"))
return logits


@lru_cache(maxsize=32)
def _get_allowed_token_ids_logits_processor(
allowed_token_ids: FrozenSet[int],
vocab_size: int,
) -> LogitsProcessor:
if not allowed_token_ids:
raise ValueError("Empty allowed_token_ids provided")
if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
raise ValueError("allowed_token_ids contains "
"out-of-vocab token id")
return AllowedTokenIdsLogitsProcessor(allowed_token_ids)


def get_logits_processors(
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
allowed_token_ids: Optional[List[int]],
tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]:
logits_processors = []
if logit_bias:
try:
# Convert token_id to integer
# Clamp the bias between -100 and 100 per OpenAI API spec
clamped_logit_bias: Dict[int, float] = {
int(token_id): min(100.0, max(-100.0, bias))
for token_id, bias in logit_bias.items()
}
except ValueError as exc:
raise ValueError(
"Found token_id in logit_bias that is not "
"an integer or string representing an integer") from exc

def logit_bias_logits_processor(token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in clamped_logit_bias.items():
logits[token_id] += bias
return logits

logits_processors.append(logit_bias_logits_processor)

if allowed_token_ids is not None:
logits_processors.append(
_get_allowed_token_ids_logits_processor(
frozenset(allowed_token_ids), tokenizer.vocab_size))

return logits_processors
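A minimal sketch of what the new processor does to a single logits row, assuming vllm is importable; the vocabulary size and allowed ids are toy values:

import torch

from vllm.entrypoints.openai.logits_processors import (
    AllowedTokenIdsLogitsProcessor)

processor = AllowedTokenIdsLogitsProcessor(allowed_ids=[2, 5])
logits = torch.zeros(8)  # toy vocabulary of 8 tokens
masked = processor(token_ids=[], logits=logits)

# Every id outside {2, 5} is pushed to -inf, so sampling can only
# pick an allowed token.
assert torch.isinf(masked).sum() == 6

Note that _get_allowed_token_ids_logits_processor is wrapped in lru_cache keyed on (frozenset(allowed_token_ids), vocab_size), so repeated requests with the same allowed set reuse one processor instance and its lazily built mask.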
60 changes: 16 additions & 44 deletions vllm/entrypoints/openai/protocol.py
@@ -5,9 +5,11 @@

import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated

from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
@@ -213,30 +215,15 @@ class ChatCompletionRequest(OpenAIBaseModel):

# doc: end-chat-completion-extra-params

def to_sampling_params(self) -> SamplingParams:
def to_sampling_params(self,
tokenizer: PreTrainedTokenizer) -> SamplingParams:
# We now allow logprobs being true without top_logprobs.

logits_processors = None
if self.logit_bias:
logit_bias: Dict[int, float] = {}
try:
for token_id, bias in self.logit_bias.items():
# Convert token_id to integer before we add to LLMEngine
# Clamp the bias between -100 and 100 per OpenAI API spec
logit_bias[int(token_id)] = min(100, max(-100, bias))
except ValueError as exc:
raise ValueError(f"Found token_id `{token_id}` in logit_bias "
f"but token_id must be an integer or string "
f"representing an integer") from exc

def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in logit_bias.items():
logits[token_id] += bias
return logits

logits_processors = [logit_bias_logits_processor]
logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
allowed_token_ids=None,
tokenizer=tokenizer,
)

return SamplingParams(
n=self.n,
@@ -358,6 +345,7 @@ class CompletionRequest(OpenAIBaseModel):
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
allowed_token_ids: Optional[List[int]] = None
# doc: end-completion-sampling-params

# doc: begin-completion-extra-params
@@ -407,30 +395,14 @@

# doc: end-completion-extra-params

def to_sampling_params(self):
def to_sampling_params(self, tokenizer: PreTrainedTokenizer):
echo_without_generation = self.echo and self.max_tokens == 0

logits_processors = None
if self.logit_bias:
logit_bias: Dict[int, float] = {}
try:
for token_id, bias in self.logit_bias.items():
# Convert token_id to integer
# Clamp the bias between -100 and 100 per OpenAI API spec
logit_bias[int(token_id)] = min(100, max(-100, bias))
except ValueError as exc:
raise ValueError(f"Found token_id `{token_id}` in logit_bias "
f"but token_id must be an integer or string "
f"representing an integer") from exc

def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in logit_bias.items():
logits[token_id] += bias
return logits

logits_processors = [logit_bias_logits_processor]
logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids,
tokenizer=tokenizer,
)

return SamplingParams(
n=self.n,
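Both request classes now delegate to the shared helper instead of duplicating the clamping logic inline. A minimal sketch of calling it directly; the tokenizer stub is hypothetical, since the helper only reads vocab_size when validating allowed_token_ids:

from types import SimpleNamespace

from vllm.entrypoints.openai.logits_processors import get_logits_processors

# Duck-typed stand-in for a PreTrainedTokenizer; only vocab_size is used here.
fake_tokenizer = SimpleNamespace(vocab_size=32000)

processors = get_logits_processors(
    logit_bias={"11": -100},  # string keys are converted to int and clamped
    allowed_token_ids=[21555, 21557, 21558],
    tokenizer=fake_tokenizer,
)

# One processor applies the clamped logit bias, the other masks
# everything outside the allowed id set.
assert len(processors) == 2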
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_chat.py
@@ -134,7 +134,7 @@ async def create_chat_completion(

request_id = f"chat-{random_uuid()}"
try:
sampling_params = request.to_sampling_params()
sampling_params = request.to_sampling_params(tokenizer)
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -95,7 +95,7 @@ async def create_completion(self, request: CompletionRequest,

tokenizer = await self.engine.get_tokenizer(lora_request)

sampling_params = request.to_sampling_params()
sampling_params = request.to_sampling_params(tokenizer)
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
