Add LogProbs for Chat Completions in OpenAI #2918

Merged: 18 commits, Feb 26, 2024
vllm/entrypoints/openai/protocol.py: 5 additions & 0 deletions
@@ -62,6 +62,7 @@ class ChatCompletionRequest(BaseModel):
max_tokens: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
logprobs: Optional[int] = None

Collaborator:
IMO, this should be a bool in chat completion, and we should make the sampling params accept both.
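For context, the public Chat Completions API exposes this as a boolean logprobs plus an integer top_logprobs. A minimal sketch of that field split, assuming Pydantic as used elsewhere in this module (the class and field names below are illustrative, not part of this diff):

from typing import Optional
from pydantic import BaseModel

class ChatLogprobFieldsSketch(BaseModel):
    # Hypothetical split mirroring the public API; this diff instead
    # keeps a single Optional[int] named `logprobs`.
    logprobs: Optional[bool] = False       # whether to return logprobs at all
    top_logprobs: Optional[int] = None     # how many alternatives per position

print(ChatLogprobFieldsSketch(logprobs=True, top_logprobs=5))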

presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
@@ -93,6 +94,8 @@ def to_sampling_params(self) -> SamplingParams:
stop=self.stop,
stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
logprobs=self.logprobs,
prompt_logprobs=self.logprobs if self.echo else None,

Collaborator:
Should prompt_logprobs be assigned from top_logprobs?
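A hedged sketch of the mapping this question points at, reusing the bool/int split from the earlier comment (the helper name and signature are assumptions, not code from this PR): with echo enabled, prompt logprobs would be driven by the same top_logprobs count.

from typing import Optional, Tuple

def resolve_logprobs(logprobs: Optional[bool],
                     top_logprobs: Optional[int],
                     echo: Optional[bool]) -> Tuple[Optional[int], Optional[int]]:
    # Returns (logprobs, prompt_logprobs) in the integer form SamplingParams takes.
    num = (top_logprobs or 0) if logprobs else None
    return num, (num if echo else None)

assert resolve_logprobs(True, 5, True) == (5, 5)
assert resolve_logprobs(True, 5, False) == (5, None)
assert resolve_logprobs(False, None, True) == (None, None)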

best_of=self.best_of,
top_k=self.top_k,
ignore_eos=self.ignore_eos,
@@ -208,6 +211,7 @@ class ChatMessage(BaseModel):
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None


@@ -228,6 +232,7 @@ class DeltaMessage(BaseModel):
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None


vllm/entrypoints/openai/serving_chat.py: 65 additions & 7 deletions
@@ -1,20 +1,25 @@
import time
import codecs
from fastapi import Request
from typing import AsyncGenerator, AsyncIterator, Optional, List, Union
from typing import AsyncGenerator, AsyncIterator, Optional, List, Union, Dict, Callable
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
LogProbs,
UsageInfo)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA

logger = init_logger(__name__)

TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]

class OpenAIServingChat(OpenAIServing):

@@ -77,10 +82,10 @@ async def create_chat_completion(
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id)
request, result_generator, request_id, self._create_logprobs)
else:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id)
request, raw_request, result_generator, request_id, self._create_logprobs)

def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
@@ -90,7 +95,8 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:

async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput], request_id: str
result_generator: AsyncIterator[RequestOutput], request_id: str,
create_logprobs_fn: TypeCreateLogProbsFn
) -> Union[ErrorResponse, AsyncGenerator[str, None]]:

model_name = request.model
@@ -101,7 +107,7 @@ async def chat_completion_stream_generator(
role = self.get_chat_request_role(request)
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i, delta=DeltaMessage(role=role), finish_reason=None)
index=i, delta=DeltaMessage(role=role), logprobs=None, finish_reason=None)
chunk = ChatCompletionStreamResponse(id=request_id,
object=chunk_object_type,
created=created_time,
@@ -118,6 +124,7 @@ async def chat_completion_stream_generator(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]

if last_msg_content:
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
Expand All @@ -129,6 +136,7 @@ async def chat_completion_stream_generator(
object=chunk_object_type,
created=created_time,
choices=[choice_data],
logprobs=None,
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
@@ -147,13 +155,37 @@ async def chat_completion_stream_generator(

delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text

Contributor (author):
Ah yes, good call. I just added that in this commit.

previous_num_tokens[i] = len(output.token_ids)

Contributor (author):
I just moved this line down.


if request.echo and request.max_tokens == 0:

Collaborator:
I think there's no need to consider echo prompt logprobs in streaming chunks. It's also simpler without them.

delta_token_ids = res.prompt_token_ids
top_logprobs = res.prompt_logprobs
elif request.echo and request.max_tokens > 0:
delta_token_ids = res.prompt_token_ids + output.token_ids
top_logprobs = res.prompt_logprobs + (output.logprobs or [])
else:
delta_token_ids = output.token_ids[previous_num_tokens[i]:]
top_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None

if request.logprobs is not None:

Collaborator:
if request.logprobs:

assert(top_logprobs is not None),\
"top_logprobs must be provided when logprobs is requested"
logprobs = create_logprobs_fn(

Collaborator:
Why not invoke self._create_logprobs directly?

Contributor (author):
This is what serving_completion.py did. I can only speculate on the design decision, but perhaps it was to allow for extensibility with other log-probability formats in the future, not just the OpenAI ones?

Contributor (author):
I am happy to change it if you think that is best.

Collaborator:
This function is already defined on the class, so there's no need to pass it in as a callable here. We can delete this and the related plumbing and invoke the parent method _create_logprobs directly.
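A toy illustration of that point (stand-in classes, not vLLM code): because _create_logprobs is inherited from the serving base class, the call site can use it directly, and the create_logprobs_fn parameter and the TypeCreateLogProbsFn alias become unnecessary.

class ServingBase:
    def _create_logprobs(self, token_ids, top_logprobs, num_output_top_logprobs):
        # Stand-in for OpenAIServing._create_logprobs.
        return {"tokens": token_ids, "top": top_logprobs, "n": num_output_top_logprobs}

class ServingChat(ServingBase):
    def build_logprobs(self, token_ids, top_logprobs, n):
        # No injected callable needed; the inherited helper is already in scope.
        return self._create_logprobs(token_ids, top_logprobs, n)

print(ServingChat().build_logprobs([1, 2], [{1: -0.1}, {2: -0.2}], 1))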

token_ids=delta_token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=len(previous_texts[i]),
)
else:
logprobs = None

previous_num_tokens[i] = len(output.token_ids)
if output.finish_reason is None:
# Send token-by-token response for each request.n
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
logprobs=logprobs,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
@@ -174,6 +206,7 @@ async def chat_completion_stream_generator(
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
logprobs=logprobs,
finish_reason=output.finish_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
@@ -193,7 +226,8 @@ async def chat_completion_stream_generator(
async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Request,
result_generator: AsyncIterator[RequestOutput],
request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
request_id: str,
create_logprobs_fn: TypeCreateLogProbsFn) -> Union[ErrorResponse, ChatCompletionResponse]:

model_name = request.model
created_time = int(time.monotonic())
@@ -208,11 +242,35 @@ async def chat_completion_full_generator(
assert final_res is not None

choices = []

prompt_token_ids = final_res.prompt_token_ids
prompt_logprobs = final_res.prompt_logprobs

role = self.get_chat_request_role(request)
for output in final_res.outputs:
if request.echo and request.max_tokens == 0:
token_ids = prompt_token_ids
top_logprobs = prompt_logprobs
elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + output.token_ids
top_logprobs = prompt_logprobs + output.logprobs
else:
token_ids = output.token_ids
top_logprobs = output.logprobs

if request.logprobs is not None:

Collaborator:
ditto

logprobs = create_logprobs_fn(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None

choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role=role, content=output.text),
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)