diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index 29d0e6fd537d5..72e2374899793 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -9,6 +9,8 @@
 import openai  # use the official client for correctness check
 from huggingface_hub import snapshot_download  # downloading lora to test lora requests
 
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
 MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 60 seconds
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # any model with a chat template should work here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"  # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here
@@ -310,5 +312,51 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI,
     assert texts[0] == texts[1]
 
 
+async def test_logits_bias(server, client: openai.AsyncOpenAI):
+    prompt = "Hello, my name is"
+    max_tokens = 5
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+
+    # Test exclusive selection
+    token_id = 1000
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+        logit_bias={str(token_id): 100},
+    )
+    assert completion.choices[0].text is not None and len(
+        completion.choices[0].text) >= 5
+    response_tokens = tokenizer(completion.choices[0].text,
+                                add_special_tokens=False)["input_ids"]
+    expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
+                                add_special_tokens=False)["input_ids"]
+    assert all([
+        response == expected
+        for response, expected in zip(response_tokens, expected_tokens)
+    ])
+
+    # Test ban
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+    )
+    response_tokens = tokenizer(completion.choices[0].text,
+                                add_special_tokens=False)["input_ids"]
+    first_response = completion.choices[0].text
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+        logit_bias={str(token): -100
+                    for token in response_tokens},
+    )
+    assert first_response != completion.choices[0].text
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index f57a2fb775783..e85e7e2b1ede9 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -8,6 +8,8 @@
 from vllm.utils import random_uuid
 from vllm.sampling_params import SamplingParams
 
+import torch
+
 
 class ErrorResponse(BaseModel):
     object: str = "error"
@@ -88,6 +90,21 @@ class ChatCompletionRequest(BaseModel):
     def to_sampling_params(self) -> SamplingParams:
         if self.logprobs and not self.top_logprobs:
             raise ValueError("Top logprobs must be set when logprobs is.")
+
+        logits_processors = None
+        if self.logit_bias:
+
+            def logit_bias_logits_processor(
+                    token_ids: List[int],
+                    logits: torch.Tensor) -> torch.Tensor:
+                for token_id, bias in self.logit_bias.items():
+                    # Clamp the bias between -100 and 100 per OpenAI API spec
+                    bias = min(100, max(-100, bias))
+                    logits[int(token_id)] += bias
+                return logits
+
+            logits_processors = [logit_bias_logits_processor]
+
         return SamplingParams(
             n=self.n,
             presence_penalty=self.presence_penalty,
@@ -111,6 +128,7 @@ def to_sampling_params(self) -> SamplingParams:
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             include_stop_str_in_output=self.include_stop_str_in_output,
             length_penalty=self.length_penalty,
+            logits_processors=logits_processors,
         )
 
 
@@ -149,6 +167,20 @@ class CompletionRequest(BaseModel):
     def to_sampling_params(self):
         echo_without_generation = self.echo and self.max_tokens == 0
 
+        logits_processors = None
+        if self.logit_bias:
+
+            def logit_bias_logits_processor(
+                    token_ids: List[int],
+                    logits: torch.Tensor) -> torch.Tensor:
+                for token_id, bias in self.logit_bias.items():
+                    # Clamp the bias between -100 and 100 per OpenAI API spec
+                    bias = min(100, max(-100, bias))
+                    logits[int(token_id)] += bias
+                return logits
+
+            logits_processors = [logit_bias_logits_processor]
+
         return SamplingParams(
             n=self.n,
             best_of=self.best_of,
@@ -172,6 +204,7 @@ def to_sampling_params(self):
             spaces_between_special_tokens=(self.spaces_between_special_tokens),
             include_stop_str_in_output=self.include_stop_str_in_output,
             length_penalty=self.length_penalty,
+            logits_processors=logits_processors,
         )
 
 
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index dd152583c2329..5635ac6c9e106 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -39,19 +39,13 @@ async def create_chat_completion(
         See https://platform.openai.com/docs/api-reference/chat/create
         for the API specification. This API mimics the OpenAI ChatCompletion API.
 
-        NOTE: Currently we do not support the following features:
+        NOTE: Currently we do not support the following feature:
             - function_call (Users should implement this by themselves)
-            - logit_bias (to be supported by vLLM engine)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
 
-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            # TODO: support logit_bias in vLLM engine.
-            return self.create_error_response(
-                "logit_bias is not currently supported")
-
         try:
             prompt = self.tokenizer.apply_chat_template(
                 conversation=request.messages,
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 667b659f81e9e..610f53549da48 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -264,10 +264,9 @@ async def create_completion(self, request: CompletionRequest,
         See https://platform.openai.com/docs/api-reference/completions/create
         for the API specification. This API mimics the OpenAI Completion API.
 
-        NOTE: Currently we do not support the following features:
+        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
              suffix)
-            - logit_bias (to be supported by vLLM engine)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
@@ -277,9 +276,6 @@ async def create_completion(self, request: CompletionRequest,
         if request.suffix is not None:
             return self.create_error_response(
                 "suffix is not currently supported")
-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            return self.create_error_response(
-                "logit_bias is not currently supported")
 
         model_name = request.model
         request_id = f"cmpl-{random_uuid()}"
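For reviewers, a minimal client-side sketch of what this change enables, mirroring the "exclusive selection" branch of test_logits_bias above. It assumes a vLLM OpenAI-compatible server is already running; the base_url, api_key, and biased token id below are illustrative placeholders, not part of the patch.

from openai import OpenAI

# Assumed local endpoint of a running vLLM OpenAI-compatible server (placeholder).
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Push token id 1000 to the OpenAI-spec maximum bias of +100 so that greedy
# decoding (temperature=0.0) keeps emitting that token at every step.
completion = client.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
    logit_bias={"1000": 100},
)
print(completion.choices[0].text)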