diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 63904ea929870..d21b54b16db4b 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -50,20 +50,15 @@ def copy_blocks( @dataclass -class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): +class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata, + AttentionMetadataPerStage): """Metadata for TorchSDPABackend. """ # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool + slot_mapping: torch.Tensor prompt_lens: Optional[List[int]] - prompt_lens_tensor: Optional[torch.Tensor] - - max_subquery_len: Optional[int] = None - max_prompt_len: Optional[int] = None - subquery_start_loc: Optional[torch.Tensor] = None - seq_start_loc: Optional[torch.Tensor] = None - use_cuda_graph: bool = False def __post_init__(self): # Set during the execution of the first attention op. @@ -111,7 +106,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[TorchSDPAMetadata], + attn_metadata: TorchSDPAMetadata, kv_scale: float, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -140,51 +135,36 @@ def forward( attn_metadata.kv_cache_dtype, kv_scale) - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens - - output = torch.empty_like(query) - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - # QKV for prefill. 
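Context for the metadata change above: on the CPU backend a batch is either all prompts or all decodes, so the per-stage split (separate prefill/decode metadata) collapses into one flat structure carrying an `is_prompt` flag plus the slot mapping. A simplified mirror of the fields the forward pass relies on, as a sketch only (the real class also inherits the base metadata fields):

    from dataclasses import dataclass
    from typing import List, Optional

    import torch

    @dataclass
    class _SketchSDPAMetadata:
        # True -> the whole batch is prompts; False -> the whole batch is decodes.
        is_prompt: bool
        # Flat physical KV-cache slot per token (-1 for masked/padded tokens).
        slot_mapping: torch.Tensor
        # One entry per prompt; None for decode batches.
        prompt_lens: Optional[List[int]]
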
- query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - if prefill_meta := attn_metadata.prefill_metadata: - if (kv_cache is None or prefill_meta.block_tables.numel() == 0): + if attn_metadata.is_prompt: + if (kv_cache is None or attn_metadata.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - if prefill_meta.attn_bias is None: + if attn_metadata.attn_bias is None: if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - prefill_meta.prompt_lens) # type: ignore + attn_metadata.prompt_lens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( - prefill_meta.prompt_lens, self.sliding_window, + attn_metadata.prompt_lens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = [None] * len(prefill_meta.prompt_lens) - prefill_meta.attn_bias = att_masks + att_masks = [None] * len(attn_metadata.prompt_lens) + attn_metadata.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) start = 0 - out = torch.empty((num_tokens, self.num_heads, self.head_size), - dtype=query.dtype) - for prompt_len, mask in zip(prefill_meta.prompt_lens, - prefill_meta.attn_bias): + output = torch.empty( + (num_tokens, self.num_heads, self.head_size), + dtype=query.dtype) + for prompt_len, mask in zip(attn_metadata.prompt_lens, + attn_metadata.attn_bias): end = start + prompt_len sub_out = scaled_dot_product_attention( query[:, start:end, :], @@ -194,32 +174,28 @@ def forward( dropout_p=0.0, is_causal=not self.need_mask, scale=self.scale).movedim(query.dim() - 2, 0) - out[start:end, :, :] = sub_out + output[start:end, :, :] = sub_out start = end - assert out.shape == output[:num_prefill_tokens].shape - output[:num_prefill_tokens] = out else: # prefix-enabled attention raise RuntimeError( "Torch SDPA backend doesn't support prefix decoding.") - if decode_meta := attn_metadata.decode_metadata: + else: # Decoding run. - out = PagedAttention.forward_decode( - decode_query, + output = PagedAttention.forward_decode( + query, key_cache, value_cache, - decode_meta.block_tables, - decode_meta.context_lens, - decode_meta.max_context_len, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, kv_scale, ) - assert out.shape == output[num_prefill_tokens:].shape - output[num_prefill_tokens:] # Reshape the output tensor. 
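To make the new prompt branch above concrete: KV heads are first expanded to match the query heads (grouped-query attention), the token dimension is moved to second-to-last, and SDPA runs once per prompt so sequences never attend across each other. A runnable toy with made-up sizes, plain causal masking, and no ALiBi or sliding-window bias:

    import torch
    from torch.nn.functional import scaled_dot_product_attention

    num_heads, num_kv_heads, head_size = 8, 2, 32
    num_queries_per_kv = num_heads // num_kv_heads
    prompt_lens = [3, 5]
    num_tokens = sum(prompt_lens)

    query = torch.randn(num_tokens, num_heads, head_size)
    key = torch.randn(num_tokens, num_kv_heads, head_size)
    value = torch.randn(num_tokens, num_kv_heads, head_size)

    # Grouped-query attention: repeat each KV head for its group of query heads.
    key = key.repeat_interleave(num_queries_per_kv, dim=1)
    value = value.repeat_interleave(num_queries_per_kv, dim=1)

    # Move the token dim to second-to-last, mirroring the movedim calls above.
    query = query.movedim(0, query.dim() - 2)
    key = key.movedim(0, key.dim() - 2)
    value = value.movedim(0, value.dim() - 2)

    output = torch.empty(num_tokens, num_heads, head_size)
    start = 0
    for prompt_len in prompt_lens:
        end = start + prompt_len
        # One causal SDPA call per prompt slice.
        sub_out = scaled_dot_product_attention(
            query[:, start:end, :],
            key[:, start:end, :],
            value[:, start:end, :],
            dropout_p=0.0,
            is_causal=True,
        ).movedim(query.dim() - 2, 0)
        output[start:end, :, :] = sub_out
        start = end

    assert output.shape == (num_tokens, num_heads, head_size)
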
return output.view(-1, self.num_heads * self.head_size) @@ -241,7 +217,7 @@ def _make_alibi_bias( bias = bias[None, :] - bias[:, None] num_heads = alibi_slopes.shape[0] - bias = bias[None, :].expand(num_heads, prompt_len, prompt_len) + bias = bias[None, :].repeat((num_heads, 1, 1)) bias.mul_(alibi_slopes[:, None, None]) inf_mask = torch.empty( (1, prompt_len, prompt_len), diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 2bf97338da0ed..eda4e8989c163 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -25,6 +25,7 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, assert lora_config is None, "cpu backend doesn't support LoRA" model_config = _verify_and_get_model_config(model_config) cache_config = _verify_and_get_cache_config(cache_config) + scheduler_config = _verify_and_get_scheduler_config(scheduler_config) self.model_config = model_config self.cache_config = cache_config @@ -116,6 +117,15 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: return config +def _verify_and_get_scheduler_config( + config: SchedulerConfig) -> SchedulerConfig: + if config.chunked_prefill_enabled: + logger.warning("Chunked prefill is not supported on CPU, disable it.") + config.chunked_prefill_enabled = False + + return config + + def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: _GB = 1 << 30 if config.enable_prefix_caching: diff --git a/vllm/utils.py b/vllm/utils.py index 8ab8927512cc9..fdb0a3768ab0d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -372,7 +372,6 @@ def is_pin_memory_available() -> bool: print_warning_once("Pin memory is not supported on Neuron.") return False elif is_cpu(): - print_warning_once("Pin memory is not supported on CPU.") return False return True diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py new file mode 100644 index 0000000000000..49e1ad5709f5d --- /dev/null +++ b/vllm/worker/cpu_model_runner.py @@ -0,0 +1,408 @@ +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig) +from vllm.distributed import broadcast_tensor_dict +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.model_loader import get_model +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad, maybe_expand_dim + +logger = init_logger(__name__) + +_PAD_SLOT_ID = -1 + + +class CPUModelRunner: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + *args, + **kwargs, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.lora_config = lora_config + self.is_driver_worker = is_driver_worker + + # model_config can be None in tests/samplers/test_sampler.py. + # FIXME(woosuk): This is a hack to make the tests work. Refactor this. 
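Back in the `_make_alibi_bias` hunk of torch_sdpa.py above, `repeat` materializes a separate, writable bias copy per head, whereas the old `expand` produced a stride-0 view in which all heads alias the same memory, which does not combine safely with the in-place `mul_` that follows (presumably the motivation for the change). A toy reproduction with made-up slopes:

    import torch

    num_heads, prompt_len = 4, 6
    alibi_slopes = torch.tensor([0.5, 0.25, 0.125, 0.0625])

    bias = torch.arange(prompt_len, dtype=torch.float32)
    bias = bias[None, :] - bias[:, None]            # (prompt_len, prompt_len) relative positions
    bias = bias[None, :].repeat((num_heads, 1, 1))  # one writable copy per head
    bias.mul_(alibi_slopes[:, None, None])          # per-head slope scaling, in place
    assert bias.shape == (num_heads, prompt_len, prompt_len)
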
+ self.sliding_window = (model_config.get_sliding_window() + if model_config is not None else None) + self.device_config = (device_config + if device_config is not None else DeviceConfig()) + self.device = self.device_config.device + + self.model = None + self.block_size = None # Set after initial profiling. + + self.kv_cache_dtype = kv_cache_dtype + + self.attn_backend = get_attn_backend( + self.model_config.dtype if model_config is not None else None) + + def load_model(self) -> None: + self.model = get_model(self.model_config, + self.device_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int]]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + prompt_lens: List[int] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + computed_len = seq_data.get_num_computed_tokens() + prompt_len = len(prompt_tokens) + + prompt_lens.append(prompt_len) # Prompt token num + input_tokens.extend(prompt_tokens) # Token ids + + # Token position ids + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.extend(list(range(computed_len, prompt_len))) + + # Compute the slot mapping. + block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, prompt_len - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. 
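The worked example in the comment above can be checked directly. One block_table consistent with the quoted mapping is `[0, 1, 0]` (an assumption for illustration: the first physical block's leading slots get reused once tokens 0 and 1 fall outside the sliding window), with `computed_len` taken as 0:

    _PAD_SLOT_ID = -1
    prompt_len, sliding_window, block_size = 10, 8, 4
    block_table = [0, 1, 0]  # hypothetical logical-to-physical block mapping

    start_idx = max(0, prompt_len - sliding_window)  # first 2 tokens are outside the window
    slot_mapping = []
    for i in range(prompt_len):  # computed_len == 0 in this toy
        if i < start_idx:
            slot_mapping.append(_PAD_SLOT_ID)
            continue
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slot_mapping.append(block_number * block_size + block_offset)

    assert slot_mapping == [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]
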
+ start_idx = 0 + if self.sliding_window is not None: + start_idx = max(0, prompt_len - self.sliding_window) + + for i in range(computed_len, prompt_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // + self.block_size] # type: ignore + block_offset = i % self.block_size # type: ignore + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + num_prompt_tokens = len(input_tokens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) # type: ignore + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) # type: ignore + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) # type: ignore + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + prompt_lens=prompt_lens, + num_prefills=len(prompt_lens), + num_prefill_tokens=num_prompt_tokens, + num_decode_tokens=0, + prefill_metadata=None, + decode_metadata=None, + max_context_len=None, + context_lens=None, + block_tables=torch.tensor([]), + slot_mapping=slot_mapping, + kv_cache_dtype=self.kv_cache_dtype, + ) + return ( + input_tokens, + input_positions, + attn_metadata, + prompt_lens, + ) + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + context_lens: List[int] = [] + block_tables: List[List[int]] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + assert seq_group_metadata.token_chunk_size == 1 + + seq_ids = list(seq_group_metadata.seq_data.keys()) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append(generation_token) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append(position) + + context_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + context_lens.append(context_len) + + block_table = seq_group_metadata.block_tables[seq_id] + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + max_context_len = max(context_lens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + context_lens = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + + max_block_table_len = max( + len(block_table) for block_table in block_tables) + block_tables = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device, + ) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + slot_mapping=slot_mapping, + prompt_lens=None, + num_prefill_tokens=0, + num_decode_tokens=len(input_tokens), + max_context_len=max_context_len, + num_prefills=0, + prefill_metadata=None, + decode_metadata=None, + 
context_lens=context_lens, + block_tables=block_tables, + kv_cache_dtype=self.kv_cache_dtype, + ) + return ( + input_tokens, + input_positions, + attn_metadata, + ) + + def _prepare_sample( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> SamplingMetadata: + seq_groups: List[Tuple[List[int], SamplingParams]] = [] + selected_token_indices: List[int] = [] + generators: List[torch.Generator] = [] + selected_token_start_idx = 0 + categorized_sample_indices = {t: [] for t in SamplingType} + categorized_sample_indices_start_idx = 0 + categorized_sampled_token_indices_start_idx = 0 + + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + seq_groups.append((seq_ids, sampling_params)) + + if seq_group_metadata.is_prompt: + assert len(seq_ids) == 1 + subquery_len = prompt_lens[i] + if sampling_params.prompt_logprobs is not None: + # NOTE: prompt token positions do not need sample, skip + categorized_sample_indices_start_idx += subquery_len - 1 + + categorized_sample_indices[ + sampling_params.sampling_type].append([ + categorized_sample_indices_start_idx, + categorized_sampled_token_indices_start_idx + ]) + categorized_sample_indices_start_idx += 1 + categorized_sampled_token_indices_start_idx += 1 + + if sampling_params.prompt_logprobs is not None: + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + subquery_len - 1)) + selected_token_indices.append(selected_token_start_idx + + subquery_len - 1) + selected_token_start_idx += subquery_len + + if sampling_params.seed is not None: + seq_group_metadata.state.generator = torch.Generator( + device=self.device).manual_seed(sampling_params.seed) + else: + num_seqs = len(seq_ids) + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + num_seqs)) + selected_token_start_idx += num_seqs + + categorized_sample_indices[ + sampling_params.sampling_type].extend( + zip( + range( + categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + + num_seqs), + range( + categorized_sampled_token_indices_start_idx, + categorized_sampled_token_indices_start_idx + + num_seqs))) + categorized_sample_indices_start_idx += num_seqs + categorized_sampled_token_indices_start_idx += num_seqs + + if sampling_params.seed is not None: + generators.append(seq_group_metadata.state.generator) + + selected_token_indices = torch.tensor(selected_token_indices, + dtype=torch.long) + + categorized_sample_indices = { + t: maybe_expand_dim(torch.tensor(seq_ids, dtype=torch.int), 2, 2) + for t, seq_ids in categorized_sample_indices.items() + } + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + sampling_metadata = SamplingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + selected_token_indices=selected_token_indices, + categorized_sample_indices=categorized_sample_indices, + generators=generators, + ) + return sampling_metadata + + def prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, + SamplingMetadata]: + if self.is_driver_worker: + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. 
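To make the prompt-side bookkeeping in `_prepare_sample` above concrete: when `prompt_logprobs` is not requested, only the last token of each prompt is selected for sampling, so two prompts of lengths 3 and 5 select flat hidden-state indices 2 and 7. A minimal sketch of that index arithmetic (toy lengths, no prompt logprobs):

    prompt_lens = [3, 5]
    selected_token_indices = []
    selected_token_start_idx = 0
    for subquery_len in prompt_lens:
        # Only the final prompt token of each sequence feeds the sampler.
        selected_token_indices.append(selected_token_start_idx + subquery_len - 1)
        selected_token_start_idx += subquery_len
    assert selected_token_indices == [2, 7]
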
+ if is_prompt: + (input_tokens, input_positions, attn_metadata, + prompt_lens) = self._prepare_prompt(seq_group_metadata_list) + else: + (input_tokens, input_positions, + attn_metadata) = self._prepare_decode(seq_group_metadata_list) + prompt_lens = [] + sampling_metadata = self._prepare_sample(seq_group_metadata_list, + prompt_lens) + # Broadcast the metadata. + metadata_dict = { + "input_tokens": input_tokens, + "input_positions": input_positions, + "selected_token_indices": + sampling_metadata.selected_token_indices, + } + metadata_dict.update(attn_metadata.asdict_zerocopy()) + broadcast_tensor_dict(metadata_dict, src=0) + else: + metadata_dict = broadcast_tensor_dict(src=0) + input_tokens = metadata_dict.pop("input_tokens") + input_positions = metadata_dict.pop("input_positions") + selected_token_indices = metadata_dict.pop( + "selected_token_indices") + attn_metadata = self.attn_backend.make_metadata(**metadata_dict) + sampling_metadata = SamplingMetadata( + seq_groups=None, + seq_data=None, + prompt_lens=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + generators=None, + perform_sampling=False, + ) + + return ( + input_tokens, + input_positions, + attn_metadata, + sampling_metadata, + ) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + kv_caches: List[torch.Tensor], + ) -> Optional[SamplerOutput]: + (input_tokens, input_positions, attn_metadata, sampling_metadata + ) = self.prepare_input_tensors(seq_group_metadata_list) + + model_executable = self.model + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "kv_caches": kv_caches, + "attn_metadata": attn_metadata, + } + + hidden_states = model_executable(**execute_model_kwargs) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Only perform sampling in the driver worker. + if not sampling_metadata.perform_sampling: + return None + + # Sample the next token. + output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + return output diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 751384eb72af3..3989207e8dd83 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -12,25 +12,14 @@ init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.model_executor.model_loader import get_model from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from vllm.worker.model_runner import ModelRunner +from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.worker_base import LoraNotSupportedWorkerBase logger = init_logger(__name__) -class CPUModelRunner(ModelRunner): - - def load_model(self) -> None: - self.model = get_model(self.model_config, - self.device_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) - - class CPUCacheEngine: """Manages the KV cache for CPU backend.