Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
[Bug fix][Core] fixup ngram not setup correctly (vllm-project#4551)
Browse files Browse the repository at this point in the history
Co-authored-by: Lei Wen <[email protected]>
Co-authored-by: Cade Daniel <[email protected]>
Co-authored-by: Cody Yu <[email protected]>
  • Loading branch information
4 people authored and robertgshaw2-redhat committed May 19, 2024
1 parent fd69572 commit 8673ad0
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 13 deletions.
24 changes: 18 additions & 6 deletions tests/spec_decode/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
) -> None:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
self.engine_args = AsyncEngineArgs(
engine_args = AsyncEngineArgs(
model=model,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
Expand All @@ -76,6 +76,8 @@ def __init__(
**kwargs,
)
self.request_counter = Counter()
self.llm_engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS)

def generate(
self,
Expand All @@ -88,9 +90,6 @@ def generate(
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:

llm_engine = AsyncLLMEngine.from_engine_args(
self.engine_args, usage_context=UsageContext.LLM_CLASS)

if prompts is None:
raise ValueError("prompts must be provided.")
if isinstance(prompts, str):
Expand All @@ -111,8 +110,8 @@ def generate(

async def get_output(prompt, sampling_param) -> str:
request_id = random_uuid()
results_generator = llm_engine.generate(prompt, sampling_param,
request_id)
results_generator = self.llm_engine.generate(
prompt, sampling_param, request_id)
final_output = None
async for request_output in results_generator:
final_output = request_output
Expand Down Expand Up @@ -185,12 +184,25 @@ def generator_outer():
return generator_outer


def maybe_assert_ngram_worker(llm):
# Verify the proposer worker is ngram if ngram is specified.
if (not isinstance(llm, AsyncLLM)
and llm.llm_engine.speculative_config is not None
and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
from vllm.spec_decode.ngram_worker import NGramWorker
assert isinstance(
llm.llm_engine.model_executor.driver_worker.proposer_worker,
NGramWorker)


def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]]]:
tokens = []
token_ids = []
for llm in llm_generator():
maybe_assert_ngram_worker(llm)

outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
tokens = [output.outputs[0].text for output in outputs]
Expand Down
4 changes: 4 additions & 0 deletions vllm/executor/gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def _init_spec_worker(self):
draft_worker_kwargs.update(
model_config=self.speculative_config.draft_model_config,
parallel_config=self.speculative_config.draft_parallel_config,
ngram_prompt_lookup_max=self.speculative_config.
ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.speculative_config.
ngram_prompt_lookup_min,
# TODO allow draft-model specific load config.
#load_config=self.load_config,
)
Expand Down
14 changes: 7 additions & 7 deletions vllm/spec_decode/spec_decode_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,10 @@ def create_worker(
draft_worker_kwargs,
) -> "SpecDecodeWorker":

if "ngram_prompt_lookup_max" in draft_worker_kwargs:
ngram_prompt_lookup_max = (
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
else:
ngram_prompt_lookup_max = 0
ngram_prompt_lookup_max = (
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))

if ngram_prompt_lookup_max > 0:
proposer_worker = NGramWorker(**draft_worker_kwargs)
Expand All @@ -72,6 +69,9 @@ def create_worker(
else:
proposer_worker = MultiStepWorker(**draft_worker_kwargs)

logger.info("Configuring SpecDecodeWorker with proposer=%s",
type(proposer_worker))

return SpecDecodeWorker(
proposer_worker,
scorer_worker,
Expand Down

0 comments on commit 8673ad0

Please sign in to comment.