Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Push logprob generation to LLMEngine #3065

Merged
merged 9 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 58 additions & 3 deletions tests/entrypoints/test_openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,14 +213,14 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=10)
top_logprobs=5)
assert chat_completion.id is not None
assert chat_completion.choices is not None and len(
chat_completion.choices) == 1
assert chat_completion.choices[0].message is not None
assert chat_completion.choices[0].logprobs is not None
assert chat_completion.choices[0].logprobs.top_logprobs is not None
assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 10
assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
Expand All @@ -229,14 +229,69 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
# test multi-turn dialogue
messages.append({"role": "user", "content": "express your result in json"})
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
model=model_name,
messages=messages,
max_tokens=10,
)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0


@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]

# Default max_logprobs is 5, so this should raise an error
with pytest.raises((openai.BadRequestError, openai.APIError)):
stream = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=10,
stream=True)
async for chunk in stream:
...

with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=10,
stream=False)

with pytest.raises((openai.BadRequestError, openai.APIError)):
stream = await client.completions.create(model=model_name,
prompt="Test",
max_tokens=10,
logprobs=10,
stream=True)
async for chunk in stream:
...

with pytest.raises(openai.BadRequestError):
await client.completions.create(model=model_name,
prompt="Test",
max_tokens=10,
logprobs=10,
stream=False)

# the server should still work afterwards
chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
stream=False)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0


@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
Expand Down
42 changes: 36 additions & 6 deletions tests/samplers/test_logprobs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import torch
from tests.conftest import VllmRunner

from vllm import SamplingParams

Expand All @@ -16,41 +17,70 @@ def test_get_prompt_logprobs(
example_prompts,
):
max_tokens = 5
num_top_logprobs = 6
hf_model = hf_runner(model, dtype=dtype)
hf_logprobs = hf_model.generate_greedy_logprobs(
example_prompts,
max_tokens=max_tokens,
)
del hf_model

vllm_model = vllm_runner(model, dtype=dtype)
vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs)
vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
logprobs=5,
logprobs=num_top_logprobs,
prompt_logprobs=5,
temperature=0.0)
vllm_results = vllm_model.model.generate(
example_prompts, sampling_params=vllm_sampling_params)
del vllm_model

# Test whether logprobs are included in the results.
for result in vllm_results:
assert result.prompt_logprobs is not None
assert result.outputs[0].logprobs is not None
assert len(result.outputs[0].logprobs) == max_tokens
for logprobs in result.outputs[0].logprobs:
assert len(logprobs) == num_top_logprobs
output_text = result.outputs[0].text
output_string_from_most_likely_tokens = []
for top_logprobs in result.outputs[0].logprobs:
top_logprob = next(iter(top_logprobs.values()))
output_string_from_most_likely_tokens.append(
top_logprob.decoded_token)
output_string_from_most_likely_tokens = "".join(
output_string_from_most_likely_tokens)
assert output_text == output_string_from_most_likely_tokens, (
"The output text from the top logprob for each token position "
"should be the same as the output text in the result.")

