Skip to content

Commit

Permalink
[Core] Pipeline Parallel Support (vllm-project#4412)
Browse files Browse the repository at this point in the history
Signed-off-by: Muralidhar Andoorveedu <[email protected]>
Signed-off-by: Alvant <[email protected]>
  • Loading branch information
andoorve authored and Alvant committed Oct 26, 2024
1 parent 0c065f4 commit e664e07
Show file tree
Hide file tree
Showing 82 changed files with 1,100 additions and 404 deletions.
10 changes: 10 additions & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ steps:
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py

- label: Pipeline Parallelism Test
working_dir: "/vllm-workspace/tests"
num_gpus: 4
commands:
- TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
- TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
- PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
- PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py


- label: Engine Test
mirror_hardwares: [amd]
command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
Expand Down
14 changes: 13 additions & 1 deletion tests/async_engine/test_async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch

from vllm import SamplingParams
from vllm.config import ParallelConfig
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine

from ..utils import wait_for_gpu_memory_to_clear
Expand All @@ -23,15 +24,21 @@ def __init__(self):
self.add_request_calls = 0
self.abort_request_calls = 0
self.request_id = None
# Ugly, remove dependency when possible
self.parallel_config = ParallelConfig(1, 1, False)

async def step_async(self):
async def step_async(self, virtual_engine):
# PP size is 1, ignore virtual engine
self.step_calls += 1
return [RequestOutput(
request_id=self.request_id)] if self.request_id else []

async def process_model_inputs_async(self, *args, **kwargs):
pass

async def stop_remote_worker_execution_loop_async(self):
pass

def generate(self, request_id):
self.request_id = request_id

Expand All @@ -41,6 +48,7 @@ def stop_generating(self):
def add_request(self, **kwargs):
del kwargs # Unused
self.add_request_calls += 1
print(f'Request calls: {self.add_request_calls}')

async def add_request_async(self, **kwargs):
self.add_request_calls += 1
Expand All @@ -53,6 +61,9 @@ def abort_request(self, request_id):
def has_unfinished_requests(self):
return self.request_id is not None

def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
return self.request_id is not None


class MockAsyncLLMEngine(AsyncLLMEngine):

Expand All @@ -76,6 +87,7 @@ async def test_new_requests_event():
engine.engine.generate("2")
await asyncio.sleep(0)
await asyncio.sleep(0)
await asyncio.sleep(0)
assert engine.engine.add_request_calls == 2
assert engine.engine.step_calls >= 2
await asyncio.sleep(0.001)
Expand Down
4 changes: 2 additions & 2 deletions tests/async_engine/test_openapi_server_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
# and debugging.
import ray

from ..utils import RemoteOpenAIServer
from ..utils import VLLM_PATH, RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"


@pytest.fixture(scope="module")
def ray_ctx():
ray.init()
ray.init(runtime_env={"working_dir": VLLM_PATH})
yield
ray.shutdown()

Expand Down
24 changes: 12 additions & 12 deletions tests/basic_correctness/test_preemption.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def test_chunked_prefill_recompute(
max_num_seqs=max_num_seqs,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)

for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
Expand Down Expand Up @@ -91,10 +91,10 @@ def test_preemption(
disable_log_stats=False,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)
total_preemption = (
vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption)

check_outputs_equal(
outputs_0_lst=hf_outputs,
Expand Down Expand Up @@ -147,10 +147,10 @@ def test_swap(
) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search(example_prompts,
beam_width, max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)
total_preemption = (
vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption)

for i in range(len(example_prompts)):
hf_output_ids, _ = hf_outputs[i]
Expand Down Expand Up @@ -214,8 +214,8 @@ def test_swap_infeasible(
example_prompts,
sampling_params=sampling_params,
)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)

# Verify the request is ignored and not hang.
assert req_outputs[0].outputs[0].finish_reason == "length"
Expand Down Expand Up @@ -252,8 +252,8 @@ def test_preemption_infeasible(
sampling_params=sampling_params,
)

assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)

# Verify the request is ignored and not hang.
for req_output in req_outputs:
Expand Down
20 changes: 17 additions & 3 deletions tests/distributed/test_comm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
(r + 1) for r in range(tp_size)
]
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
t = all_tensors[rank]
t = all_tensors[rank % tp_size]
t = tensor_model_parallel_all_reduce(t)
assert torch.allclose(t, expected)

Expand Down Expand Up @@ -60,7 +60,7 @@ def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
for r in range(tp_size)
]
expected = torch.cat(all_tensors, dim=all_gather_dimension)
t = all_tensors[rank]
t = all_tensors[rank % tp_size]
t = tensor_model_parallel_all_gather(t, all_gather_dimension)
assert torch.allclose(t, expected)

