diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index be5f4190e633f..8dc2e58e4709d 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -272,6 +272,8 @@ async def create_chat_completion(request: ChatCompletionRequest,
             top_k=request.top_k,
             ignore_eos=request.ignore_eos,
             use_beam_search=request.use_beam_search,
+            logprobs=request.logprobs,
+            prompt_logprobs=request.logprobs if request.echo else None,
             skip_special_tokens=request.skip_special_tokens,
             spaces_between_special_tokens=spaces_between_special_tokens,
         )
@@ -292,7 +294,11 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
         role = get_role()
         for i in range(request.n):
             choice_data = ChatCompletionResponseStreamChoice(
-                index=i, delta=DeltaMessage(role=role), finish_reason=None)
+                index=i,
+                delta=DeltaMessage(role=role),
+                logprobs=(LogProbs()
+                          if request.logprobs is not None else None),
+                finish_reason=None)
             chunk = ChatCompletionStreamResponse(id=request_id,
                                                  object=chunk_object_type,
                                                  created=created_time,
@@ -314,6 +320,8 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=i,
                     delta=DeltaMessage(content=last_msg_content),
+                    logprobs=(LogProbs()
+                              if request.logprobs is not None else None),
                     finish_reason=None)
                 chunk = ChatCompletionStreamResponse(
                     id=request_id,
@@ -339,11 +347,24 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
                 if output.finish_reason is None:
                     # Send token-by-token response for each request.n
                     delta_text = output.text[len(previous_texts[i]):]
+                    if request.logprobs is not None:
+                        token_ids = output.token_ids[previous_num_tokens[i]:]
+                        top_logprobs = output.logprobs[previous_num_tokens[i]:]
+                        offsets = len(previous_texts[i])
+                        logprobs = create_logprobs(
+                            token_ids=token_ids,
+                            top_logprobs=top_logprobs,
+                            num_output_top_logprobs=request.logprobs,
+                            initial_text_offset=offsets,
+                        )
+                    else:
+                        logprobs = None
                     previous_texts[i] = output.text
                     previous_num_tokens[i] = len(output.token_ids)
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
+                        logprobs=logprobs,
                         finish_reason=None)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
@@ -362,7 +383,11 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
                         total_tokens=prompt_tokens + previous_num_tokens[i],
                     )
                     choice_data = ChatCompletionResponseStreamChoice(
-                        index=i, delta=[], finish_reason=output.finish_reason)
+                        index=i,
+                        delta=[],
+                        logprobs=(LogProbs()
+                                  if request.logprobs is not None else None),
+                        finish_reason=output.finish_reason)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
                         object=chunk_object_type,
@@ -391,11 +416,27 @@ async def completion_full_generator():
         assert final_res is not None
 
         choices = []
+        prompt_token_ids = final_res.prompt_token_ids
+        prompt_logprobs = final_res.prompt_logprobs
         role = get_role()
         for output in final_res.outputs:
+            if request.logprobs is not None:
+                token_ids = output.token_ids
+                top_logprobs = output.logprobs
+                if request.echo:
+                    token_ids = prompt_token_ids + token_ids
+                    top_logprobs = prompt_logprobs + top_logprobs
+                logprobs = create_logprobs(
+                    token_ids=token_ids,
+                    top_logprobs=top_logprobs,
+                    num_output_top_logprobs=request.logprobs,
+                )
+            else:
+                logprobs = None
             choice_data = ChatCompletionResponseChoice(
                 index=output.index,
                 message=ChatMessage(role=role, content=output.text),
+                logprobs=logprobs,
                 finish_reason=output.finish_reason,
             )
             choices.append(choice_data)
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 7a86a19c4bf80..e08cc6d144f93 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -61,6 +61,7 @@ class ChatCompletionRequest(BaseModel):
     max_tokens: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    logprobs: Optional[int] = None
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
@@ -155,6 +156,7 @@ class ChatMessage(BaseModel):
 
 
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
@@ -175,6 +177,7 @@ class DeltaMessage(BaseModel):
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
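
Both the streaming and the non-streaming paths funnel the per-token data through create_logprobs(), which is outside the scope of this diff. For orientation only, below is a hypothetical minimal sketch of what such a helper could look like, inferred purely from the call sites above (token_ids, top_logprobs, num_output_top_logprobs, initial_text_offset); the LogProbsLike container, the create_logprobs_sketch name, and the tokenizer.convert_ids_to_tokens interface are assumptions, not the implementation shipped in vLLM.

# Hypothetical sketch only -- NOT the helper shipped in vLLM. It illustrates
# how the arguments passed at the call sites above could be turned into an
# OpenAI-completions-style logprobs payload. The container type and the
# tokenizer interface (convert_ids_to_tokens) are assumptions.
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class LogProbsLike:
    # Mirrors the fields of the OpenAI completions `logprobs` object.
    text_offset: List[int] = field(default_factory=list)
    token_logprobs: List[Optional[float]] = field(default_factory=list)
    tokens: List[str] = field(default_factory=list)
    top_logprobs: List[Optional[Dict[str, float]]] = field(default_factory=list)


def create_logprobs_sketch(
    tokenizer,
    token_ids: List[int],
    top_logprobs: List[Optional[Dict[int, float]]],
    num_output_top_logprobs: Optional[int] = None,
    initial_text_offset: int = 0,
) -> LogProbsLike:
    out = LogProbsLike()
    last_token_len = 0
    for token_id, step_logprobs in zip(token_ids, top_logprobs):
        token = tokenizer.convert_ids_to_tokens(token_id)
        out.tokens.append(token)
        # Logprob of the token that was actually emitted at this step.
        out.token_logprobs.append(
            step_logprobs.get(token_id) if step_logprobs else None)
        # Character offset of each token in the generated text; the streaming
        # path above passes len(previous_texts[i]) as the starting offset.
        if not out.text_offset:
            out.text_offset.append(initial_text_offset)
        else:
            out.text_offset.append(out.text_offset[-1] + last_token_len)
        last_token_len = len(token)
        if num_output_top_logprobs:
            # Top candidate tokens and their logprobs for this position.
            out.top_logprobs.append({
                tokenizer.convert_ids_to_tokens(i): lp
                for i, lp in step_logprobs.items()
            } if step_logprobs else None)
    return out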
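
With the ChatCompletionRequest change in place, a client can request per-token logprobs on the chat endpoint much like it already can on /v1/completions. A minimal usage sketch, assuming a locally running OpenAI-compatible server on the default port 8000 and a placeholder model name:

# Usage sketch, assuming a local server started with
# `python -m vllm.entrypoints.openai.api_server ...`; the model name below
# is a placeholder.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-model",  # placeholder; use the model the server serves
        "messages": [{"role": "user", "content": "Say hello in one word."}],
        "max_tokens": 8,
        "logprobs": 5,   # new field: top-5 logprobs per generated token
        "echo": False,   # per the patch, echo=True also returns prompt logprobs
    },
)
choice = resp.json()["choices"][0]
print(choice["message"]["content"])
print(choice["logprobs"])  # per-token logprobs payload added by this patch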