Add logits processors to enable logit_bias in OpenAI server #535

Closed
32 changes: 19 additions & 13 deletions vllm/entrypoints/openai/api_server.py
@@ -32,6 +32,7 @@
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid
from vllm.logits_processors import BiasLogitsProcessor

TIMEOUT_KEEP_ALIVE = 5 # seconds

@@ -170,7 +171,6 @@ async def create_chat_completion(raw_request: Request):

NOTE: Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- logit_bias (to be supported by vLLM engine)
"""
request = ChatCompletionRequest(**await raw_request.json())
logger.info(f"Received chat completion request: {request}")
@@ -179,16 +179,19 @@ async def create_chat_completion(raw_request: Request):
if error_check_ret is not None:
return error_check_ret

if request.logit_bias is not None:
# TODO: support logit_bias in vLLM engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")

prompt = await get_gen_prompt(request)
error_check_ret = await check_length(request, prompt, engine_model_config)
if error_check_ret is not None:
return error_check_ret

if not request.logit_bias:
logit_processors = []
else:
biases = dict(
map(lambda bias: (int(bias[0]), bias[1]),
request.logit_bias.items()))
logit_processors = [BiasLogitsProcessor(biases)]

model_name = request.model
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
@@ -205,7 +208,7 @@ async def create_chat_completion(raw_request: Request):
top_k=request.top_k,
ignore_eos=request.ignore_eos,
use_beam_search=request.use_beam_search,
)
logits_processors=logit_processors)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

@@ -342,7 +345,6 @@ async def create_completion(raw_request: Request):
getting the logprobs of prompt tokens)
- suffix (the language models we currently support do not support
suffix)
- logit_bias (to be supported by vLLM engine)
"""
request = CompletionRequest(**await raw_request.json())
logger.info(f"Received completion request: {request}")
@@ -362,10 +364,13 @@ async def create_completion(raw_request: Request):
return create_error_response(HTTPStatus.BAD_REQUEST,
"suffix is not currently supported")

if request.logit_bias is not None:
# TODO: support logit_bias in vLLM engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
if not request.logit_bias:
logit_processors = []
else:
logit_bias = dict(
map(lambda logit: (int(logit[0]), logit[1]),
request.logit_bias.items()))
logit_processors = [BiasLogitsProcessor(logit_bias)]

model_name = request.model
request_id = f"cmpl-{random_uuid()}"
@@ -381,6 +386,7 @@ async def create_completion(raw_request: Request):
else:
prompt = request.prompt
created_time = int(time.time())

try:
sampling_params = SamplingParams(
n=request.n,
@@ -395,7 +401,7 @@ async def create_completion(raw_request: Request):
max_tokens=request.max_tokens,
logprobs=request.logprobs,
use_beam_search=request.use_beam_search,
)
logits_processors=logit_processors)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

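For reference, here is a minimal standalone sketch of the conversion the two handlers above perform, turning an OpenAI-style logit_bias payload (string token-id keys) into the integer-keyed dict that BiasLogitsProcessor expects. The sample payload and token ids are illustrative, and a dict comprehension stands in for the diff's dict(map(...)) form:

```python
# Sketch of the logit_bias handling above; the payload below is illustrative.
from vllm.logits_processors import BiasLogitsProcessor

# OpenAI clients send logit_bias with string token-id keys.
raw_logit_bias = {"50256": -100.0, "198": 25.0}  # hypothetical token ids

if not raw_logit_bias:
    logit_processors = []
else:
    biases = {int(token_id): bias for token_id, bias in raw_logit_bias.items()}
    logit_processors = [BiasLogitsProcessor(biases)]
```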
46 changes: 46 additions & 0 deletions vllm/logits_processors.py
@@ -0,0 +1,46 @@
from abc import ABC, abstractmethod
import torch
from typing import Dict


class LogitsProcessor(ABC):

@abstractmethod
def __call__(self, logits: torch.tensor) -> torch.tensor:
pass


class BiasLogitsProcessor(LogitsProcessor):
"""This is to enable logit_bias in the OpenAI server.

biases is a dict where each value is -100 to 100
according to the OpenAI API docs.

Args:
biases: Dict of values from -100 to 100 to scale the
probability of a token being generated.
Each key of the dict corresponds to the token id.
"""

def __init__(self, biases: Dict[int, float]):
self.biases = biases

if not biases:
return

self.keys = torch.tensor(list(self.biases.keys()), dtype=torch.long)
self.values = torch.tensor(list(self.biases.values()),
dtype=torch.long)

def __call__(self, logits):
if not self.biases:
return logits

values = self.values.to(logits.device)
keys = self.keys.to(logits.device)

update_factors = torch.where(values >= 0, 1 + (values / 100),
1 / (1 - (values / 100)))
logits[0, keys] *= update_factors

return logits
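To see the scaling in isolation, here is a small usage sketch of the class above; the vocabulary size and token ids are illustrative, and torch is assumed to be installed:

```python
# Exercise BiasLogitsProcessor directly; vocab size and token ids are illustrative.
import torch

from vllm.logits_processors import BiasLogitsProcessor

vocab_size = 32000
logits = torch.randn(1, vocab_size)
original = logits.clone()  # the processor modifies logits in place

processor = BiasLogitsProcessor({123: 100.0, 456: -100.0})
biased = processor(logits)

# A +100 bias scales the logit by 1 + 100/100 = 2; a -100 bias scales it by
# 1 / (1 - (-100/100)) = 0.5.
print(biased[0, 123] / original[0, 123])  # ~2.0
print(biased[0, 456] / original[0, 456])  # ~0.5
```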
20 changes: 20 additions & 0 deletions vllm/model_executor/layers/sampler.py
@@ -45,6 +45,7 @@ def forward(

# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())

if embedding_bias is not None:
logits += embedding_bias
logits = gather_from_tensor_model_parallel_region(logits)
@@ -61,6 +62,9 @@ def forward(
logits = _apply_penalties(logits, output_tokens, presence_penalties,
frequency_penalties, self.vocab_size)

# Apply any user defined logits processors.
logits = _apply_logits_processors(input_metadata, logits, output_tokens)

# Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)
assert len(temperatures) == logits.shape[0]
@@ -142,6 +146,22 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
return output_tokens


def _apply_logits_processors(
input_metadata: InputMetadata,
logits: torch.Tensor,
output_tokens: List[List[int]]
) -> torch.Tensor:
for _, seq_group in enumerate(input_metadata.seq_groups):
_, sampling_params = seq_group
logits_processors = sampling_params.logits_processors

if logits_processors is not None:
for logits_processor in logits_processors:
logits = logits_processor(logits, output_tokens)
Contributor commented:

In this call, you send output_tokens to logits_processor(). However, in the LogitsProcessor interface, the output_tokens parameter does not exist:

def __call__(self, logits: torch.tensor) -> torch.tensor:

How does it work?


return logits


def _apply_penalties(
logits: torch.Tensor,
output_tokens: List[List[int]],
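The review comment above points out that _apply_logits_processors passes output_tokens while the LogitsProcessor interface in this PR accepts only logits. One hypothetical way to reconcile the two (not part of the PR) would be to widen the interface so the call site and the processors agree:

```python
# Hypothetical sketch, not part of this PR: widen the interface so it
# matches the logits_processor(logits, output_tokens) call site above.
from abc import ABC, abstractmethod
from typing import List

import torch


class LogitsProcessor(ABC):

    @abstractmethod
    def __call__(self, logits: torch.Tensor,
                 output_tokens: List[List[int]]) -> torch.Tensor:
        """Return (possibly modified) logits; output_tokens carries the
        tokens generated so far, for processors that need them."""


# BiasLogitsProcessor would then accept (and simply ignore) output_tokens.
```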
33 changes: 18 additions & 15 deletions vllm/sampling_params.py
@@ -1,5 +1,6 @@
"""Sampling parameters for text generation."""
from typing import List, Optional, Union
from vllm.logits_processors import LogitsProcessor

_SAMPLING_EPS = 1e-5

@@ -40,23 +41,24 @@ class SamplingParams:
tokens after the EOS token is generated.
max_tokens: Maximum number of tokens to generate per output sequence.
logprobs: Number of log probabilities to return per output token.
logits_processors: List of LogitsProcessors to change the probability
of token prediction at runtime.
"""

def __init__(
self,
n: int = 1,
best_of: Optional[int] = None,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
use_beam_search: bool = False,
stop: Union[None, str, List[str]] = None,
ignore_eos: bool = False,
max_tokens: int = 16,
logprobs: Optional[int] = None,
) -> None:
def __init__(self,
n: int = 1,
best_of: Optional[int] = None,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
use_beam_search: bool = False,
stop: Union[None, str, List[str]] = None,
ignore_eos: bool = False,
max_tokens: int = 16,
logprobs: Optional[int] = None,
logits_processors: List[LogitsProcessor] = None) -> None:
self.n = n
self.best_of = best_of if best_of is not None else n
self.presence_penalty = presence_penalty
@@ -74,6 +76,7 @@ def __init__(
self.ignore_eos = ignore_eos
self.max_tokens = max_tokens
self.logprobs = logprobs
self.logits_processors = logits_processors

self._verify_args()
if self.use_beam_search:
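With the new parameter in place, a caller could attach a processor when building sampling parameters, for example (a minimal sketch; the bias token id and sampling settings are illustrative):

```python
# Illustrative use of the new logits_processors argument.
from vllm.logits_processors import BiasLogitsProcessor
from vllm.sampling_params import SamplingParams

params = SamplingParams(
    temperature=0.7,
    max_tokens=32,
    logits_processors=[BiasLogitsProcessor({50256: -100.0})],
)
```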