diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index c65ab04b8ddda..15f971b66e3bd 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -13,6 +13,9 @@ steps: - label: Basic Correctness Test command: pytest -v -s --forked basic_correctness + +- label: Core Test + command: pytest -v -s core - label: Distributed Comm Ops Test command: pytest -v -s --forked test_comm_ops.py diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py new file mode 100644 index 0000000000000..ecdf3025cffdf --- /dev/null +++ b/tests/core/test_block_manager.py @@ -0,0 +1,262 @@ +import pytest +import time +from typing import List + +from vllm import SamplingParams +from vllm.block import PhysicalTokenBlock +from vllm.core.block_manager import BlockAllocator, BlockSpaceManager, AllocStatus +from vllm.utils import Device +from vllm.sequence import Sequence, SequenceGroup, SequenceStatus + +from .utils import create_dummy_prompt + + +def test_block_allocator_allocate(): + block_size = 4 + num_cpu_blocks = 4 + cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks) + + # Allocate all available cpu blocks. + num_free = num_cpu_blocks + assert cpu_allocator.get_num_free_blocks() == num_free + for _ in range(num_cpu_blocks): + block = cpu_allocator.allocate() + num_free -= 1 + assert block not in cpu_allocator.free_blocks + assert cpu_allocator.get_num_free_blocks() == num_free + + with pytest.raises(ValueError): + cpu_allocator.allocate() + + +def test_block_allocator_free(): + block_size = 4 + num_cpu_blocks = 4 + cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks) + + # Allocate all available cpu blocks. + blocks: List[PhysicalTokenBlock] = [] + for _ in range(num_cpu_blocks): + block = cpu_allocator.allocate() + blocks.append(block) + assert block not in cpu_allocator.free_blocks + + # Free all allocated cpu blocks. + num_free = 0 + assert cpu_allocator.get_num_free_blocks() == num_free + for block in blocks: + cpu_allocator.free(block) + num_free += 1 + assert block in cpu_allocator.free_blocks + assert cpu_allocator.get_num_free_blocks() == num_free + + with pytest.raises(ValueError): + cpu_allocator.free(block) + + +def test_allocate(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + # Allocate same sequence group to all available gpu blocks. + for i in range(num_gpu_blocks): + _, seq_group = create_dummy_prompt(str(i), block_size) + assert block_manager.can_allocate(seq_group) + block_manager.allocate(seq_group) + assert block_manager.can_allocate(seq_group) != AllocStatus.OK + + # Allocate same sequence group to all available gpu blocks. + # Use watermark to reserve one gpu block. + block_manager = BlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=1 / num_gpu_blocks) + for i in range(num_gpu_blocks - 1): + _, seq_group = create_dummy_prompt(str(i), block_size) + assert block_manager.can_allocate(seq_group) + block_manager.allocate(seq_group) + assert block_manager.can_allocate(seq_group) != AllocStatus.OK + + +def test_append_slot_single_seq(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + # Allocate single seq to gpu block. + prompt, seq_group = create_dummy_prompt("1", block_size) + block_manager.allocate(seq_group) + + # Nothing to append. Sequence has no new logical blocks. + assert block_manager.can_append_slot(seq_group) + before_blocks = block_manager.get_num_free_gpu_blocks() + assert not block_manager.append_slot(prompt) + after_blocks = block_manager.get_num_free_gpu_blocks() + assert before_blocks == after_blocks + + # Add block_size number of new tokens and append slot. + for i in range(block_size): + token_id = i + 5 + prompt.append_token_id(token_id, {token_id: 0.0}) + + assert block_manager.can_append_slot(seq_group) + before_blocks = block_manager.get_num_free_gpu_blocks() + assert not block_manager.append_slot(prompt) + after_blocks = block_manager.get_num_free_gpu_blocks() + assert before_blocks - after_blocks == 1 + + +def test_append_slot_cow(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + # Allocate prompt to gpu block. + prompt = Sequence(1, "one two three", [1, 2, 3], block_size) + child = prompt.fork(2) + token_id = 4 + child.append_token_id(token_id, {token_id: 0.0}) + seq_group = SequenceGroup("1", [prompt, child], SamplingParams(), + time.time(), time.perf_counter) + block_manager.allocate(seq_group) + + # Append slot for child token. + # Last block being modified is shared. Copy on write occurs. + assert block_manager.can_append_slot(seq_group) + before_blocks = block_manager.get_num_free_gpu_blocks() + src_block, dst_block = block_manager.append_slot(child) + assert src_block != dst_block + + after_blocks = block_manager.get_num_free_gpu_blocks() + assert before_blocks - after_blocks == 1 + + +def test_fork(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + prompt, seq_group = create_dummy_prompt("1", + block_size - 1, + block_size=block_size) + block_manager.allocate(seq_group) + + # Fork prompt and copy block tables. + child = prompt.fork(2) + block_manager.fork(prompt, child) + assert block_manager.get_block_table( + prompt) == block_manager.get_block_table(child) + token_id = 4 + # Append token to child. Block is shared so copy on write occurs. + child.append_token_id(token_id, {token_id: 0.0}) + block_manager.append_slot(child) + assert block_manager.get_block_table( + prompt) != block_manager.get_block_table(child) + + +def test_swap(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1) + prompt.status = SequenceStatus.WAITING + block_manager.allocate(seq_group) + + # Emulate a forward pass by appending a single token. + # The block manager then knows how many unprocessed + # tokens will be written in the next forward pass. + token_id = 0 + prompt.status = SequenceStatus.RUNNING + prompt.append_token_id(token_id, {token_id: 0.0}) + + # Swap seq group from GPU -> CPU. + gpu_blocks = block_manager.get_block_table(prompt) + assert block_manager.can_swap_out(seq_group) + before_cpu_blocks = block_manager.get_num_free_cpu_blocks() + before_gpu_blocks = block_manager.get_num_free_gpu_blocks() + mapping = block_manager.swap_out(seq_group) + assert list(mapping.keys()) == gpu_blocks + after_cpu_blocks = block_manager.get_num_free_cpu_blocks() + after_gpu_blocks = block_manager.get_num_free_gpu_blocks() + assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) + assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks + prompt.status = SequenceStatus.SWAPPED + + # Swap seq group from CPU -> GPU. + cpu_blocks = block_manager.get_block_table(prompt) + assert block_manager.can_swap_in(seq_group) + before_cpu_blocks = block_manager.get_num_free_cpu_blocks() + before_gpu_blocks = block_manager.get_num_free_gpu_blocks() + mapping = block_manager.swap_in(seq_group) + assert list(mapping.keys()) == cpu_blocks + after_cpu_blocks = block_manager.get_num_free_cpu_blocks() + after_gpu_blocks = block_manager.get_num_free_gpu_blocks() + assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks + assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) + + +def test_free(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + prompt, seq_group = create_dummy_prompt("1", block_size) + block_manager.allocate(seq_group) + + # Free allocated seq. + prompt_blocks = len(block_manager.get_block_table(prompt)) + before_blocks = block_manager.get_num_free_gpu_blocks() + block_manager.free(prompt) + after_blocks = block_manager.get_num_free_gpu_blocks() + assert after_blocks == before_blocks + prompt_blocks + + # Block table for freed seq is deleted. + with pytest.raises(KeyError): + block_manager.get_block_table(prompt) + + +def test_reset(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + # Allocate same seq group on all available gpu blocks. + original_blocks = block_manager.get_num_free_gpu_blocks() + for i in range(num_gpu_blocks): + _, seq_group = create_dummy_prompt(str(i), block_size) + block_manager.allocate(seq_group) + assert block_manager.get_num_free_gpu_blocks() == 0 + + # Resetting block manager frees all allocated blocks. + block_manager.reset() + assert block_manager.get_num_free_gpu_blocks() == original_blocks diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py new file mode 100644 index 0000000000000..6322b2f2d5e9e --- /dev/null +++ b/tests/core/test_scheduler.py @@ -0,0 +1,170 @@ +from typing import List +import pytest # noqa + +from vllm.config import CacheConfig, SchedulerConfig +from vllm.core.scheduler import Scheduler +from vllm.sequence import SequenceGroup + +from .utils import create_dummy_prompt + + +def test_scheduler_add_seq_group(): + block_size = 4 + scheduler_config = SchedulerConfig(100, 64, 1, 256) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 4 + cache_config.num_gpu_blocks = 4 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add seq group to scheduler. + num_seq_group = 4 + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(str(i), block_size) + scheduler.add_seq_group(seq_group) + assert scheduler.get_num_unfinished_seq_groups() == i + 1 + + +def test_scheduler_abort_seq_group(): + block_size = 4 + scheduler_config = SchedulerConfig(100, 64, 1, 256) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 4 + cache_config.num_gpu_blocks = 4 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add multiple seq groups to scheduler. + num_seq_group = 4 + request_ids = set() + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(str(i), block_size) + scheduler.add_seq_group(seq_group) + request_ids.add(str(i)) + + # Abort all added seq groups. + assert scheduler.get_num_unfinished_seq_groups() == num_seq_group + scheduler.abort_seq_group(request_ids) + assert scheduler.get_num_unfinished_seq_groups() == 0 + + +def test_scheduler_schedule_simple(): + block_size = 4 + num_seq_group = 4 + max_model_len = 16 + scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len, 256) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add seq groups to scheduler. + running: List[SequenceGroup] = [] + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Schedule seq groups prompts. + seq_group_meta, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set(running) + assert out.num_batched_tokens == num_seq_group * seq_group.get_seqs( + )[0].get_len() + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == num_seq_group + + # Schedule seq groups generation. + seq_group_meta, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set(running) + assert out.num_batched_tokens == num_seq_group + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == num_seq_group + + +def test_scheduler_schedule_preempt_abort(): + block_size = 4 + max_model_len = 16 + scheduler_config = SchedulerConfig(64, 2, max_model_len, 256) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 2 + cache_config.num_gpu_blocks = 2 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add seq groups to scheduler. + seq_a, seq_group_a = create_dummy_prompt("1", block_size) + seq_b, seq_group_b = create_dummy_prompt("2", block_size) + scheduler.add_seq_group(seq_group_a) + scheduler.add_seq_group(seq_group_b) + + # Schedule seq groups prompts. + seq_group_meta, out = scheduler.schedule() + assert out.scheduled_seq_groups == [seq_group_a, seq_group_b] + assert out.num_batched_tokens == seq_group_a.get_seqs()[0].get_len() * 2 + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == 2 + assert scheduler.get_num_unfinished_seq_groups() == 2 + + # Append "generated" tokens, allowing the sequence to mark prompt tokens as + # processed. + token_id = 0 + seq_a.append_token_id(token_id, {token_id: 0.0}) + seq_b.append_token_id(token_id, {token_id: 0.0}) + + # Schedule seq groups generation and preempt seq group b. + seq_group_meta, out = scheduler.schedule() + assert out.scheduled_seq_groups == [seq_group_a] + assert out.num_batched_tokens == 1 + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == 1 + assert scheduler.get_num_unfinished_seq_groups() == 2 + + # Abort seq group a. Re-schedule seq group b prompt with recomputation. + scheduler.abort_seq_group("1") + seq_group_meta, out = scheduler.schedule() + assert out.scheduled_seq_groups == [seq_group_b] + assert out.num_batched_tokens == seq_group_b.get_seqs()[0].get_len() + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == 1 + assert scheduler.get_num_unfinished_seq_groups() == 1 + + +def test_scheduler_max_seqs(): + block_size = 4 + num_seq_group = 4 + max_seq_group = 2 + max_model_len = 16 + scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len, 256) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + + all_seq_groups: List[SequenceGroup] = [] + # Add seq groups to scheduler. + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) + all_seq_groups.append(seq_group) + + # Append 1 seq group + scheduler.add_seq_group(all_seq_groups[0]) + + # Schedule seq groups prompts. + _, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]]) + + # Schedule seq groups generation. + _, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]]) + + # Append 2 more seq group + scheduler.add_seq_group(all_seq_groups[1]) + scheduler.add_seq_group(all_seq_groups[2]) + + # Schedule seq groups prompts. + # Only 1 seq group should be scheduled since max_seq_group is 2 + # and one is prompting. + _, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set([all_seq_groups[1]]) diff --git a/tests/core/utils.py b/tests/core/utils.py new file mode 100644 index 0000000000000..9c0cfe1a7cf66 --- /dev/null +++ b/tests/core/utils.py @@ -0,0 +1,27 @@ +import time +from typing import Tuple + +from vllm import SamplingParams +from vllm.sequence import Sequence, SequenceGroup + + +def create_dummy_prompt( + request_id: str, + prompt_length: int, + block_size: int = None) -> Tuple[Sequence, SequenceGroup]: + if not block_size: + block_size = prompt_length + + # Create dummy prompt sequence with tokens 0...block_size-1 + # and prompt "0 ... block_size". + prompt_tokens = list(range(prompt_length)) + prompt_str = " ".join([str(t) for t in prompt_tokens]) + prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) + seq_group = SequenceGroup(request_id, [prompt], SamplingParams(), + time.time(), None, None) + + return prompt, seq_group + + +def round_up_to_next_block(seq_len: int, block_size: int) -> int: + return (seq_len + block_size - 1) // block_size