working on paged attention support
liangfu committed Aug 26, 2024
1 parent 0d845a5 commit 6605960
Showing 4 changed files with 39 additions and 22 deletions.
4 changes: 2 additions & 2 deletions examples/offline_inference_neuron.py
@@ -1,8 +1,8 @@
import os, torch
os.environ['NEURONX_DUMP_TO'] = os.path.join(os.getcwd(),"_compile_cache")
# os.environ["NEURON_CC_FLAGS"]= " -O1 --internal-enable-dge-levels=vector_dynamic_offsets "
-os.environ["NEURON_CC_FLAGS"]= " -O3 --internal-enable-dge-levels=vector_dynamic_offsets --disable-internal-io-dge"
-# os.environ["NEURON_CC_FLAGS"]= " -O1 "
+# os.environ["NEURON_CC_FLAGS"]= " -O3 --internal-enable-dge-levels=vector_dynamic_offsets --disable-internal-io-dge"
+os.environ["NEURON_CC_FLAGS"]= " -O1 "
# --internal-compiler-debug-mode=penguin --tensorizer-options='--enable-dge-on-indirect-dma' "
# os.environ["NEURON_CC_FLAGS"] += " --tensorizer-options='--dump-after=All' "
# os.environ["NEURON_CC_FLAGS"]= " --tensorizer-options='--enable-dge-on-indirect-dma' "
10 changes: 4 additions & 6 deletions vllm/attention/backends/neuron_flash_attn.py
@@ -7,8 +7,7 @@
from torch.nn.functional import scaled_dot_product_attention

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata,
-                                              AttentionMetadataPerStage)
+                                              AttentionMetadata)
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)

@@ -50,17 +49,16 @@ def copy_blocks(


@dataclass
-class NeuronFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata,
-                                   AttentionMetadataPerStage):
+class NeuronFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for NeuronFlashAttentionBackend.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    slot_mapping: torch.Tensor
-    prompt_lens: Optional[List[int]]
+    seq_lens: Optional[List[int]]
    # prompt_lens stored as a tensor.
-    prompt_lens_tensor: Optional[torch.Tensor]
+    # prompt_lens_tensor: Optional[torch.Tensor]

    def __post_init__(self):
        # Set during the execution of the first attention op.
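A minimal sketch (not part of the diff) of how the slimmed-down metadata might be instantiated for a prefill batch. The field names mirror the make_metadata() call in neuron_model_runner.py below; the placeholder values and the exact set of required inherited fields are assumptions about the vLLM revision this commit targets.

import torch
from vllm.attention.backends.neuron_flash_attn import NeuronFlashAttentionMetadata

prompt_lens = [5, 7]  # placeholder prompt lengths
metadata = NeuronFlashAttentionMetadata(
    is_prompt=True,
    slot_mapping=torch.zeros(sum(prompt_lens), dtype=torch.long),
    seq_lens=prompt_lens,
    seq_lens_tensor=torch.tensor([]),
    max_decode_seq_len=0,
    num_prefills=len(prompt_lens),
    num_prefill_tokens=sum(prompt_lens),
    num_decode_tokens=0,
    block_tables=torch.tensor([]),
)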
45 changes: 32 additions & 13 deletions vllm/worker/neuron_model_runner.py
@@ -5,7 +5,7 @@
from torch import nn

from vllm.attention import AttentionMetadata, get_attn_backend
-from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
@@ -15,7 +15,13 @@
from vllm.sequence import (IntermediateTensors, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
-from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
+# from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
+from vllm.worker.model_runner_base import (
+    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
+    _add_attn_metadata_broadcastable_dict,
+    _add_sampling_metadata_broadcastable_dict,
+    _init_attn_metadata_from_tensor_dict,
+    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@@ -62,12 +68,14 @@ def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
+        cache_config: CacheConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
+        self.cache_config = cache_config

        self.sliding_window = None
        if model_config is not None and model_config.get_sliding_window():
@@ -78,6 +86,21 @@ def __init__(
        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

+        self.kv_cache_dtype = "bfloat16"
+        self.block_size = cache_config.block_size
+
+        num_attn_heads = self.model_config.get_num_attention_heads(
+            self.parallel_config)
+        self.attn_backend = get_attn_backend(
+            num_attn_heads,
+            self.model_config.get_head_size(),
+            self.model_config.get_num_kv_heads(self.parallel_config),
+            self.model_config.get_sliding_window(),
+            self.model_config.dtype if model_config is not None else None,
+            kv_cache_dtype=self.kv_cache_dtype,
+            block_size=self.block_size,
+        )
+
        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)
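A hedged sketch (not part of the diff) of how the new cache_config wiring might be exercised; the CacheConfig values are placeholders and its constructor arguments are assumed from upstream vLLM, not from this commit.

from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
from vllm.worker.neuron_model_runner import NeuronModelRunner


def build_runner(model_config: ModelConfig,
                 parallel_config: ParallelConfig,
                 scheduler_config: SchedulerConfig,
                 device_config: DeviceConfig) -> NeuronModelRunner:
    # block_size is the paged-attention page size; 128 is a placeholder value.
    cache_config = CacheConfig(block_size=128,
                               gpu_memory_utilization=0.9,
                               swap_space=0,
                               cache_dtype="auto")
    # The runner reads block_size from cache_config and, together with the
    # hard-coded bfloat16 KV-cache dtype, forwards both to get_attn_backend()
    # when selecting the attention implementation.
    return NeuronModelRunner(model_config, parallel_config, cache_config,
                             scheduler_config, device_config)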
@@ -100,7 +123,7 @@ def set_block_size(self, block_size: int) -> None:
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], BatchedTensorInputs]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
@@ -210,18 +233,15 @@ def _prepare_prompt(

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
-            prompt_lens=prompt_lens,
-            prompt_lens_tensor=prompt_lens_tensor,
-            slot_mapping=slot_mapping,
+            seq_lens=prompt_lens,
+            seq_lens_tensor=torch.tensor([]),
+            max_decode_seq_len=0,
            num_prefills=len(prompt_lens),
            num_prefill_tokens=sum(prompt_lens),
            num_decode_tokens=0,
-            prefill_metadata=None,
-            decode_metadata=None,
-            max_context_len=None,
-            context_lens=None,
            block_tables=torch.tensor([]),
+            slot_mapping=slot_mapping,
-            kv_cache_dtype="bfloat16", # "auto", # "auto" means use model weight data type
+            # kv_cache_dtype=self.kv_cache_dtype, # "auto", # "auto" means use model weight data type
        )
        input_tokens_tensor = torch.tensor(input_tokens, dtype=torch.long, device=self.device).unsqueeze(0)
        input_positions_tensor = torch.tensor(input_positions, dtype=torch.long, device=self.device).unsqueeze(0)
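The slot_mapping passed above is the central piece of paged-attention bookkeeping: every token gets a flat slot index into the block-structured KV cache. A hedged sketch of the standard computation as done in vLLM's GPU model runner; the Neuron path in this commit may differ in detail.

from typing import List


def compute_slot_mapping(block_table: List[int], seq_len: int,
                         block_size: int) -> List[int]:
    # Token position i lives in physical block block_table[i // block_size]
    # at offset i % block_size; the flat slot is block_number * block_size + offset.
    slot_mapping: List[int] = []
    for i in range(seq_len):
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slot_mapping.append(block_number * block_size + block_offset)
    return slot_mapping


# Example: a 5-token prompt on physical blocks [3, 7] with block_size=4
# occupies slots [12, 13, 14, 15, 28].
assert compute_slot_mapping([3, 7], 5, 4) == [12, 13, 14, 15, 28]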
@@ -307,7 +327,7 @@ def _prepare_decode(
            decode_metadata=None,
            context_lens=context_lens,
            block_tables=block_tables,
-            kv_cache_dtype="auto",
+            # kv_cache_dtype="auto",
        )
        return (
            input_tokens,
@@ -332,7 +352,6 @@ def prepare_model_input(
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, input_block_ids, seq_lens,
-             multi_modal_kwargs
             ) = self._prepare_prompt(seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
2 changes: 1 addition & 1 deletion vllm/worker/neuron_worker.py
@@ -37,7 +37,7 @@ def __init__(
            init_cached_hf_modules()

        self.model_runner: NeuronModelRunner = NeuronModelRunner(
-            model_config, parallel_config, scheduler_config, device_config)
+            model_config, parallel_config, cache_config, scheduler_config, device_config)
        self.is_driver_worker = True

    def init_device(self) -> None:
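A hedged usage sketch of the updated call, written with keyword arguments so the new parameter stays explicit as the signature grows; the argument names follow the NeuronModelRunner.__init__ signature shown in the runner diff above.

self.model_runner: NeuronModelRunner = NeuronModelRunner(
    model_config=model_config,
    parallel_config=parallel_config,
    cache_config=cache_config,
    scheduler_config=scheduler_config,
    device_config=device_config)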
