Support logit bias for OpenAI API (vllm-project#3027)
dylanwhawk authored Feb 27, 2024
1 parent f7382f6 commit ac493f1
Showing 4 changed files with 83 additions and 12 deletions.
48 changes: 48 additions & 0 deletions tests/entrypoints/test_openai_server.py
@@ -9,6 +9,8 @@
 import openai  # use the official client for correctness check
 from huggingface_hub import snapshot_download  # downloading lora to test lora requests
 
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
 MAX_SERVER_START_WAIT_S = 600  # wait up to 600 seconds for the server to start
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # any model with a chat template should work here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"  # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here
@@ -310,5 +312,51 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI,
     assert texts[0] == texts[1]
 
 
+async def test_logits_bias(server, client: openai.AsyncOpenAI):
+    prompt = "Hello, my name is"
+    max_tokens = 5
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+
+    # Test exclusive selection
+    token_id = 1000
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+        logit_bias={str(token_id): 100},
+    )
+    assert completion.choices[0].text is not None and len(
+        completion.choices[0].text) >= 5
+    response_tokens = tokenizer(completion.choices[0].text,
+                                add_special_tokens=False)["input_ids"]
+    expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
+                                add_special_tokens=False)["input_ids"]
+    assert all([
+        response == expected
+        for response, expected in zip(response_tokens, expected_tokens)
+    ])
+
+    # Test ban
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+    )
+    response_tokens = tokenizer(completion.choices[0].text,
+                                add_special_tokens=False)["input_ids"]
+    first_response = completion.choices[0].text
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+        logit_bias={str(token): -100
+                    for token in response_tokens},
+    )
+    assert first_response != completion.choices[0].text
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
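For reference, here is a minimal client-side sketch of what the test exercises, assuming a vLLM OpenAI-compatible server has already been started locally (for example with `python -m vllm.entrypoints.openai.api_server --model HuggingFaceH4/zephyr-7b-beta`). The base URL, API key, and token id below are illustrative assumptions, not part of this commit:

from openai import OpenAI

# Assumed local server; adjust base_url/api_key to your deployment.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

# A +100 bias makes token id 1000 dominate the logits, so greedy decoding
# keeps emitting it.
forced = client.completions.create(
    model=MODEL_NAME,
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
    logit_bias={"1000": 100},
)
print(forced.choices[0].text)

# A -100 bias effectively bans a token; banning the token that was just
# forced should change the greedy continuation.
banned = client.completions.create(
    model=MODEL_NAME,
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
    logit_bias={"1000": -100},
)
print(banned.choices[0].text)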
33 changes: 33 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -8,6 +8,8 @@
 from vllm.utils import random_uuid
 from vllm.sampling_params import SamplingParams
 
+import torch
+
 
 class ErrorResponse(BaseModel):
     object: str = "error"
@@ -88,6 +90,21 @@ class ChatCompletionRequest(BaseModel):
     def to_sampling_params(self) -> SamplingParams:
         if self.logprobs and not self.top_logprobs:
             raise ValueError("Top logprobs must be set when logprobs is.")
+
+        logits_processors = None
+        if self.logit_bias:
+
+            def logit_bias_logits_processor(
+                    token_ids: List[int],
+                    logits: torch.Tensor) -> torch.Tensor:
+                for token_id, bias in self.logit_bias.items():
+                    # Clamp the bias between -100 and 100 per OpenAI API spec
+                    bias = min(100, max(-100, bias))
+                    logits[int(token_id)] += bias
+                return logits
+
+            logits_processors = [logit_bias_logits_processor]
+
         return SamplingParams(
             n=self.n,
             presence_penalty=self.presence_penalty,
@@ -111,6 +128,7 @@ def to_sampling_params(self) -> SamplingParams:
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             include_stop_str_in_output=self.include_stop_str_in_output,
             length_penalty=self.length_penalty,
+            logits_processors=logits_processors,
         )
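To make the closure above concrete, here is a standalone sketch (not part of the commit) of the same clamp-and-add behaviour applied to a dummy logits tensor; the vocabulary size and token ids are made up for illustration:

from typing import Dict, List

import torch


def make_logit_bias_processor(logit_bias: Dict[str, float]):
    """Build a processor like the one created in to_sampling_params()."""

    def processor(token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        for token_id, bias in logit_bias.items():
            # Clamp the bias between -100 and 100 per OpenAI API spec
            bias = min(100, max(-100, bias))
            logits[int(token_id)] += bias
        return logits

    return processor


logits = torch.zeros(8)  # pretend vocabulary of 8 tokens, all tied
processor = make_logit_bias_processor({"3": 100, "5": -100})
biased = processor(token_ids=[], logits=logits)
print(biased)                # index 3 boosted to 100.0, index 5 pushed to -100.0
print(int(biased.argmax()))  # greedy decoding would now pick token 3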


@@ -149,6 +167,20 @@ class CompletionRequest(BaseModel):
     def to_sampling_params(self):
         echo_without_generation = self.echo and self.max_tokens == 0
 
+        logits_processors = None
+        if self.logit_bias:
+
+            def logit_bias_logits_processor(
+                    token_ids: List[int],
+                    logits: torch.Tensor) -> torch.Tensor:
+                for token_id, bias in self.logit_bias.items():
+                    # Clamp the bias between -100 and 100 per OpenAI API spec
+                    bias = min(100, max(-100, bias))
+                    logits[int(token_id)] += bias
+                return logits
+
+            logits_processors = [logit_bias_logits_processor]
+
         return SamplingParams(
             n=self.n,
             best_of=self.best_of,
@@ -172,6 +204,7 @@ def to_sampling_params(self):
             spaces_between_special_tokens=(self.spaces_between_special_tokens),
             include_stop_str_in_output=self.include_stop_str_in_output,
             length_penalty=self.length_penalty,
+            logits_processors=logits_processors,
         )
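The same kind of processor can also be passed straight to SamplingParams when using vLLM offline, which is all to_sampling_params() does on behalf of the server. A rough sketch, assuming the vLLM Python API of this era (LLM, SamplingParams with a logits_processors argument) and enough GPU memory for the model:

from typing import List

import torch

from vllm import LLM, SamplingParams


def ban_token_1000(token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
    # Equivalent to logit_bias={"1000": -100} after clamping per the OpenAI spec.
    logits[1000] += -100
    return logits


llm = LLM(model="HuggingFaceH4/zephyr-7b-beta")  # assumed to fit on your GPU
params = SamplingParams(temperature=0.0,
                        max_tokens=5,
                        logits_processors=[ban_token_1000])
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)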


8 changes: 1 addition & 7 deletions vllm/entrypoints/openai/serving_chat.py
@@ -39,19 +39,13 @@ async def create_chat_completion(
         See https://platform.openai.com/docs/api-reference/chat/create
         for the API specification. This API mimics the OpenAI ChatCompletion API.
 
-        NOTE: Currently we do not support the following features:
+        NOTE: Currently we do not support the following feature:
             - function_call (Users should implement this by themselves)
-            - logit_bias (to be supported by vLLM engine)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
 
-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            # TODO: support logit_bias in vLLM engine.
-            return self.create_error_response(
-                "logit_bias is not currently supported")
-
         try:
             prompt = self.tokenizer.apply_chat_template(
                 conversation=request.messages,
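With the error check removed, /v1/chat/completions now forwards logit_bias to the engine as well. A hedged usage sketch against the same assumed local server; the token id is a placeholder, not taken from the commit:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat = client.chat.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
    max_tokens=16,
    temperature=0.0,
    logit_bias={"12345": -100},  # suppress an (illustrative) token id
)
print(chat.choices[0].message.content)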
6 changes: 1 addition & 5 deletions vllm/entrypoints/openai/serving_completion.py
@@ -264,10 +264,9 @@ async def create_completion(self, request: CompletionRequest,
         See https://platform.openai.com/docs/api-reference/completions/create
         for the API specification. This API mimics the OpenAI Completion API.
 
-        NOTE: Currently we do not support the following features:
+        NOTE: Currently we do not support the following feature:
             - suffix (the language models we currently support do not support
               suffix)
-            - logit_bias (to be supported by vLLM engine)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
@@ -277,9 +276,6 @@ async def create_completion(self, request: CompletionRequest,
         if request.suffix is not None:
             return self.create_error_response(
                 "suffix is not currently supported")
-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            return self.create_error_response(
-                "logit_bias is not currently supported")
 
         model_name = request.model
         request_id = f"cmpl-{random_uuid()}"
