From 925f3332cac488e5ad2dbc8f5c6d5f42d2556816 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 24 Mar 2024 21:39:33 -0700 Subject: [PATCH] [Core] Refactor Attention Take 2 (#3462) --- tests/kernels/test_prefix_prefill.py | 3 +- tests/samplers/test_beam_search.py | 7 + tests/worker/test_model_runner.py | 60 ++--- vllm/attention/__init__.py | 10 + .../attention/backends/__init__.py | 0 vllm/attention/backends/abstract.py | 85 +++++++ vllm/attention/backends/flash_attn.py | 238 ++++++++++++++++++ .../attention/backends/xformers.py | 232 +++++++++++++---- vllm/attention/layer.py | 46 ++++ .../layers => }/attention/ops/__init__.py | 0 vllm/attention/ops/paged_attn.py | 217 ++++++++++++++++ .../attention/ops/prefix_prefill.py | 0 vllm/attention/selector.py | 44 ++++ vllm/model_executor/__init__.py | 2 - vllm/model_executor/input_metadata.py | 99 -------- .../layers/attention/__init__.py | 5 - .../layers/attention/attention.py | 85 ------- .../layers/attention/backends/flash_attn.py | 139 ---------- .../layers/attention/ops/paged_attn.py | 139 ---------- vllm/model_executor/models/baichuan.py | 30 +-- vllm/model_executor/models/bloom.py | 32 ++- vllm/model_executor/models/chatglm.py | 41 ++- vllm/model_executor/models/deepseek.py | 32 ++- vllm/model_executor/models/falcon.py | 31 ++- vllm/model_executor/models/gemma.py | 30 +-- vllm/model_executor/models/gpt2.py | 33 ++- vllm/model_executor/models/gpt_bigcode.py | 33 ++- vllm/model_executor/models/gpt_j.py | 32 ++- vllm/model_executor/models/gpt_neox.py | 32 ++- vllm/model_executor/models/internlm2.py | 30 +-- vllm/model_executor/models/jais.py | 35 ++- vllm/model_executor/models/llama.py | 30 +-- vllm/model_executor/models/mixtral.py | 32 ++- vllm/model_executor/models/mixtral_quant.py | 32 ++- vllm/model_executor/models/mpt.py | 32 ++- vllm/model_executor/models/olmo.py | 30 +-- vllm/model_executor/models/opt.py | 39 ++- vllm/model_executor/models/orion.py | 30 +-- vllm/model_executor/models/phi.py | 32 ++- vllm/model_executor/models/qwen.py | 31 +-- vllm/model_executor/models/qwen2.py | 30 +-- vllm/model_executor/models/stablelm.py | 30 +-- vllm/model_executor/models/starcoder2.py | 32 ++- vllm/sequence.py | 2 +- vllm/worker/cache_engine.py | 115 +++------ vllm/worker/model_runner.py | 85 ++++--- vllm/worker/worker.py | 3 + 47 files changed, 1269 insertions(+), 1118 deletions(-) create mode 100644 vllm/attention/__init__.py rename vllm/{model_executor/layers => }/attention/backends/__init__.py (100%) create mode 100644 vllm/attention/backends/abstract.py create mode 100644 vllm/attention/backends/flash_attn.py rename vllm/{model_executor/layers => }/attention/backends/xformers.py (55%) create mode 100644 vllm/attention/layer.py rename vllm/{model_executor/layers => }/attention/ops/__init__.py (100%) create mode 100644 vllm/attention/ops/paged_attn.py rename vllm/{model_executor/layers => }/attention/ops/prefix_prefill.py (100%) create mode 100644 vllm/attention/selector.py delete mode 100644 vllm/model_executor/input_metadata.py delete mode 100644 vllm/model_executor/layers/attention/__init__.py delete mode 100644 vllm/model_executor/layers/attention/attention.py delete mode 100644 vllm/model_executor/layers/attention/backends/flash_attn.py delete mode 100644 vllm/model_executor/layers/attention/ops/paged_attn.py diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 2b35335a9c92b..5a09095e76688 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -3,8 +3,7 
@@ import time import torch -from vllm.model_executor.layers.attention.ops.prefix_prefill import ( - context_attention_fwd) +from vllm.attention.ops.prefix_prefill import context_attention_fwd from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index 9398aeb2c214c..15fef106f1f18 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -2,7 +2,10 @@ Run `pytest tests/samplers/test_beam_search.py --forked`. """ +import gc + import pytest +import torch # FIXME(zhuohan): The test can not pass if we: # 1. Increase max_tokens to 256. @@ -36,6 +39,10 @@ def test_beam_search_single_input( vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width, max_tokens) del vllm_model + # NOTE(woosuk): For some reason, the following GC is required to avoid + # GPU OOM errors in the following tests using `vllm_runner`. + gc.collect() + torch.cuda.empty_cache() for i in range(len(example_prompts)): hf_output_ids, _ = hf_outputs[i] diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 01066ef796d67..12e3c8eff2ce8 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -34,19 +34,19 @@ def test_prepare_prompt(batch_size): expected_selected_token_indices.append(selected_token_start_idx + prompt_len - 1) selected_token_start_idx += prompt_len - (input_tokens, input_positions, input_metadata, return_prompt_lens, _, _, - _, _) = (model_runner._prepare_prompt(seq_group_metadata_list)) + (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, + _) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert return_prompt_lens == prompt_lens # Verify input metadata is correct for prompts. device = model_runner.device - assert input_metadata.is_prompt is True - assert torch.allclose(input_metadata.prompt_lens_tensor, + assert attn_metadata.is_prompt is True + assert torch.allclose(attn_metadata.prompt_lens_tensor, torch.tensor(prompt_lens, device=device)) - assert input_metadata.prompt_lens == prompt_lens - assert input_metadata.num_prompt_tokens == sum(prompt_lens) - assert input_metadata.num_generation_tokens == 0 - assert input_metadata.max_seq_len == max(prompt_lens) + assert attn_metadata.prompt_lens == prompt_lens + assert attn_metadata.num_prompt_tokens == sum(prompt_lens) + assert attn_metadata.num_generation_tokens == 0 + assert attn_metadata.max_prompt_len == max(prompt_lens) # Test subquery start locs. start_idx = 0 @@ -55,7 +55,7 @@ def test_prepare_prompt(batch_size): start_idx += prompt_len start_loc.append(start_idx) assert torch.allclose( - input_metadata.subquery_start_loc, + attn_metadata.subquery_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device)) # Test seq start locs. 
Note that for normal prefill it is @@ -67,22 +67,22 @@ def test_prepare_prompt(batch_size): seq_start_loc.append(start_idx) assert torch.allclose( - input_metadata.seq_start_loc, + attn_metadata.seq_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device)) - assert input_metadata.max_context_len is None + assert attn_metadata.max_context_len is None assert torch.allclose( - input_metadata.context_lens, - torch.zeros(input_metadata.context_lens.shape[0], + attn_metadata.context_lens, + torch.zeros(attn_metadata.context_lens.shape[0], dtype=torch.int, device=device)) expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))], dtype=torch.int32, device=model_runner.device) - assert torch.allclose(input_metadata.block_tables, expected) + assert torch.allclose(attn_metadata.block_tables, expected) # Cuda graph should not be used for prerill. - assert input_metadata.use_cuda_graph is False - assert input_metadata.kv_cache_dtype == "auto" + assert attn_metadata.use_cuda_graph is False + assert attn_metadata.kv_cache_dtype == "auto" assert input_tokens.shape == (sum(prompt_lens), ) assert input_positions.shape == (sum(prompt_lens), ) @@ -140,34 +140,34 @@ def test_prepare_decode_cuda_graph(batch_size): block_tables={0: [1]}, )) - input_tokens, input_positions, input_metadata, _, _, _ = ( + input_tokens, input_positions, attn_metadata, _, _, _ = ( model_runner._prepare_decode(seq_group_metadata_list)) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) # Verify input metadata is correct for prompts. device = model_runner.device - assert input_metadata.is_prompt is False - assert input_metadata.prompt_lens is None - assert input_metadata.num_prompt_tokens == 0 - assert input_metadata.num_generation_tokens == expected_bs - assert input_metadata.max_seq_len is None - assert input_metadata.subquery_start_loc is None - assert input_metadata.seq_start_loc is None - assert input_metadata.max_context_len == max(prompt_lens) + assert attn_metadata.is_prompt is False + assert attn_metadata.prompt_lens is None + assert attn_metadata.num_prompt_tokens == 0 + assert attn_metadata.num_generation_tokens == expected_bs + assert attn_metadata.max_prompt_len is None + assert attn_metadata.subquery_start_loc is None + assert attn_metadata.seq_start_loc is None + assert attn_metadata.max_context_len == max(prompt_lens) assert torch.allclose( - input_metadata.context_lens[:len(prompt_lens)], + attn_metadata.context_lens[:len(prompt_lens)], torch.tensor(prompt_lens, dtype=torch.int, device=device)) # block table's first index corresponds to each batch, meaning in # decoding it is each token. - assert input_metadata.block_tables.shape[0] == len(input_tokens) + assert attn_metadata.block_tables.shape[0] == len(input_tokens) # Block table's second dim correspondsd to each token's block number. # It is padded up to - assert input_metadata.block_tables.shape[1] == ( + assert attn_metadata.block_tables.shape[1] == ( model_runner.get_max_block_per_batch()) # Cuda graph should not be used for prerill. 
- assert input_metadata.use_cuda_graph is True - assert input_metadata.kv_cache_dtype == "auto" + assert attn_metadata.use_cuda_graph is True + assert attn_metadata.kv_cache_dtype == "auto" assert input_tokens.shape == (expected_bs, ) assert input_positions.shape == (expected_bs, ) diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py new file mode 100644 index 0000000000000..e8b9b95dc4234 --- /dev/null +++ b/vllm/attention/__init__.py @@ -0,0 +1,10 @@ +from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata +from vllm.attention.layer import Attention +from vllm.attention.selector import get_attn_backend + +__all__ = [ + "AttentionBackend", + "AttentionMetadata", + "Attention", + "get_attn_backend", +] diff --git a/vllm/model_executor/layers/attention/backends/__init__.py b/vllm/attention/backends/__init__.py similarity index 100% rename from vllm/model_executor/layers/attention/backends/__init__.py rename to vllm/attention/backends/__init__.py diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py new file mode 100644 index 0000000000000..a7e0ab92c7668 --- /dev/null +++ b/vllm/attention/backends/abstract.py @@ -0,0 +1,85 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass, fields +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + + +class AttentionBackend(ABC): + """Abstract class for attention backends.""" + + @staticmethod + @abstractmethod + def get_impl_cls() -> Type["AttentionImpl"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def make_metadata(*args, **kwargs) -> "AttentionMetadata": + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + raise NotImplementedError + + @staticmethod + @abstractmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + raise NotImplementedError + + +@dataclass +class AttentionMetadata: + + def asdict_zerocopy(self) -> Dict[str, Any]: + """Similar to dataclasses.asdict, but avoids deepcopying.""" + # Note that if we add dataclasses as fields, they will need + # similar handling. + return { + field.name: getattr(self, field.name) + for field in fields(self) + } + + +class AttentionImpl(ABC): + + @abstractmethod + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + raise NotImplementedError + + @abstractmethod + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py new file mode 100644 index 0000000000000..ac33a917bb0ad --- /dev/null +++ b/vllm/attention/backends/flash_attn.py @@ -0,0 +1,238 @@ +"""Attention layer with Flash and PagedAttention. + +NOTE(woosuk): At the moment, this file includes a lot of duplicated code from +XFormers backend. The duplicated code will be removed once we use flash-attn or +flashinfer for all the attention operations. 
+""" +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Type + +from flash_attn import flash_attn_varlen_func +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) +from vllm.attention.ops.paged_attn import PagedAttention, PagedAttentionMetadata + + +class FlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["FlashAttentionImpl"]: + return FlashAttentionImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata": + return FlashAttentionMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + # (batch_size,). The prompt length per sequence. None if it is a decoding. + prompt_lens: Optional[List[int]] + # prompt_lens stored as a tensor. + prompt_lens_tensor: Optional[torch.Tensor] + # The number of prompt tokens. Doesn't include padding. + num_prompt_tokens: int + # The number of generation tokens. Doesn't include padding. + num_generation_tokens: int + + # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seqlen ----------------------| + # |- subquery_len -| + + # WARNING(sang): context_len has different definition depending on if it is + # prefill vs decoding. When it is prefill, it doesn't include new tokens. + # When it is for decoding, it includes a new token. + + # Maximum subquery length in the batch. + max_subquery_len: Optional[int] + # Maximum prompt length in the batch. + max_prompt_len: Optional[int] + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + subquery_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. 
+ use_cuda_graph: bool + + +class FlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache is not None: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + attn_metadata.slot_mapping, + attn_metadata.kv_cache_dtype) + + if attn_metadata.is_prompt: + # Prompt run. + if kv_cache is None or attn_metadata.block_tables.numel() == 0: + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. 
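# Illustrative sketch (not part of the diff; names and values are assumptions):
# the cu_seqlens_q/cu_seqlens_k arguments passed below are cumulative prompt
# lengths. For prompt_lens = [4, 6] the model runner builds roughly
#     seq_start_loc = torch.zeros(len(prompt_lens) + 1, dtype=torch.int32)
#     torch.cumsum(torch.tensor(prompt_lens, dtype=torch.int32),
#                  dim=0, out=seq_start_loc[1:])
# giving seq_start_loc == [0, 4, 10] and max_prompt_len == 6.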
+ output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_prompt_len, + max_seqlen_k=attn_metadata.max_prompt_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) + else: + # prefix-enabled attention + output = PagedAttention.forward_prefix( + query, + key, + value, + key_cache, + value_cache, + attn_metadata.block_tables, + attn_metadata.subquery_start_loc, + attn_metadata.prompt_lens_tensor, + attn_metadata.context_lens, + attn_metadata.max_subquery_len, + self.alibi_slopes, + ) + else: + # Decoding run. + output = PagedAttention.forward_decode( + query, + key_cache, + value_cache, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.max_context_len, + attn_metadata.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + ) + + # Reshape the output tensor. + return output.view(num_tokens, hidden_size) diff --git a/vllm/model_executor/layers/attention/backends/xformers.py b/vllm/attention/backends/xformers.py similarity index 55% rename from vllm/model_executor/layers/attention/backends/xformers.py rename to vllm/attention/backends/xformers.py index f0ef9fac9aaa4..b7eff2b598e1a 100644 --- a/vllm/model_executor/layers/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -1,19 +1,127 @@ """Attention layer with xFormers and PagedAttention.""" import importlib -from typing import List, Optional +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Type import torch from xformers import ops as xops -from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, +from xformers.ops.fmha.attn_bias import (AttentionBias, + BlockDiagonalCausalMask, LowerTriangularMaskWithTensorBias) -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention.ops.paged_attn import ( - PagedAttentionImpl) +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) +from vllm.attention.ops.paged_attn import PagedAttention, PagedAttentionMetadata +from vllm.logger import init_logger from vllm.utils import is_hip +logger = init_logger(__name__) -class XFormersBackend: + +class XFormersBackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["XFormersImpl"]: + return XFormersImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "XFormersMetadata": + return XFormersMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for XFormersbackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. 
The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor + # (batch_size,). The prompt length per sequence. None if it is a decoding. + prompt_lens: Optional[List[int]] + # prompt_lens stored as a tensor. + prompt_lens_tensor: Optional[torch.Tensor] + # The number of prompt tokens. Doesn't include padding. + num_prompt_tokens: int + # The number of generation tokens. Doesn't include padding. + num_generation_tokens: int + + # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seqlen ----------------------| + # |- subquery_len -| + + # WARNING(sang): context_len has different definition depending on if it is + # prefill vs decoding. When it is prefill, it doesn't include new tokens. + # When it is for decoding, it includes a new token. + + # Maximum subquery length in the batch. + max_subquery_len: Optional[int] + # FIXME: It is for flash attn. + # Maximum prompt length in the batch. + max_prompt_len: Optional[int] + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + subquery_start_loc: Optional[torch.Tensor] + # FIXME: It is for flash attn. + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[AttentionBias]] = None + + +class XFormersImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: |<--------------- num_prompt_tokens --------------->| @@ -50,22 +158,25 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes() + + suppored_head_sizes = PagedAttention.get_supported_head_sizes() if head_size not in suppored_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {suppored_head_sizes}.") - self.use_ref_attention = _check_use_ref_attention() + # AMD Radeon 7900 series (gfx1100) currently does not support xFormers + # nor FlashAttention. As a temporary workaround, we use naive PyTorch + # implementation of attention. 
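# Illustrative sketch (not part of the diff): the naive fallback selected below
# computes plain masked attention per prompt in PyTorch, roughly
#     scores = torch.einsum("qhd,khd->hqk", q, k).float() * scale
#     scores = scores + causal_mask          # upper-triangular -inf mask
#     out = torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v)
# which materializes the full (prompt_len x prompt_len) attention matrix,
# hence the extra GPU memory warned about in _check_use_naive_attention().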
+ self.use_naive_attention = _check_use_naive_attention() def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - key_cache: Optional[torch.Tensor], - value_cache: Optional[torch.Tensor], - input_metadata: InputMetadata, + kv_cache: Optional[torch.Tensor], + attn_metadata: XFormersMetadata, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -73,11 +184,8 @@ def forward( query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - key_cache: shape = [num_blocks, num_kv_heads, head_size/x, - block_size, x] - value_cache: shape = [num_blocks, num_kv_heads, head_size, - block_size] - input_metadata: metadata for the inputs. + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ @@ -86,20 +194,24 @@ def forward( key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) - # Reshape the keys and values and store them in the cache. - # If key_cache and value_cache are not provided, the new key and value - # vectors will not be cached. This happens during the initial memory - # profiling run. - if key_cache is not None and value_cache is not None: - PagedAttentionImpl.reshape_and_cache(key, value, key_cache, - value_cache, input_metadata) + if kv_cache is not None: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + attn_metadata.slot_mapping, + attn_metadata.kv_cache_dtype) - if input_metadata.is_prompt: + if attn_metadata.is_prompt: # Prompt run. - # key_cache and value_cache are None when it is a profiling run. - # block tables are empty if the prompt has never been computed. - if (key_cache is None or value_cache is None - or input_metadata.block_tables.numel() == 0): + if kv_cache is None or attn_metadata.block_tables.numel() == 0: + # normal attention. + # block tables are empty if the prompt does not have a cached + # prefix. if self.num_kv_heads != self.num_heads: # As of Nov 2023, xformers only supports MHA. For MQA/GQA, # project the key and value tensors to the desired number of @@ -118,13 +230,12 @@ def forward( self.num_queries_per_kv, value.shape[-1]) - if self.use_ref_attention: - print("ref attention used.") + if self.use_naive_attention: output = torch.empty_like(query) start = 0 - for _, prompt_len in enumerate(input_metadata.prompt_lens): + for _, prompt_len in enumerate(attn_metadata.prompt_lens): end = start + prompt_len - out = _ref_masked_attention( + out = _naive_masked_attention( query[None, start:end], key[None, start:end], value[None, start:end], @@ -143,26 +254,33 @@ def forward( # Use reshape instead. 
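# Minimal repro (not part of the diff) of why reshape is needed here:
#     t = torch.arange(12).reshape(3, 4).t()   # non-contiguous view
#     t.view(12)     # RuntimeError: view size is not compatible ...
#     t.reshape(12)  # works, falling back to a copy
# .view() requires the new shape to be expressible with the existing strides;
# .reshape() copies when it cannot.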
return output.reshape(num_tokens, hidden_size) - output = self._run_memory_efficient_xformer_forward( - query, key, value, input_metadata) + output = self._run_memory_efficient_xformers_forward( + query, key, value, attn_metadata) else: # prefix-enabled attention - output = PagedAttentionImpl.forward_prefix( + output = PagedAttention.forward_prefix( query, key, value, key_cache, value_cache, - input_metadata, + attn_metadata.block_tables, + attn_metadata.subquery_start_loc, + attn_metadata.prompt_lens_tensor, + attn_metadata.context_lens, + attn_metadata.max_subquery_len, self.alibi_slopes, ) else: # Decoding run. - output = PagedAttentionImpl.forward_decode( + output = PagedAttention.forward_decode( query, key_cache, value_cache, - input_metadata, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.max_context_len, + attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, @@ -171,12 +289,12 @@ def forward( # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) - def _run_memory_efficient_xformer_forward( + def _run_memory_efficient_xformers_forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - input_metadata: InputMetadata, + attn_metadata: XFormersMetadata, ) -> torch.Tensor: """Attention for 1D query of multiple prompts. Multiple prompt tokens are flattened in to `query` input. @@ -186,23 +304,23 @@ def _run_memory_efficient_xformer_forward( query: shape = [num_prompt_tokens, num_heads, head_size] key: shape = [num_prompt_tokens, num_kv_heads, head_size] value: shape = [num_prompt_tokens, num_kv_heads, head_size] - input_metadata: metadata for paged attention. + attn_metadata: Metadata for attention. """ # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. - if input_metadata.attn_bias is None: + if attn_metadata.attn_bias is None: if self.alibi_slopes is None: attn_bias = BlockDiagonalCausalMask.from_seqlens( - input_metadata.prompt_lens) + attn_metadata.prompt_lens) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) - input_metadata.attn_bias = [attn_bias] + attn_metadata.attn_bias = [attn_bias] else: - input_metadata.attn_bias = _make_alibi_bias( + attn_metadata.attn_bias = _make_alibi_bias( self.alibi_slopes, self.num_kv_heads, query.dtype, - input_metadata) + attn_metadata.prompt_lens) op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if ( is_hip()) else None @@ -217,7 +335,7 @@ def _run_memory_efficient_xformer_forward( query, key, value, - attn_bias=input_metadata.attn_bias[0], + attn_bias=attn_metadata.attn_bias[0], p=0.0, scale=self.scale, op=op) @@ -230,13 +348,13 @@ def _run_memory_efficient_xformer_forward( # one. This is inefficient, especially when we have many short prompts. 
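# Why one call per prompt (not part of the diff; shapes hedged): the ALiBi bias
# from _make_alibi_bias() is materialized per prompt with shape roughly
# (1, num_heads, prompt_len, prompt_len), so prompts of different lengths
# cannot share a single batched xformers call; each prompt gets its own bias
# and its own memory_efficient_attention_forward call below.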
output = torch.empty_like(query) start = 0 - for i, prompt_len in enumerate(input_metadata.prompt_lens): + for i, prompt_len in enumerate(attn_metadata.prompt_lens): end = start + prompt_len out = xops.memory_efficient_attention_forward( query[None, start:end], key[None, start:end], value[None, start:end], - attn_bias=input_metadata.attn_bias[i], + attn_bias=attn_metadata.attn_bias[i], p=0.0, scale=self.scale, op=op) @@ -250,10 +368,10 @@ def _make_alibi_bias( alibi_slopes: torch.Tensor, num_kv_heads: int, dtype: torch.dtype, - input_metadata: InputMetadata, + prompt_lens: List[int], ) -> LowerTriangularMaskWithTensorBias: attn_biases = [] - for prompt_len in input_metadata.prompt_lens: + for prompt_len in prompt_lens: bias = torch.arange(prompt_len, dtype=dtype) # NOTE(zhuohan): HF uses # `bias = bias[None, :].repeat(prompt_len, 1)` @@ -282,15 +400,19 @@ def _make_alibi_bias( return attn_biases -def _check_use_ref_attention() -> bool: +def _check_use_naive_attention() -> bool: if not is_hip(): return False # For ROCm, check whether flash attention is installed or not. - # if not, use_ref_attention needs to be True - return importlib.util.find_spec("flash_attn") is None + has_flash_attn = importlib.util.find_spec("flash_attn") is None + if not has_flash_attn: + logger.warning("flash_attn is not installed. Using naive attention. " + "This will take significantly more GPU memory.") + return True + return False -def _ref_masked_attention( +def _naive_masked_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py new file mode 100644 index 0000000000000..2e0aa18e52427 --- /dev/null +++ b/vllm/attention/layer.py @@ -0,0 +1,46 @@ +"""Attention layer.""" +from typing import List, Optional + +import torch +import torch.nn as nn + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.selector import get_attn_backend + + +class Attention(nn.Module): + """Attention layer. + + This class takes query, key, and value tensors as input. The input tensors + can either contain prompt tokens or generation tokens. + The class does the following: + + 1. Store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention. + 3. Return the output tensor. 
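    Illustrative usage (a sketch, not part of the diff; the shapes and the
    construction of `attn_metadata` are assumptions based on the backends
    above):

        attn = Attention(num_heads=32, head_size=128, scale=128**-0.5)
        # query: [num_tokens, 32 * 128], key/value likewise;
        # kv_cache may be None during the initial profiling run.
        out = attn(query, key, value, kv_cache, attn_metadata)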
+ """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + super().__init__() + self.backend = get_attn_backend(torch.get_default_dtype()) + impl_cls = self.backend.get_impl_cls() + self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + return self.impl.forward(query, key, value, kv_cache, attn_metadata) diff --git a/vllm/model_executor/layers/attention/ops/__init__.py b/vllm/attention/ops/__init__.py similarity index 100% rename from vllm/model_executor/layers/attention/ops/__init__.py rename to vllm/attention/ops/__init__.py diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py new file mode 100644 index 0000000000000..b20711eb95e59 --- /dev/null +++ b/vllm/attention/ops/paged_attn.py @@ -0,0 +1,217 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm._C import cache_ops +from vllm._C import ops +from vllm.attention.ops.prefix_prefill import context_attention_fwd + +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. +_PARTITION_SIZE = 512 + + +@dataclass +class PagedAttentionMetadata: + """Metadata for PagedAttention.""" + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor + # (batch_size,). The length of context (tokens stored in KV cache) per + # sequence. WARNING: When it is a prefill request, it doesn't include new + # tokens. When it is for decoding, it includes a new token. + context_lens: Optional[torch.Tensor] + # Maximum context length in the batch. + max_context_len: Optional[int] + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. 
+ block_tables: Optional[torch.Tensor] + kv_cache_dtype: str + + +class PagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 80, 96, 112, 128, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size * num_kv_heads * head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 16 // kv_cache.element_size() + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + ) -> None: + cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + ) + + @staticmethod + def forward_decode( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + ) -> torch.Tensor: + output = torch.empty_like(query) + + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory shortage. + use_v1 = (max_context_len <= 8192 + and (max_num_partitions == 1 or num_seqs * num_heads > 512)) + if use_v1: + # Run PagedAttention V1. + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + ) + else: + # Run PagedAttention V2. 
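# How V2 partitioning works (not part of the diff; numbers are an example):
# each sequence's context is split into _PARTITION_SIZE-token partitions that
# are processed independently and reduced afterwards. E.g., with
# max_context_len = 2000 and _PARTITION_SIZE = 512:
#     max_num_partitions = (2000 + 512 - 1) // 512   # == 4
# tmp_output below holds one partial output per (seq, head, partition), while
# exp_sums and max_logits hold the per-partition softmax statistics used in
# the final reduction.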
+ assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + ) + return output + + @staticmethod + def forward_prefix( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + subquery_start_loc: torch.Tensor, + prompt_lens_tensor: torch.Tensor, + context_lens: torch.Tensor, + max_subquery_len: int, + alibi_slopes: Optional[torch.Tensor], + ) -> torch.Tensor: + output = torch.empty_like(query) + context_attention_fwd( + query, + key, + value, + output, + key_cache, + value_cache, + block_tables, + # subquery_start_loc is (batch_size + 1,) + subquery_start_loc[:-1], + prompt_lens_tensor, + context_lens, + max_subquery_len, + alibi_slopes, + ) + return output + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) diff --git a/vllm/model_executor/layers/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py similarity index 100% rename from vllm/model_executor/layers/attention/ops/prefix_prefill.py rename to vllm/attention/ops/prefix_prefill.py diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py new file mode 100644 index 0000000000000..42b05ee320314 --- /dev/null +++ b/vllm/attention/selector.py @@ -0,0 +1,44 @@ +from functools import lru_cache + +import torch + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.logger import init_logger +from vllm.utils import is_hip + +logger = init_logger(__name__) + + +@lru_cache(maxsize=None) +def get_attn_backend(dtype: torch.dtype) -> AttentionBackend: + if _can_use_flash_attn(dtype): + logger.info("Using FlashAttention backend.") + from vllm.attention.backends.flash_attn import FlashAttentionBackend # noqa: F401 + return FlashAttentionBackend + else: + logger.info("Using XFormers backend.") + from vllm.attention.backends.xformers import XFormersBackend # noqa: F401 + return XFormersBackend + + +def _can_use_flash_attn(dtype: torch.dtype) -> bool: + if is_hip(): + # AMD GPUs. + logger.info("Cannot use FlashAttention backend for AMD GPUs.") + return False + if torch.cuda.get_device_capability()[0] < 8: + # Volta and Turing NVIDIA GPUs. 
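# For reference (not part of the diff): compute capability (7, 0) is Volta
# (e.g., V100) and (7, 5) is Turing (e.g., T4); the FlashAttention kernels
# used here require Ampere or newer, i.e. get_device_capability()[0] >= 8.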
+ logger.info("Cannot use FlashAttention backend for Volta and Turing " + "GPUs.") + return False + if dtype not in (torch.float16, torch.bfloat16): + logger.info("Cannot use FlashAttention backend for dtype other than " + "torch.float16 or torch.bfloat16.") + return False + + try: + import flash_attn # noqa: F401 + except ImportError: + logger.info("flash_attn is not found.") + return False + return True diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index 5f3c78360e2d7..fb98f4a6b46f4 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,9 +1,7 @@ -from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed __all__ = [ - "InputMetadata", "SamplingMetadata", "set_random_seed", ] diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py deleted file mode 100644 index 8fdac06c82dd7..0000000000000 --- a/vllm/model_executor/input_metadata.py +++ /dev/null @@ -1,99 +0,0 @@ -from dataclasses import dataclass, fields -from typing import TYPE_CHECKING, Optional, List, Any, Dict - -import torch -if TYPE_CHECKING: - from xformers.ops.fmha.attn_bias import AttentionBias - - -@dataclass -class InputMetadata: - """Metadata for input sequences. Used in PagedAttention. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor - # (batch_size,). The prompt length per sequence. None if it is a decoding. - prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. - prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int - """ - Definition of context_len, subquery_len, and seqlen. - |---------- N-1 iteration --------| - |---------------- N iteration ---------------------| - |- tokenA -|......................|-- newTokens ---| - |---------- context_len ----------| - |-------------------- seqlen ----------------------| - |- subquery_len -| - - WARNING: context_len has different definition depending on if it is - prefill vs decoding. When it is prefill, it doesn't include new - tokens. When it is for decoding, it includes a new token. - """ - - # Maximum subquery length in the batch. - max_subquery_len: Optional[int] - # Maximum context length in the batch. - max_context_len: Optional[int] - # FIXME: It is for flash attn. - # Maximum sequence length in the batch. - max_seq_len: Optional[int] - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] - # FIXME: It is for flash attn. - # (batch_size + 1,). 
The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] - # (batch_size,). The length of context (tokens stored in KV cache) per - # sequence. WARNING: When it is a prefill request, it doesn't include new - # tokens. When it is for decoding, it includes a new token. - context_lens: Optional[torch.Tensor] - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - use_cuda_graph: bool - kv_cache_dtype: str - - def __post_init__(self): - # Set during the execution of the first attention op. - # It is a list because it is needed to set per prompt - # when alibi slopes is used. It is because of the limitation - # from xformer API. - # will not appear in the __repr__ and __init__ - self.attn_bias: Optional[List["AttentionBias"]] = None - - # Cuda graph is only used for decoding now. - if self.use_cuda_graph: - assert self.num_prompt_tokens == 0 - - def asdict_zerocopy(self) -> Dict[str, Any]: - """Similar to dataclasses.asdict, but avoids deepcopying.""" - # Note that if we add dataclasses as fields, they will need - # similar handling. - return { - field.name: getattr(self, field.name) - for field in fields(self) - } diff --git a/vllm/model_executor/layers/attention/__init__.py b/vllm/model_executor/layers/attention/__init__.py deleted file mode 100644 index 1c42a3d28f976..0000000000000 --- a/vllm/model_executor/layers/attention/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from vllm.model_executor.layers.attention.attention import Attention - -__all__ = [ - "Attention", -] diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py deleted file mode 100644 index ae598b029a007..0000000000000 --- a/vllm/model_executor/layers/attention/attention.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Attention layer.""" -from functools import lru_cache -from typing import List, Optional - -import torch -import torch.nn as nn - -from vllm.logger import init_logger -from vllm.model_executor.input_metadata import InputMetadata -from vllm.utils import is_hip - -logger = init_logger(__name__) - - -class Attention(nn.Module): - """Attention layer. - - This class takes query, key, and value tensors as input. The input tensors - can either contain prompt tokens or generation tokens. - - The class does the following: - - 1. Store the input key and value tensors in the KV cache. - 2. Perform (multi-head/multi-query/grouped-query) attention. - 3. Output the output tensor. 
- """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, - ) -> None: - super().__init__() - if _use_flash_attn(): - from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend # noqa: E501 - self.backend = FlashAttentionBackend(num_heads, head_size, scale, - num_kv_heads, alibi_slopes, - sliding_window) - else: - from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend # noqa: E501 - self.backend = XFormersBackend(num_heads, head_size, scale, - num_kv_heads, alibi_slopes, - sliding_window) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - key_cache: Optional[torch.Tensor], - value_cache: Optional[torch.Tensor], - input_metadata: InputMetadata, - ) -> torch.Tensor: - return self.backend.forward(query, key, value, key_cache, value_cache, - input_metadata) - - -@lru_cache(maxsize=1) -def _use_flash_attn() -> bool: - try: - import flash_attn # noqa: F401 - except ImportError: - logger.info("flash_attn is not found. Using xformers backend.") - return False - - if is_hip(): - # AMD GPUs. - return False - if torch.cuda.get_device_capability()[0] < 8: - # Volta and Turing NVIDIA GPUs. - logger.info("flash_attn is not supported on Turing or older GPUs. " - "Using xformers backend.") - return False - if torch.get_default_dtype() not in (torch.float16, torch.bfloat16): - logger.info( - "flash_attn only supports torch.float16 or torch.bfloat16. " - "Using xformers backend.") - return False - - logger.info("Using flash_attn backend.") - return True diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py deleted file mode 100644 index 9ce5851f3650d..0000000000000 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ /dev/null @@ -1,139 +0,0 @@ -"""Attention layer with Flash and PagedAttention.""" -from typing import List, Optional - -from flash_attn import flash_attn_varlen_func -import torch - -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention.ops.paged_attn import ( - PagedAttentionImpl) - - -class FlashAttentionBackend: - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| - - Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. 
- """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, - ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = sliding_window - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes() - if head_size not in suppored_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") - - self.sliding_window = ((self.sliding_window, self.sliding_window) if - self.sliding_window is not None else (-1, -1)) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - key_cache: Optional[torch.Tensor], - value_cache: Optional[torch.Tensor], - input_metadata: InputMetadata, - ) -> torch.Tensor: - """Forward pass with FlashAttention and PagedAttention. - - Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - key_cache: shape = [num_blocks, num_kv_heads, head_size/x, - block_size, x] - value_cache: shape = [num_blocks, num_kv_heads, head_size, - block_size] - input_metadata: metadata for the inputs. - Returns: - shape = [num_tokens, num_heads * head_size] - """ - num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. - query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - # Reshape the keys and values and store them in the cache. - # If key_cache and value_cache are not provided, the new key and value - # vectors will not be cached. This happens during the initial memory - # profiling run. - if key_cache is not None and value_cache is not None: - PagedAttentionImpl.reshape_and_cache(key, value, key_cache, - value_cache, input_metadata) - - if input_metadata.is_prompt: - # Prompt run. - if (key_cache is None or value_cache is None - or input_metadata.block_tables.numel() == 0): - # normal attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. - output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=input_metadata.seq_start_loc, - cu_seqlens_k=input_metadata.seq_start_loc, - max_seqlen_q=input_metadata.max_seq_len, - max_seqlen_k=input_metadata.max_seq_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - ) - else: - # prefix-enabled attention - output = PagedAttentionImpl.forward_prefix( - query, - key, - value, - key_cache, - value_cache, - input_metadata, - self.alibi_slopes, - ) - else: - # Decoding run. - output = PagedAttentionImpl.forward_decode( - query, - key_cache, - value_cache, - input_metadata, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - ) - - # Reshape the output tensor. 
- return output.view(num_tokens, hidden_size) diff --git a/vllm/model_executor/layers/attention/ops/paged_attn.py b/vllm/model_executor/layers/attention/ops/paged_attn.py deleted file mode 100644 index 3105ba37b9832..0000000000000 --- a/vllm/model_executor/layers/attention/ops/paged_attn.py +++ /dev/null @@ -1,139 +0,0 @@ -from typing import List, Optional - -import torch - -from vllm._C import cache_ops -from vllm._C import ops -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention.ops.prefix_prefill import ( - context_attention_fwd) - -# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. -_PARTITION_SIZE = 512 - - -class PagedAttentionImpl: - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [64, 80, 96, 112, 128, 256] - - @staticmethod - def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - input_metadata: InputMetadata, - ) -> None: - cache_ops.reshape_and_cache( - key, - value, - key_cache, - value_cache, - input_metadata.slot_mapping.flatten(), - input_metadata.kv_cache_dtype, - ) - - @staticmethod - def forward_decode( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - input_metadata: InputMetadata, - num_kv_heads: int, - scale: float, - alibi_slopes: Optional[torch.Tensor], - ) -> torch.Tensor: - output = torch.empty_like(query) - - block_size = value_cache.shape[3] - num_seqs, num_heads, head_size = query.shape - max_num_partitions = ( - (input_metadata.max_context_len + _PARTITION_SIZE - 1) // - _PARTITION_SIZE) - # NOTE(woosuk): We use a simple heuristic to decide whether to use - # PagedAttention V1 or V2. If the number of partitions is 1, we use - # V1 to avoid the overhead of reduction. Also, if the number of - # sequences or heads is large, we use V1 since there is enough work - # to parallelize. - # TODO(woosuk): Tune this heuristic. - # For context len > 8192, use V2 kernel to avoid shared memory shortage. - use_v1 = input_metadata.max_context_len <= 8192 and ( - max_num_partitions == 1 or num_seqs * num_heads > 512) - if use_v1: - # Run PagedAttention V1. - ops.paged_attention_v1( - output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - input_metadata.block_tables, - input_metadata.context_lens, - block_size, - input_metadata.max_context_len, - alibi_slopes, - input_metadata.kv_cache_dtype, - ) - else: - # Run PagedAttention V2. 
- assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=output.dtype, - device=output.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=output.device, - ) - max_logits = torch.empty_like(exp_sums) - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - input_metadata.block_tables, - input_metadata.context_lens, - block_size, - input_metadata.max_context_len, - alibi_slopes, - input_metadata.kv_cache_dtype, - ) - return output - - @staticmethod - def forward_prefix( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - input_metadata: InputMetadata, - alibi_slopes: Optional[torch.Tensor], - ) -> torch.Tensor: - output = torch.empty_like(query) - context_attention_fwd( - query, - key, - value, - output, - key_cache, - value_cache, - input_metadata.block_tables, - # subquery_start_loc is (batch_size + 1,) - input_metadata.subquery_start_loc[:-1], - input_metadata.prompt_lens_tensor, - input_metadata.context_lens, - input_metadata.max_subquery_len, - alibi_slopes, - ) - return output diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 968b9ebba87b2..2d5fcf7b9c54f 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -25,9 +25,8 @@ from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.input_metadata import InputMetadata +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -45,8 +44,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) @@ -170,15 +167,14 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.W_pack(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) if self.postion_embedding != "ALIBI": q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output @@ -217,8 +213,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -232,7 +228,7 @@ def forward( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) # Fully Connected @@ -267,8 +263,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) 
-> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None @@ -278,7 +274,7 @@ def forward( positions, hidden_states, kv_caches[i], - input_metadata, + attn_metadata, residual, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -303,11 +299,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 7cf4370236a8b..a9ff909090586 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -17,15 +17,14 @@ # limitations under the License. """Inference-only BLOOM model compatible with HuggingFace weights.""" import math -from typing import List, Optional, Tuple +from typing import List, Optional import torch from torch import nn from transformers import BloomConfig -from vllm.model_executor.input_metadata import InputMetadata +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -41,8 +40,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) @@ -117,14 +114,13 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: del position_ids # Unused. qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.dense(attn_output) return output @@ -181,8 +177,8 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: # Layer norm at the beginning of the transformer layer. 
layernorm_output = self.input_layernorm(hidden_states) @@ -198,7 +194,7 @@ def forward( position_ids=position_ids, hidden_states=layernorm_output, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) attention_output = attention_output + residual layernorm_output = self.post_attention_layernorm(attention_output) @@ -245,8 +241,8 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.word_embeddings(input_ids) hidden_states = self.word_embeddings_layernorm(hidden_states) @@ -256,7 +252,7 @@ def forward( position_ids, hidden_states, kv_caches[i], - input_metadata, + attn_metadata, ) hidden_states = self.ln_f(hidden_states) return hidden_states @@ -281,11 +277,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 15e7de03b61f1..88a1c81008558 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -2,15 +2,14 @@ # Adapted from # https://github.com/THUDM/ChatGLM2-6B """Inference-only ChatGLM model compatible with THUDM weights.""" -from typing import List, Optional, Tuple +from typing import List, Optional import torch from torch import nn from torch.nn import LayerNorm -from vllm.model_executor.input_metadata import InputMetadata +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -29,8 +28,6 @@ from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import ChatGLMConfig -KVCache = Tuple[torch.Tensor, torch.Tensor] - class GLMAttention(nn.Module): @@ -99,20 +96,18 @@ def forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(position_ids, q, k) - key_cache, value_cache = kv_cache context_layer = self.attn( q, k, v, - key_cache, - value_cache, - input_metadata, + kv_cache, + attn_metadata, ) attn_output, _ = self.dense(context_layer) return attn_output @@ -200,8 +195,8 @@ def forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: # hidden_states: [num_tokens, h] # Layer norm at the beginning of the transformer layer. @@ -211,7 +206,7 @@ def forward( hidden_states=layernorm_output, position_ids=position_ids, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) # Residual connection. 
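The per-model hunks above all follow one pattern: the layer-level (key_cache, value_cache) tuple plus InputMetadata pair becomes a single kv_cache tensor plus a shared AttentionMetadata object, and each model delegates to the new vllm.attention.Attention layer. The sketch below illustrates only that calling convention; the Attention constructor arguments are an assumption modeled on the deleted backend __init__ earlier in this patch, not copied from any single hunk.

# Illustrative sketch of the calling convention these model diffs converge on.
# The Attention constructor arguments here are assumed (mirroring the deleted
# backend's num_heads/head_size/scale parameters), not taken from the patch.
import torch
from torch import nn

from vllm.attention import Attention, AttentionMetadata


class ToyAttention(nn.Module):
    """Hypothetical attention block using the refactored vllm.attention API."""

    def __init__(self, num_heads: int, head_size: int) -> None:
        super().__init__()
        hidden_size = num_heads * head_size
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)
        # Assumed signature, based on the old backend's
        # (num_heads, head_size, scale, ...) parameters.
        self.attn = Attention(num_heads, head_size, scale=head_size**-0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,            # was: Tuple[torch.Tensor, torch.Tensor]
        attn_metadata: AttentionMetadata,  # was: InputMetadata
    ) -> torch.Tensor:
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        # Old: k_cache, v_cache = kv_cache
        #      out = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        # New: pass the fused cache tensor and the metadata straight through.
        out = self.attn(q, k, v, kv_cache, attn_metadata)
        return self.o_proj(out)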
@@ -264,8 +259,8 @@ def forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: for i in range(self.num_layers): layer = self.layers[i] @@ -273,7 +268,7 @@ def forward( hidden_states=hidden_states, position_ids=position_ids, kv_cache=kv_caches[i], - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) # Final layer norm. if self.post_layer_norm: @@ -306,8 +301,8 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: inputs_embeds = self.embedding(input_ids) @@ -316,7 +311,7 @@ def forward( hidden_states=inputs_embeds, position_ids=position_ids, kv_caches=kv_caches, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) return hidden_states @@ -340,11 +335,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 08c851f85c17b..c66f72db21e9e 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -21,15 +21,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Deepseek model.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional import torch from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.input_metadata import InputMetadata +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -51,8 +50,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - class DeepseekMLP(nn.Module): @@ -239,14 +236,13 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output @@ -294,8 +290,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -309,7 +305,7 @@ def forward( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, - 
input_metadata=input_metadata, + attn_metadata=attn_metadata, ) # Fully Connected @@ -346,15 +342,15 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None for i in range(len(self.layers)): layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, - kv_caches[i], input_metadata, + kv_caches[i], attn_metadata, residual) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -379,11 +375,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 0a01796a96416..543e87101f6ea 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -19,16 +19,15 @@ """PyTorch Falcon model.""" import math -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union import torch from torch import nn from torch.nn import LayerNorm from transformers import FalconConfig as HF_FalconConfig -from vllm.model_executor.input_metadata import InputMetadata +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -48,7 +47,6 @@ from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import RWConfig -KVCache = Tuple[torch.Tensor, torch.Tensor] FalconConfig = Union[HF_FalconConfig, RWConfig] @@ -177,8 +175,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, bias = self.query_key_value(hidden_states) if bias is not None: @@ -186,8 +184,7 @@ def forward( q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) if self.use_rotary: q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output, bias = self.dense(attn_output) return attn_output, bias @@ -263,8 +260,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states @@ -279,7 +276,7 @@ def forward( positions=positions, hidden_states=attention_layernorm_out, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) if self.reduce_row_parallel_results and attention_bias is not None: attention_output += attention_bias @@ -343,8 +340,8 @@ def forward( self, input_ids: torch.LongTensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> 
torch.Tensor: hidden_states = self.word_embeddings(input_ids) for i in range(len(self.h)): @@ -353,7 +350,7 @@ def forward( positions, hidden_states, kv_caches[i], - input_metadata, + attn_metadata, ) hidden_states = self.ln_f(hidden_states) return hidden_states @@ -378,14 +375,14 @@ def forward( self, input_ids: torch.LongTensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.transformer( input_ids, positions, kv_caches, - input_metadata, + attn_metadata, ) return hidden_states diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index fa8ce60e74056..49a08a62b54ac 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -20,10 +20,9 @@ from torch import nn from transformers import GemmaConfig +from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig -from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import GeluAndMul -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -41,8 +40,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - class GemmaMLP(nn.Module): @@ -133,14 +130,13 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output @@ -177,8 +173,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -192,7 +188,7 @@ def forward( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) # Fully Connected @@ -226,8 +222,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Normalize the embedding by sqrt(hidden_size) @@ -240,7 +236,7 @@ def forward( positions, hidden_states, kv_caches[i], - input_metadata, + attn_metadata, residual, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -290,11 +286,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/gpt2.py 
b/vllm/model_executor/models/gpt2.py index e75dda750cb26..3f816a9996be5 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -17,15 +17,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-2 model compatible with HuggingFace weights.""" -from typing import List, Optional, Tuple +from typing import List, Optional import torch from torch import nn from transformers import GPT2Config -from vllm.model_executor.input_metadata import InputMetadata +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -41,8 +40,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - class GPT2Attention(nn.Module): @@ -79,14 +76,12 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) - key_cache, value_cache = kv_cache - attn_output = self.attn(q, k, v, key_cache, value_cache, - input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output, _ = self.c_proj(attn_output) return attn_output @@ -144,15 +139,15 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) attn_output = self.attn( hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) # residual connection hidden_states = attn_output + residual @@ -190,8 +185,8 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: inputs_embeds = self.wte(input_ids) position_embeds = self.wpe(position_ids) @@ -199,7 +194,7 @@ def forward( for i in range(len(self.h)): layer = self.h[i] - hidden_states = layer(hidden_states, kv_caches[i], input_metadata) + hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) hidden_states = self.ln_f(hidden_states) return hidden_states @@ -224,11 +219,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 65caabae60daa..07c647c2e1c41 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -18,15 +18,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" -from typing import List, Optional, Tuple +from typing import List, Optional import torch from torch import nn from transformers import GPTBigCodeConfig -from vllm.model_executor.input_metadata import InputMetadata +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -42,8 +41,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - class GPTBigCodeAttention(nn.Module): @@ -94,8 +91,8 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.split( @@ -105,9 +102,7 @@ def forward( ], dim=-1, ) - key_cache, value_cache = kv_cache - attn_output = self.attn(q, k, v, key_cache, value_cache, - input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output, _ = self.c_proj(attn_output) return attn_output @@ -165,15 +160,15 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) attn_output = self.attn( hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) # residual connection hidden_states = attn_output + residual @@ -211,8 +206,8 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: inputs_embeds = self.wte(input_ids) position_embeds = self.wpe(position_ids) @@ -220,7 +215,7 @@ def forward( for i in range(len(self.h)): layer = self.h[i] - hidden_states = layer(hidden_states, kv_caches[i], input_metadata) + hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) hidden_states = self.ln_f(hidden_states) return hidden_states @@ -245,11 +240,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index c956a12f3e46e..ae5d480cf4bc4 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -16,15 +16,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GPT-J model compatible with HuggingFace weights.""" -from typing import List, Optional, Tuple +from typing import List, Optional import torch from torch import nn from transformers import GPTJConfig -from vllm.model_executor.input_metadata import InputMetadata +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -41,8 +40,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - class GPTJAttention(nn.Module): @@ -93,14 +90,13 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(position_ids, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output, _ = self.out_proj(attn_output) return attn_output @@ -154,8 +150,8 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) @@ -163,7 +159,7 @@ def forward( position_ids=position_ids, hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) mlp_output = self.mlp(hidden_states) hidden_states = attn_output + mlp_output + residual @@ -192,8 +188,8 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.wte(input_ids) for i in range(len(self.h)): @@ -202,7 +198,7 @@ def forward( position_ids, hidden_states, kv_caches[i], - input_metadata, + attn_metadata, ) hidden_states = self.ln_f(hidden_states) return hidden_states @@ -232,11 +228,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index db2173936e7d9..e08adf06bf115 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -16,15 +16,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" -from typing import List, Optional, Tuple +from typing import List, Optional import torch from torch import nn from transformers import GPTNeoXConfig -from vllm.model_executor.input_metadata import InputMetadata +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -41,8 +40,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - class GPTNeoXAttention(nn.Module): @@ -94,14 +91,13 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(position_ids, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.dense(attn_output) return output @@ -155,15 +151,15 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: attn_input = self.input_layernorm(hidden_states) attn_output = self.attention( position_ids=position_ids, hidden_states=attn_input, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) if self.use_parallel_residual: @@ -208,8 +204,8 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.embed_in(input_ids) for i in range(len(self.layers)): @@ -218,7 +214,7 @@ def forward( position_ids, hidden_states, kv_caches[i], - input_metadata, + attn_metadata, ) hidden_states = self.final_layer_norm(hidden_states) return hidden_states @@ -246,11 +242,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.gpt_neox(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 93026fc01f0f0..03b3271daa508 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -5,9 +5,8 @@ from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.input_metadata import InputMetadata +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -25,8 +24,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - class 
InternLM2MLP(nn.Module): @@ -124,14 +121,13 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.wqkv(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.wo(attn_output) return output @@ -172,8 +168,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -187,7 +183,7 @@ def forward( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) # Fully Connected @@ -221,8 +217,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.tok_embeddings(input_ids) residual = None @@ -232,7 +228,7 @@ def forward( positions, hidden_states, kv_caches[i], - input_metadata, + attn_metadata, residual, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -258,11 +254,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 74c8e7f963026..e3f3dce375046 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -20,14 +20,13 @@ """Inference-only Jais model compatible with HuggingFace weights.""" import math -from typing import List, Optional, Tuple +from typing import List, Optional import torch from torch import nn from vllm.transformers_utils.configs import JAISConfig -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention import Attention +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.linear import ( ColumnParallelLinear, LinearMethodBase, @@ -49,8 +48,6 @@ from vllm.sequence import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata -KVCache = Tuple[torch.Tensor, torch.Tensor] - class SwiGLUActivation(nn.Module): @@ -122,14 +119,12 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) - key_cache, value_cache = kv_cache - attn_output = self.attn(q, k, v, key_cache, value_cache, - input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output, _ = self.c_proj(attn_output) return attn_output @@ -196,15 +191,15 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: 
KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) attn_output = self.attn( hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) # residual connection hidden_states = attn_output + residual @@ -248,8 +243,8 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: inputs_embeds = self.wte(input_ids) if self.wpe is not None: @@ -262,7 +257,7 @@ def forward( for i in range(len(self.h)): layer = self.h[i] - hidden_states = layer(hidden_states, kv_caches[i], input_metadata) + hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) hidden_states = self.ln_f(hidden_states) return hidden_states @@ -293,11 +288,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, @@ -348,4 +343,4 @@ def load_weights( loaded_weight = loaded_weight.t() weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) \ No newline at end of file + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 757b75129845c..4d53548d5304d 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -27,10 +27,9 @@ from torch import nn from transformers import LlamaConfig +from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig -from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -48,8 +47,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - class LlamaMLP(nn.Module): @@ -150,14 +147,13 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output @@ -203,8 +199,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -218,7 +214,7 @@ def forward( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, + 
attn_metadata=attn_metadata, ) # Fully Connected @@ -258,8 +254,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None @@ -269,7 +265,7 @@ def forward( positions, hidden_states, kv_caches[i], - input_metadata, + attn_metadata, residual, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -336,11 +332,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index f0138b6f9b1db..f4dae20f9a228 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -21,15 +21,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" -from typing import List, Optional, Tuple +from typing import List, Optional import torch from torch import nn from transformers import MixtralConfig +from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -51,8 +50,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - class MixtralMoE(nn.Module): """A tensor-parallel MoE implementation for Mixtral that shards each expert @@ -209,14 +206,13 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output @@ -254,8 +250,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -269,7 +265,7 @@ def forward( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) # Fully Connected @@ -309,15 +305,15 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None for i in range(len(self.layers)): layer = self.layers[i] hidden_states, residual = layer(positions, 
hidden_states, - kv_caches[i], input_metadata, + kv_caches[i], attn_metadata, residual) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -377,11 +373,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index b8d6b45a36dd6..15068efb3b0b7 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" -from typing import List, Optional, Tuple +from typing import List, Optional import numpy as np @@ -31,8 +31,7 @@ from torch import nn from transformers import MixtralConfig -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention import Attention +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, ReplicatedLinear, @@ -52,8 +51,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - class MixtralMLP(nn.Module): @@ -227,14 +224,13 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output @@ -269,8 +265,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -284,7 +280,7 @@ def forward( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) # Fully Connected @@ -319,15 +315,15 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None for i in range(len(self.layers)): layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, - kv_caches[i], input_metadata, + kv_caches[i], attn_metadata, residual) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -352,11 +348,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - 
input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 7a2568817858c..a39f94359a948 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -1,14 +1,13 @@ # coding=utf-8 # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main import math -from typing import List, Optional, Tuple +from typing import List, Optional import torch import torch.nn as nn -from vllm.model_executor.input_metadata import InputMetadata +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -25,8 +24,6 @@ from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.mpt import MPTConfig -KVCache = Tuple[torch.Tensor, torch.Tensor] - def _get_alibi_slopes( total_num_heads: int, @@ -116,8 +113,8 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: del position_ids # unused. qkv, _ = self.Wqkv(hidden_states) @@ -127,8 +124,7 @@ def forward( if self.qk_ln: q = self.q_ln(q) k = self.k_ln(k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.out_proj(attn_output) return output @@ -184,15 +180,15 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: x = self.norm_1(hidden_states) x = self.attn( position_ids=position_ids, hidden_states=x, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) hidden_states = hidden_states + x x = self.norm_2(hidden_states) @@ -230,8 +226,8 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.wte(input_ids) for i in range(len(self.blocks)): @@ -240,7 +236,7 @@ def forward( position_ids, hidden_states, kv_caches[i], - input_metadata, + attn_metadata, ) hidden_states = self.norm_f(hidden_states) return hidden_states @@ -267,11 +263,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 19f2be6da8ed3..237f870dfe4a6 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -42,8 +42,7 @@ import torch.nn.functional as F from torch import nn -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention import Attention +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.linear 
import ( ColumnParallelLinear, LinearMethodBase, @@ -67,8 +66,6 @@ # this model must need this dependency from hf_olmo import OLMoConfig -KVCache = Tuple[torch.Tensor, torch.Tensor] - class SwiGLU(nn.Module): @@ -146,16 +143,15 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.attn_norm(hidden_states) qkv, _ = self.att_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) if self.config.rope: q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.attn_out(attn_output) return output @@ -241,12 +237,12 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: # Attention block. og_x = hidden_states - x = self.attn(positions, hidden_states, kv_cache, input_metadata) + x = self.attn(positions, hidden_states, kv_cache, attn_metadata) x = x + og_x # MLP block. @@ -296,8 +292,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: """ :param input_ids: A tensor of shape `(batch_size, seq_len)`. @@ -313,7 +309,7 @@ def forward( positions, x, kv_caches[block_idx], - input_metadata, + attn_metadata, ) # Apply final layer norm. @@ -344,14 +340,14 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.model( input_ids=input_ids, positions=positions, kv_caches=kv_caches, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) return hidden_states diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index a12f63b58f52b..c1ae1b2ae0f03 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -17,15 +17,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only OPT model compatible with HuggingFace weights.""" -from typing import List, Optional, Tuple +from typing import List, Optional import torch from torch import nn from transformers import OPTConfig -from vllm.model_executor.input_metadata import InputMetadata +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -42,8 +41,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - class OPTLearnedPositionalEmbedding(nn.Embedding): @@ -97,14 +94,12 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) - key_cache, value_cache = kv_cache - attn_output = self.attn(q, k, v, key_cache, value_cache, - input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.out_proj(attn_output) return output @@ -152,8 +147,8 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: # Self Attention residual = hidden_states @@ -162,7 +157,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn(hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata) + attn_metadata=attn_metadata) hidden_states = residual + hidden_states # 350m applies layer norm AFTER attention if not self.do_layer_norm_before: @@ -241,8 +236,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: inputs_embeds = self.embed_tokens(input_ids) pos_embeds = self.embed_positions(positions) @@ -252,7 +247,7 @@ def forward( for i in range(len(self.layers)): layer = self.layers[i] - hidden_states = layer(hidden_states, kv_caches[i], input_metadata) + hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) if self.final_layer_norm is not None: hidden_states = self.final_layer_norm(hidden_states) @@ -275,10 +270,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: - return self.decoder(input_ids, positions, kv_caches, input_metadata) + return self.decoder(input_ids, positions, kv_caches, attn_metadata) class OPTForCausalLM(nn.Module): @@ -300,11 +295,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 86428e320e0f7..ea8119df664cc 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py 
@@ -10,9 +10,8 @@ from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.input_metadata import InputMetadata +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, @@ -29,8 +28,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - class OrionMLP(nn.Module): @@ -128,14 +125,13 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output @@ -178,8 +174,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -189,7 +185,7 @@ def forward( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states @@ -227,8 +223,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None @@ -238,7 +234,7 @@ def forward( positions, hidden_states, kv_caches[i], - input_metadata, + attn_metadata, residual, ) hidden_states = self.norm(hidden_states) @@ -264,11 +260,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index ef70c823dc905..1737e5efb6cb3 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -35,15 +35,14 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
"""Inference-only Phi-1.5 model compatible with HuggingFace weights.""" -from typing import List, Optional, Tuple +from typing import List, Optional import torch from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.input_metadata import InputMetadata +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -60,8 +59,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - class PhiAttention(nn.Module): @@ -115,14 +112,13 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(position_ids, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.dense(attn_output) return output @@ -172,8 +168,8 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -181,7 +177,7 @@ def forward( position_ids=position_ids, hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) feed_forward_hidden_states = self.mlp(hidden_states) hidden_states = attn_outputs + feed_forward_hidden_states + residual @@ -209,8 +205,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) for i in range(self.config.num_hidden_layers): @@ -219,7 +215,7 @@ def forward( positions, hidden_states, kv_caches[i], - input_metadata, + attn_metadata, ) hidden_states = self.final_layernorm(hidden_states) @@ -248,11 +244,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 61ac2c6c605c6..bd7976dfc1d48 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -10,9 +10,8 @@ from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.input_metadata import InputMetadata +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -30,8 +29,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, 
torch.Tensor] - class QWenMLP(nn.Module): @@ -111,15 +108,13 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) - + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.c_proj(attn_output) return output @@ -153,8 +148,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -167,7 +162,7 @@ def forward( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) # Fully Connected @@ -201,8 +196,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.wte(input_ids) residual = None @@ -212,7 +207,7 @@ def forward( positions, hidden_states, kv_caches[i], - input_metadata, + attn_metadata, residual, ) hidden_states, _ = self.ln_f(hidden_states, residual) @@ -238,11 +233,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 49c2a8b732fed..fe34fe113866d 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -28,9 +28,8 @@ from torch import nn from transformers import Qwen2Config -from vllm.model_executor.input_metadata import InputMetadata +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -49,8 +48,6 @@ from vllm.sequence import SamplerOutput from vllm.config import LoRAConfig -KVCache = Tuple[torch.Tensor, torch.Tensor] - class Qwen2MLP(nn.Module): @@ -147,14 +144,13 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output @@ -197,8 +193,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + 
kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -212,7 +208,7 @@ def forward( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) # Fully Connected @@ -248,8 +244,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None @@ -259,7 +255,7 @@ def forward( positions, hidden_states, kv_caches[i], - input_metadata, + attn_metadata, residual, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -315,11 +311,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 7624ca89ee670..7d64bcdf3f3ba 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -25,9 +25,8 @@ from torch import nn from transformers import PretrainedConfig -from vllm.model_executor.input_metadata import InputMetadata +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, @@ -44,8 +43,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - class StablelmMLP(nn.Module): @@ -134,14 +131,13 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output @@ -166,8 +162,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states @@ -176,7 +172,7 @@ def forward( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states @@ -211,8 +207,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) for i in range(len(self.layers)): @@ -221,7 +217,7 @@ def forward( positions, hidden_states, kv_caches[i], - input_metadata, + attn_metadata, 
) hidden_states = self.norm(hidden_states) return hidden_states @@ -246,11 +242,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index e72c5cf1544f7..82e2cfa961db2 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -18,15 +18,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch Starcoder2 model.""" -from typing import List, Optional, Tuple +from typing import List, Optional import torch from torch import nn from transformers import Starcoder2Config -from vllm.model_executor.input_metadata import InputMetadata +from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -43,8 +42,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - class Starcoder2Attention(nn.Module): @@ -111,14 +108,13 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output @@ -171,8 +167,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: # Self Attention residual = hidden_states @@ -181,7 +177,7 @@ def forward( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states @@ -217,14 +213,14 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) for i in range(len(self.layers)): layer = self.layers[i] hidden_states = layer(positions, hidden_states, kv_caches[i], - input_metadata) + attn_metadata) hidden_states = self.norm(hidden_states) return hidden_states @@ -258,11 +254,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + attn_metadata) return hidden_states 
def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/sequence.py b/vllm/sequence.py index ff96dd306791c..af18eed959b1e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -431,7 +431,7 @@ def __repr__(self) -> str: class SequenceGroupMetadata: - """Metadata for a sequence group. Used to create `InputMetadata`. + """Metadata for a sequence group. Used to create `AttentionMetadata`. Args: request_id: The ID of the request. diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 307b7b778cb3f..b403e28d8934d 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -1,16 +1,15 @@ """CacheEngine class for managing the KV cache.""" -from typing import Dict, List, Tuple +from typing import Dict, List import torch +from vllm.attention import get_attn_backend from vllm.config import CacheConfig, ModelConfig, ParallelConfig from vllm.logger import init_logger from vllm.utils import is_pin_memory_available, STR_DTYPE_TO_TORCH_DTYPE logger = init_logger(__name__) -KVCache = Tuple[torch.Tensor, torch.Tensor] - class CacheEngine: """Manages the KV cache. @@ -43,95 +42,43 @@ def __init__( else: self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + # Get attention backend. + self.attn_backend = get_attn_backend(model_config.dtype) + # Initialize the cache. - self.gpu_cache = self.allocate_gpu_cache() - self.cpu_cache = self.allocate_cpu_cache() - - def get_key_block_shape(self) -> Tuple[int, int, int, int]: - element_size = torch.tensor([], dtype=self.dtype).element_size() - x = 16 // element_size - return ( - self.num_heads, - self.head_size // x, - self.block_size, - x, - ) - - def get_value_block_shape(self) -> Tuple[int, int, int]: - return ( - self.num_heads, - self.head_size, - self.block_size, - ) - - def allocate_gpu_cache(self) -> List[KVCache]: - gpu_cache: List[KVCache] = [] - key_block_shape = self.get_key_block_shape() - value_block_shape = self.get_value_block_shape() - for _ in range(self.num_layers): - key_blocks = torch.empty( - size=(self.num_gpu_blocks, *key_block_shape), - dtype=self.dtype, - device="cuda", - ) - value_blocks = torch.empty( - size=(self.num_gpu_blocks, *value_block_shape), - dtype=self.dtype, - device="cuda", - ) - gpu_cache.append((key_blocks, value_blocks)) - return gpu_cache - - def allocate_cpu_cache(self) -> List[KVCache]: - cpu_cache: List[KVCache] = [] - key_block_shape = self.get_key_block_shape() - value_block_shape = self.get_value_block_shape() - pin_memory = is_pin_memory_available() - for _ in range(self.num_layers): - key_blocks = torch.empty( - size=(self.num_cpu_blocks, *key_block_shape), - dtype=self.dtype, - pin_memory=pin_memory, - device="cpu", - ) - value_blocks = torch.empty( - size=(self.num_cpu_blocks, *value_block_shape), - dtype=self.dtype, - pin_memory=pin_memory, - device="cpu", - ) - cpu_cache.append((key_blocks, value_blocks)) - return cpu_cache - - def _swap( - self, - src: List[KVCache], - dst: List[KVCache], - src_to_dst: Dict[int, int], - ) -> None: - from vllm._C import cache_ops + self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda") + self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu") - for i in range(self.num_layers): - src_key_cache, src_value_cache = src[i] - dst_key_cache, dst_value_cache = dst[i] - # Copy the key blocks. - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - # Copy the value blocks. 
- cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + def _allocate_kv_cache( + self, + num_blocks: int, + device: str, + ) -> List[torch.Tensor]: + """Allocates KV cache on the specified device.""" + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_heads, self.head_size) + pin_memory = is_pin_memory_available() if device == "cpu" else False + kv_cache: List[torch.Tensor] = [] + for _ in range(self.num_layers): + kv_cache.append( + torch.empty(kv_cache_shape, + dtype=self.dtype, + pin_memory=pin_memory, + device=device)) + return kv_cache def swap_in(self, src_to_dst: Dict[int, int]) -> None: - self._swap(self.cpu_cache, self.gpu_cache, src_to_dst) + for i in range(self.num_layers): + self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i], + src_to_dst) def swap_out(self, src_to_dst: Dict[int, int]) -> None: - self._swap(self.gpu_cache, self.cpu_cache, src_to_dst) + for i in range(self.num_layers): + self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i], + src_to_dst) def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: - from vllm._C import cache_ops - - key_caches = [key_cache for key_cache, _ in self.gpu_cache] - value_caches = [value_cache for _, value_cache in self.gpu_cache] - # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU. - cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts) + self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) @staticmethod def get_cache_block_size( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b8eeb51379f49..6e1fb4ede815c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -6,10 +6,11 @@ import torch import torch.nn as nn +from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig) from vllm.logger import init_logger -from vllm.model_executor import InputMetadata, SamplingMetadata +from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model from vllm.model_executor.parallel_utils import cupy_utils from vllm.model_executor.parallel_utils.communication_op import ( @@ -28,7 +29,6 @@ logger = init_logger(__name__) -KVCache = Tuple[torch.Tensor, torch.Tensor] _PAD_SLOT_ID = -1 LORA_WARMUP_RANK = 8 _BATCH_SIZE_ALIGNMENT = 8 @@ -85,6 +85,9 @@ def __init__( self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = kv_cache_dtype + self.attn_backend = get_attn_backend( + self.model_config.dtype if model_config is not None else None) + def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model(self.model_config, @@ -127,8 +130,8 @@ def get_max_block_per_batch(self) -> int: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], - List[int], List[int], Set[LoRARequest]]: + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], + List[int], List[int], List[int], Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] @@ -216,7 +219,7 @@ def _prepare_prompt( slot_mapping.append(slot) max_subquery_len = max(subquery_lens) - max_seq_len = max(prompt_lens) + max_prompt_len = max(prompt_lens) num_prompt_tokens = len(input_tokens) assert max_subquery_len > 0 @@ -270,7 +273,7 @@ def _prepare_prompt( dtype=seq_start_loc.dtype, 
out=seq_start_loc[1:]) - input_metadata = InputMetadata( + attn_metadata = self.attn_backend.make_metadata( is_prompt=True, slot_mapping=slot_mapping, prompt_lens=prompt_lens, @@ -279,7 +282,7 @@ def _prepare_prompt( num_generation_tokens=0, max_subquery_len=max_subquery_len, max_context_len=None, - max_seq_len=max_seq_len, + max_prompt_len=max_prompt_len, subquery_start_loc=subquery_start_loc, seq_start_loc=seq_start_loc, context_lens=context_lens_tensor, @@ -287,15 +290,15 @@ def _prepare_prompt( use_cuda_graph=False, kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, input_metadata, prompt_lens, + return (input_tokens, input_positions, attn_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, lora_requests) def _prepare_decode( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], - Set[LoRARequest]]: + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], + List[int], Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] @@ -401,7 +404,7 @@ def _prepare_decode( device=self.device, ) - input_metadata = InputMetadata( + attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, prompt_lens=None, @@ -410,7 +413,7 @@ def _prepare_decode( num_generation_tokens=len(input_tokens), max_subquery_len=None, max_context_len=max_context_len, - max_seq_len=None, + max_prompt_len=None, subquery_start_loc=None, seq_start_loc=None, context_lens=context_lens, @@ -418,7 +421,7 @@ def _prepare_decode( use_cuda_graph=use_captured_graph, kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, input_metadata, + return (input_tokens, input_positions, attn_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) def _prepare_sample( @@ -522,7 +525,7 @@ def _prepare_sample( def prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, Set[int], LoRAMapping]: if self.is_driver_worker: # NOTE: We assume that all sequences in the group are all prompts or @@ -530,11 +533,11 @@ def prepare_input_tensors( is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors. 
if is_prompt: - (input_tokens, input_positions, input_metadata, prompt_lens, + (input_tokens, input_positions, attn_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_prompt(seq_group_metadata_list) else: - (input_tokens, input_positions, input_metadata, + (input_tokens, input_positions, attn_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_decode(seq_group_metadata_list) prompt_lens = [] @@ -560,7 +563,7 @@ def prepare_input_tensors( "lora_requests": lora_requests, "lora_mapping": lora_mapping, } - metadata_dict.update(input_metadata.asdict_zerocopy()) + metadata_dict.update(attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) @@ -570,7 +573,7 @@ def prepare_input_tensors( "selected_token_indices") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") - input_metadata = InputMetadata(**metadata_dict) + attn_metadata = self.attn_backend.make_metadata(**metadata_dict) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -581,16 +584,16 @@ def prepare_input_tensors( perform_sampling=False, ) - return (input_tokens, input_positions, input_metadata, + return (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping) @torch.inference_mode() def execute_model( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: - (input_tokens, input_positions, input_metadata, sampling_metadata, + (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list) @@ -598,7 +601,7 @@ def execute_model( self.set_active_loras(lora_requests, lora_mapping) # Execute the model. - if input_metadata.use_cuda_graph: + if attn_metadata.use_cuda_graph: graph_batch_size = input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: @@ -607,7 +610,7 @@ def execute_model( input_ids=input_tokens, positions=input_positions, kv_caches=kv_caches, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) # Compute the logits. @@ -673,7 +676,7 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. num_layers = self.model_config.get_num_layers(self.parallel_config) - kv_caches = [(None, None)] * num_layers + kv_caches = [None] * num_layers self.execute_model(seqs, kv_caches) torch.cuda.synchronize() return @@ -705,7 +708,7 @@ def list_loras(self) -> Set[int]: return self.lora_manager.list_loras() @torch.inference_mode() - def capture_model(self, kv_caches: List[KVCache]) -> None: + def capture_model(self, kv_caches: List[torch.Tensor]) -> None: """Cuda graph capture a model. Note that CUDA graph's performance gain is negligible if number @@ -759,8 +762,8 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): - # Create dummy input_metadata. - input_metadata = InputMetadata( + # Create dummy attn_metadata. 
+ attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping[:batch_size], prompt_lens=None, @@ -769,7 +772,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: num_generation_tokens=batch_size, max_subquery_len=None, max_context_len=self.max_context_len_to_capture, - max_seq_len=None, + max_prompt_len=None, subquery_start_loc=None, seq_start_loc=None, context_lens=context_lens[:batch_size], @@ -790,7 +793,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: input_tokens[:batch_size], input_positions[:batch_size], kv_caches, - input_metadata, + attn_metadata, memory_pool=self.graph_memory_pool, ) self.graph_memory_pool = graph_runner.graph.pool() @@ -826,8 +829,8 @@ def capture( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, memory_pool, ) -> None: assert self.graph is None @@ -839,7 +842,7 @@ def capture( input_ids, positions, kv_caches, - input_metadata, + attn_metadata, ) torch.cuda.synchronize() @@ -853,7 +856,7 @@ def capture( input_ids, positions, kv_caches, - input_metadata, + attn_metadata, ) torch.cuda.synchronize() @@ -862,9 +865,9 @@ def capture( "input_ids": input_ids, "positions": positions, "kv_caches": kv_caches, - "slot_mapping": input_metadata.slot_mapping, - "context_lens": input_metadata.context_lens, - "block_tables": input_metadata.block_tables, + "slot_mapping": attn_metadata.slot_mapping, + "context_lens": attn_metadata.context_lens, + "block_tables": attn_metadata.block_tables, } self.output_buffers = {"hidden_states": hidden_states} return @@ -873,8 +876,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: # KV caches are fixed tensors, so we don't need to copy them. del kv_caches @@ -882,11 +885,11 @@ def forward( # Copy the input tensors to the input buffers. self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True) - self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping, + self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) - self.input_buffers["context_lens"].copy_(input_metadata.context_lens, + self.input_buffers["context_lens"].copy_(attn_metadata.context_lens, non_blocking=True) - self.input_buffers["block_tables"].copy_(input_metadata.block_tables, + self.input_buffers["block_tables"].copy_(attn_metadata.block_tables, non_blocking=True) # Run the graph. self.graph.replay() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d8999dc172127..2f9398a701b45 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -128,6 +128,9 @@ def profile_num_available_blocks( # NOTE(woosuk): Here we assume that the other processes using the same # GPU did not change their memory usage during the profiling. peak_memory = self.init_gpu_memory - free_gpu_memory + assert peak_memory > 0, ( + "Error in memory profiling. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance.") cache_block_size = self.get_cache_block_size_bytes( block_size, cache_dtype)