feat: support prompt_logprobs output with spec decoding
Signed-off-by: Travis Johnson <[email protected]>
tjohnson31415 committed Sep 5, 2024
1 parent 74196ae commit 84490e7
Showing 3 changed files with 75 additions and 18 deletions.
11 changes: 11 additions & 0 deletions vllm/sequence.py
@@ -1085,6 +1085,17 @@ def __eq__(self, other: object):
self.__class__) and self.outputs == other.outputs


def get_all_seq_data_entries(
seq_group_metadata_list: List[SequenceGroupMetadata]
) -> List[Tuple[int, SequenceData]]:
"""Given a list of SequenceGroupMetadata, create a dict of
sequence ids to SequenceData
"""
return [(seq_id, seq_data) for sg in seq_group_metadata_list \
for seq_id, seq_data in sg.seq_data.items()
]


def get_all_seq_ids(
seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
"""Given a list of SequenceGroupMetadata, create a list of all
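For illustration, a minimal sketch (not part of the commit) of the flattening that get_all_seq_data_entries performs; FakeSeqGroup here is a hypothetical stand-in for SequenceGroupMetadata, which carries many more fields in vLLM:

from dataclasses import dataclass, field
from typing import Dict

@dataclass
class FakeSeqGroup:
    # maps sequence id -> per-sequence data (a string here, SequenceData in vLLM)
    seq_data: Dict[int, str] = field(default_factory=dict)

groups = [FakeSeqGroup({0: "data-a", 1: "data-b"}), FakeSeqGroup({7: "data-c"})]
entries = [(seq_id, seq_data) for sg in groups
           for seq_id, seq_data in sg.seq_data.items()]
print(entries)  # [(0, 'data-a'), (1, 'data-b'), (7, 'data-c')]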
37 changes: 27 additions & 10 deletions vllm/spec_decode/spec_decode_worker.py
@@ -16,7 +16,8 @@
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SequenceGroupMetadata,
get_all_seq_ids, get_all_seq_ids_and_request_ids)
get_all_seq_data_entries,
get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
@@ -30,6 +31,7 @@
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.target_model_runner import TargetModelRunner
from vllm.spec_decode.util import (Timer, create_sequence_group_output,
create_logprobs_output,
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
@@ -439,8 +441,8 @@ def _serialize_sampler_output_no_logprobs(
self, execute_model_req: ExecuteModelRequest,
sampler_output: SamplerOutput) -> SamplerOutput:
"""
Creates and returns a `SamplerOutput` with only the sampled token IDs
being serialized to CPU & populated in `CompletionSequenceGroupOutput`.
Creates and returns a `SamplerOutput` with only the token IDs being
serialized to CPU and populated in `CompletionSequenceGroupOutput`.
All other parameters in `CompletionSequenceGroupOutput` related to log
probabilities are skipped.
@@ -452,19 +454,34 @@
Returns:
SamplerOutput: A new `SamplerOutput` instance containing a list of
`CompletionSequenceGroupOutput` objects with only sampled token
IDs populated.
`CompletionSequenceGroupOutput` objects with only token IDs
populated.
"""
seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list)
# ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID
has_prompt = any(seq.is_prompt
for seq in execute_model_req.seq_group_metadata_list)
sampled_token_ids_list = (sampler_output.sampled_token_ids[torch.where(
sampler_output.sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]] \
if any(seq.is_prompt
for seq in execute_model_req.seq_group_metadata_list) else \
if has_prompt else \
sampler_output.sampled_token_ids).tolist()

seq_data_entries = get_all_seq_data_entries(
execute_model_req.seq_group_metadata_list)
completion_seq_group_output_list: List[
CompletionSequenceGroupOutput] = []
for index, seq_id in enumerate(seq_ids):
for index, (seq_id, seq_data) in enumerate(seq_data_entries):
prompt_token_ids = seq_data.get_prompt_token_ids()
prompt_logprobs = [create_logprobs_output(
token_id=p_token_id,
token_id_logprob_rank=-1,
token_id_logprob=0.0,
topk_token_ids=[],
topk_logprobs=[],
)
# no logprobs for the first token
for p_token_id in prompt_token_ids[1:]] \
if prompt_token_ids is not None else None

completion_seq_group_output_list.append(
create_sequence_group_output(
token_id=sampled_token_ids_list[index][0],
@@ -473,7 +490,7 @@ def _serialize_sampler_output_no_logprobs(
seq_id=seq_id,
topk_token_ids=[],
topk_logprobs=[],
))
prompt_logprobs=prompt_logprobs))
return SamplerOutput(outputs=completion_seq_group_output_list)

@nvtx_range("spec_decode_worker._run_no_spec")
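A hedged sketch of the placeholder prompt logprobs this hunk builds when logprob computation is skipped: every prompt token after the first gets a dummy entry with rank -1 and logprob 0.0. The Logprob class below is a simplified stand-in for vllm.sequence.Logprob:

from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class Logprob:  # simplified stand-in for vllm.sequence.Logprob
    logprob: float
    rank: Optional[int] = None

def placeholder_prompt_logprobs(
        prompt_token_ids: List[int]) -> List[Dict[int, Logprob]]:
    # The first prompt token has no preceding context, so it gets no entry.
    return [{token_id: Logprob(logprob=0.0, rank=-1)}
            for token_id in prompt_token_ids[1:]]

print(placeholder_prompt_logprobs([11, 22, 33]))
# [{22: Logprob(logprob=0.0, rank=-1)}, {33: Logprob(logprob=0.0, rank=-1)}]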
45 changes: 37 additions & 8 deletions vllm/spec_decode/util.py
@@ -6,7 +6,8 @@

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceGroupMetadata, SequenceOutput)
PromptLogprobs, SequenceGroupMetadata,
SequenceOutput)

SeqId = int

@@ -49,21 +50,19 @@ def get_sampled_token_logprobs(
return sampled_token_ids_ranks, selected_logprobs


def create_sequence_group_output(
def create_logprobs_output(
token_id: int,
token_id_logprob_rank: int,
token_id_logprob: float,
seq_id: SeqId,
topk_token_ids: List[Optional[int]],
topk_logprobs: List[Optional[float]],
) -> CompletionSequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results.
) -> Dict[int, Logprob]:
"""Create a Logprob Dict for a token given the sampling results.
Args:
token_id (int): The sampled token for the sequence.
token_id_logprob_rank (int): The logprob rank of the sampled token.
token_id_logprob (float): The logprob value of the sampled token.
seq_id (int): The sequence id.
topk_token_ids (List[Optional[int]]): The list of top-k token ids.
topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
"""
@@ -85,14 +84,44 @@
if topk_token_id is not None
})

return logprobs


def create_sequence_group_output(
token_id: int,
token_id_logprob_rank: int,
token_id_logprob: float,
seq_id: SeqId,
topk_token_ids: List[Optional[int]],
topk_logprobs: List[Optional[float]],
prompt_logprobs: Optional[PromptLogprobs] = None,
) -> CompletionSequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results.
Args:
token_id (int): The sampled token for the sequence.
token_id_logprob_rank (int): The logprob rank of the sampled token.
token_id_logprob (float): The logprob value of the sampled token.
seq_id (int): The sequence id.
topk_token_ids (List[Optional[int]]): The list of top-k token ids.
topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
prompt_logprobs (Optional[PromptLogprobs]): The list of logprob dicts
    for the prompt tokens, if requested.
"""

logprobs = create_logprobs_output(
token_id,
token_id_logprob_rank,
token_id_logprob,
topk_token_ids,
topk_logprobs,
)

return CompletionSequenceGroupOutput(
samples=[
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs=logprobs)
],
# TODO add prompt logprobs support.
prompt_logprobs=None,
prompt_logprobs=prompt_logprobs,
)


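A usage sketch of the refactored helpers, assuming a vLLM checkout that includes this commit; the token ids and logprob values are illustrative:

from vllm.spec_decode.util import (create_logprobs_output,
                                   create_sequence_group_output)

# Per-token logprob dict for a sampled token with two top-k alternatives.
logprobs = create_logprobs_output(
    token_id=42,
    token_id_logprob_rank=1,
    token_id_logprob=-0.25,
    topk_token_ids=[42, 7],
    topk_logprobs=[-0.25, -1.9],
)

# Full sequence-group output; prompt logprobs can now be threaded through.
output = create_sequence_group_output(
    token_id=42,
    token_id_logprob_rank=1,
    token_id_logprob=-0.25,
    seq_id=0,
    topk_token_ids=[42, 7],
    topk_logprobs=[-0.25, -1.9],
    prompt_logprobs=None,  # or a list of per-token {token_id: Logprob} dicts
)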
