From 3b673b93e9168c6496dd54063c16e81f9b1d0d4b Mon Sep 17 00:00:00 2001
From: Dominika Olszewska
Date: Tue, 10 Sep 2024 12:16:54 +0200
Subject: [PATCH] Port flat PA from habana_next to habana_main (#169)

This PR ports the "flat" paged-attention (flat PA) decode path from
habana_next to habana_main:

- `ops.paged_attention_v1` is replaced by `ops.flat_pa`, which flattens all
  KV-cache blocks of a batch into a single block dimension and uses a one-hot
  `block_mapping` matrix to scatter per-sequence queries to their blocks and
  gather per-block partial results back per sequence.
- `HabanaPagedAttentionMetadata` now carries `block_list`, `block_mapping` and
  `block_usage` instead of `block_tables`/`seq_lens_tensor`, and
  `HabanaAttentionMetadata` is trimmed accordingly.
- Decode bucketing switches from sequence-length buckets
  (`VLLM_DECODE_SEQ_BUCKET_*`) to total-block buckets
  (`VLLM_DECODE_BLOCK_BUCKET_*`); prompt and decode batch-size defaults are
  updated and documented in README_GAUDI.md and gaudi-installation.rst.
- `habana_model_runner` builds the new decode metadata, pads it to the decode
  block bucket size, and generates decode warmup buckets over
  (batch_size, total_blocks).

The sketches below illustrate the block-mapping math, the new decode metadata
layout, and the new bucketing defaults for reviewers.
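
For reviewers unfamiliar with the flat PA layout, here is a minimal, self-contained sketch of the block-mapping math. `batch2block` and `block2batch` are copied from the new `vllm/hpu/ops.py`; the toy tensors (2 sequences, 5 blocks, hidden size 4) and the driver code are illustrative only and not part of the patch.

```python
import torch

# Toy setup: 2 decoding sequences sharing a pool of 5 KV-cache blocks.
# block_owner[i] is the index of the sequence that owns block i.
block_owner = torch.tensor([0, 0, 1, 1, 1])
batch_size = 2
# One-hot [num_blocks, batch_size] matrix, as built in _set_block_mapping().
block_mapping = torch.nn.functional.one_hot(block_owner,
                                            num_classes=batch_size).float()


def batch2block(tensor, block_mapping):
    # Broadcast each sequence's row to every block that sequence owns.
    shape = tuple(tensor.shape)
    return (block_mapping @ tensor.view(shape[0], -1)).view(-1, *shape[1:])


def block2batch(tensor, block_mapping):
    # Sum per-block partial results back into one row per sequence.
    shape = tuple(tensor.shape)
    return (block_mapping.t() @ tensor.view(shape[0], -1)).view(-1, *shape[1:])


query = torch.randn(batch_size, 4)                 # stand-in for per-sequence queries
per_block_q = batch2block(query, block_mapping)    # [5, 4]: each block sees its owner's query
per_block_out = torch.randn(5, 4)                  # stand-in for per-block attention output
per_seq_out = block2batch(per_block_out, block_mapping)  # [2, 4]: summed over owned blocks
print(per_block_q.shape, per_seq_out.shape)
```

In `flat_pa`, QK^T, softmax and AV all run in this flattened per-block layout, and `block2batch` reduces the per-block results into per-sequence outputs. `block_softmax` normalizes the exponentiated scores across all blocks belonging to the same sequence by routing the per-block sums through `block2batch`/`batch2block`, and the `-inf` bias built from `block_usage` masks the unused tail of each block.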
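
A second sketch shows how `_prepare_decode` now flattens per-sequence block tables into `block_list`, `block_mapping` and `block_usage`. The block-table entries, slot values and `block_size = 4` below are made-up illustrative numbers, not values from the patch.

```python
import itertools

block_size = 4
# Hypothetical decode batch: seq 0 owns blocks [7, 3], seq 1 owns blocks [5, 9, 2].
block_tables = [[7, 3], [5, 9, 2]]
# One slot per decoding sequence; slot = block_number * block_size + offset.
slot_mapping = [[3 * block_size + 1], [2 * block_size + 2]]

blocks_used = [len(bt) for bt in block_tables]                     # [2, 3]
block_list = list(itertools.chain(*block_tables))                  # [7, 3, 5, 9, 2]
block_mapping = [i for i, n in enumerate(blocks_used)              # owner index per block
                 for _ in range(n)]                                # [0, 0, 1, 1, 1]
# Number of tokens occupying the last block of each sequence.
last_block = [s % block_size + 1 for s in itertools.chain(*slot_mapping)]
block_usage = list(itertools.chain(*[[block_size] * (n - 1) + [lb]
                                     for n, lb in zip(blocks_used, last_block)]))
print(block_list)     # [7, 3, 5, 9, 2]
print(block_mapping)  # [0, 0, 1, 1, 1]
print(block_usage)    # [4, 2, 4, 4, 3]
```

In the patch these lists are then padded with `pad_list` up to `find_bucket(len(block_list), self.decode_block_bucket_cfg)` and converted to tensors; `block_usage` later drives the decode attention bias in `_set_block_mapping`.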

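Finally, a sketch of how the new decode block-bucket defaults resolve. The helper below mirrors the updated `read_bucket_settings` in `habana_model_runner.py` (without its logging); `max_num_seqs=256` and `block_size=128` are assumed deployment values, while `max_decode_seq=2048` is the patch's current default (with a FIXME noting it should eventually be `max_model_len`).

```python
import os


def read_bucket_settings(phase, dim, **defaults):
    # Each of min/step/max can be overridden via VLLM_{PHASE}_{DIM}_BUCKET_{PARAM};
    # otherwise the default passed in by the model runner is used.
    params = ['min', 'step', 'max']
    env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params]
    return [int(os.environ.get(e, defaults[p])) for e, p in zip(env_vars, params)]


max_num_seqs, block_size = 256, 128   # assumed deployment values
max_decode_seq = 2048                 # patch default, see FIXME in _setup_buckets
cfg = read_bucket_settings('decode', 'block',
                           min=128, step=128,
                           max=max(128, max_num_seqs * max_decode_seq // block_size))
print(cfg)  # [128, 128, 4096] unless VLLM_DECODE_BLOCK_BUCKET_{MIN,STEP,MAX} are set
```
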
--------- Co-authored-by: Michal Adamczyk Co-authored-by: barak goldberg <149692267+bgoldberg-habana@users.noreply.github.com> Co-authored-by: Michal Szutenberg <37601244+szutenberg@users.noreply.github.com> Co-authored-by: Jan Kaniecki --- README_GAUDI.md | 22 +- .../getting_started/gaudi-installation.rst | 14 +- vllm/attention/backends/habana_attn.py | 136 ++----- vllm/attention/ops/habana_paged_attn.py | 51 +-- vllm/hpu/ops.py | 114 +++--- vllm/hpu/utils.py | 7 +- vllm/worker/habana_model_runner.py | 365 +++++++++++------- 7 files changed, 330 insertions(+), 379 deletions(-) diff --git a/README_GAUDI.md b/README_GAUDI.md index 91bcbe49405eb..5109f7ddf9927 100644 --- a/README_GAUDI.md +++ b/README_GAUDI.md @@ -455,12 +455,12 @@ Environment variables - `VLLM_{phase}_{dim}_BUCKET_{param}` - collection of 12 environment variables configuring ranges of bucketing mechanism - `{phase}` is either `PROMPT` or `DECODE` - - `{dim}` is either `BS` or `SEQ` + - `{dim}` is either `BS`, `SEQ` or `BLOCK` - `{param}` is either `MIN`, `STEP` or `MAX` - Default values: - Prompt: - batch size min (`VLLM_PROMPT_BS_BUCKET_MIN`): `1` - - batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`): `32` + - batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`): `min(max_num_seqs, 32)` - batch size max (`VLLM_PROMPT_BS_BUCKET_MAX`): `min(max_num_seqs, 64)` - sequence length min (`VLLM_PROMPT_SEQ_BUCKET_MIN`): @@ -468,20 +468,20 @@ Environment variables - sequence length step (`VLLM_PROMPT_SEQ_BUCKET_STEP`): `block_size` - sequence length max (`VLLM_PROMPT_SEQ_BUCKET_MAX`): - `1024` + `max_model_len` - Decode: - - batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `1` + - batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `min(max_num_seqs, 32)` - batch size step (`VLLM_DECODE_BS_BUCKET_STEP`): - `128` + `min(max_num_seqs, 32)` - batch size max (`VLLM_DECODE_BS_BUCKET_MAX`): `max_num_seqs` - - sequence length min (`VLLM_DECODE_SEQ_BUCKET_MIN`): - `block_size` - - sequence length step - (`VLLM_DECODE_SEQ_BUCKET_STEP`): `block_size` - - sequence length max (`VLLM_DECODE_SEQ_BUCKET_MAX`): - `2048` + - block size min (`VLLM_DECODE_BLOCK_BUCKET_MIN`): + `128` + - block size step + (`VLLM_DECODE_BLOCK_BUCKET_STEP`): `128` + - block size max (`VLLM_DECODE_BLOCK_BUCKET_MAX`): + `max(128, (max_num_seqs*max_model_len)/block_size)` Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution: diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index b3234d10b3115..ed3beabb2c8aa 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -335,19 +335,19 @@ Environment variables - Prompt: - batch size min (``VLLM_PROMPT_BS_BUCKET_MIN``): ``1`` - - batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``32`` + - batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)`` - batch size max (``VLLM_PROMPT_BS_BUCKET_MAX``): ``min(max_num_seqs, 64)`` - sequence length min (``VLLM_PROMPT_SEQ_BUCKET_MIN``): ``block_size`` - sequence length step (``VLLM_PROMPT_SEQ_BUCKET_STEP``): ``block_size`` - - sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``1024`` + - sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``max_model_len`` - Decode: - - batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``1`` - - batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``128`` + - batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``min(max_num_seqs, 32)`` + - batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): 
``min(max_num_seqs, 32)`` - batch size max (``VLLM_DECODE_BS_BUCKET_MAX``): ``max_num_seqs`` - - sequence length min (``VLLM_DECODE_SEQ_BUCKET_MIN``): ``block_size`` - - sequence length step (``VLLM_DECODE_SEQ_BUCKET_STEP``): ``block_size`` - - sequence length max (``VLLM_DECODE_SEQ_BUCKET_MAX``): ``2048`` + - sequence length min (``VLLM_DECODE_SEQ_BUCKET_MIN``): ``128`` + - sequence length step (``VLLM_DECODE_SEQ_BUCKET_STEP``): ``128`` + - sequence length max (``VLLM_DECODE_SEQ_BUCKET_MAX``): ``max(128, (max_num_seqs*max_model_len)/block_size)`` Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution: diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py index 2259630fa10b7..20b0f2bc7630b 100644 --- a/vllm/attention/backends/habana_attn.py +++ b/vllm/attention/backends/habana_attn.py @@ -58,58 +58,14 @@ def copy_blocks( @dataclass -class HabanaAttentionMetadata(AttentionMetadata, HabanaPagedAttentionMetadata): - """Metadata for HabanaAttentionbackend. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ +class HabanaAttentionMetadata(HabanaPagedAttentionMetadata, AttentionMetadata): + """Metadata for HabanaAttentionbackend.""" # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. + attn_bias: Optional[torch.Tensor] seq_lens_tensor: Optional[torch.Tensor] - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ----------------------| - # |-- query_len ---| - - # Maximum query length in the batch. - max_query_len: Optional[int] - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] - # FIXME: It is for flash attn. - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - - def __post_init__(self): - # Set during the execution of the first attention op. - # It is a list because it is needed to set per prompt - # when alibi slopes is used. It is because of the limitation - # from xformer API. - # will not appear in the __repr__ and __init__ - self.attn_bias: Optional[torch.Tensor] = None - class HabanaAttentionImpl(AttentionImpl, torch.nn.Module): """ @@ -229,60 +185,48 @@ def forward( if attn_metadata.is_prompt: # Prompt run. 
- if kv_cache is None or attn_metadata.block_tables.numel() == 0: - if not self.prefill_usefusedsdpa: - # TODO: move this outside of model - assert attn_metadata.attn_bias is not None, \ + if not self.prefill_usefusedsdpa: + # TODO: move this outside of model + assert attn_metadata.attn_bias is not None, \ 'attn_bias must be set before calling model.forward!' - attn_bias = attn_metadata.attn_bias - if self.alibi_slopes is not None and \ - self.position_bias is not None: - attn_bias.add_(self.position_bias[:, :, - -attn_bias.size(2):, - -attn_bias.size(3):]) - else: - attn_bias = None - - query_shape = (batch_size, seq_len, self.num_heads, - self.head_size) - kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, - self.head_size) - out = ops.prompt_attention( - query.view(query_shape), - key.view(kv_shape), - value.view(kv_shape), - attn_bias=attn_bias, - p=0.0, - scale=self.scale, - matmul_qk_op=self.matmul_qk, - softmax_op=self.softmax, - matmul_av_op=self.matmul_av, - valid_seq_lengths=attn_metadata.seq_lens_tensor, - ) - output = out.reshape(batch_size, seq_len, hidden_size) + attn_bias = attn_metadata.attn_bias + if self.alibi_slopes is not None and \ + self.position_bias is not None: + attn_bias.add_(self.position_bias[:, :, + -attn_bias.size(2):, + -attn_bias.size(3):]) else: - # prefix-enabled attention - output = HabanaPagedAttention.forward_prefix( - query, - key, - value, - key_cache, - value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.seq_lens_tensor, - attn_metadata.context_lens_tensor, - attn_metadata.max_query_len, - self.alibi_slopes, - ) + attn_bias = None + + query_shape = (batch_size, seq_len, self.num_heads, self.head_size) + kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, + self.head_size) + out = ops.prompt_attention( + query.view(query_shape), + key.view(kv_shape), + value.view(kv_shape), + attn_bias=attn_bias, + p=0.0, + scale=self.scale, + matmul_qk_op=self.matmul_qk, + softmax_op=self.softmax, + matmul_av_op=self.matmul_av, + ) + output = out.reshape(batch_size, seq_len, hidden_size) else: # Decoding run. output = HabanaPagedAttention.forward_decode( - query, key_cache, value_cache, attn_metadata.block_tables, - attn_metadata.seq_lens_tensor, self.kv_cache_dtype, - self.num_kv_heads, self.scale, self.position_bias, k_scale, - v_scale, self.matmul_qk, self.softmax, self.matmul_av, - self.k_cache, self.v_cache) + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_list=attn_metadata.block_list, + block_mapping=attn_metadata.block_mapping, + block_bias=attn_metadata.attn_bias, + scale=self.scale, + matmul_qk_op=self.matmul_qk, + matmul_av_op=self.matmul_av, + keys_fetch_func=self.k_cache.fetch_from_cache, + values_fetch_func=self.v_cache.fetch_from_cache) # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) diff --git a/vllm/attention/ops/habana_paged_attn.py b/vllm/attention/ops/habana_paged_attn.py index 9602886299c47..cab8d7abe95fd 100644 --- a/vllm/attention/ops/habana_paged_attn.py +++ b/vllm/attention/ops/habana_paged_attn.py @@ -16,16 +16,9 @@ @dataclass class HabanaPagedAttentionMetadata: """Metadata for PagedAttention.""" - # (batch_size,). The length of sequences (entire tokens seen so far) per - # sequence. - seq_lens_tensor: Optional[torch.Tensor] - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. 
Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] + block_list: Optional[torch.Tensor] + block_mapping: Optional[torch.Tensor] + block_usage: Optional[torch.Tensor] class HabanaPagedAttention: @@ -63,42 +56,8 @@ def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor, slot_mapping, kv_cache_dtype, is_prompt) @staticmethod - def forward_decode( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, - kv_cache_dtype: str, - num_kv_heads: int, - scale: float, - alibi_slopes: Optional[torch.Tensor], - k_scale: float, - v_scale: float, - matmul_qk_op, - softmax_op, - matmul_av_op, - k_cache_cls, - v_cache_cls, - ) -> torch.Tensor: - block_size = value_cache.shape[1] - return ops.paged_attention_v1( - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - alibi_slopes, - kv_cache_dtype, - matmul_qk_op, - softmax_op, - matmul_av_op, - k_cache_cls, - v_cache_cls, - ) + def forward_decode(**kwargs) -> torch.Tensor: + return ops.flat_pa(**kwargs) @staticmethod def forward_prefix( diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index bacb755b39393..b2705429906c4 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. ############################################################################### -import os from typing import Optional import habana_frameworks.torch as htorch @@ -29,72 +28,57 @@ logger.warning("Could not import HPU FusedSDPA kernel. " "vLLM will use native implementation.") -PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1') - - -def fetch_from_cache(cache, blocks, permutations): - return [ - cache.index_select(0, blocks[:, i]).permute(permutations) - for i in range(blocks.size(1)) - ] - - -def paged_attention_v1(query, - key_cache, - value_cache, - head_mapping, - scale, - block_tables, - context_lens, - block_size, - alibi_slopes=None, - kv_cache_dtype=None, - matmul_qk_op=torch.matmul, - softmax_op=torch.softmax, - matmul_av_op=torch.matmul, - k_cache_cls=None, - v_cache_cls=None) -> None: - seq_len = block_tables.size(1) - batch_size, query_heads, _ = query.shape - _, _, kv_heads, _ = key_cache.shape - min_inf = torch.finfo(query.dtype).min - mask = (torch.arange(0, - seq_len * block_size, - dtype=torch.int32, - device=key_cache.device).view(1, -1).expand( - batch_size, -1).ge(context_lens.view(-1, 1)).view( - batch_size, 1, 1, -1)) - query.mul_(scale) - query = query.unsqueeze(-2) - fetch_keys = fetch_from_cache if k_cache_cls is None else \ - k_cache_cls.fetch_from_cache - keys = fetch_keys(key_cache, block_tables, (0, 2, 3, 1)) - if query_heads != kv_heads: + +def batch2block(tensor, block_mapping): + shape = tuple(tensor.shape) + return (block_mapping @ tensor.view(shape[0], -1)).view(-1, *shape[1:]) + + +def block2batch(tensor, block_mapping): + shape = tuple(tensor.shape) + return (block_mapping.t() @ tensor.view(shape[0], -1)).view(-1, *shape[1:]) + + +def block_softmax(batch_size, attn, block_mapping): + attn.sub_(10.0) + attn = attn.exp_() + sums = attn.sum(dim=-1).unsqueeze(-1) + sums = block2batch(sums, block_mapping) + sums = batch2block(sums, block_mapping) + sums.add_(1.0e-12) + attn.div_(sums) + return attn + + +def flat_pa(query, key_cache, value_cache, 
block_list, block_mapping, + block_bias, scale, matmul_qk_op, matmul_av_op, keys_fetch_func, + values_fetch_func): + batch_size = query.size(0) + q_heads = query.size(1) + kv_heads = key_cache.size(2) + + query = batch2block(scale * query, block_mapping).unsqueeze(-2) + key = keys_fetch_func(key_cache, block_list).transpose(1, 2) + value = values_fetch_func(value_cache, block_list).transpose(1, 2) + block_bias = block_bias.view(key.size(0), 1, 1, -1) + + if kv_heads != q_heads: + block_bias = block_bias.unsqueeze(1) query = query.unflatten(1, (kv_heads, -1)) - keys = [k.unflatten(1, (kv_heads, 1)) for k in keys] - mask = mask.unsqueeze(2) - - attn_weights = torch.cat([matmul_qk_op(query, k) for k in keys], dim=-1) - if alibi_slopes is not None: - attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):, - -attn_weights.size(3):]) - attn_weights = softmax_op(attn_weights.masked_fill(mask, min_inf), dim=-1) - - fetch_values = fetch_from_cache if v_cache_cls is None else \ - v_cache_cls.fetch_from_cache - values = fetch_values(value_cache, block_tables, (0, 2, 1, 3)) - if PA_SPLIT_VALUE: - attn_weights = attn_weights.split(block_size, dim=-1) + key = key.unflatten(1, (kv_heads, 1)) + value = value.unflatten(1, (kv_heads, 1)) + key = key.transpose(3, 4) else: - values = [torch.cat(values, dim=-2)] - attn_weights = [attn_weights] - if query_heads != kv_heads: - values = [v.unflatten(1, (kv_heads, 1)) for v in values] - attn_weights = [matmul_av_op(a, v) for a, v in zip(attn_weights, values)] - if query_heads != kv_heads: - attn_weights = [a.flatten(1, 2) for a in attn_weights] - attn_weights = sum(attn_weights) - return attn_weights.squeeze(-2) + key = key.transpose(2, 3) + + attn = matmul_qk_op(query, key) + block_bias + attn = block_softmax(batch_size, attn, block_mapping) + attn = matmul_av_op(attn, value) + attn = block2batch(attn, block_mapping) + attn = attn.squeeze(-2) + if kv_heads != q_heads: + attn = attn.flatten(1, 2) + return attn def silu_and_mul(x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/hpu/utils.py b/vllm/hpu/utils.py index 3d9c7cb1c4c22..13204b83d5742 100644 --- a/vllm/hpu/utils.py +++ b/vllm/hpu/utils.py @@ -57,8 +57,5 @@ def forward(self, input, cache, num_kv_cache_passes, num_slots_available, block_offset) return cache - def fetch_from_cache(self, cache, blocks, permutations): - return [ - cache.index_select(0, blocks[:, i]).permute(permutations) - for i in range(blocks.size(1)) - ] + def fetch_from_cache(self, cache, blocks): + return cache.index_select(0, blocks) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index a4ade587db089..a6bd5e5f68745 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -51,29 +51,47 @@ logger = init_logger(__name__) +_TYPE_CACHE = {} # These values are assumed to be zero in several places. # Use caution when updating them! 
_PAD_SLOT_ID = 0 _PAD_BLOCK_ID = 0 LORA_WARMUP_RANK = 8 -_TYPE_CACHE = {} + + +def subtuple(obj: object, + typename: str, + to_copy: List[str], + to_override: Optional[Dict[str, object]] = None): + if obj is None: + return None + if to_override is None: + to_override = {} + fields = set(to_copy) | set(to_override.keys()) + values = {f: to_override.get(f, getattr(obj, f)) for f in fields} + if typename not in _TYPE_CACHE: + _TYPE_CACHE[typename] = collections.namedtuple(typename, + ' '.join(fields)) + return _TYPE_CACHE[typename](**values) def read_bucket_settings(phase: str, dim: str, **defaults): """Read bucketing configuration from env variables. phase is either 'prompt' or 'decode' - dim is either 'bs' or 'block' + dim is either 'bs', 'seq' or 'block' param is either 'min', 'step' or 'max' example env variable: VLLM_DECODE_BS_BUCKET_STEP=128 """ params = ['min', 'step', 'max'] + env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params] + default_values = [defaults[p] for p in params] values = [ - int( - os.environ.get(f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper(), - defaults[p])) for p in params + int(os.environ.get(e, d)) for e, d in zip(env_vars, default_values) ] + for e, v, d in zip(env_vars, values, defaults): + logger.info('%s=%s (default:%s)', e, v, d) return values @@ -103,9 +121,9 @@ def warmup_range(config: Tuple[int, int, int]): return list(filter(lambda bucket: bucket >= bmin, buckets)) -def warmup_buckets(bs_bucket_config, - seq_bucket_config, - max_num_batched_tokens=None): +def generate_prompt_buckets(bs_bucket_config, + seq_bucket_config, + max_num_batched_tokens=None): buckets = list( itertools.product(warmup_range(bs_bucket_config), warmup_range(seq_bucket_config))) @@ -150,6 +168,19 @@ def warmup_buckets(bs_bucket_config, return captured_buckets, omitted_buckets +def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, + max_blocks): + buckets = [] + for bs in warmup_range(bs_bucket_config): + for blocks in warmup_range(blocks_bucket_config): + if blocks < bs: + continue + if blocks > max_blocks: + break + buckets.append((bs, blocks)) + return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) + + def next_pow2(value: int, base: int): res = base while value > 1: @@ -169,22 +200,6 @@ def find_bucket(value: int, config: Tuple[int, int, int]): return max(bmin, min(next_step, next_pow)) -def subtuple(obj: object, - typename: str, - to_copy: List[str], - to_override: Optional[Dict[str, object]] = None): - if to_override is None: - to_override = {} - if obj is None: - return None - fields = set(to_copy) | set(to_override.keys()) - values = {f: to_override.get(f, getattr(obj, f)) for f in fields} - if typename not in _TYPE_CACHE: - _TYPE_CACHE[typename] = collections.namedtuple(typename, - ' '.join(fields)) - return _TYPE_CACHE[typename](**values) - - def align_workers(value, op): group = get_world_group().cpu_group world_size = torch.distributed.get_world_size() @@ -195,13 +210,19 @@ def align_workers(value, op): return value_t.item() +def pad_list(list, k, v): + target_len = round_up(len(list), k) + padding = target_len - len(list) + return list + [v] * padding + + class HpuModelAdapter(): - def __init__(self, model, enforce_eager): + def __init__(self, model, block_size, enforce_eager): self.model = model self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', '0').lower() in ['1', 'true'] - + self.block_size = block_size if not htorch.utils.internal.is_lazy() and not enforce_eager: self.model = torch.compile(self.model, 
backend='hpu_backend', @@ -225,22 +246,45 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, mask = causal_mask.logical_or(len_mask) attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( mask, -math.inf)) - #FIXME: Restore sliding window support - #if self.sliding_window is not None: attn_metadata = prefill_metadata._replace(attn_bias=attn_bias) return attn_metadata + def _set_block_mapping(self, metadata, batch_size, device, dtype): + mask = torch.arange(0, + self.block_size, + device=device, + dtype=torch.int32).unsqueeze(0) + mask = mask >= metadata.block_usage.unsqueeze(-1) + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( + mask, -math.inf)) + block_mapping = torch.nn.functional.one_hot( + metadata.block_mapping.to(torch.long), + num_classes=batch_size).to(dtype) + metadata = metadata._replace(block_mapping=block_mapping, + attn_bias=attn_bias) + return metadata + + def _update_metadata(self, attn_metadata, batch_size, seq_len, device, + dtype): + if attn_metadata.is_prompt: + meta = attn_metadata + attn_metadata = self._set_attn_bias(meta, batch_size, seq_len, + device, dtype) + else: + meta = attn_metadata + attn_metadata = self._set_block_mapping(meta, batch_size, device, + dtype) + return attn_metadata + def forward(self, *args, **kwargs): kwargs = kwargs.copy() selected_token_indices = kwargs.pop('selected_token_indices') if 'warmup_mode' in kwargs: kwargs.pop('warmup_mode') input_ids = kwargs['input_ids'] - kwargs['attn_metadata'] = self._set_attn_bias(kwargs['attn_metadata'], - input_ids.size(0), - input_ids.size(1), - input_ids.device, - torch.bfloat16) + kwargs['attn_metadata'] = self._update_metadata( + kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), + input_ids.device, torch.bfloat16) LoraMask.setLoraMask(kwargs.pop('lora_mask')) hidden_states = self.model(*args, **kwargs) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) @@ -536,7 +580,9 @@ def load_model(self) -> None: # RuntimeErrors. 
This needs to be debugged with HabanaMemoryProfiler() as m_wrap: self.model = _maybe_wrap_in_hpu_graph( - self.model, enforce_eager=self.enforce_eager) + self.model, + self.block_size, + enforce_eager=self.enforce_eager) msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}" logger.info(msg) @@ -553,73 +599,48 @@ def _is_valid_bucket(self, bucket): return bucket[0] * bucket[1] <= self.max_num_batched_tokens def _setup_buckets(self) -> None: + align_bs = lambda x: min(self.max_num_seqs, x) max_bucket_cfg = 64 if self.lora_config and \ max_bucket_cfg > self.max_num_batched_tokens // self.block_size: max_bucket_cfg = self.max_num_batched_tokens // self.block_size - self.prompt_bs_bucket_cfg = read_bucket_settings('prompt', - 'bs', - min=1, - step=32, - max=min( - self.max_num_seqs, - max_bucket_cfg)) + blocks_step = 128 + #FIXME: The default values should be max_model_len + max_prompt_seq = 1024 + max_decode_seq = 2048 + self.prompt_bs_bucket_cfg = read_bucket_settings( + 'prompt', + 'bs', + min=1, + step=align_bs(32), + max=align_bs(max_bucket_cfg)) self.decode_bs_bucket_cfg = read_bucket_settings('decode', 'bs', - min=1, - step=128, + min=align_bs(32), + step=align_bs(32), max=self.max_num_seqs) self.prompt_seq_bucket_cfg = read_bucket_settings('prompt', 'seq', min=self.block_size, step=self.block_size, - max=1024) - self.decode_seq_bucket_cfg = read_bucket_settings('decode', - 'seq', - min=self.block_size, - step=self.block_size, - max=2048) + max=max_prompt_seq) + self.decode_block_bucket_cfg = read_bucket_settings( + 'decode', + 'block', + min=blocks_step, + step=blocks_step, + max=max(blocks_step, + self.max_num_seqs * max_decode_seq // self.block_size)) self.graphed_buckets: Set[Any] = set() msg = ("Prompt bucket config (min, step, max_warmup) " f"bs:{self.prompt_bs_bucket_cfg}, " f"seq:{self.prompt_seq_bucket_cfg}") logger.info(msg) - self.prompt_buckets, prompt_omitted_buckets = warmup_buckets( - self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg, - self.max_num_batched_tokens) - - if self.lora_config: - self.prompt_buckets[:] = [ - bucket for bucket in self.prompt_buckets - if self._is_valid_bucket(bucket) - ] - - msg = (f"Generated {len(self.prompt_buckets)} " - f"prompt buckets: {list(sorted(self.prompt_buckets))}") - logger.info(msg) - - msg = (f"Omitted {len(prompt_omitted_buckets)} " - "prompt buckets due to exceeded token budget " - f"(max_num_batched_tokens={self.max_num_batched_tokens})") - logger.info(msg) - - msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" - logger.debug(msg) msg = ("Decode bucket config (min, step, max_warmup) " f"bs:{self.decode_bs_bucket_cfg}, " - f"seq:{self.decode_seq_bucket_cfg}") - logger.info(msg) - self.decode_buckets, _ = warmup_buckets(self.decode_bs_bucket_cfg, - self.decode_seq_bucket_cfg) - if self.lora_config: - self.decode_buckets[:] = [ - bucket for bucket in self.decode_buckets - if self._is_valid_bucket(bucket) - ] - msg = (f"Generated {len(self.decode_buckets)} decode buckets: " - f"{list(sorted(self.decode_buckets))}") + f"block:{self.decode_block_bucket_cfg}") logger.info(msg) def _prepare_prompt( @@ -735,10 +756,6 @@ def _prepare_prompt( real_num_seqs = len(query_lens) assert max_query_len > 0 - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) - if multi_modal_input_list: assert self.multimodal_config, ( "Multi-modal inputs are only supported by " @@ -748,7 +765,6 @@ def _prepare_prompt( else: multi_modal_input = None - max_prompt_block_table_len = 
max(len(t) for t in prefix_block_tables) max_prompt_len = max( find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg), self.block_size) @@ -814,37 +830,17 @@ def _prepare_prompt( dtype=torch.long, device=self.device) - block_tables = make_tensor_with_pad(prefix_block_tables, - max_len=max_prompt_block_table_len, - pad=0, - dtype=torch.int, - device=self.device) - - # Query length can be shorter than key (i.e., prompt) when prefill - # is chunked or prefix cached. - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=self.device) - subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.long, device=self.device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) attn_metadata = self.attn_backend.make_metadata( is_prompt=True, - seq_lens=seq_lens, + block_list=None, + block_mapping=None, + block_usage=None, + attn_bias=None, seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - subquery_start_loc=subquery_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, num_prefills=real_num_seqs, num_prefill_tokens=sum_query_len, num_decode_tokens=0, @@ -950,32 +946,50 @@ def _prepare_decode( s if s != _PAD_SLOT_ID else next(dummy_slots) for s in sl ] for sl in slot_mapping] + num_decode_tokens = sum(seq_lens) + + blocks_used = [len(bt) for bt in block_tables] + block_list = list(itertools.chain(*block_tables)) + block_mapping_nested: List[List[int]] = [ + [i] * b_u for i, b_u in enumerate(blocks_used) + ] + block_mapping: List[int] = list( + itertools.chain.from_iterable(block_mapping_nested)) + + last_block = [ + sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping) + ] + block_usage = [[self.block_size] * (b_u - 1) + [lb] + for b_u, lb in zip(blocks_used, last_block)] + block_usage = list(itertools.chain(*block_usage)) + + block_bucket_size = find_bucket(len(block_list), + self.decode_block_bucket_cfg) + block_list = pad_list(block_list, block_bucket_size, _PAD_SLOT_ID) + block_mapping = pad_list(block_mapping, block_bucket_size, 0) + block_usage = pad_list(block_usage, block_bucket_size, 0) + + block_list = torch.tensor(block_list, + dtype=torch.int, + device=self.device) + block_mapping = torch.tensor(block_mapping, + dtype=torch.int, + device=self.device) + block_usage = torch.tensor(block_usage, + dtype=torch.bfloat16, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) - num_decode_tokens = sum(seq_lens) - max_block_table_len = max( - len(block_table) for block_table in block_tables) - block_tables = make_tensor_with_pad( - block_tables, - max_len=max_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) + attn_metadata = self.attn_backend.make_metadata( is_prompt=False, - seq_lens=None, - seq_lens_tensor=seq_lens_tensor, - max_query_len=None, - subquery_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=block_tables, - use_cuda_graph=False, + block_list=block_list, + block_mapping=block_mapping, + block_usage=block_usage, + attn_bias=None, + seq_lens_tensor=None, num_prefills=0, num_prefill_tokens=0, num_decode_tokens=num_decode_tokens, @@ -1163,7 +1177,7 @@ def _seq_len(self, attn_metadata): if attn_metadata.num_prefills != 0: 
return attn_metadata.slot_mapping.size(1) else: - return attn_metadata.block_tables.size(1) * self.block_size + return attn_metadata.block_list.numel() def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: # NOTE(kzawora): To anyone working on this in the future: @@ -1187,8 +1201,8 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: # input_hash(123) != input_hash(321) # input_hash("abc") != input_hash("cba") attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ - 'block_tables', 'seq_lens_tensor', 'attn_bias', 'slot_mapping', - 'is_prompt' + 'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping', + 'block_usage', 'slot_mapping', 'is_prompt' ]) return attention_metadata @@ -1222,9 +1236,8 @@ def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers max_batch_size = self.prompt_bs_bucket_cfg[-1] - max_seq_len = self.prompt_seq_bucket_cfg[-1] - if self.lora_config: - max_seq_len = self.max_num_batched_tokens // max_batch_size + max_seq_len = min(self.prompt_seq_bucket_cfg[-1], + self.max_num_batched_tokens // max_batch_size) self.warmup_scenario(max_batch_size, max_seq_len, @@ -1277,21 +1290,34 @@ def warmup_scenario(self, [0] * batch_size * seq_len, ) self.set_active_loras(set(), lora_mapping) - seqs = [ - self.create_dummy_seq_group_metadata( - i, - seq_len, - is_prompt, - lora_request=dummy_lora_requests_per_seq[i] - if dummy_lora_requests_per_seq else None) - for i in range(batch_size) - ] + if is_prompt: + seqs = [ + self.create_dummy_seq_group_metadata( + i, + seq_len, + is_prompt, + lora_request=dummy_lora_requests_per_seq[i] + if dummy_lora_requests_per_seq else None) + for i in range(batch_size) + ] + else: + # FIXME: seq_len is actually number of blocks + blocks = [seq_len // batch_size for _ in range(batch_size)] + blocks[0] += seq_len % batch_size + seqs = [ + self.create_dummy_seq_group_metadata( + i, + b * self.block_size - 1, + is_prompt, + lora_request=dummy_lora_requests_per_seq[i] + if dummy_lora_requests_per_seq else None) + for i, b in enumerate(blocks) + ] torch.hpu.synchronize() for _ in range(times): inputs = self.prepare_model_input(seqs) - self.execute_model(inputs, kv_caches, warmup_mode=True) + self.execute_model(inputs, kv_caches, warmup_mode=False) torch.hpu.synchronize() - self.profiler.end() gc.collect() def remove_all_loras(self): @@ -1328,9 +1354,12 @@ def list_loras(self) -> Set[int]: def log_warmup(self, phase, i, max_i, batch_size, seq_len): free_mem = format_bytes( HabanaMemoryProfiler.current_free_device_memory()) + dim = "num_blocks" + if phase == "Prompt": + dim = "seq_len" msg = (f"[Warmup][{phase}][{i+1}/{max_i}] " f"batch_size:{batch_size} " - f"seq_len:{seq_len} " + f"{dim}:{seq_len} " f"free_mem:{free_mem}") logger.info(msg) @@ -1390,6 +1419,8 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): phase = f'Graph/{"Prompt" if is_prompt else "Decode"}' graphed = list(c[:2] for c in self.graphed_buckets if c[2] == is_prompt) + if num_candidates == 0: + num_candidates = 1 msg = (f'{phase} captured:{len(graphed)} ' f'({100 * len(graphed) / num_candidates:.1f}%) ' f'used_mem:{format_bytes(total_mem)} ' @@ -1402,6 +1433,42 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: logger.info("Skipping warmup...") return self.profiler.start('internal', 'warmup') + max_blocks = kv_caches[0][0].size(0) + + self.prompt_buckets, prompt_omitted_buckets = generate_prompt_buckets( + self.prompt_bs_bucket_cfg, 
self.prompt_seq_bucket_cfg, + self.max_num_batched_tokens) + if self.lora_config: + self.prompt_buckets[:] = [ + bucket for bucket in self.prompt_buckets + if self._is_valid_bucket(bucket) + ] + + msg = ( + f"Generated {len(self.prompt_buckets)} " + f"prompt buckets [bs, seq]: {list(sorted(self.prompt_buckets))}") + logger.info(msg) + + msg = (f"Omitted {len(prompt_omitted_buckets)} " + "prompt buckets due to exceeded token budget " + f"(max_num_batched_tokens={self.max_num_batched_tokens})") + logger.info(msg) + + msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" + logger.debug(msg) + + self.decode_buckets = generate_decode_buckets( + self.decode_bs_bucket_cfg, self.decode_block_bucket_cfg, + max_blocks) + if self.lora_config: + self.decode_buckets[:] = [ + bucket for bucket in self.decode_buckets + if self._is_valid_bucket(bucket) + ] + logger.info("Generated %d decode buckets [bs, total_blocks]: %s", + len(self.decode_buckets), + list(sorted(self.decode_buckets))) + start_mem = HabanaMemoryProfiler.current_device_memory_usage() start_time = time.perf_counter()