-
-
Notifications
You must be signed in to change notification settings - Fork 5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Push logprob generation to LLMEngine #3065
Changes from 1 commit
16a49f5
3b59c01
0b7a9c9
306d3dd
cafccae
f101ef6
05fcdcc
9da7def
2c3e8da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,7 @@ | |
from vllm.logger import init_logger | ||
from vllm.outputs import RequestOutput | ||
from vllm.sampling_params import SamplingParams | ||
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, | ||
from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup, | ||
SequenceGroupOutput, SequenceOutput, SequenceStatus) | ||
from vllm.transformers_utils.tokenizer import (detokenize_incrementally, | ||
TokenizerGroup) | ||
|
@@ -449,6 +449,13 @@ def add_request( | |
if lora_request is not None and not self.lora_config: | ||
raise ValueError(f"Got lora_request {lora_request} but LoRA is " | ||
"not enabled!") | ||
max_log_probs = self.get_model_config().max_log_probs | ||
if (sampling_params.logprobs | ||
and sampling_params.logprobs > max_log_probs) or ( | ||
sampling_params.prompt_logprobs | ||
and sampling_params.prompt_logprobs > max_log_probs): | ||
raise ValueError(f"Cannot request more than " | ||
f"{max_log_probs} logprobs.") | ||
if arrival_time is None: | ||
arrival_time = time.monotonic() | ||
prompt_token_ids = self.encode_request( | ||
|
@@ -460,6 +467,8 @@ def add_request( | |
# Create the sequences. | ||
block_size = self.cache_config.block_size | ||
seq_id = next(self.seq_counter) | ||
assert prompt | ||
assert prompt_token_ids | ||
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, | ||
lora_request) | ||
|
||
|
@@ -563,6 +572,13 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, | |
# Process prompt logprobs | ||
prompt_logprobs = outputs.prompt_logprobs | ||
if prompt_logprobs is not None: | ||
# We can pick any sequence for the prompt. | ||
seq = next(iter(seq_group.seqs_dict.values())) | ||
all_token_ids = seq.get_token_ids() | ||
for i, prompt_logprobs_for_token in enumerate(prompt_logprobs): | ||
self._decode_logprobs(seq, seq_group.sampling_params, | ||
prompt_logprobs_for_token, | ||
all_token_ids[:i]) | ||
seq_group.prompt_logprobs = prompt_logprobs | ||
|
||
# Process samples | ||
|
@@ -909,12 +925,36 @@ def _get_stats(self, | |
time_e2e_requests=time_e2e_requests, | ||
) | ||
|
||
def _decode_logprobs(self, seq: Sequence, prms: SamplingParams, | ||
logprobs: Dict[int, Logprob], | ||
all_input_ids: List[int]) -> None: | ||
if not logprobs: | ||
return | ||
for token_id, sample_logprob in logprobs.items(): | ||
if (sample_logprob.decoded_token is None and token_id != -1): | ||
all_input_ids_with_logprob = all_input_ids[:-1] + [token_id] | ||
_, new_text, prefix_offset, read_offset = detokenize_incrementally( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Without knowing the OpenAI behaviour, IMHO it would be more appropriate here to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, the OpenAI behavior is exactly that - the logprob token text depends on the previous tokens and is not constant. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Yard1 ok, thanks! I'll take a closer look at this. |
||
self.get_tokenizer_for_seq(seq), | ||
all_input_ids=all_input_ids_with_logprob, | ||
prev_tokens=seq.tokens, | ||
prefix_offset=seq.prefix_offset, | ||
read_offset=seq.read_offset, | ||
skip_special_tokens=prms.skip_special_tokens, | ||
spaces_between_special_tokens=prms. | ||
spaces_between_special_tokens, | ||
) | ||
sample_logprob.decoded_token = new_text | ||
|
||
def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: | ||
"""Decodes the new token for a sequence.""" | ||
all_input_ids = seq.get_token_ids() | ||
self._decode_logprobs(seq, prms, seq.output_logprobs[-1], | ||
all_input_ids) | ||
|
||
(new_tokens, new_output_text, prefix_offset, | ||
read_offset) = detokenize_incrementally( | ||
self.get_tokenizer_for_seq(seq), | ||
all_input_ids=seq.get_token_ids(), | ||
all_input_ids=all_input_ids, | ||
prev_tokens=seq.tokens, | ||
prefix_offset=seq.prefix_offset, | ||
read_offset=seq.read_offset, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,8 +9,16 @@ | |
from vllm.sampling_params import SamplingParams | ||
from vllm.lora.request import LoRARequest | ||
|
||
PromptLogprobs = List[Optional[Dict[int, float]]] | ||
SampleLogprobs = List[Dict[int, float]] | ||
|
||
@dataclass | ||
class Logprob: | ||
"""Infos for supporting OpenAI compatible logprobs.""" | ||
logprob: float | ||
decoded_token: Optional[str] = None | ||
|
||
|
||
PromptLogprobs = List[Optional[Dict[int, Logprob]]] | ||
SampleLogprobs = List[Dict[int, Logprob]] | ||
|
||
|
||
class SequenceStatus(enum.Enum): | ||
|
@@ -187,12 +195,12 @@ def _append_tokens_to_blocks(self, token_ids: List[int]) -> None: | |
def append_token_id( | ||
self, | ||
token_id: int, | ||
logprobs: Dict[int, float], | ||
logprobs: Dict[int, Logprob], | ||
) -> None: | ||
assert token_id in logprobs | ||
self._append_tokens_to_blocks([token_id]) | ||
self.output_logprobs.append(logprobs) | ||
self.data.append_token_id(token_id, logprobs[token_id]) | ||
self.data.append_token_id(token_id, logprobs[token_id].logprob) | ||
|
||
def get_len(self) -> int: | ||
return self.data.get_len() | ||
|
@@ -465,9 +473,13 @@ def __repr__(self) -> str: | |
def __eq__(self, other: object) -> bool: | ||
if not isinstance(other, SequenceOutput): | ||
raise NotImplementedError() | ||
return (self.parent_seq_id == other.parent_seq_id | ||
and self.output_token == other.output_token | ||
and self.logprobs == other.logprobs) | ||
equal = (self.parent_seq_id == other.parent_seq_id | ||
and self.output_token == other.output_token) | ||
log_probs_equal = ((len(other.logprobs) == len(self.logprobs)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it better to move this to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good call! |
||
and all(other_logprob == self_logprob | ||
for other_logprob, self_logprob in zip( | ||
other.logprobs, self.logprobs))) | ||
return equal and log_probs_equal | ||
|
||
|
||
class SequenceGroupOutput: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is
max_logprobs
better than this?And we can add comments for why default value is 5 (from OpenAI API Reference?).