# Test whether prompt logprobs are consistent with HF
for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
# Check prompt logprobs
vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
for token_id, logprob in vllm_prompt_logprob_dict.items():
torch.testing.assert_close(logprob,
torch.testing.assert_close(logprob.logprob,
hf_logprob[0][i][token_id].item(),
atol=1e-2,
rtol=1e-2)
vllm_sample_logprobs = vllm_result.outputs[0].logprobs
for i, vllm_sample_logprob_dict in enumerate(vllm_sample_logprobs):
for token_id, logprob in vllm_sample_logprob_dict.items():
for i, top_logprobs in enumerate(vllm_sample_logprobs):
for token_id, sample_logprob in top_logprobs.items():
logprob = sample_logprob.logprob
torch.testing.assert_close(logprob,
hf_logprob[i][-1][token_id].item(),
atol=1e-2,
rtol=1e-2)
assert isinstance(sample_logprob.decoded_token, str), \
("The token should be decoded by the time it is returned "
" to the user.")


def test_max_logprobs():
runner = VllmRunner("facebook/opt-125m", max_logprobs=1)
vllm_sampling_params = SamplingParams(logprobs=1)
# should pass
runner.generate(["Hello world"], sampling_params=vllm_sampling_params)

bad_sampling_params = SamplingParams(logprobs=2)
with pytest.raises(ValueError):
runner.generate(["Hello world"], sampling_params=bad_sampling_params)
12 changes: 7 additions & 5 deletions tests/worker/spec_decode/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from vllm.worker.worker import Worker
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.engine.arg_utils import EngineArgs
from vllm.sequence import SequenceGroupMetadata, SequenceData
from vllm.sequence import Logprob, SequenceGroupMetadata, SequenceData
from vllm.sampling_params import SamplingParams
from vllm.worker.cache_engine import CacheEngine
from vllm.model_executor.utils import set_random_seed
Expand Down Expand Up @@ -166,13 +166,15 @@ def create_seq_group_metadata_from_prompts(


def assert_logprobs_dict_allclose(
actual_logprobs: List[Dict[int, float]],
expected_logprobs: List[Dict[int, float]]) -> None:
actual_logprobs: List[Dict[int, Logprob]],
expected_logprobs: List[Dict[int, Logprob]]) -> None:
for single_step_actual_logprobs, single_step_expected_logprobs in zip(
actual_logprobs, expected_logprobs):
assert set(single_step_actual_logprobs.keys()) == set(
single_step_expected_logprobs.keys())
for token_id in single_step_actual_logprobs:
actual = torch.tensor(single_step_actual_logprobs[token_id])
expected = torch.tensor(single_step_expected_logprobs[token_id])
actual = torch.tensor(
single_step_actual_logprobs[token_id].logprob)
expected = torch.tensor(
single_step_expected_logprobs[token_id].logprob)
assert torch.allclose(actual, expected)
2 changes: 2 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
quantization: Optional[str] = None,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_logprobs: int = 5,
) -> None:
self.model = model
self.tokenizer = tokenizer
Expand All @@ -93,6 +94,7 @@ def __init__(
self.quantization = quantization
self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture
self.max_logprobs = max_logprobs

if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
# download model from ModelScope hub,
Expand Down
10 changes: 9 additions & 1 deletion vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class EngineArgs:
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_paddings: int = 256
max_logprobs: int = 5 # OpenAI default value
disable_log_stats: bool = False
revision: Optional[str] = None
code_revision: Optional[str] = None
Expand Down Expand Up @@ -212,6 +213,12 @@ def add_cli_args(
type=int,
default=EngineArgs.max_paddings,
help='maximum number of paddings in a batch')
parser.add_argument(
'--max-logprobs',
type=int,
default=EngineArgs.max_logprobs,
help=('max number of log probs to return logprobs is specified in'
' SamplingParams'))
parser.add_argument('--disable-log-stats',
action='store_true',
help='disable logging statistics')
Expand Down Expand Up @@ -300,7 +307,8 @@ def create_engine_configs(
self.trust_remote_code, self.download_dir, self.load_format,
self.dtype, self.seed, self.revision, self.code_revision,
self.tokenizer_revision, self.max_model_len, self.quantization,
self.enforce_eager, self.max_context_len_to_capture)
self.enforce_eager, self.max_context_len_to_capture,
self.max_logprobs)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,
Expand Down
29 changes: 24 additions & 5 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, request_id: str) -> None:
self._queue = asyncio.Queue()
self._finished = False

