Add LogProbs for Chat Completions in OpenAI #2918
File: vllm/entrypoints/openai/protocol.py
@@ -62,6 +62,7 @@ class ChatCompletionRequest(BaseModel):
     max_tokens: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    logprobs: Optional[int] = None
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
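With this field in place, a chat request can ask for per-token log-probabilities the way the completions endpoint already does. A rough client-side sketch (the server URL and model name are placeholders, not part of this PR):

```python
# Hypothetical request against a locally running vLLM OpenAI-compatible
# server; host, port and model name are assumptions for illustration.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "meta-llama/Llama-2-7b-chat-hf",
        "messages": [{"role": "user", "content": "Say hi"}],
        "max_tokens": 16,
        # With this change, `logprobs` is an integer: the number of top
        # log-probabilities to return per generated token.
        "logprobs": 5,
    },
)
print(resp.json()["choices"][0]["logprobs"])
```

Note that this is the completions-style integer parameter; the review comment at the end of the diff questions whether the chat endpoint should take a boolean instead.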
@@ -93,6 +94,8 @@ def to_sampling_params(self) -> SamplingParams:
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             max_tokens=self.max_tokens,
+            logprobs=self.logprobs,
+            prompt_logprobs=self.logprobs if self.echo else None,

Review comment: Should …

             best_of=self.best_of,
             top_k=self.top_k,
             ignore_eos=self.ignore_eos,
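For context, `logprobs` and `prompt_logprobs` are existing SamplingParams fields; a minimal offline sketch of what they produce (the model name is a placeholder, and the exact container type of the per-token entries depends on the vLLM version this PR targets):

```python
# Sketch of what logprobs/prompt_logprobs mean at the engine level. The
# model name is a placeholder; the per-token entries are dicts keyed by
# token id, matching the TypeTopLogProbs alias used later in this diff.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(max_tokens=8, logprobs=5, prompt_logprobs=5)
out = llm.generate(["Hello, my name is"], params)[0]

# Top-5 candidates for the first generated token.
print(out.outputs[0].logprobs[0])
# Prompt-token logprobs are only populated because prompt_logprobs was set;
# this is what the echo path relies on. The first entry is None.
print(out.prompt_logprobs)
```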
@@ -208,6 +211,7 @@ class ChatMessage(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
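Since the chat choice reuses the completions-style `LogProbs` model, a non-streaming response choice would roughly look like this (field names assume the existing `LogProbs` definition in the protocol module; all values are invented):

```python
# Illustrative response choice with logprobs attached. Assumes the existing
# completions-style LogProbs model (tokens / token_logprobs / top_logprobs /
# text_offset); values are made up.
choice = {
    "index": 0,
    "message": {"role": "assistant", "content": "Hi there"},
    "logprobs": {
        "tokens": ["Hi", " there"],
        "token_logprobs": [-0.11, -0.42],
        "top_logprobs": [
            {"Hi": -0.11, "Hello": -2.35},
            {" there": -0.42, "!": -1.87},
        ],
        "text_offset": [0, 2],
    },
    "finish_reason": "stop",
}
```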
@@ -228,6 +232,7 @@ class DeltaMessage(BaseModel):
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
File: vllm/entrypoints/openai/serving_chat.py
@@ -1,20 +1,25 @@
 import time
 import codecs
 from fastapi import Request
-from typing import AsyncGenerator, AsyncIterator, Optional, List, Union
+from typing import AsyncGenerator, AsyncIterator, Optional, List, Union, Dict, Callable
 from vllm.logger import init_logger
 from vllm.utils import random_uuid
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionRequest, ChatCompletionResponse,
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
+    LogProbs,
     UsageInfo)
 from vllm.outputs import RequestOutput
 from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA

 logger = init_logger(__name__)

+TypeTokenIDs = List[int]
+TypeTopLogProbs = List[Optional[Dict[int, float]]]
+TypeCreateLogProbsFn = Callable[
+    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]

 class OpenAIServingChat(OpenAIServing):
@@ -77,10 +82,10 @@ async def create_chat_completion(
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
-                request, result_generator, request_id)
+                request, result_generator, request_id, self._create_logprobs)
         else:
             return await self.chat_completion_full_generator(
-                request, raw_request, result_generator, request_id)
+                request, raw_request, result_generator, request_id, self._create_logprobs)

     def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
         if request.add_generation_prompt:
@@ -90,7 +95,8 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:

     async def chat_completion_stream_generator(
             self, request: ChatCompletionRequest,
-            result_generator: AsyncIterator[RequestOutput], request_id: str
+            result_generator: AsyncIterator[RequestOutput], request_id: str,
+            create_logprobs_fn: TypeCreateLogProbsFn
     ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:

         model_name = request.model
@@ -101,7 +107,7 @@ async def chat_completion_stream_generator(
         role = self.get_chat_request_role(request)
         for i in range(request.n):
             choice_data = ChatCompletionResponseStreamChoice(
-                index=i, delta=DeltaMessage(role=role), finish_reason=None)
+                index=i, delta=DeltaMessage(role=role), logprobs=None, finish_reason=None)
             chunk = ChatCompletionStreamResponse(id=request_id,
                                                  object=chunk_object_type,
                                                  created=created_time,
@@ -118,6 +124,7 @@ async def chat_completion_stream_generator(
                         "content") and request.messages[-1].get(
                             "role") == role:
                 last_msg_content = request.messages[-1]["content"]
+
             if last_msg_content:
                 for i in range(request.n):
                     choice_data = ChatCompletionResponseStreamChoice(
@@ -129,6 +136,7 @@ async def chat_completion_stream_generator(
                         object=chunk_object_type,
                         created=created_time,
                         choices=[choice_data],
+                        logprobs=None,
                         model=model_name)
                     data = chunk.model_dump_json(exclude_unset=True)
                     yield f"data: {data}\n\n"
@@ -147,13 +155,37 @@ async def chat_completion_stream_generator(

                 delta_text = output.text[len(previous_texts[i]):]
                 previous_texts[i] = output.text

Review comment: Don't we need to move this line down as well?
Reply: Ah yes good call. I just added that in this commit.

-                previous_num_tokens[i] = len(output.token_ids)

Review comment: I just moved this line down.

+                if request.echo and request.max_tokens == 0:

Review comment: I think there's no need to consider echo prompt logprobs in streaming chunks. It's also simpler without these.

+                    delta_token_ids = res.prompt_token_ids
+                    top_logprobs = res.prompt_logprobs
+                elif request.echo and request.max_tokens > 0:
+                    delta_token_ids = res.prompt_token_ids + output.token_ids
+                    top_logprobs = res.prompt_logprobs + (output.logprobs or [])
+                else:
+                    delta_token_ids = output.token_ids[previous_num_tokens[i]:]
+                    top_logprobs = output.logprobs[
+                        previous_num_tokens[i]:] if output.logprobs else None
+
+                if request.logprobs is not None:
+                    assert (top_logprobs is not None), \
+                        "top_logprobs must be provided when logprobs is requested"
+                    logprobs = create_logprobs_fn(
+                        token_ids=delta_token_ids,
+                        top_logprobs=top_logprobs,
+                        num_output_top_logprobs=request.logprobs,
+                        initial_text_offset=len(previous_texts[i]),
+                    )

Review comment: Why not invoke …
Reply: This is what …
Reply: I am happy to change it if you think that is best.
Review comment: Because this function is already defined in a class, so there's no need to delegate a callable function here. Finally we can delete this and relevant things and invoke the parent method.

+                else:
+                    logprobs = None
+
+                previous_num_tokens[i] = len(output.token_ids)
                 if output.finish_reason is None:
                     # Send token-by-token response for each request.n
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
+                        logprobs=logprobs,
                         finish_reason=None)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
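The thread above argues that threading `create_logprobs_fn` through the generator signatures is unnecessary, because the generators are methods of the same class and can call the inherited helper directly. A toy, self-contained illustration of that design point (class names and method bodies here are placeholders, not vLLM code):

```python
# Toy illustration of the reviewer's point: a helper defined on the base
# class can be called via self instead of being passed in as a callable.
class Serving:
    def _create_logprobs(self, token_ids, top_logprobs, num_top, offset):
        # Placeholder body; the real helper builds a LogProbs object.
        return {"tokens": token_ids, "top": top_logprobs[:num_top]}

class ServingChat(Serving):
    # Before: stream(self, create_logprobs_fn) and create_logprobs_fn(...).
    # After: no extra parameter, just use the inherited method.
    def stream(self, token_ids, top_logprobs, num_top):
        return self._create_logprobs(token_ids, top_logprobs, num_top, offset=0)

print(ServingChat().stream([1, 2], [{1: -0.1}, {2: -0.3}], num_top=1))
```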
@@ -174,6 +206,7 @@ async def chat_completion_stream_generator(
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
+                        logprobs=logprobs,
                         finish_reason=output.finish_reason)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
@@ -193,7 +226,8 @@ async def chat_completion_stream_generator(
     async def chat_completion_full_generator(
             self, request: ChatCompletionRequest, raw_request: Request,
             result_generator: AsyncIterator[RequestOutput],
-            request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
+            request_id: str,
+            create_logprobs_fn: TypeCreateLogProbsFn) -> Union[ErrorResponse, ChatCompletionResponse]:

         model_name = request.model
         created_time = int(time.monotonic())
@@ -208,11 +242,35 @@ async def chat_completion_full_generator(
         assert final_res is not None

         choices = []
+
+        prompt_token_ids = final_res.prompt_token_ids
+        prompt_logprobs = final_res.prompt_logprobs
+
         role = self.get_chat_request_role(request)
         for output in final_res.outputs:
+            if request.echo and request.max_tokens == 0:
+                token_ids = prompt_token_ids
+                top_logprobs = prompt_logprobs
+            elif request.echo and request.max_tokens > 0:
+                token_ids = prompt_token_ids + output.token_ids
+                top_logprobs = prompt_logprobs + output.logprobs
+            else:
+                token_ids = output.token_ids
+                top_logprobs = output.logprobs
+
+            if request.logprobs is not None:

Review comment: ditto

+                logprobs = create_logprobs_fn(
+                    token_ids=token_ids,
+                    top_logprobs=top_logprobs,
+                    num_output_top_logprobs=request.logprobs,
+                )
+            else:
+                logprobs = None
+
             choice_data = ChatCompletionResponseChoice(
                 index=output.index,
                 message=ChatMessage(role=role, content=output.text),
+                logprobs=logprobs,
                 finish_reason=output.finish_reason,
             )
             choices.append(choice_data)
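One use of the echo branch above is prompt scoring: with `echo` set and `max_tokens` equal to 0, the choice carries only the prompt tokens and their log-probabilities. A hypothetical client call (URL and model name are placeholders; the `echo` field on chat requests is taken from this diff, not from the official chat API):

```python
# Hypothetical "score the prompt" request that exercises the echo branch.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "meta-llama/Llama-2-7b-chat-hf",
        "messages": [{"role": "user", "content": "The capital of France is Paris."}],
        "max_tokens": 0,
        "echo": True,
        "logprobs": 1,
    },
)
choice = resp.json()["choices"][0]
# Log-probabilities of the echoed prompt tokens; the first entry is null.
print(choice["logprobs"]["token_logprobs"])
```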
Review comment: IMO, this should be bool type in chat completion. And we should make sampling params accept both.
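For reference, the official chat-completions convention the reviewer is pointing at splits this into a boolean `logprobs` switch plus an integer `top_logprobs` count. A rough sketch of what the request fields would look like under that convention (this is not what the PR implements, and the default mapping below is an assumption):

```python
# Sketch of the boolean-style request fields the reviewer suggests. The
# fallback of returning 1 candidate when only `logprobs: true` is sent is
# an illustrative assumption, not vLLM behaviour.
from typing import Optional
from pydantic import BaseModel

class ChatCompletionRequestSketch(BaseModel):
    logprobs: Optional[bool] = False      # switch per-token logprobs on/off
    top_logprobs: Optional[int] = None    # how many candidates per token

    def num_output_top_logprobs(self) -> Optional[int]:
        # Map back to the single integer that SamplingParams expects.
        if not self.logprobs:
            return None
        return self.top_logprobs if self.top_logprobs is not None else 1
```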