Push logprob generation to LLMEngine (vllm-project#3065)
Co-authored-by: Avnish Narayan <[email protected]>
2 people authored and dbogunowicz committed Mar 26, 2024
1 parent 304abaa commit b00c20d
Showing 13 changed files with 555 additions and 335 deletions.
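For orientation, a minimal sketch (not part of the diff) of what this commit adds: an engine-level max_logprobs cap (default 5) that LLMEngine.add_request enforces against SamplingParams.logprobs and prompt_logprobs, with each returned Logprob carrying a decoded token string. The model name is simply the one used in the tests, and passing max_logprobs through LLM(...) assumes engine kwargs are forwarded to EngineArgs the way the VllmRunner test helper does.

# Hedged sketch (not part of the commit): the engine-level cap and the
# per-request validation it enables.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", max_logprobs=6)  # engine-wide cap (assumed kwarg passthrough)

ok_params = SamplingParams(max_tokens=5, logprobs=6, prompt_logprobs=5,
                           temperature=0.0)
outputs = llm.generate(["The capital of France is"], ok_params)
print(outputs[0].prompt_logprobs[1])      # dict of token_id -> Logprob for a prompt position
print(outputs[0].outputs[0].logprobs[0])  # dict of token_id -> Logprob for the first generated token

# Requests above the cap are now rejected up front by LLMEngine.add_request:
try:
    llm.generate(["Hello"], SamplingParams(logprobs=7))
except ValueError as e:
    print("rejected:", e)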
61 changes: 58 additions & 3 deletions tests/entrypoints/test_openai_server.py
@@ -213,14 +213,14 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=10)
top_logprobs=5)
assert chat_completion.id is not None
assert chat_completion.choices is not None and len(
chat_completion.choices) == 1
assert chat_completion.choices[0].message is not None
assert chat_completion.choices[0].logprobs is not None
assert chat_completion.choices[0].logprobs.top_logprobs is not None
assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 10
assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
@@ -229,14 +229,69 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
# test multi-turn dialogue
messages.append({"role": "user", "content": "express your result in json"})
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
model=model_name,
messages=messages,
max_tokens=10,
)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0


@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]

# Default max_logprobs is 5, so this should raise an error
with pytest.raises((openai.BadRequestError, openai.APIError)):
stream = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=10,
stream=True)
async for chunk in stream:
...

with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=10,
stream=False)

with pytest.raises((openai.BadRequestError, openai.APIError)):
stream = await client.completions.create(model=model_name,
prompt="Test",
max_tokens=10,
logprobs=10,
stream=True)
async for chunk in stream:
...

with pytest.raises(openai.BadRequestError):
await client.completions.create(model=model_name,
prompt="Test",
max_tokens=10,
logprobs=10,
stream=False)

# the server should still work afterwards
chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
stream=False)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0


@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
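Mirroring test_too_many_logprobs, a hedged client-side sketch of what a caller sees against a running OpenAI-compatible vLLM server (base_url, api_key and model name are placeholders; the cap can be raised at server start with the new --max-logprobs flag added in vllm/engine/arg_utils.py below):

# Hedged sketch, not part of the diff.  Assumes a vLLM OpenAI-compatible server
# is already running locally and was started with the default max_logprobs of 5.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
MODEL_NAME = "facebook/opt-125m"  # placeholder: whatever model the server serves

messages = [{"role": "user", "content": "what is 1+1?"}]

# Within the cap: top-5 logprobs are returned per generated token, which is the
# structure the test above reads via choices[0].logprobs.top_logprobs.
ok = client.chat.completions.create(model=MODEL_NAME,
                                    messages=messages,
                                    max_tokens=10,
                                    logprobs=True,
                                    top_logprobs=5)
print(ok.choices[0].logprobs)

# Above the cap: the request is rejected instead of crashing the engine.
try:
    client.chat.completions.create(model=MODEL_NAME,
                                   messages=messages,
                                   max_tokens=10,
                                   logprobs=True,
                                   top_logprobs=10)
except openai.BadRequestError as exc:
    print("rejected:", exc)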
42 changes: 36 additions & 6 deletions tests/samplers/test_logprobs.py
@@ -1,5 +1,6 @@
import pytest
import torch
from tests.conftest import VllmRunner

from vllm import SamplingParams

@@ -16,41 +17,70 @@ def test_get_prompt_logprobs(
example_prompts,
):
max_tokens = 5
num_top_logprobs = 6
hf_model = hf_runner(model, dtype=dtype)
hf_logprobs = hf_model.generate_greedy_logprobs(
example_prompts,
max_tokens=max_tokens,
)
del hf_model

vllm_model = vllm_runner(model, dtype=dtype)
vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs)
vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
logprobs=5,
logprobs=num_top_logprobs,
prompt_logprobs=5,
temperature=0.0)
vllm_results = vllm_model.model.generate(
example_prompts, sampling_params=vllm_sampling_params)
del vllm_model

# Test whether logprobs are included in the results.
for result in vllm_results:
assert result.prompt_logprobs is not None
assert result.outputs[0].logprobs is not None
assert len(result.outputs[0].logprobs) == max_tokens
for logprobs in result.outputs[0].logprobs:
assert len(logprobs) == num_top_logprobs
output_text = result.outputs[0].text
output_string_from_most_likely_tokens = []
for top_logprobs in result.outputs[0].logprobs:
top_logprob = next(iter(top_logprobs.values()))
output_string_from_most_likely_tokens.append(
top_logprob.decoded_token)
output_string_from_most_likely_tokens = "".join(
output_string_from_most_likely_tokens)
assert output_text == output_string_from_most_likely_tokens, (
"The output text from the top logprob for each token position "
"should be the same as the output text in the result.")

# Test whether prompt logprobs are consistent with HF
for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
# Check prompt logprobs
vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
for token_id, logprob in vllm_prompt_logprob_dict.items():
torch.testing.assert_close(logprob,
torch.testing.assert_close(logprob.logprob,
hf_logprob[0][i][token_id].item(),
atol=1e-2,
rtol=1e-2)
vllm_sample_logprobs = vllm_result.outputs[0].logprobs
for i, vllm_sample_logprob_dict in enumerate(vllm_sample_logprobs):
for token_id, logprob in vllm_sample_logprob_dict.items():
for i, top_logprobs in enumerate(vllm_sample_logprobs):
for token_id, sample_logprob in top_logprobs.items():
logprob = sample_logprob.logprob
torch.testing.assert_close(logprob,
hf_logprob[i][-1][token_id].item(),
atol=1e-2,
rtol=1e-2)
assert isinstance(sample_logprob.decoded_token, str), \
("The token should be decoded by the time it is returned "
" to the user.")


def test_max_logprobs():
runner = VllmRunner("facebook/opt-125m", max_logprobs=1)
vllm_sampling_params = SamplingParams(logprobs=1)
# should pass
runner.generate(["Hello world"], sampling_params=vllm_sampling_params)

bad_sampling_params = SamplingParams(logprobs=2)
with pytest.raises(ValueError):
runner.generate(["Hello world"], sampling_params=bad_sampling_params)
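The structure these assertions rely on, as a hedged sketch (model and prompt are illustrative, and max_logprobs is again assumed to pass through LLM(...)): each position in outputs[0].logprobs maps token id to a Logprob whose decoded_token is already a string, and the top entry per position reproduces the generated text.

# Hedged sketch of consuming the per-token Logprob objects (not part of the diff).
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", max_logprobs=6)
result = llm.generate(["San Francisco is"],
                      SamplingParams(max_tokens=4, logprobs=6, temperature=0.0))[0]

pieces = []
for position in result.outputs[0].logprobs:    # one dict per generated token
    top = next(iter(position.values()))        # first entry is the sampled token
    assert isinstance(top.decoded_token, str)  # decoded before being returned
    pieces.append(top.decoded_token)
assert "".join(pieces) == result.outputs[0].text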
12 changes: 7 additions & 5 deletions tests/worker/spec_decode/utils.py
@@ -4,7 +4,7 @@
from vllm.worker.worker import Worker
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.engine.arg_utils import EngineArgs
from vllm.sequence import SequenceGroupMetadata, SequenceData
from vllm.sequence import Logprob, SequenceGroupMetadata, SequenceData
from vllm.sampling_params import SamplingParams
from vllm.worker.cache_engine import CacheEngine
from vllm.model_executor.utils import set_random_seed
@@ -166,13 +166,15 @@ def create_seq_group_metadata_from_prompts(


def assert_logprobs_dict_allclose(
actual_logprobs: List[Dict[int, float]],
expected_logprobs: List[Dict[int, float]]) -> None:
actual_logprobs: List[Dict[int, Logprob]],
expected_logprobs: List[Dict[int, Logprob]]) -> None:
for single_step_actual_logprobs, single_step_expected_logprobs in zip(
actual_logprobs, expected_logprobs):
assert set(single_step_actual_logprobs.keys()) == set(
single_step_expected_logprobs.keys())
for token_id in single_step_actual_logprobs:
actual = torch.tensor(single_step_actual_logprobs[token_id])
expected = torch.tensor(single_step_expected_logprobs[token_id])
actual = torch.tensor(
single_step_actual_logprobs[token_id].logprob)
expected = torch.tensor(
single_step_expected_logprobs[token_id].logprob)
assert torch.allclose(actual, expected)
2 changes: 2 additions & 0 deletions vllm/config.py
@@ -79,6 +79,7 @@ def __init__(
quantization: Optional[str] = None,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_logprobs: int = 5,
) -> None:
self.model = model
self.tokenizer = tokenizer
@@ -93,6 +94,7 @@ def __init__(
self.quantization = quantization
self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture
self.max_logprobs = max_logprobs

if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
# download model from ModelScope hub,
10 changes: 9 additions & 1 deletion vllm/engine/arg_utils.py
@@ -31,6 +31,7 @@ class EngineArgs:
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_paddings: int = 256
max_logprobs: int = 5 # OpenAI default value
disable_log_stats: bool = False
revision: Optional[str] = None
code_revision: Optional[str] = None
@@ -212,6 +213,12 @@ def add_cli_args(
type=int,
default=EngineArgs.max_paddings,
help='maximum number of paddings in a batch')
parser.add_argument(
'--max-logprobs',
type=int,
default=EngineArgs.max_logprobs,
help=('max number of log probs to return when logprobs is specified in'
      ' SamplingParams'))
parser.add_argument('--disable-log-stats',
action='store_true',
help='disable logging statistics')
@@ -300,7 +307,8 @@ def create_engine_configs(
self.trust_remote_code, self.download_dir, self.load_format,
self.dtype, self.seed, self.revision, self.code_revision,
self.tokenizer_revision, self.max_model_len, self.quantization,
self.enforce_eager, self.max_context_len_to_capture)
self.enforce_eager, self.max_context_len_to_capture,
self.max_logprobs)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,
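The same cap set programmatically, as a hedged sketch (not part of the diff): EngineArgs.max_logprobs defaults to 5 and is what the new --max-logprobs CLI flag maps onto; create_engine_configs then threads it into ModelConfig.

# Hedged sketch, not part of the diff.  In this version of the code the
# ModelConfig is the first element of the tuple create_engine_configs returns.
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="facebook/opt-125m", max_logprobs=10)
model_config = args.create_engine_configs()[0]
print(model_config.max_logprobs)  # -> 10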
29 changes: 24 additions & 5 deletions vllm/engine/async_llm_engine.py
@@ -47,7 +47,7 @@ def __init__(self, request_id: str) -> None:
self._queue = asyncio.Queue()
self._finished = False

def put(self, item: RequestOutput) -> None:
def put(self, item: Union[RequestOutput, Exception]) -> None:
if self._finished:
return
self._queue.put_nowait(item)
@@ -110,6 +110,17 @@ def process_request_output(self,
logger.info(f"Finished request {request_id}.")
self.abort_request(request_id)

def process_exception(self,
request_id: str,
exception: Exception,
*,
verbose: bool = False) -> None:
"""Propagate an exception from the engine."""
self._request_streams[request_id].put(exception)
if verbose:
logger.info(f"Finished request {request_id}.")
self.abort_request(request_id)

def add_request(self, request_id: str,
**engine_add_request_kwargs) -> AsyncStream:
"""Add a request to be sent to the engine on the next background
@@ -377,10 +388,18 @@ async def engine_step(self) -> bool:
for new_request in new_requests:
# Add the request into the vLLM engine's waiting queue.
# TODO: Maybe add add_request_batch to reduce Ray overhead
if self.engine_use_ray:
await self.engine.add_request.remote(**new_request)
else:
await self.engine.add_request_async(**new_request)
try:
if self.engine_use_ray:
await self.engine.add_request.remote(**new_request)
else:
await self.engine.add_request_async(**new_request)
except ValueError as e:
# TODO: use a vLLM specific error for failed validation
self._request_tracker.process_exception(
new_request["request_id"],
e,
verbose=self.log_requests,
)

if finished_requests:
await self._engine_abort(finished_requests)
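With this change, a per-request validation failure is pushed into that request's stream and re-raised for the one caller instead of tearing down the background engine loop. A minimal, generic asyncio sketch of the pattern (not vLLM's actual AsyncStream/tracker code):

# Generic sketch of propagating an exception through a queue-backed stream.
import asyncio
from typing import Union


class Stream:
    def __init__(self) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()

    def put(self, item: Union[str, Exception]) -> None:
        self._queue.put_nowait(item)

    async def get(self) -> str:
        item = await self._queue.get()
        if isinstance(item, Exception):
            raise item  # surfaces only to the caller awaiting this stream
        return item


async def main() -> None:
    stream = Stream()
    # The engine-side producer puts the ValueError raised by add_request here.
    stream.put(ValueError("Cannot request more than 5 logprobs."))
    try:
        await stream.get()
    except ValueError as exc:
        print("request failed:", exc)  # other requests keep running


asyncio.run(main())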
42 changes: 40 additions & 2 deletions vllm/engine/llm_engine.py
@@ -18,7 +18,7 @@
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
TokenizerGroup)
@@ -476,6 +476,13 @@ def add_request(
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
max_logprobs = self.get_model_config().max_logprobs
if (sampling_params.logprobs
and sampling_params.logprobs > max_logprobs) or (
sampling_params.prompt_logprobs
and sampling_params.prompt_logprobs > max_logprobs):
raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.")
if arrival_time is None:
arrival_time = time.monotonic()
prompt_token_ids = self.encode_request(
@@ -586,6 +593,13 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
# Process prompt logprobs
prompt_logprobs = outputs.prompt_logprobs
if prompt_logprobs is not None:
# We can pick any sequence for the prompt.
seq = next(iter(seq_group.seqs_dict.values()))
all_token_ids = seq.get_token_ids()
for i, prompt_logprobs_for_token in enumerate(prompt_logprobs):
self._decode_logprobs(seq, seq_group.sampling_params,
prompt_logprobs_for_token,
all_token_ids[:i])
seq_group.prompt_logprobs = prompt_logprobs

# Process samples
@@ -933,12 +947,36 @@ def _get_stats(self,
time_e2e_requests=time_e2e_requests,
)

def _decode_logprobs(self, seq: Sequence, prms: SamplingParams,
logprobs: Dict[int, Logprob],
all_input_ids: List[int]) -> None:
if not logprobs:
return
for token_id, sample_logprob in logprobs.items():
if (sample_logprob.decoded_token is None and token_id != -1):
all_input_ids_with_logprob = all_input_ids[:-1] + [token_id]
_, new_text, prefix_offset, read_offset = detokenize_incrementally(
self.get_tokenizer_for_seq(seq),
all_input_ids=all_input_ids_with_logprob,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.
spaces_between_special_tokens,
)
sample_logprob.decoded_token = new_text

def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
"""Decodes the new token for a sequence."""
all_input_ids = seq.get_token_ids()
self._decode_logprobs(seq, prms, seq.output_logprobs[-1],
all_input_ids)

(new_tokens, new_output_text, prefix_offset,
read_offset) = detokenize_incrementally(
self.get_tokenizer_for_seq(seq),
all_input_ids=seq.get_token_ids(),
all_input_ids=all_input_ids,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
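_decode_logprobs fills in Logprob.decoded_token for every candidate token by detokenizing it in the context of the tokens generated so far, so the text is readable by the time results reach the user. A rough standalone illustration of that idea with a plain Hugging Face tokenizer (the model and candidate ids are arbitrary, and the prefix/read-offset bookkeeping handled by detokenize_incrementally is skipped):

# Standalone illustration, not vLLM code: decode each candidate id in context
# rather than in isolation, which is what populating decoded_token amounts to.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/opt-125m")

generated_ids = tok("Hello world", add_special_tokens=False).input_ids
candidate_ids = [262, 287]  # hypothetical top-k candidates for the next position

prefix_text = tok.decode(generated_ids)
for token_id in candidate_ids:
    # The suffix that appears when the candidate is appended is its decoded
    # text in context (more robust than decoding the single id on its own).
    decoded_token = tok.decode(generated_ids + [token_id])[len(prefix_text):]
    print(token_id, repr(decoded_token))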
(Diffs for the remaining changed files were not loaded on this page.)
