Add LogProbs for Chat Completions in OpenAI #2918
Conversation
@@ -147,13 +155,37 @@ async def chat_completion_stream_generator(
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
I just moved this line down.
@@ -147,13 +155,37 @@ async def chat_completion_stream_generator(
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
Don't we need to move this line down as well?
Ah yes, good call. I just added that in this commit.
Could you fix the formatting issue? cc @simon-mo if you also want to review and approve (this should be a quick one since it simply adds the same LogProbs logic from `serving_completion.py`).
Thanks! I just committed the formatting fix.
@esmeetu Can you help check and merge this PR if you are available?
@jlcmoore I left some comments; please also merge the latest main branch.
vllm/entrypoints/openai/protocol.py (outdated)
@@ -62,6 +62,7 @@ class ChatCompletionRequest(BaseModel):
max_tokens: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
logprobs: Optional[int] = None
IMO, this should be a bool type in the chat completion API, and we should make the sampling params accept both.
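For concreteness, a minimal sketch of what that could look like in `protocol.py` (the `top_logprobs` field name and the defaults are assumptions mirroring the OpenAI Chat Completions API, not this PR's diff):

```python
from typing import Optional

from pydantic import BaseModel


class ChatCompletionRequest(BaseModel):
    # ... existing fields (model, messages, max_tokens, ...) elided ...
    # Hypothetical shape: a bool switch plus a separate top-k count,
    # as in the OpenAI Chat Completions API.
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None
```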
vllm/entrypoints/openai/protocol.py (outdated)
@@ -93,6 +94,8 @@ def to_sampling_params(self) -> SamplingParams:
stop=self.stop,
stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
logprobs=self.logprobs,
prompt_logprobs=self.logprobs if self.echo else None,
Should `prompt_logprobs` be assigned from `top_logprobs`?
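One possible reading of that suggestion, sketched against `SamplingParams` (the request fields here are assumptions carried over from the bool/int proposal above, not code from this PR):

```python
from typing import Optional

from pydantic import BaseModel

from vllm.sampling_params import SamplingParams


class ChatCompletionRequest(BaseModel):
    # Assumed fields, following the suggestion above; everything else elided.
    echo: Optional[bool] = False
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None

    def to_sampling_params(self) -> SamplingParams:
        return SamplingParams(
            # ... other sampling parameters omitted ...
            # The per-token top-k count comes from top_logprobs; echo
            # additionally requests prompt logprobs.
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.top_logprobs if self.echo else None,
        )
```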
@@ -145,15 +157,39 @@ async def chat_completion_stream_generator(
if finish_reason_sent[i]:
    continue

if request.echo and request.max_tokens == 0:
I think there's no need to consider echo prompt logprobs in streaming chunks. It's also simpler without these.
if request.logprobs is not None:
    assert(top_logprobs is not None),\
        "top_logprobs must be provided when logprobs is requested"
    logprobs = create_logprobs_fn(
Why not invoke `self._create_logprobs` directly?
This is what `serving_completion.py` did. I can only speculate on the design decision, but perhaps it was to allow for extensibility with other log probability formats in the future, not just the OpenAI ones?
I am happy to change it if you think that is best.
This function is already defined in the class, so there's no need to pass a callable here. We can delete this and the related plumbing and invoke the parent method `_create_logprobs` directly.
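Roughly like this (a sketch only: this lives inside the chat generators, which are on a class inheriting the helper, and the keyword names of `_create_logprobs` are an assumption based on how `serving_completion.py` calls it):

```python
# Inside chat_completion_full_generator / chat_completion_stream_generator,
# call the inherited helper directly instead of threading it through as
# a create_logprobs_fn argument.
if request.logprobs is not None:
    logprobs = self._create_logprobs(
        token_ids=output.token_ids,
        top_logprobs=output.logprobs,
        num_output_top_logprobs=request.logprobs,
    )
else:
    logprobs = None
```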
@esmeetu I addressed your comments in the latest commits.
If you can add a test case in https://github.com/vllm-project/vllm/blob/6f32cddf1c795e74a47e84620462431154718f49/tests/entrypoints/test_openai_server.py that would be great!
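Something along these lines could work as a starting point (a rough sketch only: the `client` fixture, `MODEL_NAME`, and the asserted response field are assumptions modeled on the existing completion tests, not code from this PR):

```python
import openai  # the official client, as used by the existing entrypoint tests
import pytest

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # assumption: whichever model the test server runs


@pytest.mark.asyncio
async def test_chat_completion_logprobs(client: openai.AsyncOpenAI):
    # `client` is assumed to be a fixture pointing at the running test server.
    chat_completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Say this is a test."}],
        max_tokens=5,
        temperature=0.0,
        logprobs=True,
    )
    choice = chat_completion.choices[0]
    # With logprobs requested, the server should attach per-token
    # log probabilities to the returned choice.
    assert choice.logprobs is not None
```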
else:
    return await self.chat_completion_full_generator(
-       request, raw_request, result_generator, request_id)
+       request, raw_request, result_generator, request_id,
+       self._create_logprobs)
Revert this since we can directly call that method.
top_logprobs = output.logprobs[
    previous_num_tokens[i]:] if output.logprobs else None

if request.logprobs is not None:
if request.logprobs:
token_ids = output.token_ids
top_logprobs = output.logprobs

if request.logprobs is not None:
ditto
top_logprobs = output.logprobs

if request.logprobs is not None:
    assert(top_logprobs is not None),\
It's better to move this assertion to `protocol.py` using a validator or something similar, and return 400 when the condition isn't met.
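One way to read that suggestion, as a sketch (assuming pydantic v1-style validators and the bool/int field shapes discussed earlier; none of this is code from the PR):

```python
from typing import Optional

from pydantic import BaseModel, root_validator


class ChatCompletionRequest(BaseModel):
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None

    @root_validator
    def check_logprobs(cls, values):
        # Reject inconsistent settings up front so the HTTP layer can return
        # an error response (e.g. 400) instead of tripping an assertion deep
        # inside the generators.
        if values.get("top_logprobs") is not None and not values.get("logprobs"):
            raise ValueError(
                "top_logprobs can only be used when logprobs is enabled")
        return values
```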
@esmeetu I addressed your comments in the latest commits. @simon-mo I added some tests to the existing test cases.
@jlcmoore LGTM! Please remove the unused code and format the files before merging.
Unrelated to this PR, but I think we should push the logprob creation to the engine. The OpenAI server should not need to do this.
Yeah, I agree with you. The current implementation is a temporary solution that already exists, but we do indeed need to migrate it to the engine layer in the future.
@esmeetu sounds great. Just waiting for the approval!
Co-authored-by: Roy <[email protected]>
@jlcmoore For the failing CI tests, please use
Could you try to resolve these?
@esmeetu Good points! But I feel like that is out of scope here. There is significant work that needs to be done to clean up the completions endpoint, logprobs included. We could open a new issue and pull request for those?
Fine. Could you fix the failing CI tests? I meant that you can revert the completion endpoint changes in this PR and let CI pass.
@esmeetu Happy to do so, and sorry if I'm being daft here, but I don't see where I have changed any of the completions endpoints.
temperature=0.0,
stream=True,
logprobs=True,
top_logprobs=10)
`top_logprobs` doesn't exist in that API. Did this test pass on your machine? 🤔
Ah sorry, fixed!
@jlcmoore No worries
Excuse me, can I get `prompt_logprobs` from chat.completions?
I added the option to request log probabilities in the chat completions endpoint, fixing issue #2276. Most of this code is simply copied over from `vllm/entrypoints/openai/serving_completion.py` to `vllm/entrypoints/openai/serving_chat.py`. I also had to update the protocol.

No tests fail with this commit that do not already fail on the main branch. (Twelve tests already fail for `pytest tests/entrypoints/test_openai_server.py`.)

I ran `./format.sh` and it reported no issues.

I personally tested the log probabilities with streaming and non-streaming requests, albeit only on a Llama 2 model.
Non-streaming outputs:

Streaming outputs:
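As a usage sketch (not taken from this PR): once logprobs are supported on the chat endpoint, a request could look roughly like this with the OpenAI Python client, where the base URL, API key, model name, and exact logprobs fields are assumptions.

```python
from openai import OpenAI

# Placeholders: point the client at a locally running vLLM OpenAI-compatible server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",  # placeholder model name
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=16,
    logprobs=True,
)

# Inspect the per-token log probabilities attached to the first choice.
print(response.choices[0].logprobs)
```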