Expand Down Expand Up @@ -91,7 +91,7 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
}

if rank == 0:
if (rank % tp_size) == 0:
broadcast_tensor_dict(test_dict, src=0)
else:
recv_dict = broadcast_tensor_dict(src=0)
Expand Down Expand Up @@ -184,3 +184,17 @@ def test_multi_process_tensor_parallel(tp_size, test_target):
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
def test_multi_process_pipeline_parallel(pp_size, test_target):
multi_process_parallel(1, pp_size, test_target)


@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize("test_target", [
send_recv_test_worker, send_recv_tensor_dict_test_worker,
all_reduce_test_worker, all_gather_test_worker,
broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel_pipeline_parallel(
tp_size, pp_size, test_target):
multi_process_parallel(tp_size, pp_size, test_target)
149 changes: 149 additions & 0 deletions tests/distributed/test_pipeline_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import os

import openai # use the official client for correctness check
import pytest
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray

from ..utils import VLLM_PATH, RemoteOpenAIServer

# downloading lora to test lora requests

# any model with a chat template should work here
MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
EAGER_MODE = bool(int(os.getenv("EAGER_MODE", 0)))
CHUNKED_PREFILL = bool(int(os.getenv("CHUNKED_PREFILL", 0)))
TP_SIZE = int(os.getenv("TP_SIZE", 1))
PP_SIZE = int(os.getenv("PP_SIZE", 1))

pytestmark = pytest.mark.asyncio


@pytest.fixture(scope="module")
def ray_ctx():
ray.init(runtime_env={"working_dir": VLLM_PATH})
yield
ray.shutdown()


@pytest.fixture(scope="module")
def server(ray_ctx):
args = [
"--model",
MODEL_NAME,
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--pipeline-parallel-size",
str(PP_SIZE),
"--tensor-parallel-size",
str(TP_SIZE),
"--distributed-executor-backend",
"ray",
]
if CHUNKED_PREFILL:
args += [
"--enable-chunked-prefill",
]
if EAGER_MODE:
args += [
"--enforce-eager",
]
return RemoteOpenAIServer(args, num_gpus=PP_SIZE * TP_SIZE)


@pytest.fixture(scope="module")
def client(server):
return server.get_async_client()


async def test_check_models(server, client: openai.AsyncOpenAI):
models = await client.models.list()
models = models.data
served_model = models[0]
assert served_model.id == MODEL_NAME
assert all(model.root == MODEL_NAME for model in models)


@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_single_completion(server, client: openai.AsyncOpenAI,
model_name: str):
completion = await client.completions.create(model=model_name,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5
assert completion.choices[0].finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11)

# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5


@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
[MODEL_NAME],
)
async def test_batch_completions(server, client: openai.AsyncOpenAI,
model_name: str):
# test simple list
batch = await client.completions.create(
model=model_name,
prompt=["Hello, my name is", "Hello, my name is"],
max_tokens=5,
temperature=0.0,
)
assert len(batch.choices) == 2
assert batch.choices[0].text == batch.choices[1].text

# test n = 2
batch = await client.completions.create(
model=model_name,
prompt=["Hello, my name is", "Hello, my name is"],
n=2,
max_tokens=5,
temperature=0.0,
extra_body=dict(
# NOTE: this has to be true for n > 1 in vLLM, but not necessary
# for official client.
use_beam_search=True),
)
assert len(batch.choices) == 4
assert batch.choices[0].text != batch.choices[
1].text, "beam search should be different"
assert batch.choices[0].text == batch.choices[
2].text, "two copies of the same prompt should be the same"
assert batch.choices[1].text == batch.choices[
3].text, "two copies of the same prompt should be the same"

# test streaming
batch = await client.completions.create(
model=model_name,
prompt=["Hello, my name is", "Hello, my name is"],
max_tokens=5,
temperature=0.0,
stream=True,
)
texts = [""] * 2
async for chunk in batch:
assert len(chunk.choices) == 1
choice = chunk.choices[0]
texts[choice.index] += choice.text
assert texts[0] == texts[1]
8 changes: 4 additions & 4 deletions tests/engine/output_processor/test_multi_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):

output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
scheduler=[scheduler],
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
stop_checker=stop_checker,
Expand Down Expand Up @@ -86,7 +86,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,

output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
scheduler=[scheduler],
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
stop_checker=stop_checker,
Expand Down Expand Up @@ -148,7 +148,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,

output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
scheduler=[scheduler],
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
stop_checker=stop_checker,
Expand Down Expand Up @@ -215,7 +215,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,

output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
scheduler=[scheduler],
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
stop_checker=stop_checker,
Expand Down
Loading

0 comments on commit e664e07

Please sign in to comment.