[Speculative decoding] Improve n-gram efficiency (vllm-project#4724)
comaniac authored May 13, 2024
1 parent 8bc68e1 commit ce532ff
Showing 4 changed files with 132 additions and 58 deletions.
17 changes: 9 additions & 8 deletions tests/spec_decode/test_ngram_worker.py
@@ -34,8 +34,8 @@ def test_ngram_algo_correctness_for_single_no_match():
         max_proposal_len=20,
     )
 
-    # set ngram window (0, 3], which is window=1/2/3
-    ngram_worker.set_ngram_window_size(0, 3)
+    # set ngram window [1, 3], which is window=1/2/3
+    ngram_worker.set_ngram_window_size(1, 3)
 
     prompts = [
         # shall find no candidate
@@ -90,8 +90,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
         max_proposal_len=20,
     )
 
-    # set ngram window (0, 3], which is window=1/2/3
-    ngram_worker.set_ngram_window_size(0, 3)
+    # set ngram window [1, 3], which is window=1/2/3
+    ngram_worker.set_ngram_window_size(1, 3)
 
     prompts = [
         # shall find no candidate
@@ -128,11 +128,12 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
     assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len])
     assert proposals.proposal_lens.shape == torch.Size([5])
 
+    # the first sequence has no match so proposal_len should be overwritten to 0
     assert proposals.proposal_lens.tolist(
-    ) == [proposal_len for _ in range(4)] + [0]
+    ) == [0] + [proposal_len for _ in range(3)] + [0]
 
     for i in range(proposal_len):
-        assert proposals.proposal_token_ids[0][i] == 0
+        assert proposals.proposal_token_ids[0][i] == -1
         assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1]
         assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3]
         assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5]
@@ -167,8 +168,8 @@ def test_ngram_algo_correctness_for_batches_match_all():
         max_proposal_len=20,
     )
 
-    # set ngram window (0, 3], which is window=1/2/3
-    ngram_worker.set_ngram_window_size(0, 3)
+    # set ngram window [1, 3], which is window=1/2/3
+    ngram_worker.set_ngram_window_size(1, 3)
 
     prompts = [
         # shall find candidate 12,13,14,15,16
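
For reference alongside these tests (not part of the commit): a window of [1, 3] means the worker tries n-gram sizes 3, 2, 1 against the end of the prompt and proposes the tokens that followed the first earlier occurrence. A minimal pure-Python sketch of that idea, with a hypothetical propose() helper and the simplification that it gives up when fewer than sample_len tokens follow the match:

def propose(tokens, sample_len, lookup_min=1, lookup_max=3):
    # Try the largest n-gram first, shrinking down to lookup_min.
    for n in range(min(lookup_max, len(tokens) - 1), lookup_min - 1, -1):
        suffix = tokens[-n:]
        # Only search earlier positions so the suffix cannot match itself.
        for start in range(len(tokens) - n):
            if tokens[start:start + n] == suffix:
                cont = tokens[start + n:start + n + sample_len]
                return cont if len(cont) == sample_len else None
    return None  # no match -> the sequence gets proposal_len 0

print(propose([1, 2, 3, 4, 1, 2, 3], sample_len=2))  # [4, 1]
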
13 changes: 8 additions & 5 deletions vllm/config.py
@@ -784,12 +784,15 @@ def maybe_create_spec_config(
         draft_quantization = None
 
         if speculative_model == "[ngram]":
-            assert (ngram_prompt_lookup_max is not None
-                    and ngram_prompt_lookup_max > 0)
             if ngram_prompt_lookup_min is None:
-                ngram_prompt_lookup_min = 0
-            else:
-                assert ngram_prompt_lookup_max > ngram_prompt_lookup_min
+                ngram_prompt_lookup_min = 1
+            if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
+                raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
+            if ngram_prompt_lookup_min < 1:
+                raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
+            if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
+                raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
+                                 f"larger than {ngram_prompt_lookup_max=}")
 
             # TODO: current we still need extract vocab_size from target model
             # config, in future, we may try refactor it out, and set
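
For illustration only (a standalone sketch, not vLLM's API), the net effect of the new checks: ngram_prompt_lookup_min now defaults to 1 instead of 0, and invalid windows raise ValueError instead of tripping an assert. Argument names here are illustrative.

def check_ngram_window(lookup_min, lookup_max):
    # Mirrors the validation added above.
    if lookup_min is None:
        lookup_min = 1
    if lookup_max is None or lookup_max < 1:
        raise ValueError(f"{lookup_max=} must be > 0")
    if lookup_min < 1:
        raise ValueError(f"{lookup_min=} must be > 0")
    if lookup_min > lookup_max:
        raise ValueError(f"{lookup_min=} cannot be larger than {lookup_max=}")
    return lookup_min, lookup_max

print(check_ngram_window(None, 3))  # (1, 3): the minimum defaults to 1
# check_ngram_window(4, 3) raises ValueError because min exceeds max
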
97 changes: 52 additions & 45 deletions vllm/spec_decode/ngram_worker.py
@@ -77,9 +77,11 @@ def sampler_output(
         """
         self._raise_if_unsupported(execute_model_req)
 
-        arr = []
         has_spec_out = False
-        for seq_group_metadata in execute_model_req.seq_group_metadata_list:
+        token_id_list = []
+        token_prob_list = []
+        for idx, seq_group_metadata in enumerate(
+                execute_model_req.seq_group_metadata_list):
             seq_data = next(iter(seq_group_metadata.seq_data.values()))
 
             input_ids = torch.as_tensor(seq_data.get_token_ids(),
@@ -89,59 +91,64 @@
 
             for ngram_size in range(
                     min(self.ngram_prompt_lookup_max, input_length - 1),
-                    self.ngram_prompt_lookup_min,
+                    self.ngram_prompt_lookup_min - 1,
                     -1,
             ):
-                ngram_tensor = input_ids[-1 * ngram_size:]
-                windows = input_ids.unfold(dimension=0,
-                                           size=ngram_size,
-                                           step=1)
-                matches = (windows == ngram_tensor).all(dim=1)
-                match_indices = matches.nonzero(as_tuple=True)[0]
-                if match_indices.size()[0] > 1:
+                ngram_tensor = input_ids[-ngram_size:]
+                proposal_start_idx = None
+                if ngram_size == 1:
+                    # Do not match itself and do not use unfold and all
+                    matches = (input_ids[:-1] == ngram_tensor)
+                else:
+                    windows = input_ids.unfold(dimension=0,
+                                               size=ngram_size,
+                                               step=1)
+                    # Do not match itself
+                    matches = (windows[:-1] == ngram_tensor).all(dim=-1)
+
+                # first_match includes "values" (bool), indicating whether
+                # the match is found, and "indices", indicating the index
+                # of the first match.
+                # Note that "first_match.values.item()" triggers GPU-CPU
+                # sync so it is a bit inefficient, but we have not found
+                # a better way to do this.
+                first_match = matches.max(dim=-1)
+                if first_match.values.item():
+                    proposal_start_idx = first_match.indices.add_(ngram_size)
+                    spec_indices = (
+                        proposal_start_idx).repeat(sample_len) + torch.arange(
+                            sample_len, device=self.device)
+                    spec_indices.clamp_(max=input_ids.shape[-1] - 1)
+                    res = input_ids.gather(dim=-1, index=spec_indices)
+                    token_id_list.append(res)
+                    token_prob_list.append(
+                        torch.nn.functional.one_hot(
+                            res,
+                            num_classes=self.vocab_size).to(torch.float32))
                     has_spec_out = True
-                    res = seq_data.get_token_ids()
-                    res = res[match_indices[0] + ngram_size:match_indices[0] +
-                              ngram_size + sample_len]
-                    res_len = len(res)
-                    # pad 0 towards output as sample_len tokens required
-                    res += [0] * (sample_len - res_len)
-
                     break
             else:
-                # if no candidate found, fill with 0
-                res = [0] * sample_len
-
-            arr.append(res)
+                token_id_list.append(None)
+                token_prob_list.append(None)
 
         if not has_spec_out:
             return None, False
 
-        outputs = []
-        token_ids = torch.as_tensor(arr, dtype=torch.long, device=self.device)
-        indices = token_ids.unsqueeze(2)
+        outputs: List[Optional[SamplerOutput]] = []
+        for idx in range(len(execute_model_req.seq_group_metadata_list)):
+            if token_id_list[idx] is None:
+                outputs.append(None)
+            else:
+                outputs.append(
+                    SamplerOutput(
+                        outputs=None,
+                        sampled_token_probs=token_prob_list[idx],
+                        logprobs=torch.zeros((sample_len, self.vocab_size),
+                                             dtype=torch.float32,
+                                             device=self.device),
+                        sampled_token_ids=token_id_list[idx],
+                    ))
 
-        token_probs = torch.zeros(
-            (len(execute_model_req.seq_group_metadata_list), sample_len,
-             self.vocab_size),
-            dtype=torch.float32,
-            device=self.device,
-        )
-        token_probs.scatter_(2, indices, 1)
-        token_logprobs = torch.zeros(
-            (len(execute_model_req.seq_group_metadata_list), sample_len,
-             self.vocab_size),
-            dtype=torch.float32,
-            device=self.device,
-        )
-        for i in range(len(execute_model_req.seq_group_metadata_list)):
-            outputs.append(
-                SamplerOutput(
-                    outputs=None,
-                    sampled_token_probs=token_probs[i],
-                    logprobs=token_logprobs[i],
-                    sampled_token_ids=token_ids[i],
-                ))
         return outputs, False
 
     def get_spec_proposals(
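
The heart of this change is replacing the per-sequence Python list handling with tensor operations that find the first earlier occurrence of the trailing n-gram and gather the tokens that follow it. Below is a self-contained toy sketch of that matching step (illustrative values on CPU, not the worker's actual entry point):

import torch

input_ids = torch.tensor([5, 6, 7, 8, 5, 6, 7])
ngram_size, sample_len, vocab_size = 3, 2, 10

ngram_tensor = input_ids[-ngram_size:]                     # trailing n-gram [5, 6, 7]
windows = input_ids.unfold(dimension=0, size=ngram_size, step=1)
matches = (windows[:-1] == ngram_tensor).all(dim=-1)       # drop the last window so it cannot match itself
first_match = matches.max(dim=-1)                          # values: match found; indices: first match position

if first_match.values.item():                              # this .item() is the GPU-CPU sync noted above
    start = first_match.indices + ngram_size               # first token after the matched n-gram
    spec_indices = start.repeat(sample_len) + torch.arange(sample_len)
    spec_indices.clamp_(max=input_ids.shape[-1] - 1)       # clamp at the end instead of zero-padding
    proposal = input_ids.gather(dim=-1, index=spec_indices)
    probs = torch.nn.functional.one_hot(proposal, num_classes=vocab_size).float()
    print(proposal)       # tensor([8, 5])
    print(probs.shape)    # torch.Size([2, 10])
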
63 changes: 63 additions & 0 deletions vllm/spec_decode/top1_proposer.py
@@ -73,6 +73,14 @@ def get_proposals(
                 execute_model_req=nonzero_execute_model_req,
                 sample_len=proposal_len,
             )
+            (
+                proposal_lens,
+                maybe_sampler_output,
+                nonzero_proposal_len_indices,
+            ) = self._remove_no_proposal_seqs(proposal_lens,
+                                              maybe_sampler_output,
+                                              nonzero_proposal_len_indices,
+                                              transposed)
         else:
             # If no sequences can be speculated, set sampler output to None.
             maybe_sampler_output = None
@@ -140,6 +148,61 @@ def _split_by_proposal_len(
             nonzero_proposal_len_indices,
         )
 
+    def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output,
+                                 nonzero_proposal_len_indices, transposed):
+        """Remove sequences from nonzero_proposal_len_indices and reset
+        their proposal_len to 0 when the draft worker does not provide a
+        proposal (maybe_sampler_output=None). This can avoid scoring
+        overheads.
+        """
+
+        # If maybe_sampler_output is None, then the draft worker did not
+        # provide a proposal for any sequence and thus no action is needed.
+        # Also we do not support transposed maybe_sampler_output for now
+        # because it seems not straightforward for draft workers outputting
+        # transposed sampler outputs to handle the case of no proposal.
+        if maybe_sampler_output is None or transposed:
+            return (proposal_lens, maybe_sampler_output,
+                    nonzero_proposal_len_indices)
+
+        new_proposal_lens: List[int] = []
+        new_nonzero_proposal_len_indices: List[int] = []
+        new_maybe_sampler_output: List[SamplerOutput] = []
+        nonzero_proposal_len_idx_ptr = 0
+        seq_idx = 0
+        while seq_idx < len(
+                proposal_lens) and nonzero_proposal_len_idx_ptr < len(
+                    nonzero_proposal_len_indices):
+            if seq_idx < nonzero_proposal_len_indices[
+                    nonzero_proposal_len_idx_ptr]:
+                # Sequence is not in the original nonzero_proposal_len_indices,
+                # meaning that it has a proposal length of 0 before sending to
+                # the draft worker.
+                assert proposal_lens[seq_idx] == 0
+                new_proposal_lens.append(0)
+            else:
+                # Sequence is in the original nonzero_proposal_len_indices
+                if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None:
+                    # but does not have a proposal from the draft worker.
+                    new_proposal_lens.append(0)
+                else:
+                    # and has a proposal from the draft worker. Add it to the
+                    # new nonzero proposal list and keep the sampler output.
+                    new_proposal_lens.append(proposal_lens[seq_idx])
+                    new_nonzero_proposal_len_indices.append(seq_idx)
+                    new_maybe_sampler_output.append(
+                        maybe_sampler_output[nonzero_proposal_len_idx_ptr])
+                nonzero_proposal_len_idx_ptr += 1
+            seq_idx += 1
+
+        # The remaining sequences should have proposal length of 0.
+        new_proposal_lens.extend(proposal_lens[seq_idx:])
+
+        # We assume sampler_output will not be a list of all Nones.
+        # In this case this function should not be called.
+        assert new_maybe_sampler_output
+        return (new_proposal_lens, new_maybe_sampler_output,
+                new_nonzero_proposal_len_indices)
+
     def _merge_outputs(
         self,
         batch_size: int,
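
A hedged, standalone sketch of the filtering that _remove_no_proposal_seqs performs (plain lists stand in for SamplerOutput objects; this is not the vLLM method itself): sequences whose draft output is None fall back to proposal_len 0 and are never scored, while the surviving outputs stay aligned with nonzero_proposal_len_indices.

from typing import List, Optional

def remove_no_proposal_seqs(proposal_lens: List[int],
                            sampler_output: List[Optional[str]],
                            nonzero_indices: List[int]):
    # Walk the draft outputs in lockstep with the indices they belong to.
    new_lens, new_output, new_indices = list(proposal_lens), [], []
    for ptr, seq_idx in enumerate(nonzero_indices):
        if sampler_output[ptr] is None:
            new_lens[seq_idx] = 0          # no proposal -> skip scoring this sequence
        else:
            new_indices.append(seq_idx)
            new_output.append(sampler_output[ptr])
    return new_lens, new_output, new_indices

# Sequence 0 already had proposal_len 0; the draft worker returned None for sequence 2.
print(remove_no_proposal_seqs([0, 3, 3, 3], ["out_1", None, "out_3"], [1, 2, 3]))
# ([0, 3, 0, 3], ['out_1', 'out_3'], [1, 3])
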
