
Add LogProbs for Chat Completions in OpenAI #2918

Merged · 18 commits · Feb 26, 2024

Conversation

jlcmoore
Contributor

I added the option to request log probabilities in the chat completions endpoint, fixing issue #2276. Most of this code is simply copied over from vllm/entrypoints/openai/serving_completion.py to vllm/entrypoints/openai/serving_chat.py. I also had to update the protocol.

This commit introduces no new test failures; the twelve tests that fail for pytest tests/entrypoints/test_openai_server.py already fail on the main branch.

I ran ./format.sh and it reported no issues.

I personally tested the log probabilities with streaming and non-streaming requests, albeit only on a Llama-2 model.


Non-streaming

model = "meta-llama/Llama-2-7b-chat-hf"

from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
completion = client.chat.completions.create(
    model=model,
    messages=[{'role': 'user', 'content': "Hello!"}],
    logprobs=5,
)
print("Completion result:", completion)

Outputs:

Completion result: ChatCompletion(id='cmpl-38930648424340f8bbef9deddbde806d', choices=[Choice(finish_reason='stop', index=0, logprobs=ChoiceLogprobs(content=None, text_offset=[0, 1, 7, 8, 11, 12, 13, 18, 21, 26, 30, 31, 34, 40, 50, 52, 56, 61, 65, 70, 73, 79, 83, 88, 91, 96, 97], token_logprobs=[-1.1920928244535389e-07, -2.706014311115723e-05, -0.0263030007481575, -0.006746845785528421, -0.0017300175968557596, 0.0, -6.90197994117625e-05, 0.0, -2.145764938177308e-06, 0.0, -0.01528126560151577, -0.012348978780210018, 0.0, -0.001032772590406239, -0.00016556799528189003, -1.1920928244535389e-07, -4.410734163684538e-06, 0.0, 0.0, -0.0011732844868674874, -9.536738616588991e-07, 0.0, -5.960462772236497e-07, 0.0, -6.794906312279636e-06, -5.793403761344962e-05, -0.01143308263272047], tokens=['▁', '▁Hello', '!', '▁It', "'", 's', '▁nice', '▁to', '▁meet', '▁you', '.', '▁Is', '▁there', '▁something', '▁I', '▁can', '▁help', '▁you', '▁with', '▁or', '▁would', '▁you', '▁like', '▁to', '▁chat', '?', '</s>'], top_logprobs=[{'▁': -1.1920928244535389e-07, '▁▁': -16.421875, '▁Hello': -20.90625, '▁▁▁': -22.5625, 'Hello': -24.109375}, {'▁Hello': -2.706014311115723e-05, '▁Hey': -10.82815170288086, '▁Hi': -12.03127670288086, 'Hello': -14.04690170288086, '▁hello': -14.81252670288086}, {'!': -0.0263030007481575, '▁there': -3.6513030529022217, '▁There': -13.8075532913208, ',': -14.0263032913208, '!)': -14.6981782913208}, {'▁It': -0.006746845785528421, '▁*': -5.287996768951416, '▁Nice': -6.866121768951416, '▁': -7.912996768951416, '▁:': -8.616122245788574}, {"'": -0.0017300175968557596, '▁is': -6.361104965209961, '▁nice': -14.892354965209961, '’': -15.361104965209961, '▁seems': -15.501729965209961}, {'s': 0.0, 'S': -18.421875, '▁s': -18.4296875, 'm': -18.9765625, 'll': -19.0546875}, {'▁nice': -6.90197994117625e-05, '▁great': -9.593818664550781, '▁lov': -14.125068664550781, '▁Nice': -16.68756866455078, '▁good': -17.03131866455078}, {'▁to': 0.0, '▁meeting': -22.453125, '▁and': -25.859375, '▁talking': -26.359375, '▁meet': -26.390625}, {'▁meet': -2.145764938177308e-06, '▁connect': -13.750001907348633, '▁hear': -14.640626907348633, '▁see': -15.390626907348633, '▁chat': -15.515626907348633}, {'▁you': 0.0, 'you': -20.78125, '▁You': -20.84375, '▁or': -22.90625, '▁(': -24.578125}, {'.': -0.01528126560151577, '!': -4.2027812004089355, '▁:': -8.968406677246094, '▁': -9.765281677246094, ',': -10.655906677246094}, {'▁Is': -0.012348978780210018, '▁How': -5.0123491287231445, '▁Can': -5.6998491287231445, '▁Could': -6.9498491287231445, '▁Hello': -8.215474128723145}, {'▁there': 0.0, '▁this': -17.125, 'there': -18.109375, '▁There': -19.6875, '▁here': -20.546875}, {'▁something': -0.001032772590406239, '▁anything': -6.876032829284668, '▁Something': -16.70415687561035, 'something': -19.37603187561035, 'Something': -19.56353187561035}, {'▁I': -0.00016556799528189003, '▁you': -8.734540939331055, '▁on': -12.281415939331055, 'I': -17.343915939331055, '▁i': -18.828290939331055}, {'▁can': -1.1920928244535389e-07, '▁could': -15.953125, 'can': -19.234375, '▁might': -20.4375, '▁may': -22.15625}, {'▁help': -4.410734163684538e-06, '▁assist': -12.343754768371582, 'help': -16.625003814697266, '▁helps': -18.000003814697266, '▁Help': -18.781253814697266}, {'▁you': 0.0, '▁with': -20.25, 'you': -21.703125, '▁You': -26.046875, '▁yo': -29.0}, {'▁with': 0.0, 'with': -19.4375, '▁With': -21.84375, '▁avec': -23.265625, '▁wit': -25.609375}, {'▁or': -0.0011732844868674874, ',': -6.751173496246338, '?': -12.70429801940918, '▁today': -17.89179801940918, '.': 
-22.82929801940918}, {'▁would': -9.536738616588991e-07, '▁Would': -14.234375953674316, '▁did': -15.531250953674316, '▁are': -17.0, '▁something': -17.140625}, {'▁you': 0.0, 'you': -22.15625, '▁like': -23.3125, '▁You': -26.390625, '▁your': -30.78125}, {'▁like': -5.960462772236497e-07, '▁just': -14.625000953674316, 'like': -16.25, '▁Like': -20.0, '▁simply': -21.046875}, {'▁to': 0.0, '▁me': -22.296875, '▁a': -25.40625, '▁us': -25.671875, 'to': -26.125}, {'▁chat': -6.794906312279636e-06, 'chat': -13.046881675720215, '▁just': -13.875006675720215, '▁Ch': -13.875006675720215, '▁talk': -14.046881675720215}, {'?': -5.793403761344962e-05, '▁for': -9.8125581741333, '▁about': -13.1875581741333, '▁and': -13.7969331741333, '?)': -15.8438081741333}, {'</s>': -0.01143308263272047, '▁': -4.558308124542236, '▁I': -7.073933124542236, '▁:': -10.636432647705078, '▁Please': -11.558307647705078}]), message=ChatCompletionMessage(content="  Hello! It's nice to meet you. Is there something I can help you with or would you like to chat?", role='assistant', function_call=None, tool_calls=None))], created=23274268, model='meta-llama/Llama-2-7b-chat-hf', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=27, prompt_tokens=11, total_tokens=38))
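For reference, the per-token values can then be read straight off the returned choice. This is only a minimal sketch based on the ChoiceLogprobs fields visible in the output above (tokens, token_logprobs, top_logprobs); it is not part of the PR itself.

logprobs = completion.choices[0].logprobs

# Walk the sampled tokens alongside their log probabilities and the
# top-5 candidate distribution returned for each position.
for token, token_logprob, candidates in zip(logprobs.tokens,
                                            logprobs.token_logprobs,
                                            logprobs.top_logprobs):
    most_likely = max(candidates, key=candidates.get)
    print(f"{token!r}: {token_logprob:.4f} (most likely candidate: {most_likely!r})")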

Streaming

model = "meta-llama/Llama-2-7b-chat-hf"

from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

completion = client.chat.completions.create(
  model=model,
  messages=[
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"}
  ],
  stream=True,
  logprobs=True,
  top_logprobs=5,
)

for chunk in completion:
  print(chunk.choices[0])

Outputs:

Choice(delta=ChoiceDelta(content=None, function_call=None, role='assistant', tool_calls=None), finish_reason=None, index=0, logprobs=None)
Choice(delta=ChoiceDelta(content=' ', function_call=None, role=None, tool_calls=None), finish_reason=None, index=0, logprobs=ChoiceLogprobs(content=None, top_logprobs=[{'▁': 0.0}]))
Choice(delta=ChoiceDelta(content=' Hello', function_call=None, role=None, tool_calls=None), finish_reason=None, index=0, logprobs=ChoiceLogprobs(content=None, top_logprobs=[{'▁Hello': -0.00013171759201213717}]))
Choice(delta=ChoiceDelta(content=' there', function_call=None, role=None, tool_calls=None), finish_reason=None, index=0, logprobs=ChoiceLogprobs(content=None, top_logprobs=[{'▁there': -0.05832238495349884}]))
Choice(delta=ChoiceDelta(content='!', function_call=None, role=None, tool_calls=None), finish_reason=None, index=0, logprobs=ChoiceLogprobs(content=None, top_logprobs=[{'!': -2.145764938177308e-06}]))
Choice(delta=ChoiceDelta(content=' *', function_call=None, role=None, tool_calls=None), finish_reason=None, index=0, logprobs=ChoiceLogprobs(content=None, top_logprobs=[{'▁*': -0.10629353672266006}]))
Choice(delta=ChoiceDelta(content='ad', function_call=None, role=None, tool_calls=None), finish_reason=None, index=0, logprobs=ChoiceLogprobs(content=None, top_logprobs=[{'ad': -0.3723995089530945}]))
Choice(delta=ChoiceDelta(content='just', function_call=None, role=None, tool_calls=None), finish_reason=None, index=0, logprobs=ChoiceLogprobs(content=None, top_logprobs=[{'just': 0.0}]))
...
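Since each streamed chunk carries at most one token's worth of log probabilities, a client has to collect them across chunks. A rough sketch, assuming a fresh streaming request shaped like the one above (each chunk's logprobs holds a single-entry top_logprobs list):

collected_text = []
collected_logprobs = []

for chunk in completion:
    choice = chunk.choices[0]
    if choice.delta.content is not None:
        collected_text.append(choice.delta.content)
    if choice.logprobs is not None and choice.logprobs.top_logprobs:
        # Each entry maps the emitted token (and any alternatives) to a log probability.
        collected_logprobs.append(choice.logprobs.top_logprobs[0])

print("".join(collected_text))
print(collected_logprobs)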

@@ -147,13 +155,37 @@ async def chat_completion_stream_generator(

delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
Contributor Author

I just moved this line down.

@@ -147,13 +155,37 @@ async def chat_completion_stream_generator(

delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
Member

Contributor Author

Ah yes, good call. I just added that in this commit.

@ywang96
Member

ywang96 commented Feb 20, 2024

Could you fix the formatting issue from yapf? The code itself looks good to me.

cc @simon-mo if you also want to review and approve (this should be a quick one since it simply adds the same LogProbs logic as the completions API)

@jlcmoore
Contributor Author

Thanks! I just committed the yapf changes here.

@zhuohan123
Member

@esmeetu Can you help check and merge this PR if you are available?

Collaborator

@esmeetu left a comment

@jlcmoore Left some comments; please also merge the latest main branch.

@@ -62,6 +62,7 @@ class ChatCompletionRequest(BaseModel):
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    logprobs: Optional[int] = None
Collaborator

IMO, this should be a bool in the chat completion request, and we should make the sampling params accept both.
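For context, matching the OpenAI chat schema would mean roughly the following field shapes on ChatCompletionRequest. This is only a sketch of the suggestion (pydantic model as in the diff above), not necessarily the code as merged:

from typing import Optional
from pydantic import BaseModel

class ChatCompletionRequest(BaseModel):
    # Other request fields elided for brevity.
    logprobs: Optional[bool] = False      # whether to return logprobs at all
    top_logprobs: Optional[int] = None    # how many candidates per position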

@@ -93,6 +94,8 @@ def to_sampling_params(self) -> SamplingParams:
    stop=self.stop,
    stop_token_ids=self.stop_token_ids,
    max_tokens=self.max_tokens,
    logprobs=self.logprobs,
    prompt_logprobs=self.logprobs if self.echo else None,
Collaborator

Should prompt_logprobs be assigned from top_logprobs instead?

@@ -145,15 +157,39 @@ async def chat_completion_stream_generator(
    if finish_reason_sent[i]:
        continue

    if request.echo and request.max_tokens == 0:
Collaborator

I think there's no need to consider echo prompt logprobs in streaming chunks. It's also simpler without them.

if request.logprobs is not None:
    assert(top_logprobs is not None),\
        "top_logprobs must be provided when logprobs is requested"
    logprobs = create_logprobs_fn(
Collaborator

Why not invoke self._create_logprobs directly?

Contributor Author

This is what serving_completion.py did. I can only speculate on the design decision, but perhaps it was to allow extensibility to other log probability formats in the future, not just the OpenAI ones?

Contributor Author

I am happy to change it if you think that is best.

Collaborator

This function is already defined on the class, so there's no need to pass a callable around here. We can delete this and the related plumbing and invoke the parent method _create_logprobs directly.
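To illustrate: since the chat handler inherits the shared helper, the generators can call it directly instead of receiving it as an argument. A rough, simplified sketch of the suggestion (build_logprobs_for_output is a hypothetical method used only for illustration, not vLLM code):

class OpenAIServing:
    def _create_logprobs(self, token_ids, top_logprobs):
        # Shared helper that builds the OpenAI-style LogProbs payload.
        ...

class OpenAIServingChat(OpenAIServing):
    def build_logprobs_for_output(self, token_ids, top_logprobs):
        # No create_logprobs_fn parameter needed: the helper is inherited,
        # so the chat generators can invoke it directly.
        return self._create_logprobs(token_ids, top_logprobs)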

@jlcmoore
Contributor Author

@esmeetu I addressed your comments in the latest commits.

@simon-mo
Collaborator

else:
    return await self.chat_completion_full_generator(
        request, raw_request, result_generator, request_id)
        request, raw_request, result_generator, request_id,
        self._create_logprobs)
Collaborator

Revert this since we can directly call that method.

top_logprobs = output.logprobs[
    previous_num_tokens[i]:] if output.logprobs else None

if request.logprobs is not None:
Collaborator

if request.logprobs:

token_ids = output.token_ids
top_logprobs = output.logprobs

if request.logprobs is not None:
Collaborator

ditto

top_logprobs = output.logprobs

if request.logprobs is not None:
    assert(top_logprobs is not None),\
Collaborator

@esmeetu Feb 23, 2024

It's better to move this assertion to protocol.py using a validator or something similar, and return 400 when the condition isn't met.
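For what it's worth, one way to express that in protocol.py is a pydantic validator on the request model, so an inconsistent request is rejected before it reaches serving_chat.py. A sketch of the idea (assumes pydantic v2; not the code that was merged):

from typing import Optional
from pydantic import BaseModel, model_validator

class ChatCompletionRequest(BaseModel):
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None

    @model_validator(mode="after")
    def check_logprobs(self):
        # Reject requests that ask for top_logprobs without enabling logprobs;
        # the framework turns the validation error into a client-error response.
        if self.top_logprobs is not None and not self.logprobs:
            raise ValueError("top_logprobs requires logprobs to be set to true")
        return self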

@jlcmoore
Contributor Author

@esmeetu I addressed your comments in the latest commits. @simon-mo I added some tests to existing test cases.
In the latest build attached you can see that the chat test cases are not failing, although the completions test cases are still failing for unrelated reasons.

@esmeetu
Collaborator

esmeetu commented Feb 23, 2024

@jlcmoore LGTM! Please remove unused code and format the files before merging.

@Yard1
Collaborator

Yard1 commented Feb 23, 2024

Unrelated to this PR, but I think we should push the logprob creation to the engine. The OpenAI server should not need to do this.

@esmeetu
Collaborator

esmeetu commented Feb 24, 2024

Unrelated to this PR, but I think we should push the logprob creation to the engine. The OpenAI server should not need to do this.

Yeah, I agree. The current implementation is a temporary solution that already exists, but we do need to migrate it to the engine layer in the future.

@jlcmoore
Contributor Author

@esmeetu sounds great. Just waiting for the approval!

@esmeetu
Collaborator

esmeetu commented Feb 25, 2024

@jlcmoore For the failing CI tests, please use the logprobs parameter instead of top_logprobs in the test_completion_streaming test, because the OpenAI completions API doesn't have that parameter.
There are also some warnings in chat streaming when using logprobs:

  Expected `int` but got `str` - serialized value may not be as expected
  Expected `int` but got `str` - serialized value may not be as expected
  Expected `int` but got `str` - serialized value may not be as expected

Could you try to resolve these?
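To illustrate the first point: the legacy completions endpoint expresses log probabilities as a single integer count, with no separate top_logprobs parameter. A hedged sketch of such a streaming completions call, reusing the client from the earlier examples (the real test in test_openai_server.py has more parameters):

completion = client.completions.create(
    model=model,
    prompt="Hello!",
    stream=True,
    logprobs=5,  # integer count of candidates; the completions API has no top_logprobs
)
for chunk in completion:
    print(chunk.choices[0].logprobs)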

@jlcmoore
Contributor Author

@esmeetu Good points! But I feel like that is out of scope here. There is significant work that needs to be done to clean up the completions endpoint, logprobs included. Could we open a new issue and pull request for those?

@esmeetu
Collaborator

esmeetu commented Feb 25, 2024

@esmeetu Good points! But I feel like that is out of scope here. There is significant work that needs to be done to clean up the completions endpoint, logprobs included. Could we open a new issue and pull request for those?

Fine. Could you fix the failing CI tests? I meant that you can revert the completion endpoint changes in this PR so that CI passes.

@jlcmoore
Contributor Author

@esmeetu Happy to do so, and sorry if I'm being daft here, but I don't see where I changed any of the completions endpoints.

temperature=0.0,
stream=True,
logprobs=True,
top_logprobs=10)
Collaborator

top_logprobs doesn't exist in that API. Did this test pass on your machine? 🤔

Contributor Author

Ah, sorry. Fixed!

@esmeetu
Collaborator

esmeetu commented Feb 26, 2024

@jlcmoore No worries ☺️, but it still doesn't look OK. You can see the error log here: https://buildkite.com/vllm/ci/builds/1667#018de10e-6152-4c86-80ba-030c56d8fe7c. I suggest running the test code locally and getting it to pass first.

@esmeetu merged commit 70f3e8e into vllm-project:main on Feb 26, 2024
22 checks passed
xjpang pushed a commit to xjpang/vllm that referenced this pull request Mar 4, 2024
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
@Skytliang

Excuse me, can I get prompt_logprobs from chat.completions?

from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
completion = client.chat.completions.create(
    model=model,
    messages=[{'role': 'user', 'content': "Hello!"}],
    logprobs=True,
    prompt_logprobs=True,
    top_logprobs=5,
)
print("Completion result:", completion)
TypeError: Completions.create() got an unexpected keyword argument 'prompt_logprobs'
