
[core] Multi Step Scheduling #7000

Merged: 2 commits, Aug 19, 2024
9 changes: 9 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -311,6 +311,15 @@ steps:
  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py

- label: Multi-step Tests (4 GPUs) # 10min
  working_dir: "/vllm-workspace/tests"
  num_gpus: 4
  source_file_dependencies:
  - vllm/
  - tests/multi_step/test_correctness.py
  commands:
  - pytest -v -s multi_step/test_correctness.py

- label: Pipeline Parallelism Test # 23min
  working_dir: "/vllm-workspace/tests"
  num_gpus: 4
Empty file added tests/multi_step/__init__.py
85 changes: 85 additions & 0 deletions tests/multi_step/test_correctness.py
@@ -0,0 +1,85 @@
# Test the AsyncLLMEngine with multi-step decoding

from typing import List

import pytest

from ..utils import RemoteOpenAIServer

MODELS = [
    "JackFram/llama-160m",
]
NUM_SCHEDULER_STEPS = [8]  # Multi-step decoding steps
NUM_PROMPTS = [10]

DEFAULT_SERVER_ARGS: List[str] = [
    "--disable-log-requests",
    "--use-v2-block-manager",
    "--worker-use-ray",
    "--gpu-memory-utilization",
    "0.85",
    "--swap-space",
    "16",
]


async def completions_with_server_args(prompts: List[str], model_name: str,
                                       server_cli_args: List[str]):

    outputs = None
    with RemoteOpenAIServer(model_name, server_cli_args) as server:
@afeldman-nm (Contributor) commented on Aug 18, 2024:
@SolitaryThinker no need to block on this feedback, but if you have time, I would propose adding an example/offline_inference_multi_step.py script which instantiates an engine instance with multi-step enabled, similar in structure to example/offline_inference.py.

An example of why this is useful: as part of the logprobs workstream, I am trying to step through the multi-step model runner with the Python debugger and examine the output logprobs. I am using your multi_step/test_correctness.py to set up a server with multi-step enabled.

However, multi_step/test_correctness.py is an end-to-end client/server test, and it is not straightforward (although technically doable) to step through the server code with the debugger because the server runs in another process.

I will get around this by writing a short script which sets up an engine instance with multi-step enabled.

However, for someone who is approaching this code for the first time, it could be helpful to have an example file (or unit test) which just sets up an engine instance with multi-step enabled and invokes inference using LLM.generate(). This could be a good way to facilitate quick debugging and would also give insight into how the server works.

Contributor (follow-up comment):

Here is the offline_inference_multi_step.py script I wrote for myself to facilitate debugging, if you would like to use it.

'''
Example of setting up an LLM with multi-step enabled.
In actuality, the async engine would be a more sensible choice
for a real use case. However, this example is useful
for demonstration and debugging of multi-step code.
'''

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="JackFram/llama-160m",
          swap_space=16,
          tensor_parallel_size=1,
          gpu_memory_utilization=0.9,
          num_scheduler_steps=8,
          use_v2_block_manager=True,
          )
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

        client = server.get_async_client()
        outputs = await client.completions.create(model=model_name,
                                                  prompt=prompts,
                                                  temperature=0,
                                                  stream=False,
                                                  max_tokens=5)
    assert outputs is not None

    return outputs


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(("tp_size, pp_size"), [
(1, 1),
(2, 2),
])
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.asyncio
async def test_multi_step(example_prompts, model: str, tp_size: int,
pp_size: int, eager_mode: int,
num_scheduler_steps: int, num_prompts: int):