def put(self, item: RequestOutput) -> None:
def put(self, item: Union[RequestOutput, Exception]) -> None:
if self._finished:
return
self._queue.put_nowait(item)
Expand Down Expand Up @@ -110,6 +110,17 @@ def process_request_output(self,
logger.info(f"Finished request {request_id}.")
self.abort_request(request_id)

def process_exception(self,
request_id: str,
exception: Exception,
*,
verbose: bool = False) -> None:
"""Propagate an exception from the engine."""
self._request_streams[request_id].put(exception)
if verbose:
logger.info(f"Finished request {request_id}.")
self.abort_request(request_id)

def add_request(self, request_id: str,
**engine_add_request_kwargs) -> AsyncStream:
"""Add a request to be sent to the engine on the next background
Expand Down Expand Up @@ -377,10 +388,18 @@ async def engine_step(self) -> bool:
for new_request in new_requests:
# Add the request into the vLLM engine's waiting queue.
# TODO: Maybe add add_request_batch to reduce Ray overhead
if self.engine_use_ray:
await self.engine.add_request.remote(**new_request)
else:
await self.engine.add_request_async(**new_request)
try:
if self.engine_use_ray:
await self.engine.add_request.remote(**new_request)
else:
await self.engine.add_request_async(**new_request)
except ValueError as e:
# TODO: use a vLLM specific error for failed validation
self._request_tracker.process_exception(
new_request["request_id"],
e,
verbose=self.log_requests,
)

if finished_requests:
await self._engine_abort(finished_requests)
Expand Down
42 changes: 40 additions & 2 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
TokenizerGroup)
Expand Down Expand Up @@ -473,6 +473,13 @@ def add_request(
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
max_logprobs = self.get_model_config().max_logprobs
if (sampling_params.logprobs
and sampling_params.logprobs > max_logprobs) or (
sampling_params.prompt_logprobs
and sampling_params.prompt_logprobs > max_logprobs):
raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.")
if arrival_time is None:
arrival_time = time.monotonic()
prompt_token_ids = self.encode_request(
Expand Down Expand Up @@ -583,6 +590,13 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
# Process prompt logprobs
prompt_logprobs = outputs.prompt_logprobs
if prompt_logprobs is not None:
# We can pick any sequence for the prompt.
seq = next(iter(seq_group.seqs_dict.values()))
all_token_ids = seq.get_token_ids()
for i, prompt_logprobs_for_token in enumerate(prompt_logprobs):
self._decode_logprobs(seq, seq_group.sampling_params,
prompt_logprobs_for_token,
all_token_ids[:i])
seq_group.prompt_logprobs = prompt_logprobs

# Process samples
Expand Down Expand Up @@ -930,12 +944,36 @@ def _get_stats(self,
time_e2e_requests=time_e2e_requests,
)

def _decode_logprobs(self, seq: Sequence, prms: SamplingParams,
logprobs: Dict[int, Logprob],
all_input_ids: List[int]) -> None:
if not logprobs:
return
for token_id, sample_logprob in logprobs.items():
if (sample_logprob.decoded_token is None and token_id != -1):
all_input_ids_with_logprob = all_input_ids[:-1] + [token_id]
_, new_text, prefix_offset, read_offset = detokenize_incrementally(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without knowing the OpenAI behaviour, IMHO it would be more appropriate here to use convert_ids_to_tokens and include the explicit/atomic token strings. Otherwise the text may not line up with the token.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the OpenAI behavior is exactly that - the logprob token text depends on the previous tokens and is not constant.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Yard1 ok, thanks! I'll take a closer look at this.

self.get_tokenizer_for_seq(seq),
all_input_ids=all_input_ids_with_logprob,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.
spaces_between_special_tokens,
)
sample_logprob.decoded_token = new_text

def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
"""Decodes the new token for a sequence."""
all_input_ids = seq.get_token_ids()
self._decode_logprobs(seq, prms, seq.output_logprobs[-1],
all_input_ids)

(new_tokens, new_output_text, prefix_offset,
read_offset) = detokenize_incrementally(
self.get_tokenizer_for_seq(seq),
all_input_ids=seq.get_token_ids(),
all_input_ids=all_input_ids,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
Expand Down
Loading
Loading