diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 4f69ebef662cb..9183077c28c91 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -141,7 +141,7 @@ def forward( attn_metadata.kv_cache_dtype) if attn_metadata.is_prompt: - if (kv_cache is None or attn_metadata.block_tables.numel() == 0): + if (kv_cache is None or attn_metadata.block_tables is None): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, @@ -221,8 +221,8 @@ def _make_alibi_bias( bias = bias[None, :] - bias[:, None] num_heads = alibi_slopes.shape[0] - bias = bias[None, :].expand(num_heads, prompt_len, prompt_len) - bias.mul_(alibi_slopes[:, None, None]) + bias = bias[None, :].expand(num_heads, prompt_len, prompt_len)\ + .mul(alibi_slopes[:, None, None]) inf_mask = torch.empty( (1, prompt_len, prompt_len), dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) diff --git a/vllm/utils.py b/vllm/utils.py index 17b97f393ff21..fbdeb906f2631 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -370,7 +370,6 @@ def is_pin_memory_available() -> bool: print_warning_once("Pin memory is not supported on Neuron.") return False elif is_cpu(): - print_warning_once("Pin memory is not supported on CPU.") return False return True diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py new file mode 100644 index 0000000000000..89f6da14f2e69 --- /dev/null +++ b/vllm/worker/cpu_model_runner.py @@ -0,0 +1,405 @@ +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig) +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.parallel_utils.communication_op import ( + broadcast_tensor_dict) +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad, maybe_expand_dim + +logger = init_logger(__name__) + +_PAD_SLOT_ID = -1 + + +class CPUModelRunner: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + *args, + **kwargs, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.lora_config = lora_config + self.is_driver_worker = is_driver_worker + + # model_config can be None in tests/samplers/test_sampler.py. + # FIXME(woosuk): This is a hack to make the tests work. Refactor this. + self.sliding_window = (model_config.get_sliding_window() + if model_config is not None else None) + self.device_config = (device_config + if device_config is not None else DeviceConfig()) + self.device = self.device_config.device + + self.model = None + self.block_size = None # Set after initial profiling. + + self.kv_cache_dtype = kv_cache_dtype + + self.attn_backend = get_attn_backend( + self.model_config.dtype if model_config is not None else None) + + def load_model(self) -> None: + self.model = get_model(self.model_config, + self.device_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int]]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + prompt_lens: List[int] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + computed_len = seq_data.get_num_computed_tokens() + prompt_len = len(prompt_tokens) + + prompt_lens.append(prompt_len) # Prompt token num + input_tokens.extend(prompt_tokens) # Token ids + + # Token position ids + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.extend(list(range(computed_len, prompt_len))) + + # Compute the slot mapping. + block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, prompt_len - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 + if self.sliding_window is not None: + start_idx = max(0, prompt_len - self.sliding_window) + + for i in range(computed_len, prompt_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // + self.block_size] # type: ignore + block_offset = i % self.block_size # type: ignore + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + num_prompt_tokens = len(input_tokens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) # type: ignore + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) # type: ignore + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) # type: ignore + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + slot_mapping=slot_mapping, + prompt_lens=prompt_lens, + prompt_lens_tensor=None, + num_prompt_tokens=num_prompt_tokens, + num_generation_tokens=0, + max_context_len=None, + context_lens=None, + block_tables=None, + kv_cache_dtype=self.kv_cache_dtype, + ) + return ( + input_tokens, + input_positions, + attn_metadata, + prompt_lens, + ) + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + context_lens: List[int] = [] + block_tables: List[List[int]] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + assert seq_group_metadata.token_chunk_size == 1 + + seq_ids = list(seq_group_metadata.seq_data.keys()) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append(generation_token) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append(position) + + context_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + context_lens.append(context_len) + + block_table = seq_group_metadata.block_tables[seq_id] + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + max_context_len = max(context_lens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + context_lens = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + + max_block_table_len = max( + len(block_table) for block_table in block_tables) + block_tables = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device, + ) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + slot_mapping=slot_mapping, + prompt_lens=None, + prompt_lens_tensor=None, + num_prompt_tokens=0, + num_generation_tokens=len(input_tokens), + max_context_len=max_context_len, + context_lens=context_lens, + block_tables=block_tables, + kv_cache_dtype=self.kv_cache_dtype, + ) + return ( + input_tokens, + input_positions, + attn_metadata, + ) + + def _prepare_sample( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> SamplingMetadata: + seq_groups: List[Tuple[List[int], SamplingParams]] = [] + selected_token_indices: List[int] = [] + generators: List[torch.Generator] = [] + selected_token_start_idx = 0 + categorized_sample_indices = {t: [] for t in SamplingType} + categorized_sample_indices_start_idx = 0 + categorized_sampled_token_indices_start_idx = 0 + + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + seq_groups.append((seq_ids, sampling_params)) + + if seq_group_metadata.is_prompt: + assert len(seq_ids) == 1 + subquery_len = prompt_lens[i] + if sampling_params.prompt_logprobs is not None: + # NOTE: prompt token positions do not need sample, skip + categorized_sample_indices_start_idx += subquery_len - 1 + + categorized_sample_indices[ + sampling_params.sampling_type].append([ + categorized_sample_indices_start_idx, + categorized_sampled_token_indices_start_idx + ]) + categorized_sample_indices_start_idx += 1 + categorized_sampled_token_indices_start_idx += 1 + + if sampling_params.prompt_logprobs is not None: + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + subquery_len - 1)) + selected_token_indices.append(selected_token_start_idx + + subquery_len - 1) + selected_token_start_idx += subquery_len + + if sampling_params.seed is not None: + seq_group_metadata.state.generator = torch.Generator( + device=self.device).manual_seed(sampling_params.seed) + else: + num_seqs = len(seq_ids) + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + num_seqs)) + selected_token_start_idx += num_seqs + + categorized_sample_indices[ + sampling_params.sampling_type].extend( + zip( + range( + categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + + num_seqs), + range( + categorized_sampled_token_indices_start_idx, + categorized_sampled_token_indices_start_idx + + num_seqs))) + categorized_sample_indices_start_idx += num_seqs + categorized_sampled_token_indices_start_idx += num_seqs + + if sampling_params.seed is not None: + generators.append(seq_group_metadata.state.generator) + + selected_token_indices = torch.tensor(selected_token_indices, + dtype=torch.long) + + categorized_sample_indices = { + t: maybe_expand_dim(torch.tensor(seq_ids, dtype=torch.int), 2, 2) + for t, seq_ids in categorized_sample_indices.items() + } + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + sampling_metadata = SamplingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + selected_token_indices=selected_token_indices, + categorized_sample_indices=categorized_sample_indices, + generators=generators, + ) + return sampling_metadata + + def prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, + SamplingMetadata]: + if self.is_driver_worker: + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, attn_metadata, + prompt_lens) = self._prepare_prompt(seq_group_metadata_list) + else: + (input_tokens, input_positions, + attn_metadata) = self._prepare_decode(seq_group_metadata_list) + prompt_lens = [] + sampling_metadata = self._prepare_sample(seq_group_metadata_list, + prompt_lens) + # Broadcast the metadata. + metadata_dict = { + "input_tokens": input_tokens, + "input_positions": input_positions, + "selected_token_indices": + sampling_metadata.selected_token_indices, + } + metadata_dict.update(attn_metadata.asdict_zerocopy()) + broadcast_tensor_dict(metadata_dict, src=0) + else: + metadata_dict = broadcast_tensor_dict(src=0) + input_tokens = metadata_dict.pop("input_tokens") + input_positions = metadata_dict.pop("input_positions") + selected_token_indices = metadata_dict.pop( + "selected_token_indices") + attn_metadata = self.attn_backend.make_metadata(**metadata_dict) + sampling_metadata = SamplingMetadata( + seq_groups=None, + seq_data=None, + prompt_lens=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + generators=None, + perform_sampling=False, + ) + + return ( + input_tokens, + input_positions, + attn_metadata, + sampling_metadata, + ) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + kv_caches: List[torch.Tensor], + ) -> Optional[SamplerOutput]: + (input_tokens, input_positions, attn_metadata, sampling_metadata + ) = self.prepare_input_tensors(seq_group_metadata_list) + + model_executable = self.model + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "kv_caches": kv_caches, + "attn_metadata": attn_metadata, + } + + hidden_states = model_executable(**execute_model_kwargs) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Only perform sampling in the driver worker. + if not sampling_metadata.perform_sampling: + return None + + # Sample the next token. + output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + return output diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 262ed9abd36b7..c84c3653040d4 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -9,28 +9,17 @@ ParallelConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.model_executor.model_loader import get_model from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) from vllm.model_executor.parallel_utils.parallel_state import ( ensure_model_parallel_initialized) from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from vllm.worker.model_runner import ModelRunner +from vllm.worker.cpu_model_runner import CPUModelRunner logger = init_logger(__name__) -class CPUModelRunner(ModelRunner): - - def load_model(self) -> None: - self.model = get_model(self.model_config, - self.device_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) - - class CPUCacheEngine: """Manages the KV cache for CPU backend.