Fix echo/logprob OpenAI completion bug (vllm-project#3441)
Co-authored-by: Dylan Hawk <[email protected]>
2 people authored and andy-neuma committed Apr 12, 2024
1 parent 9f93716 commit 1091308
Showing 4 changed files with 73 additions and 29 deletions.
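
For context, the bug sat on the code path taken when a completion request combines echo with logprobs. A minimal sketch of such a request is shown below, mirroring the parameters used by the new test; the base URL, API key, and model name are illustrative assumptions, not anything mandated by the commit.

# Sketch only (not part of the commit): issue an echo + logprobs completion
# request against a vLLM OpenAI-compatible server. Base URL, API key, and
# model name are assumptions.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",  # assumed; use whatever model the server exposes
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
    echo=True,      # return the prompt text in front of the generated text
    logprobs=1,     # return per-token logprobs, including for prompt tokens
)

choice = completion.choices[0]
print(choice.text)                         # prompt followed by the generated continuation
print(choice.logprobs.token_logprobs[0])   # None: the first prompt token has no logprob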
tests/entrypoints/test_openai_server.py (31 additions, 0 deletions)
@@ -742,5 +742,36 @@ async def test_guided_grammar(server, client: openai.AsyncOpenAI):
     assert content.strip() == ground_truth


+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
+                                       model_name: str):
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+    # test using text and token IDs
+    for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
+        completion = await client.completions.create(model=model_name,
+                                                      prompt=prompt,
+                                                      max_tokens=5,
+                                                      temperature=0.0,
+                                                      echo=True,
+                                                      logprobs=1)
+
+        prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
+                                                             list) else prompt
+        assert (completion.choices[0].text is not None
+                and re.search(r"^" + prompt_text, completion.choices[0].text))
+        logprobs = completion.choices[0].logprobs
+        assert logprobs is not None
+        assert len(logprobs.text_offset) > 5
+        assert (len(logprobs.token_logprobs) > 5
+                and logprobs.token_logprobs[0] is None)
+        assert (len(logprobs.top_logprobs) > 5
+                and logprobs.top_logprobs[0] is None)
+        assert len(logprobs.tokens) > 5
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
vllm/entrypoints/openai/serving_chat.py (5 additions, 4 deletions)
@@ -63,8 +63,9 @@ async def create_chat_completion(

         request_id = f"cmpl-{random_uuid()}"
         try:
-            token_ids = self._validate_prompt_and_tokenize(request,
-                                                           prompt=prompt)
+            # Tokenize/detokenize depending on prompt format (string/token list)
+            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
+                request, prompt=prompt)
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
             guided_decode_logits_processor = (
@@ -78,8 +79,8 @@ async def create_chat_completion(
         except ValueError as e:
             return self.create_error_response(str(e))

-        result_generator = self.engine.generate(prompt, sampling_params,
-                                                request_id, token_ids,
+        result_generator = self.engine.generate(prompt_text, sampling_params,
+                                                request_id, prompt_ids,
                                                 lora_request)
         # Streaming response
         if request.stream:
vllm/entrypoints/openai/serving_completion.py (10 additions, 5 deletions)
@@ -136,23 +136,24 @@ async def create_completion(self, request: CompletionRequest,

             for i, prompt in enumerate(prompts):
                 if prompt_is_tokens:
-                    input_ids = self._validate_prompt_and_tokenize(
+                    prompt_formats = self._validate_prompt_and_tokenize(
                         request,
                         prompt_ids=prompt,
                         truncate_prompt_tokens=sampling_params.
                         truncate_prompt_tokens)
                 else:
-                    input_ids = self._validate_prompt_and_tokenize(
+                    prompt_formats = self._validate_prompt_and_tokenize(
                         request,
                         prompt=prompt,
                         truncate_prompt_tokens=sampling_params.
                         truncate_prompt_tokens)
+                prompt_ids, prompt_text = prompt_formats

                 generators.append(
-                    self.engine.generate(prompt,
+                    self.engine.generate(prompt_text,
                                          sampling_params,
                                          f"{request_id}-{i}",
-                                         prompt_token_ids=input_ids,
+                                         prompt_token_ids=prompt_ids,
                                          lora_request=lora_request))
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
@@ -326,14 +327,18 @@ def request_output_to_completion_response(
                 output_text = prompt_text
             elif request.echo and request.max_tokens > 0:
                 token_ids = prompt_token_ids + output.token_ids
-                top_logprobs = prompt_logprobs + output.logprobs
+                top_logprobs = (prompt_logprobs + output.logprobs
+                                if request.logprobs else None)
                 output_text = prompt_text + output.text
             else:
                 token_ids = output.token_ids
                 top_logprobs = output.logprobs
                 output_text = output.text

             if request.logprobs is not None:
+                assert top_logprobs is not None, (
+                    "top_logprobs must be provided when logprobs "
+                    "is requested")
                 logprobs = self._create_logprobs(
                     token_ids=token_ids,
                     top_logprobs=top_logprobs,
vllm/entrypoints/openai/serving_engine.py (27 additions, 20 deletions)
@@ -2,7 +2,7 @@
 import json
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

 from pydantic import conint

@@ -107,27 +107,32 @@ def _create_logprobs(
         last_token_len = 0
         if num_output_top_logprobs:
             logprobs.top_logprobs = []
+
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
-            if step_top_logprobs is not None:
-                token_logprob = step_top_logprobs[token_id].logprob
+            if step_top_logprobs is None:
+                token = self.tokenizer.decode(token_id)
+                logprobs.tokens.append(token)
+                logprobs.token_logprobs.append(None)
+                logprobs.top_logprobs.append(None)
             else:
-                token_logprob = None
-            token = step_top_logprobs[token_id].decoded_token
-            logprobs.tokens.append(token)
-            logprobs.token_logprobs.append(token_logprob)
+                token_logprob = step_top_logprobs[token_id].logprob
+                token = step_top_logprobs[token_id].decoded_token
+                logprobs.tokens.append(token)
+                logprobs.token_logprobs.append(token_logprob)
+
+                if num_output_top_logprobs:
+                    logprobs.top_logprobs.append({
+                        p.decoded_token: p.logprob
+                        for i, p in step_top_logprobs.items()
+                    } if step_top_logprobs else None)
+
             if len(logprobs.text_offset) == 0:
                 logprobs.text_offset.append(initial_text_offset)
             else:
                 logprobs.text_offset.append(logprobs.text_offset[-1] +
                                             last_token_len)
             last_token_len = len(token)
-
-            if num_output_top_logprobs:
-                logprobs.top_logprobs.append({
-                    p.decoded_token: p.logprob
-                    for i, p in step_top_logprobs.items()
-                } if step_top_logprobs else None)
         return logprobs

     def create_error_response(
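
The hunk above is the core of the fix: when a position has no logprob entry (top_logprobs[i] is None, which happens for prompt tokens when echoing), the token is now decoded directly from its ID and the logprob fields are padded with None, instead of dereferencing step_top_logprobs while it is None as the old code did. As a rough illustration of the resulting payload for the echoed request in the new test (token strings and numbers below are invented, not taken from a real run):

# Illustrative shape only; every value here is made up.
example_logprobs = {
    "tokens": ["Hello", ",", " my", " name", " is", " John", ...],  # prompt tokens first, then generated tokens
    "token_logprobs": [None, -1.9, -0.4, ...],                      # None where no distribution exists
    "top_logprobs": [None, {",": -1.9}, ...],                       # likewise None at that position
    "text_offset": [0, 5, 6, 9, 14, ...],                           # running character offsets into choices[0].text
}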
@@ -172,12 +177,12 @@ def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
         raise ValueError("The model `{request.model}` does not exist.")

     def _validate_prompt_and_tokenize(
-            self,
-            request: Union[ChatCompletionRequest, CompletionRequest],
-            prompt: Optional[str] = None,
-            prompt_ids: Optional[List[int]] = None,
-            truncate_prompt_tokens: Optional[conint(ge=1)] = None
-    ) -> List[int]:
+        self,
+        request: Union[ChatCompletionRequest, CompletionRequest],
+        prompt: Optional[str] = None,
+        prompt_ids: Optional[List[int]] = None,
+        truncate_prompt_tokens: Optional[conint(ge=1)] = None
+    ) -> Tuple[List[int], str]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
         if (prompt and prompt_ids):
@@ -195,6 +200,8 @@ def _validate_prompt_and_tokenize(
         else:
             input_ids = prompt_ids

+        input_text = prompt if prompt is not None else self.tokenizer.decode(
+            prompt_ids)
         token_num = len(input_ids)

         if request.max_tokens is None:
@@ -209,4 +216,4 @@ def _validate_prompt_and_tokenize(
                 f"{request.max_tokens} in the completion). "
                 f"Please reduce the length of the messages or completion.", )
         else:
-            return input_ids
+            return input_ids, input_text
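
Taken together with the signature change above, _validate_prompt_and_tokenize now hands back both representations of the prompt, so callers no longer re-derive one from the other. A rough sketch of the new contract as a caller sees it, where serving stands in for an OpenAIServing instance and request for a validated CompletionRequest (both assumed here, not constructed):

# Sketch of the new return contract; `serving` and `request` are assumed objects.
prompt_ids, prompt_text = serving._validate_prompt_and_tokenize(
    request, prompt="Hello, my name is")
# prompt_ids  -> token IDs produced by serving.tokenizer
# prompt_text -> "Hello, my name is" (string prompts pass through unchanged)

prompt_ids, prompt_text = serving._validate_prompt_and_tokenize(
    request, prompt_ids=[0, 0, 0, 0, 0])
# prompt_ids  -> [0, 0, 0, 0, 0] (unchanged when no truncation is requested)
# prompt_text -> serving.tokenizer.decode([0, 0, 0, 0, 0])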
