Push logprob generation to LLMEngine #3065

Merged Mar 4, 2024 (9 commits)
Changes from 1 commit
44 changes: 38 additions & 6 deletions tests/samplers/test_logprobs.py
@@ -1,5 +1,6 @@
import pytest
import torch
from tests.conftest import VllmRunner

from vllm import SamplingParams

@@ -16,41 +17,72 @@ def test_get_prompt_logprobs(
example_prompts,
):
max_tokens = 5
num_top_logprobs = 6
hf_model = hf_runner(model, dtype=dtype)
hf_logprobs = hf_model.generate_greedy_logprobs(
example_prompts,
max_tokens=max_tokens,
)
del hf_model

vllm_model = vllm_runner(model, dtype=dtype)
vllm_model = vllm_runner(model,
dtype=dtype,
max_log_probs=num_top_logprobs)
vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
logprobs=5,
logprobs=num_top_logprobs,
prompt_logprobs=5,
temperature=0.0)
vllm_results = vllm_model.model.generate(
example_prompts, sampling_params=vllm_sampling_params)
del vllm_model

# Test whether logprobs are included in the results.
for result in vllm_results:
assert result.prompt_logprobs is not None
assert result.outputs[0].logprobs is not None
assert len(result.outputs[0].logprobs) == max_tokens
for logprobs in result.outputs[0].logprobs:
assert len(logprobs) == num_top_logprobs
output_text = result.outputs[0].text
output_string_from_most_likely_tokens = []
for top_logprobs in result.outputs[0].logprobs:
top_logprob = next(iter(top_logprobs.values()))
output_string_from_most_likely_tokens.append(
top_logprob.decoded_token)
output_string_from_most_likely_tokens = "".join(
output_string_from_most_likely_tokens)
assert output_text == output_string_from_most_likely_tokens, (
"The output text from the top logprob for each token position "
"should be the same as the output text in the result.")

# Test whether prompt logprobs are consistent with HF
for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
# Check prompt logprobs
vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
for token_id, logprob in vllm_prompt_logprob_dict.items():
torch.testing.assert_close(logprob,
torch.testing.assert_close(logprob.logprob,
hf_logprob[0][i][token_id].item(),
atol=1e-2,
rtol=1e-2)
vllm_sample_logprobs = vllm_result.outputs[0].logprobs
for i, vllm_sample_logprob_dict in enumerate(vllm_sample_logprobs):
for token_id, logprob in vllm_sample_logprob_dict.items():
for i, top_logprobs in enumerate(vllm_sample_logprobs):
for token_id, sample_logprob in top_logprobs.items():
logprob = sample_logprob.logprob
torch.testing.assert_close(logprob,
hf_logprob[i][-1][token_id].item(),
atol=1e-2,
rtol=1e-2)
assert isinstance(sample_logprob.decoded_token, str), \
("The token should be decoded by the time it is returned "
" to the user.")


def test_max_log_probs():
runner = VllmRunner("facebook/opt-125m", max_log_probs=1)
vllm_sampling_params = SamplingParams(logprobs=1)
# should pass
runner.generate(["Hello world"], sampling_params=vllm_sampling_params)

bad_sampling_params = SamplingParams(logprobs=2)
with pytest.raises(ValueError):
runner.generate(["Hello world"], sampling_params=bad_sampling_params)
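
For context on what these tests exercise, here is a minimal usage sketch of the new logprobs shape: each output position now maps token ids to Logprob objects (with .logprob and .decoded_token) rather than plain floats. The model name and the assumption that the offline LLM entrypoint forwards max_log_probs to EngineArgs are illustrative, not part of this diff.

# Minimal sketch, assuming LLM(**kwargs) forwards max_log_probs to EngineArgs.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", max_log_probs=5)
params = SamplingParams(max_tokens=3, logprobs=2, prompt_logprobs=2, temperature=0.0)
result = llm.generate(["Hello world"], sampling_params=params)[0]

# result.outputs[0].logprobs is a List[Dict[int, Logprob]], one dict per generated token.
for position, top in enumerate(result.outputs[0].logprobs):
    for token_id, lp in top.items():
        print(position, token_id, lp.logprob, lp.decoded_token)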
12 changes: 7 additions & 5 deletions tests/worker/spec_decode/utils.py
@@ -4,7 +4,7 @@
from vllm.worker.worker import Worker
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.engine.arg_utils import EngineArgs
from vllm.sequence import SequenceGroupMetadata, SequenceData
from vllm.sequence import Logprob, SequenceGroupMetadata, SequenceData
from vllm.sampling_params import SamplingParams
from vllm.worker.cache_engine import CacheEngine
from vllm.model_executor.utils import set_random_seed
@@ -166,13 +166,15 @@ def create_seq_group_metadata_from_prompts(


def assert_logprobs_dict_allclose(
actual_logprobs: List[Dict[int, float]],
expected_logprobs: List[Dict[int, float]]) -> None:
actual_logprobs: List[Dict[int, Logprob]],
expected_logprobs: List[Dict[int, Logprob]]) -> None:
for single_step_actual_logprobs, single_step_expected_logprobs in zip(
actual_logprobs, expected_logprobs):
assert set(single_step_actual_logprobs.keys()) == set(
single_step_expected_logprobs.keys())
for token_id in single_step_actual_logprobs:
actual = torch.tensor(single_step_actual_logprobs[token_id])
expected = torch.tensor(single_step_expected_logprobs[token_id])
actual = torch.tensor(
single_step_actual_logprobs[token_id].logprob)
expected = torch.tensor(
single_step_expected_logprobs[token_id].logprob)
assert torch.allclose(actual, expected)
2 changes: 2 additions & 0 deletions vllm/config.py
@@ -79,6 +79,7 @@ def __init__(
quantization: Optional[str] = None,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_log_probs: int = 5,
) -> None:
self.model = model
self.tokenizer = tokenizer
@@ -93,6 +94,7 @@ def __init__(
self.quantization = quantization
self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture
self.max_log_probs = max_log_probs

if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
# download model from ModelScope hub,
10 changes: 9 additions & 1 deletion vllm/engine/arg_utils.py
@@ -30,6 +30,7 @@ class EngineArgs:
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_paddings: int = 256
max_log_probs: int = 5
Collaborator: Is max_logprobs a better name for this? And we could add a comment explaining why the default value is 5 (from the OpenAI API Reference?).

disable_log_stats: bool = False
revision: Optional[str] = None
code_revision: Optional[str] = None
@@ -201,6 +202,12 @@ def add_cli_args(
type=int,
default=EngineArgs.max_paddings,
help='maximum number of paddings in a batch')
parser.add_argument(
'--max-log-probs',
type=int,
default=EngineArgs.max_log_probs,
help=('max number of log probs to return when logprobs is'
' specified in SamplingParams'))
parser.add_argument('--disable-log-stats',
action='store_true',
help='disable logging statistics')
@@ -291,7 +298,8 @@ def create_engine_configs(
self.trust_remote_code, self.download_dir, self.load_format,
self.dtype, self.seed, self.revision, self.code_revision,
self.tokenizer_revision, self.max_model_len, self.quantization,
self.enforce_eager, self.max_context_len_to_capture)
self.enforce_eager, self.max_context_len_to_capture,
self.max_log_probs)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,
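
A hedged sketch of how the cap added in this file can be set on both entry points: the --max-log-probs CLI flag and the max_log_probs engine argument. The server module path shown in the comment is the one in use around this PR and is an assumption here, not part of the diff.

# Server side (assumed invocation of the OpenAI-compatible server with the new flag):
#   python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --max-log-probs 10
#
# Programmatic side, via EngineArgs (names taken from this diff):
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine

engine_args = EngineArgs(model="facebook/opt-125m", max_log_probs=10)
engine = LLMEngine.from_engine_args(engine_args)
# add_request() will now reject SamplingParams with logprobs/prompt_logprobs > 10.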
44 changes: 42 additions & 2 deletions vllm/engine/llm_engine.py
@@ -16,7 +16,7 @@
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
TokenizerGroup)
@@ -449,6 +449,13 @@ def add_request(
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
max_log_probs = self.get_model_config().max_log_probs
if (sampling_params.logprobs
and sampling_params.logprobs > max_log_probs) or (
sampling_params.prompt_logprobs
and sampling_params.prompt_logprobs > max_log_probs):
raise ValueError(f"Cannot request more than "
f"{max_log_probs} logprobs.")
if arrival_time is None:
arrival_time = time.monotonic()
prompt_token_ids = self.encode_request(
@@ -460,6 +467,8 @@
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
assert prompt
assert prompt_token_ids
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
lora_request)

@@ -563,6 +572,13 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
# Process prompt logprobs
prompt_logprobs = outputs.prompt_logprobs
if prompt_logprobs is not None:
# We can pick any sequence for the prompt.
seq = next(iter(seq_group.seqs_dict.values()))
all_token_ids = seq.get_token_ids()
for i, prompt_logprobs_for_token in enumerate(prompt_logprobs):
self._decode_logprobs(seq, seq_group.sampling_params,
prompt_logprobs_for_token,
all_token_ids[:i])
seq_group.prompt_logprobs = prompt_logprobs

# Process samples
@@ -909,12 +925,36 @@ def _get_stats(self,
time_e2e_requests=time_e2e_requests,
)

def _decode_logprobs(self, seq: Sequence, prms: SamplingParams,
logprobs: Dict[int, Logprob],
all_input_ids: List[int]) -> None:
if not logprobs:
return
for token_id, sample_logprob in logprobs.items():
if (sample_logprob.decoded_token is None and token_id != -1):
all_input_ids_with_logprob = all_input_ids[:-1] + [token_id]
_, new_text, prefix_offset, read_offset = detokenize_incrementally(
Member: Without knowing the OpenAI behaviour, IMHO it would be more appropriate here to use convert_ids_to_tokens and include the explicit/atomic token strings. Otherwise the text may not line up with the token.

Collaborator (author): Actually, the OpenAI behavior is exactly that - the logprob token text depends on the previous tokens and is not constant.

Member: @Yard1 ok, thanks! I'll take a closer look at this.

self.get_tokenizer_for_seq(seq),
all_input_ids=all_input_ids_with_logprob,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.
spaces_between_special_tokens,
)
sample_logprob.decoded_token = new_text

def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
"""Decodes the new token for a sequence."""
all_input_ids = seq.get_token_ids()
self._decode_logprobs(seq, prms, seq.output_logprobs[-1],
all_input_ids)

(new_tokens, new_output_text, prefix_offset,
read_offset) = detokenize_incrementally(
self.get_tokenizer_for_seq(seq),
all_input_ids=seq.get_token_ids(),
all_input_ids=all_input_ids,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
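
Regarding the convert_ids_to_tokens discussion above, a small illustration of why the two approaches can disagree: the atomic vocabulary piece carries tokenizer-internal markers, while incremental detokenization yields context-dependent surface text (which is what _decode_logprobs stores in decoded_token). The tokenizer choice and example token ids below are assumptions for illustration only.

# Illustrative only; GPT-2's BPE marks a leading space with 'Ġ' in the raw vocab piece.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
ids = tok("Hello world")["input_ids"]

atomic_piece = tok.convert_ids_to_tokens(ids[-1])         # e.g. 'Ġworld' (vocab piece)
full_text = tok.decode(ids)                               # 'Hello world'
incremental_text = full_text[len(tok.decode(ids[:-1])):]  # ' world', depends on prior tokens
print(atomic_piece, repr(incremental_text))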
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -297,7 +297,7 @@ async def create_completion(self, request: CompletionRequest,
request, prompt=prompt)

generators.append(
self.engine.generate(None,
self.engine.generate(prompt,
sampling_params,
f"{request_id}-{i}",
prompt_token_ids=input_ids,
9 changes: 5 additions & 4 deletions vllm/entrypoints/openai/serving_engine.py
@@ -11,6 +11,7 @@
ModelCard, ModelList,
ModelPermission)
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob

logger = init_logger(__name__)

@@ -83,7 +84,7 @@ async def show_available_models(self) -> ModelList:
def _create_logprobs(
self,
token_ids: List[int],
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None,
num_output_top_logprobs: Optional[int] = None,
initial_text_offset: int = 0,
) -> LogProbs:
@@ -95,10 +96,10 @@
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is not None:
token_logprob = step_top_logprobs[token_id]
token_logprob = step_top_logprobs[token_id].logprob
else:
token_logprob = None
token = self.tokenizer.convert_ids_to_tokens(token_id)
token = step_top_logprobs[token_id].decoded_token
logprobs.tokens.append(token)
logprobs.token_logprobs.append(token_logprob)
if len(logprobs.text_offset) == 0:
@@ -110,7 +111,7 @@

if num_output_top_logprobs:
logprobs.top_logprobs.append({
self.tokenizer.convert_ids_to_tokens(i): p
p.decoded_token: p.logprob
for i, p in step_top_logprobs.items()
} if step_top_logprobs else None)
return logprobs
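
For reference, a hedged sketch of the completions-style logprobs payload that _create_logprobs assembles after this change: token strings now come from Logprob.decoded_token rather than convert_ids_to_tokens, and the top_logprobs keys are the decoded texts. The concrete tokens and numbers below are invented for illustration.

# Shape of the LogProbs fields populated above (illustrative values only):
logprobs_payload = {
    "tokens": [" Hello", " world"],            # Logprob.decoded_token per sampled token
    "token_logprobs": [-0.12, -0.48],          # Logprob.logprob of the sampled token
    "text_offset": [0, 6],                     # running character offsets
    "top_logprobs": [                          # decoded_token -> logprob for the top-k set
        {" Hello": -0.12, " Hi": -2.31},
        {" world": -0.48, " there": -1.95},
    ],
}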
15 changes: 11 additions & 4 deletions vllm/model_executor/layers/sampler.py
@@ -8,8 +8,9 @@
tensor_model_parallel_gather)
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
SequenceData, SequenceGroupOutput, SequenceOutput)
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
SamplerOutput, SequenceData, SequenceGroupOutput,
SequenceOutput)


class Sampler(nn.Module):
@@ -520,7 +521,10 @@ def _get_logprobs(
prompt_logprobs_dict.update(
zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
top_logprobs[sample_idx, :num_logprobs].tolist()))
group_prompt_logprobs.append(prompt_logprobs_dict)
group_prompt_logprobs.append({
token_id: Logprob(logprob)
for token_id, logprob in prompt_logprobs_dict.items()
})
sample_idx += 1
query_result_idx += 1
result_prompt_logprobs.append(group_prompt_logprobs)
@@ -545,7 +549,10 @@
parent_id, :num_logprobs].tolist(),
top_logprobs[sample_idx +
parent_id, :num_logprobs].tolist()))
group_sample_logprobs.append(sample_logprobs_dict)
group_sample_logprobs.append({
token_id: Logprob(logprob)
for token_id, logprob in sample_logprobs_dict.items()
})
result_sample_logprobs.append(group_sample_logprobs)
sample_idx += len(seq_ids)

26 changes: 19 additions & 7 deletions vllm/sequence.py
@@ -9,8 +9,16 @@
from vllm.sampling_params import SamplingParams
from vllm.lora.request import LoRARequest

PromptLogprobs = List[Optional[Dict[int, float]]]
SampleLogprobs = List[Dict[int, float]]

@dataclass
class Logprob:
"""Infos for supporting OpenAI compatible logprobs."""
logprob: float
decoded_token: Optional[str] = None


PromptLogprobs = List[Optional[Dict[int, Logprob]]]
SampleLogprobs = List[Dict[int, Logprob]]


class SequenceStatus(enum.Enum):
@@ -187,12 +195,12 @@ def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
def append_token_id(
self,
token_id: int,
logprobs: Dict[int, float],
logprobs: Dict[int, Logprob],
) -> None:
assert token_id in logprobs
self._append_tokens_to_blocks([token_id])
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id])
self.data.append_token_id(token_id, logprobs[token_id].logprob)

def get_len(self) -> int:
return self.data.get_len()
@@ -465,9 +473,13 @@ def __repr__(self) -> str:
def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceOutput):
raise NotImplementedError()
return (self.parent_seq_id == other.parent_seq_id
and self.output_token == other.output_token
and self.logprobs == other.logprobs)
equal = (self.parent_seq_id == other.parent_seq_id
and self.output_token == other.output_token)
log_probs_equal = ((len(other.logprobs) == len(self.logprobs))
Collaborator: Is it better to move this to Logprob's __eq__?

Collaborator (author): good call!

and all(other_logprob == self_logprob
for other_logprob, self_logprob in zip(
other.logprobs, self.logprobs)))
return equal and log_probs_equal


class SequenceGroupOutput:
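
On the __eq__ thread above: since Logprob is a dataclass it already gets a field-wise __eq__, so the hand-rolled zip comparison could collapse to a direct comparison of the logprob containers. This is a sketch of the reviewer's suggestion, not the code that was ultimately merged.

from dataclasses import dataclass
from typing import Optional

@dataclass
class Logprob:
    logprob: float
    decoded_token: Optional[str] = None

# Dataclass equality compares logprob and decoded_token field-wise, so nested
# containers of Logprob objects compare correctly without manual zipping.
a = [{1: Logprob(-0.5, " hi")}]
b = [{1: Logprob(-0.5, " hi")}]
assert a == b
# SequenceOutput.__eq__ could then simply include `self.logprobs == other.logprobs`.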
2 changes: 1 addition & 1 deletion vllm/worker/spec_decode/multi_step_worker.py
@@ -77,7 +77,7 @@ def _append_new_tokens(
token_id = seq_output.output_token
token_logprob = seq_output.logprobs[token_id]

seq.append_token_id(token_id, token_logprob)
seq.append_token_id(token_id, token_logprob.logprob)

def _shallow_copy_inputs(
self, seq_group_metadata_list: List[SequenceGroupMetadata]