diff --git a/.buildkite/run-benchmarks.sh b/.buildkite/run-benchmarks.sh index 6af6c8149c317..f6a542afe1a3d 100644 --- a/.buildkite/run-benchmarks.sh +++ b/.buildkite/run-benchmarks.sh @@ -49,7 +49,9 @@ sed -n '$p' benchmark_throughput.txt >> benchmark_results.md # last line echo "### Serving Benchmarks" >> benchmark_results.md sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line echo "" >> benchmark_results.md -tail -n 13 benchmark_serving.txt >> benchmark_results.md # last 13 lines +echo '```' >> benchmark_results.md +tail -n 20 benchmark_serving.txt >> benchmark_results.md # last 20 lines +echo '```' >> benchmark_results.md # upload the results to buildkite /workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 542a51f116db2..ee384c27e1d0c 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -36,6 +36,16 @@ steps: - label: Entrypoints Test command: pytest -v -s entrypoints +- label: Examples Test + working_dir: "/vllm-workspace/examples" + commands: + # install aws cli for llava_example.py + - pip install awscli + - python3 offline_inference.py + - python3 offline_inference_with_prefix.py + - python3 llm_engine_example.py + - python3 llava_example.py + - label: Kernels Test %N command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 diff --git a/README.md b/README.md index bf8fbd4173949..08e46b68cb7ce 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) - Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.) +- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.) - StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.) - Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.) - Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.) 
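For context on the CI and README additions above: the new "Examples Test" step runs examples/offline_inference.py, and Qwen2MoE is now listed as a supported architecture. A minimal offline-generation sketch with one of the newly listed checkpoints, modeled on examples/offline_inference.py (the prompt and sampling values are illustrative, not prescribed by this patch):

```python
# Minimal offline-generation sketch for a newly listed Qwen2MoE checkpoint,
# modeled on examples/offline_inference.py; sampling values are illustrative.
from vllm import LLM, SamplingParams

prompts = ["The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="Qwen/Qwen1.5-MoE-A2.7B")
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated: {output.outputs[0].text!r}")
```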
diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 0f223571782c6..da02493b17fd3 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -26,7 +26,9 @@ def main(args: argparse.Namespace): kv_cache_dtype=args.kv_cache_dtype, device=args.device, ray_workers_use_nsight=args.ray_workers_use_nsight, - download_dir=args.download_dir) + enable_chunked_prefill=args.enable_chunked_prefill, + download_dir=args.download_dir, + block_size=args.block_size) sampling_params = SamplingParams( n=args.n, @@ -145,6 +147,16 @@ def run_to_completion(profile_dir: Optional[str] = None): default="cuda", choices=["cuda"], help='device type for vLLM execution, supporting CUDA only currently.') + parser.add_argument('--block-size', + type=int, + default=16, + help='block size of key/value cache') + parser.add_argument( + '--enable-chunked-prefill', + type=bool, + default=False, + help='If True, the prefill requests can be chunked based on the ' + 'max_num_batched_tokens') parser.add_argument( "--ray-workers-use-nsight", action='store_true', diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 8ef6da4a6dac1..9c2f5ba458eb4 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -119,6 +119,10 @@ Alongside each architecture, we include some popular models that use it. - Qwen2 - :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc. - ✅︎ + * - :code:`Qwen2MoeForCausalLM` + - Qwen2MoE + - :code:`Qwen/Qwen1.5-MoE-A2.7B`, :code:`Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. + - * - :code:`StableLmForCausalLM` - StableLM - :code:`stabilityai/stablelm-3b-4e1t/` , :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc. diff --git a/examples/llava_example.py b/examples/llava_example.py index a455e98585983..3d22b492654bf 100644 --- a/examples/llava_example.py +++ b/examples/llava_example.py @@ -78,7 +78,13 @@ def main(args): # Make sure the local directory exists or create it os.makedirs(local_directory, exist_ok=True) - # Use AWS CLI to sync the directory - subprocess.check_call( - ["aws", "s3", "sync", s3_bucket_path, local_directory]) + # Use AWS CLI to sync the directory, assume anonymous access + subprocess.check_call([ + "aws", + "s3", + "sync", + s3_bucket_path, + local_directory, + "--no-sign-request", + ]) main(args) diff --git a/examples/offline_inference_with_prefix.py b/examples/offline_inference_with_prefix.py index fbfb384fd4282..7ed0563f14e0e 100644 --- a/examples/offline_inference_with_prefix.py +++ b/examples/offline_inference_with_prefix.py @@ -22,7 +22,7 @@ sampling_params = SamplingParams(temperature=0.0) # Create an LLM. 
-llm = LLM(model="facebook/opt-125m") +llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True) generating_prompts = [prefix + prompt for prompt in prompts] diff --git a/tests/conftest.py b/tests/conftest.py index eb5424f909fd7..770da1e6f14b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -281,6 +281,8 @@ def __init__( dtype: str = "half", disable_log_stats: bool = True, tensor_parallel_size: int = 1, + block_size: int = 16, + enable_chunked_prefill: bool = False, **kwargs, ) -> None: self.model = LLM( @@ -292,6 +294,8 @@ def __init__( disable_log_stats=disable_log_stats, tensor_parallel_size=tensor_parallel_size, max_model_len=max_model_len, + block_size=block_size, + enable_chunked_prefill=enable_chunked_prefill, **kwargs, ) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index f40969cf2f3c8..88c2c37f4fb39 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -10,6 +10,10 @@ from .utils import create_dummy_prompt +def get_sequence_groups(scheduler_output): + return [s.seq_group for s in scheduler_output.scheduled_seq_groups] + + def test_scheduler_add_seq_group(): block_size = 4 scheduler_config = SchedulerConfig(100, 64, 1) @@ -57,9 +61,9 @@ def test_scheduler_schedule_simple(): cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] # Add seq groups to scheduler. - running: List[SequenceGroup] = [] for i in range(num_seq_group): _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) scheduler.add_seq_group(seq_group) @@ -68,7 +72,7 @@ def test_scheduler_schedule_simple(): # Schedule seq groups prompts. num_tokens = block_size * num_seq_group seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set(running) + assert set(get_sequence_groups(out)) == set(running) assert out.num_batched_tokens == num_tokens assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) @@ -76,7 +80,7 @@ def test_scheduler_schedule_simple(): # Schedule seq groups generation. seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set(running) + assert set(get_sequence_groups(out)) == set(running) assert out.num_batched_tokens == num_seq_group assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) @@ -100,7 +104,7 @@ def test_scheduler_schedule_preempt_abort(): # Schedule seq groups prompts. seq_group_meta, out = scheduler.schedule() - assert out.scheduled_seq_groups == [seq_group_a, seq_group_b] + assert get_sequence_groups(out) == [seq_group_a, seq_group_b] assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) @@ -115,7 +119,7 @@ def test_scheduler_schedule_preempt_abort(): # Schedule seq groups generation and preempt seq group b. seq_group_meta, out = scheduler.schedule() - assert out.scheduled_seq_groups == [seq_group_a] + assert get_sequence_groups(out) == [seq_group_a] assert out.num_batched_tokens == 1 assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) @@ -125,7 +129,7 @@ def test_scheduler_schedule_preempt_abort(): # Abort seq group a. Re-schedule seq group b prompt with recomputation. 
scheduler.abort_seq_group("1") seq_group_meta, out = scheduler.schedule() - assert out.scheduled_seq_groups == [seq_group_b] + assert get_sequence_groups(out) == [seq_group_b] assert out.num_batched_tokens == 5 # 4 prompt + 1 generation. assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) @@ -155,11 +159,11 @@ def test_scheduler_max_seqs(): # Schedule seq groups prompts. _, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]]) + assert set(get_sequence_groups(out)) == set([all_seq_groups[0]]) # Schedule seq groups generation. _, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]]) + assert set(get_sequence_groups(out)) == set([all_seq_groups[0]]) # Append 2 more seq group scheduler.add_seq_group(all_seq_groups[1]) @@ -169,7 +173,7 @@ def test_scheduler_max_seqs(): # Only 1 seq group should be scheduled since max_seq_group is 2 # and one is prompting. _, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set([all_seq_groups[1]]) + assert set(get_sequence_groups(out)) == set([all_seq_groups[1]]) def test_scheduler_delay_factor(): diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index d1811cb694db6..62085742761df 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -24,7 +24,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int, del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, tensor_parallel_size, rank, + init_test_distributed_environment(1, tensor_parallel_size, rank, rank, distributed_init_port) num_elements = 8 all_tensors = [ @@ -46,7 +46,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, tensor_parallel_size, rank, + init_test_distributed_environment(1, tensor_parallel_size, rank, rank, distributed_init_port) num_dimensions = 3 tensor_size = list(range(2, num_dimensions + 2)) @@ -74,7 +74,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, tensor_parallel_size, rank, + init_test_distributed_environment(1, tensor_parallel_size, rank, rank, distributed_init_port) test_dict = { "a": torch.arange(8, dtype=torch.float32, device="cuda"), diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 1e6e7f89a528c..0bd3bf8837450 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -23,7 +23,7 @@ def graph_allreduce(world_size, rank, distributed_init_port): del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, world_size, rank, + init_test_distributed_environment(1, world_size, rank, rank, distributed_init_port) custom_ar.init_custom_ar() @@ -58,7 +58,7 @@ def eager_allreduce(world_size, rank, distributed_init_port): del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, world_size, rank, + init_test_distributed_environment(1, world_size, rank, rank, 
distributed_init_port) sz = 1024 diff --git a/tests/test_sequence.py b/tests/test_sequence.py index bb6bcddf1343e..1dec928158b16 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,6 +1,7 @@ import pytest -from vllm.sequence import SamplerOutput, SequenceGroupOutput, SequenceOutput +from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput, + SequenceOutput) @pytest.fixture @@ -48,3 +49,24 @@ def test_sampler_output_eq(sample_outputs): sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1]) assert sampler_output1 == sampler_output2 assert sampler_output1 != sampler_output3 + + +def test_sequence_data_prefill(): + seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4]) + assert seq_data.get_num_uncomputed_tokens() == 4 + assert seq_data.get_num_computed_tokens() == 0 + # advance by 2 + seq_data.update_num_computed_tokens(2) + assert seq_data.get_num_uncomputed_tokens() == 2 + assert seq_data.get_num_computed_tokens() == 2 + + # advance by 1 + seq_data.update_num_computed_tokens(1) + assert seq_data.get_num_uncomputed_tokens() == 1 + assert seq_data.get_num_computed_tokens() == 3 + + # append tokens and reset, simulating recompute + seq_data.append_token_id(1, logprob=0.0) + seq_data.reset_num_computed_tokens() + assert seq_data.get_num_uncomputed_tokens() == 5 + assert seq_data.get_num_computed_tokens() == 0 diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 930ecad3f1755..5b6f001f62fa7 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -18,15 +18,16 @@ def test_prepare_prompt(batch_size): # make sure all tokens fit into one block prompt_len = i % (model_runner.block_size - 1) + 1 prompt_lens.append(prompt_len) - seq_data = list(range(prompt_len)) - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData(seq_data)}, - sampling_params=SamplingParams(temperature=0), - block_tables=block_tables, - )) + seq_data = SequenceData(list(range(prompt_len))) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + ) + assert seq_group_metadata.token_chunk_size == seq_data.get_len() + seq_group_metadata_list.append(seq_group_metadata) expected_selected_token_indices = [] selected_token_start_idx = 0 @@ -131,14 +132,16 @@ def test_prepare_decode_cuda_graph(batch_size): prompt_len = i % (model_runner.block_size - 1) + 1 prompt_lens.append(prompt_len) seq_data = list(range(prompt_len)) - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=False, - seq_data={0: SequenceData(seq_data)}, - sampling_params=SamplingParams(temperature=0), - block_tables={0: [1]}, - )) + seq_data = SequenceData(seq_data) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables={0: [1]}, + ) + assert seq_group_metadata.token_chunk_size == 1 + seq_group_metadata_list.append(seq_group_metadata) input_tokens, input_positions, attn_metadata, _, _, _ = ( model_runner._prepare_decode(seq_group_metadata_list)) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 90fce1a0349b2..42f4284c6c775 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -41,6 +41,8 @@ def _can_use_flash_attn(dtype: torch.dtype) -> 
bool: try: import flash_attn # noqa: F401 except ImportError: - logger.info("flash_attn is not found.") + logger.info( + "Cannot use FlashAttention because the package is not found. " + "Please install it for better performance.") return False return True diff --git a/vllm/config.py b/vllm/config.py index 5025b046fecb8..265cfa56c04fc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -533,6 +533,8 @@ class SchedulerConfig: delay_factor: Apply a delay (of delay factor multiplied by previous prompt latency) before scheduling next prompt. use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not. + enable_chunked_prefill: If True, prefill requests can be chunked based + on the remaining max_num_batched_tokens. """ def __init__( @@ -542,6 +544,7 @@ def __init__( max_model_len: int, use_v2_block_manager: bool = False, delay_factor: float = 0.0, + enable_chunked_prefill: bool = False, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -553,6 +556,7 @@ def __init__( self.max_model_len = max_model_len self.delay_factor = delay_factor self.use_v2_block_manager = use_v2_block_manager + self.chunked_prefill_enabled = enable_chunked_prefill self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index c5c8d0a05539b..160a86556f031 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -230,13 +230,12 @@ def __init__( self.watermark_blocks = int(watermark * num_gpu_blocks) if self.enable_caching: - logger.info("enable automatic prefix caching") + logger.info("Automatic prefix caching is enabled.") self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size, num_gpu_blocks) self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size, num_cpu_blocks) else: - logger.info("disable automatic prefix caching") self.gpu_allocator = UncachedBlockAllocator( Device.GPU, block_size, num_gpu_blocks) self.cpu_allocator = UncachedBlockAllocator( diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 85c2fdf75c084..04e8056aab544 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1,6 +1,7 @@ import enum import time from collections import deque +from dataclasses import dataclass from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig @@ -27,11 +28,24 @@ class PreemptionMode(enum.Enum): RECOMPUTE = enum.auto() +# seq_group: SequenceGroup to schedule. +# token_chunk_size: The number of prefill tokens to be processed in the next +# step. +@dataclass +class ScheduledSequenceGroup: + # A sequence group that's scheduled. + seq_group: SequenceGroup + # The total chunk size (number of tokens) to process for next iteration. + # 1 for decoding. Same as prompt tokens for prefill, but if prefill is + # chunked, it can be smaller than that. 
+ token_chunk_size: int + + class SchedulerOutputs: def __init__( self, - scheduled_seq_groups: Iterable[SequenceGroup], + scheduled_seq_groups: Iterable[ScheduledSequenceGroup], prompt_run: bool, num_batched_tokens: int, blocks_to_swap_in: Dict[int, int], @@ -39,17 +53,41 @@ def __init__( blocks_to_copy: Dict[int, List[int]], ignored_seq_groups: List[SequenceGroup], ) -> None: - self.scheduled_seq_groups = scheduled_seq_groups - self.prompt_run = prompt_run - self.num_batched_tokens = num_batched_tokens - self.blocks_to_swap_in = blocks_to_swap_in - self.blocks_to_swap_out = blocks_to_swap_out - self.blocks_to_copy = blocks_to_copy + """A list of sequence groups to be scheduled as a single batch. + + Args: + scheduled_seq_groups: A tuple of scheduled sequence group and its + token chunk size. + prompt_run: True if all sequence groups are in prefill phase. + If False, all sequence groups are in decoding phase. + num_batched_tokens: Total number of batched tokens. + blocks_to_swap_in: Blocks to swap in. Dict of CPU -> GPU block + number. + blocks_to_swap_out: Blocks to swap out. Dict of GPU -> CPU block + number. + blocks_to_copy: Blocks to copy. Source to a list of dest blocks. + ignored_seq_groups: Sequence groups that are going to be ignored. + """ + # A tuple of scheduled sequence group and its chunk size. + self.scheduled_seq_groups: ScheduledSequenceGroup = scheduled_seq_groups + # True if all sequence groups are in prefill phase. If False, all + # sequence groups are in decoding phase. + self.prompt_run: bool = prompt_run + # Total number of batched tokens. + self.num_batched_tokens: int = num_batched_tokens + # Blocks to swap in. Dict of CPU -> GPU block number. + self.blocks_to_swap_in: Dict[int, int] = blocks_to_swap_in + # Blocks to swap out. Dict of GPU -> CPU block number. + self.blocks_to_swap_out: Dict[int, int] = blocks_to_swap_out + # Blocks to copy. Source to a list of dest blocks. + self.blocks_to_copy: Dict[int, List[int]] = blocks_to_copy + # Sequence groups that are going to be ignored. + self.ignored_seq_groups: List[SequenceGroup] = ignored_seq_groups + # Swap in and swap out should never happen at the same time. assert not (blocks_to_swap_in and blocks_to_swap_out) - self.ignored_seq_groups = ignored_seq_groups - self.num_loras = len(self.lora_requests) + self.num_loras: int = len(self.lora_requests) if self.num_loras > 0: self._sort_by_lora_ids() @@ -59,13 +97,13 @@ def is_empty(self) -> bool: and not self.blocks_to_swap_out and not self.blocks_to_copy) def _sort_by_lora_ids(self) -> bool: - self.scheduled_seq_groups = sorted(self.scheduled_seq_groups, - key=lambda g: - (g.lora_int_id, g.request_id)) + self.scheduled_seq_groups = sorted( + self.scheduled_seq_groups, + key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id)) @property def lora_requests(self) -> Set[LoRARequest]: - return {g.lora_request for g in self.scheduled_seq_groups} + return {g.seq_group.lora_request for g in self.scheduled_seq_groups} class Scheduler: @@ -198,11 +236,13 @@ def _schedule(self) -> SchedulerOutputs: assert len(waiting_seqs) == 1, ( "Waiting sequence group should have only one prompt " "sequence.") - num_prompt_tokens = waiting_seqs[0].get_len() - if num_prompt_tokens > self.prompt_limit: + # get_len includes output tokens if the request has been + # preempted. 
+ num_prefill_tokens = waiting_seqs[0].get_len() + if num_prefill_tokens > self.prompt_limit: logger.warning( - f"Input prompt ({num_prompt_tokens} tokens) is too long" - f" and exceeds limit of {self.prompt_limit}") + f"Input prompt ({num_prefill_tokens} tokens) is too " + f"long and exceeds limit of {self.prompt_limit}") for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) @@ -215,8 +255,8 @@ def _schedule(self) -> SchedulerOutputs: break elif can_allocate == AllocStatus.NEVER: logger.warning( - f"Input prompt ({num_prompt_tokens} tokens) is too long" - f" and exceeds the capacity of block_manager") + f"Input prompt ({num_prefill_tokens} tokens) is too " + f"long and exceeds the capacity of block_manager") for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) @@ -235,7 +275,7 @@ def _schedule(self) -> SchedulerOutputs: continue # If the number of batched tokens exceeds the limit, stop. - num_batched_tokens += num_prompt_tokens + num_batched_tokens += num_prefill_tokens if (num_batched_tokens > self.scheduler_config.max_num_batched_tokens): break @@ -253,8 +293,10 @@ def _schedule(self) -> SchedulerOutputs: self._allocate(seq_group) self.running.append(seq_group) num_curr_seqs += num_new_seqs - scheduled.append(seq_group) - + scheduled.append( + ScheduledSequenceGroup( + seq_group=seq_group, + token_chunk_size=num_prefill_tokens)) self.waiting.extendleft(leftover_waiting_sequences) if scheduled or ignored_seq_groups: @@ -352,7 +394,11 @@ def _schedule(self) -> SchedulerOutputs: for seq_group in self.running) scheduler_outputs = SchedulerOutputs( - scheduled_seq_groups=self.running, + scheduled_seq_groups=[ + ScheduledSequenceGroup(seq_group=running_group, + token_chunk_size=1) + for running_group in self.running + ], prompt_run=False, num_batched_tokens=num_batched_tokens, blocks_to_swap_in=blocks_to_swap_in, @@ -371,10 +417,14 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # Create input data structures. seq_group_metadata_list: List[SequenceGroupMetadata] = [] - for seq_group in scheduler_outputs.scheduled_seq_groups: + for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + seq_group = scheduled_seq_group.seq_group + token_chunk_size = scheduled_seq_group.token_chunk_size seq_group.maybe_set_first_scheduled_time(now) + # seq_id -> SequenceData seq_data: Dict[int, SequenceData] = {} + # seq_id -> physical block numbers block_tables: Dict[int, List[int]] = {} for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): @@ -393,6 +443,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, + token_chunk_size=token_chunk_size, lora_request=seq_group.lora_request, computed_block_nums=common_computed_block_nums, state=seq_group.state, @@ -409,8 +460,9 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # batch will have been computed before the next scheduling invocation. # This is because the engine assumes that a failure in model execution # will crash the vLLM instance / will not retry. 
- for seq_group in scheduler_outputs.scheduled_seq_groups: - self.block_manager.mark_blocks_as_computed(seq_group) + for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + self.block_manager.mark_blocks_as_computed( + scheduled_seq_group.seq_group) return seq_group_metadata_list, scheduler_outputs @@ -418,6 +470,7 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: self.block_manager.fork(parent_seq, child_seq) def free_seq(self, seq: Sequence) -> None: + """Free a sequence from a block table.""" self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: @@ -480,7 +533,8 @@ def _preempt_by_recompute( assert len(seqs) == 1 for seq in seqs: seq.status = SequenceStatus.WAITING - self.block_manager.free(seq) + self.free_seq(seq) + seq.reset_state_for_recompute() # NOTE: For FCFS, we insert the preempted sequence group to the front # of the waiting queue. self.waiting.appendleft(seq_group) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 09f90d10ab2e9..83ef7ca182c3d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -62,6 +62,7 @@ class EngineArgs: image_input_shape: Optional[str] = None image_feature_size: Optional[int] = None scheduler_delay_factor: float = 0.0 + enable_chunked_prefill: bool = False def __post_init__(self): if self.tokenizer is None: @@ -356,6 +357,12 @@ def add_cli_args( default=EngineArgs.scheduler_delay_factor, help='Apply a delay (of delay factor multiplied by previous' 'prompt latency) before scheduling next prompt.') + parser.add_argument( + '--enable-chunked-prefill', + type=bool, + default=False, + help='If True, the prefill requests can be chunked based on the ' + 'max_num_batched_tokens') return parser @classmethod @@ -394,11 +401,14 @@ def create_engine_configs( self.tokenizer_pool_type, self.tokenizer_pool_extra_config, ), self.ray_workers_use_nsight) - scheduler_config = SchedulerConfig(self.max_num_batched_tokens, - self.max_num_seqs, - model_config.max_model_len, - self.use_v2_block_manager, - self.scheduler_delay_factor) + scheduler_config = SchedulerConfig( + self.max_num_batched_tokens, + self.max_num_seqs, + model_config.max_model_len, + self.use_v2_block_manager, + delay_factor=self.scheduler_delay_factor, + enable_chunked_prefill=self.enable_chunked_prefill, + ) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 649cd0400de6a..a977a23d54fe0 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -553,7 +553,10 @@ def _process_model_outputs( # Update the scheduled sequence groups with the model outputs. scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups - for seq_group, outputs in zip(scheduled_seq_groups, output): + for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output): + seq_group = scheduled_seq_group.seq_group + token_chunk_size = scheduled_seq_group.token_chunk_size + seq_group.update_num_computed_tokens(token_chunk_size) self._process_sequence_group_outputs(seq_group, outputs) # Free the finished sequence groups. @@ -561,7 +564,8 @@ def _process_model_outputs( # Create the outputs. 
request_outputs: List[RequestOutput] = [] - for seq_group in scheduled_seq_groups: + for scheduled_seq_group in scheduled_seq_groups: + seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) @@ -676,17 +680,20 @@ def _get_stats(self, # Number of Tokens. if prompt_run: num_prompt_tokens = sum( - len(seq_group.prompt_token_ids) - for seq_group in scheduler_outputs.scheduled_seq_groups) + len(scheduled_seq_group.seq_group.prompt_token_ids) + for scheduled_seq_group in + scheduler_outputs.scheduled_seq_groups) num_generation_tokens = sum( - seq_group.num_seqs() - for seq_group in scheduler_outputs.scheduled_seq_groups) + scheduled_seq_group.seq_group.num_seqs() + for scheduled_seq_group in + scheduler_outputs.scheduled_seq_groups) else: num_generation_tokens = scheduler_outputs.num_batched_tokens # Latency Timings. time_last_iters = [] - for seq_group in scheduler_outputs.scheduled_seq_groups: + for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + seq_group = scheduled_seq_group.seq_group # Time since last token. # (n.b. updates seq_group.metrics.last_token_time) time_last_iters.append(seq_group.get_last_latency(now)) diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 9eae7a7df1367..f64c411cc6cb0 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -1,7 +1,7 @@ from typing import Dict, List, Optional from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) + ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -20,6 +20,7 @@ def __init__( scheduler_config: SchedulerConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 4ac72bb0de34c..8f80c20738bba 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -188,8 +188,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", is_driver_worker=True, ) - # FIXME(woosuk): We are not properly initializing pynccl when - # we have multiple nodes. 
self._run_workers("init_device") self._run_workers( "load_model", diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json new file mode 100644 index 0000000000000..9262a74a4a0e1 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..0ecf814a28a94 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..3793fcafee60b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json new file mode 100644 index 0000000000000..f4c0f8417b384 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + 
"128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json new file mode 100644 index 0000000000000..b41f9d443e506 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 79ddb4736e25c..b5c7e44de619c 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -47,6 +47,7 @@ "PhiForCausalLM": ("phi", "PhiForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), + "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py new file mode 100644 index 0000000000000..6b4a74198fd52 --- /dev/null +++ b/vllm/model_executor/models/qwen2_moe.py @@ -0,0 +1,457 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Qwen2MoE model compatible with HuggingFace weights.""" +from typing import Any, Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + + +class Qwen2MoeMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, + reduce_results: bool = True, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method, + reduce_results=reduce_results) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Qwen2MoeSparseMoeBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.n_routed_experts = config.num_experts + self.top_k = config.num_experts_per_tok + if self.tp_size > self.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.n_routed_experts}.") + + self.experts = nn.ModuleList([ + Qwen2MoeMLP(hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + reduce_results=False) + for idx in range(self.n_routed_experts) + ]) + self.pack_params() + + self.gate = ReplicatedLinear(config.hidden_size, + self.n_routed_experts, + bias=False, + linear_method=None) + if config.shared_expert_intermediate_size > 0: + self.shared_expert = Qwen2MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.shared_expert_intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + reduce_results=False, + ) + else: + self.shared_expert = None + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, + 1, + bias=False) + + def pack_params(self): + w1 = [] + w2 = [] + for expert in self.experts: + w1.append(expert.gate_up_proj.weight) + w2.append(expert.down_proj.weight) + self.w1 = torch._utils._flatten_dense_tensors(w1) + w1s = torch._utils._unflatten_dense_tensors(self.w1, w1) + for data, param in zip(w1s, w1): + param.data = data + self.w1 = self.w1.view(len(w1), *w1s[0].shape) + + self.w2 = torch._utils._flatten_dense_tensors(w2) + w2s = torch._utils._unflatten_dense_tensors(self.w2, w2) + for data, param in zip(w2s, w2): + param.data = data + + self.w2 = self.w2.view(len(w2), *w2s[0].shape) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + shared_output = None + if self.shared_expert is not None: + shared_output = self.shared_expert(hidden_states) + if self.shared_expert_gate is not None: + shared_output = F.sigmoid( + self.shared_expert_gate(hidden_states)) * shared_output + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = fused_moe(hidden_states, + self.w1, + self.w2, + router_logits, + self.top_k, + renormalize=self.config.norm_topk_prob, + inplace=True) + + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +class Qwen2MoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = 
self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + linear_method=linear_method, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Qwen2MoeDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = Qwen2MoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + linear_method=linear_method, + ) + if (config.num_experts is not None + and (layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = Qwen2MoeSparseMoeBlock(config=config, + linear_method=linear_method) + else: + self.mlp = Qwen2MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + 
hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Qwen2MoeModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + Qwen2MoeDecoderLayer(config, + layer_idx, + linear_method=linear_method) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i], attn_metadata, + residual) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Qwen2MoeForCausalLM(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = Qwen2MoeModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, + cache_dir, + load_format, + revision, + fall_back_to_pt=False): + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if (("mlp.experts." in name or "mlp.shared_expert." 
in name) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if (("mlp.experts." in name or "mlp.shared_expert." in name) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/model_executor/parallel_utils/pynccl.py index 0eb75e02d62cf..968dd7e17d021 100644 --- a/vllm/model_executor/parallel_utils/pynccl.py +++ b/vllm/model_executor/parallel_utils/pynccl.py @@ -202,6 +202,7 @@ def __init__( init_method=None, timeout=datetime.timedelta(seconds=10), world_size: int = -1, + local_rank: int = -1, rank: int = -1, store=None, group_name: str = "", @@ -219,25 +220,22 @@ def __init__( store=store, group_name=group_name, pg_options=pg_options) - self.world_size = dist.get_world_size() - self.rank = dist.get_rank() - torch.cuda.set_device(self.rank) - if self.rank == 0: + torch.cuda.set_device(local_rank) + if rank == 0: self.unique_id = ncclGetUniqueId() else: self.unique_id = NcclUniqueId() - tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda( - self.rank) + tensor = torch.ByteTensor(list( + self.unique_id.internal)).cuda(local_rank) dist.broadcast(tensor, src=0) byte_list = tensor.cpu().tolist() - self.unique_id = NcclUniqueId() for i, byte in enumerate(byte_list): self.unique_id.internal[i] = byte self.comm = ctypes.c_void_p() - result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size, - self.unique_id, self.rank) + result = _c_ncclCommInitRank(ctypes.byref(self.comm), world_size, + self.unique_id, rank) assert result == 0 - self.stream = torch.cuda.Stream(device=f"cuda:{self.rank}") + self.stream = torch.cuda.Stream(device=f"cuda:{local_rank}") def all_reduce(self, tensor: torch.Tensor, diff --git a/vllm/model_executor/parallel_utils/pynccl_utils.py b/vllm/model_executor/parallel_utils/pynccl_utils.py index a12d620d7a24c..5b5eebbde44f6 100644 --- a/vllm/model_executor/parallel_utils/pynccl_utils.py +++ b/vllm/model_executor/parallel_utils/pynccl_utils.py @@ -10,7 +10,6 @@ try: from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, ncclGetVersion) - logger.info(f"vLLM is using nccl=={ncclGetVersion()}") except Exception as e: # in non-NVIDIA environments, we can't import the nccl module # e.g. when running on machines with AMD GPUs @@ -36,11 +35,14 @@ def set_pynccl_stream(stream: torch.cuda.Stream): pass -def init_process_group(world_size: int, rank: int, init_method: str) -> None: +def init_process_group(world_size: int, local_rank: int, rank: int, + init_method: str) -> None: assert not is_initialized() global comm + logger.info(f"vLLM is using nccl=={ncclGetVersion()}") comm = NCCLCommunicator(init_method=init_method, world_size=world_size, + local_rank=local_rank, rank=rank) diff --git a/vllm/sequence.py b/vllm/sequence.py index 8292e207b8078..a40f38f76d1c4 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -113,6 +113,8 @@ def __init__( self.prompt_token_ids = prompt_token_ids self.output_token_ids = output_token_ids self.cumulative_logprob = 0.0 + # The number of tokens that are computed (that run against the model). 
+        self._num_computed_tokens = 0

     def append_token_id(self, token_id: int, logprob: float) -> None:
         self.output_token_ids.append(token_id)
@@ -130,6 +132,28 @@ def get_output_len(self) -> int:
     def get_token_ids(self) -> List[int]:
         return self.prompt_token_ids + self.output_token_ids

+    def get_num_computed_tokens(self) -> int:
+        """Return the number of prefill tokens that are already computed."""
+        return self._num_computed_tokens
+
+    def update_num_computed_tokens(self, num_new_computed_tokens: int) -> None:
+        """Update number of tokens computed so far."""
+        self._num_computed_tokens += num_new_computed_tokens
+
+    def reset_num_computed_tokens(self) -> None:
+        """Reset the number of computed tokens from this sequence. It is
+        supposed to be called when a sequence needs to be started from
+        the beginning again (e.g., sequence is preempted).
+        """
+        self._num_computed_tokens = 0
+
+    def get_num_uncomputed_tokens(self) -> int:
+        """Return the number of prefill tokens that are not computed."""
+        # We use `get_len()` which includes prompt_len + output_len instead
+        # of prompt_len here. This is because during recompute we need to
+        # prefill for both prompt and output.
+        return self.get_len() - self.get_num_computed_tokens()
+
     def get_last_token_id(self) -> int:
         if not self.output_token_ids:
             return self.prompt_token_ids[-1]
@@ -208,6 +232,10 @@ def hash_of_block(self, logical_idx: int) -> int:
     def num_hashed_tokens_of_block(self, logical_idx: int):
         return logical_idx * self.block_size + self.block_size

+    def reset_state_for_recompute(self):
+        """Reset the sequence states for recomputation."""
+        self.data.reset_num_computed_tokens()
+
     def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(
             block_number=len(self.logical_token_blocks),
@@ -430,6 +458,18 @@ def get_unfinished_seqs(self) -> List[Sequence]:
     def get_finished_seqs(self) -> List[Sequence]:
         return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

+    def update_num_computed_tokens(self, num_new_computed_tokens: int):
+        """Update number of tokens computed so far."""
+        for seq in self.seqs_dict.values():
+            seq.data.update_num_computed_tokens(num_new_computed_tokens)
+
+    def get_num_uncomputed_tokens(self) -> int:
+        # All sequences in the group should have the same prompt, so the
+        # number of unfinished prefill tokens is the same across all
+        # sequences.
+        return list(
+            self.seqs_dict.values())[0].data.get_num_uncomputed_tokens()
+
     def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
         return len(self.get_seqs(status))

@@ -473,6 +513,8 @@ class SequenceGroupMetadata:
         sampling_params: The sampling parameters used to generate the outputs.
         block_tables: The block tables. (Seq id -> list of physical block
             numbers)
+        token_chunk_size: The number of tokens to be processed. None if
+            chunking is not required.
         state: Internal state tied to this sequence group.
         lora_request: LoRA request.
         multi_modal_data: Multi modal data.
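
The sequence.py changes above introduce per-sequence accounting of how many tokens have already been run through the model. To make the intent concrete, here is a standalone sketch of the same arithmetic; the TokenAccounting class and its demo are hypothetical illustrations, not vLLM code. Prefill advances the counter chunk by chunk, a decode step extends the sequence, and preemption resets the counter so that both the prompt and the already generated output are prefilled again.

# Standalone sketch of the computed-token bookkeeping above.
# "TokenAccounting" is a hypothetical name used for illustration only.
class TokenAccounting:
    """Tracks how many of a sequence's tokens have been run through the model."""

    def __init__(self, prompt_len: int) -> None:
        self.prompt_len = prompt_len
        self.output_len = 0
        self._num_computed_tokens = 0

    def get_len(self) -> int:
        # Prompt plus generated output, mirroring SequenceData.get_len().
        return self.prompt_len + self.output_len

    def update_num_computed_tokens(self, num_new_computed_tokens: int) -> None:
        self._num_computed_tokens += num_new_computed_tokens

    def reset_num_computed_tokens(self) -> None:
        # Called when a sequence is preempted and must be recomputed.
        self._num_computed_tokens = 0

    def get_num_uncomputed_tokens(self) -> int:
        # Based on the full length (prompt + output) so that recomputation
        # after preemption also re-prefills already generated tokens.
        return self.get_len() - self._num_computed_tokens


if __name__ == "__main__":
    acct = TokenAccounting(prompt_len=1000)
    chunk_size = 256  # plays the role of token_chunk_size
    while acct.get_num_uncomputed_tokens() > 0:
        step = min(chunk_size, acct.get_num_uncomputed_tokens())
        acct.update_num_computed_tokens(step)  # one chunked-prefill step
    assert acct.get_num_uncomputed_tokens() == 0

    acct.output_len += 1              # one decode step appends a token
    acct.reset_num_computed_tokens()  # preemption: recompute from scratch
    assert acct.get_num_uncomputed_tokens() == 1001  # prompt + generated token

The final assertion is why get_num_uncomputed_tokens() is defined over get_len() rather than the prompt length alone.
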
@@ -485,6 +527,7 @@ def __init__(
         seq_data: Dict[int, SequenceData],
         sampling_params: SamplingParams,
         block_tables: Dict[int, List[int]],
+        token_chunk_size: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
         computed_block_nums: Optional[List[int]] = None,
         state: Optional[SequenceGroupState] = None,
@@ -499,11 +542,23 @@ def __init__(
         self.computed_block_nums = computed_block_nums
         self.multi_modal_data = multi_modal_data
         self.state = SequenceGroupState() if state is None else state
+        self._token_chunk_size = token_chunk_size
+
+        if self._token_chunk_size is None:
+            if is_prompt:
+                self._token_chunk_size = list(seq_data.values())[0].get_len()
+            else:
+                self._token_chunk_size = 1

     @property
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0

+    @property
+    def token_chunk_size(self) -> int:
+        """Return the number of tokens to be processed (chunk size)."""
+        return self._token_chunk_size
+

 class SequenceOutput:
     """The model output associated with a sequence.
diff --git a/vllm/test_utils.py b/vllm/test_utils.py
index 5b2eeafad197e..735cc0037ba5f 100644
--- a/vllm/test_utils.py
+++ b/vllm/test_utils.py
@@ -8,6 +8,7 @@
 def init_test_distributed_environment(
     pipeline_parallel_size: int,
     tensor_parallel_size: int,
+    local_rank: int,
     rank: int,
     distributed_init_port: str,
 ) -> None:
@@ -16,7 +17,10 @@ def init_test_distributed_environment(
                                      worker_use_ray=True)
     distributed_init_method = f"tcp://localhost:{distributed_init_port}"
     init_distributed_environment(
-        parallel_config, rank, distributed_init_method=distributed_init_method)
+        parallel_config,
+        local_rank,
+        rank,
+        distributed_init_method=distributed_init_method)


 def multi_process_tensor_parallel(
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index f0c98700ab749..31fa52476af1d 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -150,39 +150,58 @@ def _prepare_prompt(
         subquery_lens: List[int] = []
         prefix_block_tables: List[List[int]] = []
         multi_modal_input_list: List[torch.Tensor] = []
+
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
             seq_ids = list(seq_group_metadata.seq_data.keys())
             assert len(seq_ids) == 1
             seq_id = seq_ids[0]

+            computed_block_nums = seq_group_metadata.computed_block_nums
+            if (self.scheduler_config is not None
+                    and self.scheduler_config.chunked_prefill_enabled
+                    and computed_block_nums is not None):
+                raise RuntimeError(
+                    "chunked prefill cannot be used with prefix caching "
+                    "now.")
+
+            token_chunk_size = seq_group_metadata.token_chunk_size
             seq_data = seq_group_metadata.seq_data[seq_id]
-            prompt_tokens = seq_data.get_token_ids()
+            computed_len = seq_data.get_num_computed_tokens()
+            # We should use get_len here because, in case of preemption,
+            # the sequence data already contains output tokens.
+            prefill_end = min(seq_data.get_len(),
+                              computed_len + token_chunk_size)
+            # TODO(sang): Rename it after chunked prefill is introduced.
+            prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
             prompt_len = len(prompt_tokens)
+            # Right now, prefill_end is always the same as the length of the
+            # sequence. However, once chunked prefill is introduced, this
+            # assumption can be changed.
+            assert prefill_end == seq_data.get_len()
             prompt_lens.append(prompt_len)

-            computed_len = 0

             # NOTE: This only works for oooooooxxx style attention.
- computed_block_nums = seq_group_metadata.computed_block_nums if computed_block_nums is not None and len( computed_block_nums) > 0 and self.sliding_window is None: # Prefix is not supported with sliding_window computed_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[computed_len:] prefix_block_tables.append(computed_block_nums) - context_len = computed_len else: prefix_block_tables.append([]) - context_len = 0 + # Right now, prefill start is always 0. However, this + # assumption can be changed once chunked prefill is introduced. + assert computed_len == 0 + # actual prompt lens - context_lens.append(context_len) + context_lens.append(computed_len) subquery_lens.append(prompt_len - computed_len) input_tokens.extend(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend( - list(range(computed_len, computed_len + len(prompt_tokens)))) + input_positions.extend(list(range(computed_len, prefill_end))) lora_id = seq_group_metadata.lora_int_id @@ -218,7 +237,8 @@ def _prepare_prompt( "Prefix caching is currently not supported with " "sliding window attention") start_idx = max(0, prompt_len - self.sliding_window) - for i in range(computed_len, prompt_len): + + for i in range(computed_len, prefill_end): if i < start_idx: slot_mapping.append(_PAD_SLOT_ID) continue @@ -331,6 +351,7 @@ def _prepare_decode( for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt + assert seq_group_metadata.token_chunk_size == 1 seq_ids = list(seq_group_metadata.seq_data.keys()) lora_id = seq_group_metadata.lora_int_id diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 6459c0cda669a..4ffe780400101 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -97,8 +97,8 @@ def init_device(self) -> None: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_distributed_environment(self.parallel_config, self.rank, - self.distributed_init_method) + init_distributed_environment(self.parallel_config, self.local_rank, + self.rank, self.distributed_init_method) # Set random seed. set_random_seed(self.model_config.seed) @@ -249,6 +249,7 @@ def get_cache_block_size_bytes(self, block_size: int, def init_distributed_environment( parallel_config: ParallelConfig, + local_rank: int, rank: int, distributed_init_method: Optional[str] = None, ) -> None: @@ -282,9 +283,9 @@ def init_distributed_environment( elif parallel_config.world_size > 1: # NOTE(woosuk): We don't initialize pynccl process group when world size # is 1. - # TODO(woosuk): Support multi-node connection. pynccl_utils.init_process_group( world_size=parallel_config.world_size, + local_rank=local_rank, rank=rank, init_method=distributed_init_method, )
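
The _prepare_prompt changes above derive the tokens to run in a step from the slice [computed_len, prefill_end), where prefill_end = min(seq_data.get_len(), computed_len + token_chunk_size). In this diff the assertion still forces the chunk to cover the whole remaining sequence; chunking only takes effect once the scheduler hands out smaller token_chunk_size values. Below is a minimal sketch of that windowing using a hypothetical prefill_window helper (not a vLLM API).

from typing import List, Tuple


def prefill_window(token_ids: List[int], computed_len: int,
                   token_chunk_size: int) -> Tuple[List[int], List[int]]:
    """Return (tokens, positions) for the next prefill chunk.

    Mirrors the computed_len/prefill_end arithmetic in _prepare_prompt: the
    chunk starts where the previous one stopped and ends after at most
    token_chunk_size tokens, never running past the end of the sequence.
    """
    prefill_end = min(len(token_ids), computed_len + token_chunk_size)
    tokens = token_ids[computed_len:prefill_end]
    positions = list(range(computed_len, prefill_end))
    return tokens, positions


if __name__ == "__main__":
    seq = list(range(10))  # a 10-token prompt
    tokens, positions = prefill_window(seq, computed_len=0, token_chunk_size=4)
    assert tokens == [0, 1, 2, 3] and positions == [0, 1, 2, 3]
    tokens, positions = prefill_window(seq, computed_len=8, token_chunk_size=4)
    assert tokens == [8, 9] and positions == [8, 9]  # final, shorter chunk

Repeated calls with a growing computed_len walk through the prompt chunk by chunk, which is exactly what SequenceData.update_num_computed_tokens records between scheduler steps.
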