Add LogProbs for Chat Completions in OpenAI (vllm-project#2918)
jlcmoore authored Feb 26, 2024
1 parent fe671f4 commit f783298
Showing 3 changed files with 57 additions and 14 deletions.
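
In client terms, this change lets an OpenAI-compatible chat request ask for per-token log probabilities, which the updated test below exercises end to end. As a minimal sketch of the resulting client call (assuming a vLLM OpenAI-compatible server is already running locally; the base URL, API key, and model name here are placeholders, not part of the commit):

import openai

# Placeholder endpoint and credentials for a locally running vLLM server.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_completion = client.chat.completions.create(
    model="your-model-name",
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=10,
    logprobs=True,    # new: return log probabilities for the sampled tokens
    top_logprobs=10,  # new: number of top alternatives reported per position
)

# The choice now carries a logprobs object, as the test below asserts.
print(chat_completion.choices[0].logprobs)
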
25 changes: 13 additions & 12 deletions tests/entrypoints/test_openai_server.py
@@ -155,15 +155,18 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
}]

# test single completion
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
)
chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=10)
assert chat_completion.id is not None
assert chat_completion.choices is not None and len(
chat_completion.choices) == 1
assert chat_completion.choices[0].message is not None
assert chat_completion.choices[0].logprobs is not None
assert chat_completion.choices[0].logprobs.top_logprobs is not None
assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 10
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
@@ -198,13 +201,11 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI,
single_output = single_completion.choices[0].text
single_usage = single_completion.usage

stream = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
)
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True)
chunks = []
async for chunk in stream:
chunks.append(chunk.choices[0].text)
8 changes: 8 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -63,6 +63,8 @@ class ChatCompletionRequest(BaseModel):
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
@@ -84,6 +86,8 @@ class ChatCompletionRequest(BaseModel):
length_penalty: Optional[float] = 1.0

def to_sampling_params(self) -> SamplingParams:
if self.logprobs and not self.top_logprobs:
raise ValueError("Top logprobs must be set when logprobs is.")
return SamplingParams(
n=self.n,
presence_penalty=self.presence_penalty,
@@ -96,6 +100,8 @@ def to_sampling_params(self) -> SamplingParams:
stop=self.stop,
stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.top_logprobs if self.echo else None,
best_of=self.best_of,
top_k=self.top_k,
ignore_eos=self.ignore_eos,
@@ -216,6 +222,7 @@ class ChatMessage(BaseModel):
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None


@@ -236,6 +243,7 @@ class DeltaMessage(BaseModel):
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None


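
To make the new request fields concrete: logprobs and top_logprobs are accepted on chat requests, top_logprobs is required whenever logprobs is enabled, and its value is forwarded to SamplingParams.logprobs (and to prompt_logprobs when echo is set). A rough sketch of that behaviour, using the field names from the diff above; the model name and message content are made up, and the remaining required fields are assumed from the surrounding class:

from vllm.entrypoints.openai.protocol import ChatCompletionRequest

ok = ChatCompletionRequest(
    model="your-model-name",
    messages=[{"role": "user", "content": "Hello!"}],
    logprobs=True,
    top_logprobs=5,
)
print(ok.to_sampling_params().logprobs)  # 5, forwarded from top_logprobs

bad = ChatCompletionRequest(
    model="your-model-name",
    messages=[{"role": "user", "content": "Hello!"}],
    logprobs=True,  # top_logprobs left unset
)
bad.to_sampling_params()  # raises ValueError, per the check added above
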
38 changes: 36 additions & 2 deletions vllm/entrypoints/openai/serving_chat.py
@@ -101,7 +101,10 @@ async def chat_completion_stream_generator(
role = self.get_chat_request_role(request)
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i, delta=DeltaMessage(role=role), finish_reason=None)
index=i,
delta=DeltaMessage(role=role),
logprobs=None,
finish_reason=None)
chunk = ChatCompletionStreamResponse(id=request_id,
object=chunk_object_type,
created=created_time,
@@ -118,6 +121,7 @@ async def chat_completion_stream_generator(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]

if last_msg_content:
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
@@ -129,6 +133,7 @@ async def chat_completion_stream_generator(
object=chunk_object_type,
created=created_time,
choices=[choice_data],
logprobs=None,
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
@@ -145,15 +150,29 @@ async def chat_completion_stream_generator(
if finish_reason_sent[i]:
continue

delta_token_ids = output.token_ids[previous_num_tokens[i]:]
top_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None

if request.logprobs:
logprobs = self._create_logprobs(
token_ids=delta_token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=len(previous_texts[i]),
)
else:
logprobs = None

delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)

if output.finish_reason is None:
# Send token-by-token response for each request.n
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
logprobs=logprobs,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
@@ -174,6 +193,7 @@ async def chat_completion_stream_generator(
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
logprobs=logprobs,
finish_reason=output.finish_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
@@ -208,11 +228,25 @@ async def chat_completion_full_generator(
assert final_res is not None

choices = []

role = self.get_chat_request_role(request)
for output in final_res.outputs:
token_ids = output.token_ids
top_logprobs = output.logprobs

if request.logprobs:
logprobs = self._create_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None

choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role=role, content=output.text),
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
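
On the streaming side, each content chunk now carries the same logprobs payload next to its delta text. A rough client-side sketch of consuming it, again with a placeholder server URL and model name:

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="your-model-name",
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=5,
    logprobs=True,
    top_logprobs=5,
    stream=True,
)
for chunk in stream:
    choice = chunk.choices[0]
    # Token deltas should now arrive with a per-chunk logprobs entry attached.
    print(choice.delta.content, choice.logprobs)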