diff --git a/examples/offline_inference_neuron.py b/examples/offline_inference_neuron.py
index 7f9672e9e6be6..c77d2bbb83972 100644
--- a/examples/offline_inference_neuron.py
+++ b/examples/offline_inference_neuron.py
@@ -1,8 +1,8 @@
 import os, torch
 os.environ['NEURONX_DUMP_TO'] = os.path.join(os.getcwd(),"_compile_cache")
 # os.environ["NEURON_CC_FLAGS"]= " -O1 --internal-enable-dge-levels=vector_dynamic_offsets "
-os.environ["NEURON_CC_FLAGS"]= " -O3 --internal-enable-dge-levels=vector_dynamic_offsets --disable-internal-io-dge"
-# os.environ["NEURON_CC_FLAGS"]= " -O1 "
+# os.environ["NEURON_CC_FLAGS"]= " -O3 --internal-enable-dge-levels=vector_dynamic_offsets --disable-internal-io-dge"
+os.environ["NEURON_CC_FLAGS"]= " -O1 "
 # --internal-compiler-debug-mode=penguin --tensorizer-options='--enable-dge-on-indirect-dma' "
 # os.environ["NEURON_CC_FLAGS"] += " --tensorizer-options='--dump-after=All' "
 # os.environ["NEURON_CC_FLAGS"]= " --tensorizer-options='--enable-dge-on-indirect-dma' "
diff --git a/vllm/attention/backends/neuron_flash_attn.py b/vllm/attention/backends/neuron_flash_attn.py
index e4ec43f9f6a76..1434c1f2dae1a 100644
--- a/vllm/attention/backends/neuron_flash_attn.py
+++ b/vllm/attention/backends/neuron_flash_attn.py
@@ -7,8 +7,7 @@
 from torch.nn.functional import scaled_dot_product_attention
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata,
-                                              AttentionMetadataPerStage)
+                                              AttentionMetadata)
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
 
@@ -50,17 +49,16 @@ def copy_blocks(
 
 
 @dataclass
-class NeuronFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata,
-                                   AttentionMetadataPerStage):
+class NeuronFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
     """Metadata for NeuronFlashAttentionBackend.
     """
     # Currently, input sequences can only contain all prompts
     # or all decoding. True if all sequences are prompts.
     is_prompt: bool
     slot_mapping: torch.Tensor
-    prompt_lens: Optional[List[int]]
+    seq_lens: Optional[List[int]]
     # prompt_lens stored as a tensor.
-    prompt_lens_tensor: Optional[torch.Tensor]
+    # prompt_lens_tensor: Optional[torch.Tensor]
 
     def __post_init__(self):
         # Set during the execution of the first attention op.
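
For reference, a minimal stand-in sketch (not vLLM's real class hierarchy) of what the consolidated metadata looks like after this change: the per-stage base class is gone, prompt_lens is renamed to seq_lens, and prompt_lens_tensor is dropped. The field set mirrors the make_metadata call in the runner diff below; the dataclass itself is hypothetical.

from dataclasses import dataclass
from typing import List, Optional

import torch


# Hypothetical stand-in, NOT the real NeuronFlashAttentionMetadata: it only
# illustrates the field set the runner now passes to make_metadata().
@dataclass
class MetadataSketch:
    is_prompt: bool                 # batch is all-prefill or all-decode
    slot_mapping: torch.Tensor      # token index -> KV-cache slot
    seq_lens: Optional[List[int]]   # renamed from prompt_lens
    seq_lens_tensor: torch.Tensor   # tensor view of seq_lens
    max_decode_seq_len: int         # 0 for an all-prompt batch
    num_prefills: int
    num_prefill_tokens: int
    num_decode_tokens: int
    block_tables: torch.Tensor      # empty when no paged decode is running


# Example: metadata for a single 3-token prompt.
meta = MetadataSketch(
    is_prompt=True,
    slot_mapping=torch.tensor([0, 1, 2]),
    seq_lens=[3],
    seq_lens_tensor=torch.tensor([]),
    max_decode_seq_len=0,
    num_prefills=1,
    num_prefill_tokens=3,
    num_decode_tokens=0,
    block_tables=torch.tensor([]),
)
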
diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py
index 81ca7bd289a80..adeff75fd51a3 100644
--- a/vllm/worker/neuron_model_runner.py
+++ b/vllm/worker/neuron_model_runner.py
@@ -5,7 +5,7 @@
 from torch import nn
 
 from vllm.attention import AttentionMetadata, get_attn_backend
-from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
@@ -15,7 +15,13 @@
 from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
-from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
+# from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
+from vllm.worker.model_runner_base import (
+    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
+    _add_attn_metadata_broadcastable_dict,
+    _add_sampling_metadata_broadcastable_dict,
+    _init_attn_metadata_from_tensor_dict,
+    _init_sampling_metadata_from_tensor_dict)
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
@@ -62,12 +68,14 @@ def __init__(
         self,
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
+        cache_config: CacheConfig,
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.cache_config = cache_config
 
         self.sliding_window = None
         if model_config is not None and model_config.get_sliding_window():
@@ -78,6 +86,21 @@ def __init__(
         self.device = self.device_config.device
         self.pin_memory = is_pin_memory_available()
 
+        self.kv_cache_dtype = "bfloat16"
+        self.block_size = cache_config.block_size
+
+        num_attn_heads = self.model_config.get_num_attention_heads(
+            self.parallel_config)
+        self.attn_backend = get_attn_backend(
+            num_attn_heads,
+            self.model_config.get_head_size(),
+            self.model_config.get_num_kv_heads(self.parallel_config),
+            self.model_config.get_sliding_window(),
+            self.model_config.dtype if model_config is not None else None,
+            kv_cache_dtype=self.kv_cache_dtype,
+            block_size=self.block_size,
+        )
+
         # Multi-modal data support
         self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
             .create_input_mapper(self.model_config)
@@ -100,7 +123,7 @@ def set_block_size(self, block_size: int) -> None:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], BatchedTensorInputs]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []
@@ -210,18 +233,15 @@
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=True,
-            prompt_lens=prompt_lens,
-            prompt_lens_tensor=prompt_lens_tensor,
+            slot_mapping=slot_mapping,
+            seq_lens=prompt_lens,
+            seq_lens_tensor=torch.tensor([]),
+            max_decode_seq_len=0,
             num_prefills=len(prompt_lens),
             num_prefill_tokens=sum(prompt_lens),
             num_decode_tokens=0,
-            prefill_metadata=None,
-            decode_metadata=None,
-            max_context_len=None,
-            context_lens=None,
             block_tables=torch.tensor([]),
-            slot_mapping=slot_mapping,
-            kv_cache_dtype="bfloat16",  # "auto", # "auto" means use model weight data type
+            # kv_cache_dtype=self.kv_cache_dtype, # "auto", # "auto" means use model weight data type
         )
 
         input_tokens_tensor = torch.tensor(input_tokens, dtype=torch.long, device=self.device).unsqueeze(0)
         input_positions_tensor = torch.tensor(input_positions, dtype=torch.long, device=self.device).unsqueeze(0)
@@ -307,7 +327,7 @@ def _prepare_decode(
             decode_metadata=None,
             context_lens=context_lens,
             block_tables=block_tables,
-            kv_cache_dtype="auto",
+            # kv_cache_dtype="auto",
         )
         return (
             input_tokens,
@@ -332,7 +352,6 @@ def prepare_model_input(
         # Prepare input tensors.
         if is_prompt:
             (input_tokens, input_positions, input_block_ids, seq_lens,
-             multi_modal_kwargs
             ) = self._prepare_prompt(seq_group_metadata_list)
         else:
             (input_tokens, input_positions,
diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py
index d8940f14793bb..29e1259f33e7f 100644
--- a/vllm/worker/neuron_worker.py
+++ b/vllm/worker/neuron_worker.py
@@ -37,7 +37,7 @@ def __init__(
             init_cached_hf_modules()
 
         self.model_runner: NeuronModelRunner = NeuronModelRunner(
-            model_config, parallel_config, scheduler_config, device_config)
+            model_config, parallel_config, cache_config, scheduler_config, device_config)
         self.is_driver_worker = True
 
     def init_device(self) -> None:
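
A minimal, self-contained sketch of the constructor wiring this diff introduces (stub classes stand in for vLLM's real ModelConfig/CacheConfig, and the keyword names are illustrative): the runner now accepts cache_config, reads block_size from it, pins kv_cache_dtype to bfloat16, and resolves its attention backend once in __init__ instead of hard-coding those values at each make_metadata call site.

from dataclasses import dataclass
from typing import Optional


# Stand-in configs; the real ModelConfig/CacheConfig come from vllm.config.
@dataclass
class CacheConfigStub:
    block_size: int = 8


@dataclass
class ModelConfigStub:
    num_heads: int = 32
    num_kv_heads: int = 8
    head_size: int = 128
    dtype: str = "bfloat16"

    def get_num_attention_heads(self, parallel_config=None) -> int:
        return self.num_heads

    def get_num_kv_heads(self, parallel_config=None) -> int:
        return self.num_kv_heads

    def get_head_size(self) -> int:
        return self.head_size

    def get_sliding_window(self) -> Optional[int]:
        return None


def attn_backend_kwargs(model_config, cache_config, parallel_config=None) -> dict:
    """Collect the values the runner now feeds to get_attn_backend() in __init__."""
    return dict(
        num_heads=model_config.get_num_attention_heads(parallel_config),
        head_size=model_config.get_head_size(),
        num_kv_heads=model_config.get_num_kv_heads(parallel_config),
        sliding_window=model_config.get_sliding_window(),
        dtype=model_config.dtype,
        kv_cache_dtype="bfloat16",   # fixed for the Neuron path in this diff
        block_size=cache_config.block_size,
    )


print(attn_backend_kwargs(ModelConfigStub(), CacheConfigStub()))

On the worker side, NeuronWorker now passes cache_config positionally between parallel_config and scheduler_config, matching the runner's new signature.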