prompts = example_prompts
if len(prompts) < num_prompts:
prompts = prompts * ((num_prompts // len(prompts)) + 1)
prompts = prompts[:num_prompts]
assert len(prompts) == num_prompts

server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"]
ms_server_args = DEFAULT_SERVER_ARGS + \
["--num-scheduler-steps", f"{num_scheduler_steps}"]

if eager_mode:
ms_server_args.append("--enforce-eager")

distributed_args = [
"--tensor-parallel-size",
str(tp_size),
"--pipeline-parallel-size",
str(pp_size),
]

ref_completions = await completions_with_server_args(
prompts, model, server_args + distributed_args)
test_completions = await completions_with_server_args(
prompts, model, ms_server_args + distributed_args)

def get_text_generations(completions):
return [x.text for x in completions.choices]

ref_generations = get_text_generations(ref_completions)
test_generations = get_text_generations(test_completions)
assert ref_generations == test_generations
77 changes: 77 additions & 0 deletions tests/worker/test_model_input.py
@@ -10,6 +10,7 @@
from vllm.worker.embedding_model_runner import (
    ModelInputForGPUWithPoolingMetadata)
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from vllm.worker.multi_step_model_runner import StatefulModelInput


class MockAttentionBackend(AttentionBackend):
@@ -154,3 +155,79 @@ def test_embedding_model_runner_input():
                       None) == getattr(attn_metadata, field.name, None)
    # Pooling metadata is not broadcast.
    assert received_model_input.pooling_metadata is None


def test_multi_step_model_runner_input():
    sampling_metadata = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    frozen_model_input = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata)

    model_input = StatefulModelInput(
        frozen_model_input=frozen_model_input,
        is_last_step=True,
        is_first_multi_step=False,
        current_step=4,
        last_sampled_token_ids=torch.ones((10, 1)),
        is_multi_step=True,
        num_queries=8,
        num_seqs=5,
        cached_outputs=[],
    )

    assert isinstance(model_input, StatefulModelInput)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
        tensor_dict, attn_backend=attn_backend))

    received_frozen_input = received_model_input.frozen_model_input

    # Check that received copy has correct values.
    assert isinstance(received_model_input, StatefulModelInput)
    assert received_frozen_input.input_tokens is not None
    assert (received_frozen_input.input_tokens ==
            frozen_model_input.input_tokens).all()
    assert received_frozen_input.input_positions is not None
    assert (received_frozen_input.input_positions ==
            frozen_model_input.input_positions).all()
    assert received_frozen_input.multi_modal_kwargs is None
    assert (received_frozen_input.multi_modal_kwargs ==
            frozen_model_input.multi_modal_kwargs)
    assert received_frozen_input.lora_requests is None
    assert (received_frozen_input.lora_requests ==
            frozen_model_input.lora_requests)
    assert received_frozen_input.lora_mapping is None
    assert (
        received_frozen_input.lora_mapping == frozen_model_input.lora_mapping)
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_frozen_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # For sampling metadata, only selected_token_indices is copied.
    assert (received_frozen_input.sampling_metadata.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert received_frozen_input.sampling_metadata.seq_groups is None

    # Check the non-frozen fields.
    assert received_model_input.is_last_step == model_input.is_last_step
    assert (received_model_input.is_first_multi_step ==
            model_input.is_first_multi_step)
    assert received_model_input.current_step == model_input.current_step
    assert (received_model_input.last_sampled_token_ids ==
            model_input.last_sampled_token_ids).all()
    assert received_model_input.is_multi_step == model_input.is_multi_step
7 changes: 6 additions & 1 deletion vllm/engine/arg_utils.py
@@ -851,6 +851,12 @@ def create_engine_config(self, ) -> EngineConfig:
"in low performance due to small KV cache space. Consider "
"setting --max-model-len to a smaller value.", max_model_len)

if self.num_scheduler_steps > 1 and not self.use_v2_block_manager:
self.use_v2_block_manager = True
logger.warning(
"Enabled BlockSpaceManagerV2 because it is "
"required for multi-step (--num-scheduler-steps > 1)")

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
@@ -879,7 +885,6 @@ def create_engine_config(self, ) -> EngineConfig:
        )

        if self.num_scheduler_steps > 1:
-           raise NotImplementedError("Multi-step is not yet supported.")
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")