From 56577b2d59b7b8c1d38fa478ef35a35fba107c80 Mon Sep 17 00:00:00 2001 From: Jared Moore <27744679+jlcmoore@users.noreply.github.com> Date: Sun, 25 Feb 2024 18:39:34 -0800 Subject: [PATCH] Add LogProbs for Chat Completions in OpenAI (#2918) --- tests/entrypoints/test_openai_server.py | 25 ++++++++-------- vllm/entrypoints/openai/protocol.py | 8 ++++++ vllm/entrypoints/openai/serving_chat.py | 38 +++++++++++++++++++++++-- 3 files changed, 57 insertions(+), 14 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 3a359502c39d5..29d0e6fd537d5 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -155,15 +155,18 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, }] # test single completion - chat_completion = await client.chat.completions.create( - model=model_name, - messages=messages, - max_tokens=10, - ) + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=10, + logprobs=True, + top_logprobs=10) assert chat_completion.id is not None assert chat_completion.choices is not None and len( chat_completion.choices) == 1 assert chat_completion.choices[0].message is not None + assert chat_completion.choices[0].logprobs is not None + assert chat_completion.choices[0].logprobs.top_logprobs is not None + assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 10 message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 10 assert message.role == "assistant" @@ -198,13 +201,11 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI, single_output = single_completion.choices[0].text single_usage = single_completion.usage - stream = await client.completions.create( - model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - ) + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True) chunks = [] async for chunk in stream: chunks.append(chunk.choices[0].text) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 7c2aa707775ff..f57a2fb775783 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -63,6 +63,8 @@ class ChatCompletionRequest(BaseModel): seed: Optional[int] = None stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = None presence_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0 logit_bias: Optional[Dict[str, float]] = None @@ -84,6 +86,8 @@ class ChatCompletionRequest(BaseModel): length_penalty: Optional[float] = 1.0 def to_sampling_params(self) -> SamplingParams: + if self.logprobs and not self.top_logprobs: + raise ValueError("Top logprobs must be set when logprobs is.") return SamplingParams( n=self.n, presence_penalty=self.presence_penalty, @@ -96,6 +100,8 @@ def to_sampling_params(self) -> SamplingParams: stop=self.stop, stop_token_ids=self.stop_token_ids, max_tokens=self.max_tokens, + logprobs=self.top_logprobs if self.logprobs else None, + prompt_logprobs=self.top_logprobs if self.echo else None, best_of=self.best_of, top_k=self.top_k, ignore_eos=self.ignore_eos, @@ -216,6 +222,7 @@ class ChatMessage(BaseModel): class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage + logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length"]] = None @@ -236,6 +243,7 @@ class DeltaMessage(BaseModel): class ChatCompletionResponseStreamChoice(BaseModel): index: int delta: DeltaMessage + logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length"]] = None diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 850797ae4b9b6..dd152583c2329 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -101,7 +101,10 @@ async def chat_completion_stream_generator( role = self.get_chat_request_role(request) for i in range(request.n): choice_data = ChatCompletionResponseStreamChoice( - index=i, delta=DeltaMessage(role=role), finish_reason=None) + index=i, + delta=DeltaMessage(role=role), + logprobs=None, + finish_reason=None) chunk = ChatCompletionStreamResponse(id=request_id, object=chunk_object_type, created=created_time, @@ -118,6 +121,7 @@ async def chat_completion_stream_generator( "content") and request.messages[-1].get( "role") == role: last_msg_content = request.messages[-1]["content"] + if last_msg_content: for i in range(request.n): choice_data = ChatCompletionResponseStreamChoice( @@ -129,6 +133,7 @@ async def chat_completion_stream_generator( object=chunk_object_type, created=created_time, choices=[choice_data], + logprobs=None, model=model_name) data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" @@ -145,15 +150,29 @@ async def chat_completion_stream_generator( if finish_reason_sent[i]: continue + delta_token_ids = output.token_ids[previous_num_tokens[i]:] + top_logprobs = output.logprobs[ + previous_num_tokens[i]:] if output.logprobs else None + + if request.logprobs: + logprobs = self._create_logprobs( + token_ids=delta_token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + initial_text_offset=len(previous_texts[i]), + ) + else: + logprobs = None + delta_text = output.text[len(previous_texts[i]):] previous_texts[i] = output.text previous_num_tokens[i] = len(output.token_ids) - if output.finish_reason is None: # Send token-by-token response for each request.n choice_data = ChatCompletionResponseStreamChoice( index=i, delta=DeltaMessage(content=delta_text), + logprobs=logprobs, finish_reason=None) chunk = ChatCompletionStreamResponse( id=request_id, @@ -174,6 +193,7 @@ async def chat_completion_stream_generator( choice_data = ChatCompletionResponseStreamChoice( index=i, delta=DeltaMessage(content=delta_text), + logprobs=logprobs, finish_reason=output.finish_reason) chunk = ChatCompletionStreamResponse( id=request_id, @@ -208,11 +228,25 @@ async def chat_completion_full_generator( assert final_res is not None choices = [] + role = self.get_chat_request_role(request) for output in final_res.outputs: + token_ids = output.token_ids + top_logprobs = output.logprobs + + if request.logprobs: + logprobs = self._create_logprobs( + token_ids=token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + ) + else: + logprobs = None + choice_data = ChatCompletionResponseChoice( index=output.index, message=ChatMessage(role=role, content=output.text), + logprobs=logprobs, finish_reason=output.finish_reason, ) choices.append(choice_data)