From 5cf8441311b341e60d6538c442656e48ab38d230 Mon Sep 17 00:00:00 2001 From: Dominika Olszewska Date: Tue, 10 Sep 2024 12:16:54 +0200 Subject: [PATCH 1/9] Port flat PA from habana_next to habana_main (#169)

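A minimal standalone sketch of the block-mapping scheme this port introduces. The `batch2block`/`block2batch` helpers are copied from the `vllm/hpu/ops.py` changes below; the toy shapes and values are illustrative assumptions only, not part of the diff.

```python
import torch


def batch2block(tensor, block_mapping):
    # Scatter per-sequence rows to one row per KV-cache block.
    shape = tuple(tensor.shape)
    return (block_mapping @ tensor.view(shape[0], -1)).view(-1, *shape[1:])


def block2batch(tensor, block_mapping):
    # Sum per-block rows back into one row per sequence.
    shape = tuple(tensor.shape)
    return (block_mapping.t() @ tensor.view(shape[0], -1)).view(-1, *shape[1:])


# Toy layout: sequence 0 owns blocks 0-1, sequence 1 owns block 2.
block_owner = torch.tensor([0, 0, 1])
block_mapping = torch.nn.functional.one_hot(block_owner, num_classes=2).float()

per_seq = torch.tensor([[10.0], [20.0]])         # one value per sequence
print(batch2block(per_seq, block_mapping))       # [[10.], [10.], [20.]]

per_block = torch.tensor([[1.0], [2.0], [5.0]])  # e.g. per-block softmax sums
print(block2batch(per_block, block_mapping))     # [[3.], [5.]]
```

In `flat_pa` below, `block_softmax` uses exactly this round trip to normalize attention scores across all blocks of a sequence, so decode no longer pads every sequence's block table up to `max_blocks_per_seq`.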
--------- Co-authored-by: Michal Adamczyk Co-authored-by: barak goldberg <149692267+bgoldberg-habana@users.noreply.github.com> Co-authored-by: Michal Szutenberg <37601244+szutenberg@users.noreply.github.com> Co-authored-by: Jan Kaniecki --- README_GAUDI.md | 22 +- .../getting_started/gaudi-installation.rst | 14 +- vllm/attention/backends/habana_attn.py | 136 ++----- vllm/attention/ops/habana_paged_attn.py | 51 +-- vllm/hpu/ops.py | 114 +++--- vllm/hpu/utils.py | 7 +- vllm/worker/habana_model_runner.py | 365 +++++++++++------- 7 files changed, 330 insertions(+), 379 deletions(-) diff --git a/README_GAUDI.md b/README_GAUDI.md index 91bcbe49405eb..5109f7ddf9927 100644 --- a/README_GAUDI.md +++ b/README_GAUDI.md @@ -455,12 +455,12 @@ Environment variables - `VLLM_{phase}_{dim}_BUCKET_{param}` - collection of 12 environment variables configuring ranges of bucketing mechanism - `{phase}` is either `PROMPT` or `DECODE` - - `{dim}` is either `BS` or `SEQ` + - `{dim}` is either `BS`, `SEQ` or `BLOCK` - `{param}` is either `MIN`, `STEP` or `MAX` - Default values: - Prompt: - batch size min (`VLLM_PROMPT_BS_BUCKET_MIN`): `1` - - batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`): `32` + - batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`): `min(max_num_seqs, 32)` - batch size max (`VLLM_PROMPT_BS_BUCKET_MAX`): `min(max_num_seqs, 64)` - sequence length min (`VLLM_PROMPT_SEQ_BUCKET_MIN`): @@ -468,20 +468,20 @@ Environment variables - sequence length step (`VLLM_PROMPT_SEQ_BUCKET_STEP`): `block_size` - sequence length max (`VLLM_PROMPT_SEQ_BUCKET_MAX`): - `1024` + `max_model_len` - Decode: - - batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `1` + - batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `min(max_num_seqs, 32)` - batch size step (`VLLM_DECODE_BS_BUCKET_STEP`): - `128` + `min(max_num_seqs, 32)` - batch size max (`VLLM_DECODE_BS_BUCKET_MAX`): `max_num_seqs` - - sequence length min (`VLLM_DECODE_SEQ_BUCKET_MIN`): - `block_size` - - sequence length step - (`VLLM_DECODE_SEQ_BUCKET_STEP`): `block_size` - - sequence length max (`VLLM_DECODE_SEQ_BUCKET_MAX`): - `2048` + - block size min (`VLLM_DECODE_BLOCK_BUCKET_MIN`): + `128` + - block size step + (`VLLM_DECODE_BLOCK_BUCKET_STEP`): `128` + - block size max (`VLLM_DECODE_BLOCK_BUCKET_MAX`): + `max(128, (max_num_seqs*max_model_len)/block_size)` Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution: diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index b3234d10b3115..ed3beabb2c8aa 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -335,19 +335,19 @@ Environment variables - Prompt: - batch size min (``VLLM_PROMPT_BS_BUCKET_MIN``): ``1`` - - batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``32`` + - batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)`` - batch size max (``VLLM_PROMPT_BS_BUCKET_MAX``): ``min(max_num_seqs, 64)`` - sequence length min (``VLLM_PROMPT_SEQ_BUCKET_MIN``): ``block_size`` - sequence length step (``VLLM_PROMPT_SEQ_BUCKET_STEP``): ``block_size`` - - sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``1024`` + - sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``max_model_len`` - Decode: - - batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``1`` - - batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``128`` + - batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``min(max_num_seqs, 32)`` + - batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): 
``min(max_num_seqs, 32)`` - batch size max (``VLLM_DECODE_BS_BUCKET_MAX``): ``max_num_seqs`` - - sequence length min (``VLLM_DECODE_SEQ_BUCKET_MIN``): ``block_size`` - - sequence length step (``VLLM_DECODE_SEQ_BUCKET_STEP``): ``block_size`` - - sequence length max (``VLLM_DECODE_SEQ_BUCKET_MAX``): ``2048`` + - sequence length min (``VLLM_DECODE_SEQ_BUCKET_MIN``): ``128`` + - sequence length step (``VLLM_DECODE_SEQ_BUCKET_STEP``): ``128`` + - sequence length max (``VLLM_DECODE_SEQ_BUCKET_MAX``): ``max(128, (max_num_seqs*max_model_len)/block_size)`` Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution: diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py index 2259630fa10b7..20b0f2bc7630b 100644 --- a/vllm/attention/backends/habana_attn.py +++ b/vllm/attention/backends/habana_attn.py @@ -58,58 +58,14 @@ def copy_blocks( @dataclass -class HabanaAttentionMetadata(AttentionMetadata, HabanaPagedAttentionMetadata): - """Metadata for HabanaAttentionbackend. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ +class HabanaAttentionMetadata(HabanaPagedAttentionMetadata, AttentionMetadata): + """Metadata for HabanaAttentionbackend.""" # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. + attn_bias: Optional[torch.Tensor] seq_lens_tensor: Optional[torch.Tensor] - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ----------------------| - # |-- query_len ---| - - # Maximum query length in the batch. - max_query_len: Optional[int] - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] - # FIXME: It is for flash attn. - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - - def __post_init__(self): - # Set during the execution of the first attention op. - # It is a list because it is needed to set per prompt - # when alibi slopes is used. It is because of the limitation - # from xformer API. - # will not appear in the __repr__ and __init__ - self.attn_bias: Optional[torch.Tensor] = None - class HabanaAttentionImpl(AttentionImpl, torch.nn.Module): """ @@ -229,60 +185,48 @@ def forward( if attn_metadata.is_prompt: # Prompt run. 
- if kv_cache is None or attn_metadata.block_tables.numel() == 0: - if not self.prefill_usefusedsdpa: - # TODO: move this outside of model - assert attn_metadata.attn_bias is not None, \ + if not self.prefill_usefusedsdpa: + # TODO: move this outside of model + assert attn_metadata.attn_bias is not None, \ 'attn_bias must be set before calling model.forward!' - attn_bias = attn_metadata.attn_bias - if self.alibi_slopes is not None and \ - self.position_bias is not None: - attn_bias.add_(self.position_bias[:, :, - -attn_bias.size(2):, - -attn_bias.size(3):]) - else: - attn_bias = None - - query_shape = (batch_size, seq_len, self.num_heads, - self.head_size) - kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, - self.head_size) - out = ops.prompt_attention( - query.view(query_shape), - key.view(kv_shape), - value.view(kv_shape), - attn_bias=attn_bias, - p=0.0, - scale=self.scale, - matmul_qk_op=self.matmul_qk, - softmax_op=self.softmax, - matmul_av_op=self.matmul_av, - valid_seq_lengths=attn_metadata.seq_lens_tensor, - ) - output = out.reshape(batch_size, seq_len, hidden_size) + attn_bias = attn_metadata.attn_bias + if self.alibi_slopes is not None and \ + self.position_bias is not None: + attn_bias.add_(self.position_bias[:, :, + -attn_bias.size(2):, + -attn_bias.size(3):]) else: - # prefix-enabled attention - output = HabanaPagedAttention.forward_prefix( - query, - key, - value, - key_cache, - value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.seq_lens_tensor, - attn_metadata.context_lens_tensor, - attn_metadata.max_query_len, - self.alibi_slopes, - ) + attn_bias = None + + query_shape = (batch_size, seq_len, self.num_heads, self.head_size) + kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, + self.head_size) + out = ops.prompt_attention( + query.view(query_shape), + key.view(kv_shape), + value.view(kv_shape), + attn_bias=attn_bias, + p=0.0, + scale=self.scale, + matmul_qk_op=self.matmul_qk, + softmax_op=self.softmax, + matmul_av_op=self.matmul_av, + ) + output = out.reshape(batch_size, seq_len, hidden_size) else: # Decoding run. output = HabanaPagedAttention.forward_decode( - query, key_cache, value_cache, attn_metadata.block_tables, - attn_metadata.seq_lens_tensor, self.kv_cache_dtype, - self.num_kv_heads, self.scale, self.position_bias, k_scale, - v_scale, self.matmul_qk, self.softmax, self.matmul_av, - self.k_cache, self.v_cache) + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_list=attn_metadata.block_list, + block_mapping=attn_metadata.block_mapping, + block_bias=attn_metadata.attn_bias, + scale=self.scale, + matmul_qk_op=self.matmul_qk, + matmul_av_op=self.matmul_av, + keys_fetch_func=self.k_cache.fetch_from_cache, + values_fetch_func=self.v_cache.fetch_from_cache) # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) diff --git a/vllm/attention/ops/habana_paged_attn.py b/vllm/attention/ops/habana_paged_attn.py index 9602886299c47..cab8d7abe95fd 100644 --- a/vllm/attention/ops/habana_paged_attn.py +++ b/vllm/attention/ops/habana_paged_attn.py @@ -16,16 +16,9 @@ @dataclass class HabanaPagedAttentionMetadata: """Metadata for PagedAttention.""" - # (batch_size,). The length of sequences (entire tokens seen so far) per - # sequence. - seq_lens_tensor: Optional[torch.Tensor] - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. 
Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] + block_list: Optional[torch.Tensor] + block_mapping: Optional[torch.Tensor] + block_usage: Optional[torch.Tensor] class HabanaPagedAttention: @@ -63,42 +56,8 @@ def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor, slot_mapping, kv_cache_dtype, is_prompt) @staticmethod - def forward_decode( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, - kv_cache_dtype: str, - num_kv_heads: int, - scale: float, - alibi_slopes: Optional[torch.Tensor], - k_scale: float, - v_scale: float, - matmul_qk_op, - softmax_op, - matmul_av_op, - k_cache_cls, - v_cache_cls, - ) -> torch.Tensor: - block_size = value_cache.shape[1] - return ops.paged_attention_v1( - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - alibi_slopes, - kv_cache_dtype, - matmul_qk_op, - softmax_op, - matmul_av_op, - k_cache_cls, - v_cache_cls, - ) + def forward_decode(**kwargs) -> torch.Tensor: + return ops.flat_pa(**kwargs) @staticmethod def forward_prefix( diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index bacb755b39393..b2705429906c4 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. ############################################################################### -import os from typing import Optional import habana_frameworks.torch as htorch @@ -29,72 +28,57 @@ logger.warning("Could not import HPU FusedSDPA kernel. " "vLLM will use native implementation.") -PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1') - - -def fetch_from_cache(cache, blocks, permutations): - return [ - cache.index_select(0, blocks[:, i]).permute(permutations) - for i in range(blocks.size(1)) - ] - - -def paged_attention_v1(query, - key_cache, - value_cache, - head_mapping, - scale, - block_tables, - context_lens, - block_size, - alibi_slopes=None, - kv_cache_dtype=None, - matmul_qk_op=torch.matmul, - softmax_op=torch.softmax, - matmul_av_op=torch.matmul, - k_cache_cls=None, - v_cache_cls=None) -> None: - seq_len = block_tables.size(1) - batch_size, query_heads, _ = query.shape - _, _, kv_heads, _ = key_cache.shape - min_inf = torch.finfo(query.dtype).min - mask = (torch.arange(0, - seq_len * block_size, - dtype=torch.int32, - device=key_cache.device).view(1, -1).expand( - batch_size, -1).ge(context_lens.view(-1, 1)).view( - batch_size, 1, 1, -1)) - query.mul_(scale) - query = query.unsqueeze(-2) - fetch_keys = fetch_from_cache if k_cache_cls is None else \ - k_cache_cls.fetch_from_cache - keys = fetch_keys(key_cache, block_tables, (0, 2, 3, 1)) - if query_heads != kv_heads: + +def batch2block(tensor, block_mapping): + shape = tuple(tensor.shape) + return (block_mapping @ tensor.view(shape[0], -1)).view(-1, *shape[1:]) + + +def block2batch(tensor, block_mapping): + shape = tuple(tensor.shape) + return (block_mapping.t() @ tensor.view(shape[0], -1)).view(-1, *shape[1:]) + + +def block_softmax(batch_size, attn, block_mapping): + attn.sub_(10.0) + attn = attn.exp_() + sums = attn.sum(dim=-1).unsqueeze(-1) + sums = block2batch(sums, block_mapping) + sums = batch2block(sums, block_mapping) + sums.add_(1.0e-12) + attn.div_(sums) + return attn + + +def flat_pa(query, key_cache, value_cache, 
block_list, block_mapping, + block_bias, scale, matmul_qk_op, matmul_av_op, keys_fetch_func, + values_fetch_func): + batch_size = query.size(0) + q_heads = query.size(1) + kv_heads = key_cache.size(2) + + query = batch2block(scale * query, block_mapping).unsqueeze(-2) + key = keys_fetch_func(key_cache, block_list).transpose(1, 2) + value = values_fetch_func(value_cache, block_list).transpose(1, 2) + block_bias = block_bias.view(key.size(0), 1, 1, -1) + + if kv_heads != q_heads: + block_bias = block_bias.unsqueeze(1) query = query.unflatten(1, (kv_heads, -1)) - keys = [k.unflatten(1, (kv_heads, 1)) for k in keys] - mask = mask.unsqueeze(2) - - attn_weights = torch.cat([matmul_qk_op(query, k) for k in keys], dim=-1) - if alibi_slopes is not None: - attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):, - -attn_weights.size(3):]) - attn_weights = softmax_op(attn_weights.masked_fill(mask, min_inf), dim=-1) - - fetch_values = fetch_from_cache if v_cache_cls is None else \ - v_cache_cls.fetch_from_cache - values = fetch_values(value_cache, block_tables, (0, 2, 1, 3)) - if PA_SPLIT_VALUE: - attn_weights = attn_weights.split(block_size, dim=-1) + key = key.unflatten(1, (kv_heads, 1)) + value = value.unflatten(1, (kv_heads, 1)) + key = key.transpose(3, 4) else: - values = [torch.cat(values, dim=-2)] - attn_weights = [attn_weights] - if query_heads != kv_heads: - values = [v.unflatten(1, (kv_heads, 1)) for v in values] - attn_weights = [matmul_av_op(a, v) for a, v in zip(attn_weights, values)] - if query_heads != kv_heads: - attn_weights = [a.flatten(1, 2) for a in attn_weights] - attn_weights = sum(attn_weights) - return attn_weights.squeeze(-2) + key = key.transpose(2, 3) + + attn = matmul_qk_op(query, key) + block_bias + attn = block_softmax(batch_size, attn, block_mapping) + attn = matmul_av_op(attn, value) + attn = block2batch(attn, block_mapping) + attn = attn.squeeze(-2) + if kv_heads != q_heads: + attn = attn.flatten(1, 2) + return attn def silu_and_mul(x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/hpu/utils.py b/vllm/hpu/utils.py index 3d9c7cb1c4c22..13204b83d5742 100644 --- a/vllm/hpu/utils.py +++ b/vllm/hpu/utils.py @@ -57,8 +57,5 @@ def forward(self, input, cache, num_kv_cache_passes, num_slots_available, block_offset) return cache - def fetch_from_cache(self, cache, blocks, permutations): - return [ - cache.index_select(0, blocks[:, i]).permute(permutations) - for i in range(blocks.size(1)) - ] + def fetch_from_cache(self, cache, blocks): + return cache.index_select(0, blocks) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index a4ade587db089..a6bd5e5f68745 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -51,29 +51,47 @@ logger = init_logger(__name__) +_TYPE_CACHE = {} # These values are assumed to be zero in several places. # Use caution when updating them! 
_PAD_SLOT_ID = 0 _PAD_BLOCK_ID = 0 LORA_WARMUP_RANK = 8 -_TYPE_CACHE = {} + + +def subtuple(obj: object, + typename: str, + to_copy: List[str], + to_override: Optional[Dict[str, object]] = None): + if obj is None: + return None + if to_override is None: + to_override = {} + fields = set(to_copy) | set(to_override.keys()) + values = {f: to_override.get(f, getattr(obj, f)) for f in fields} + if typename not in _TYPE_CACHE: + _TYPE_CACHE[typename] = collections.namedtuple(typename, + ' '.join(fields)) + return _TYPE_CACHE[typename](**values) def read_bucket_settings(phase: str, dim: str, **defaults): """Read bucketing configuration from env variables. phase is either 'prompt' or 'decode' - dim is either 'bs' or 'block' + dim is either 'bs', 'seq' or 'block' param is either 'min', 'step' or 'max' example env variable: VLLM_DECODE_BS_BUCKET_STEP=128 """ params = ['min', 'step', 'max'] + env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params] + default_values = [defaults[p] for p in params] values = [ - int( - os.environ.get(f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper(), - defaults[p])) for p in params + int(os.environ.get(e, d)) for e, d in zip(env_vars, default_values) ] + for e, v, d in zip(env_vars, values, defaults): + logger.info('%s=%s (default:%s)', e, v, d) return values @@ -103,9 +121,9 @@ def warmup_range(config: Tuple[int, int, int]): return list(filter(lambda bucket: bucket >= bmin, buckets)) -def warmup_buckets(bs_bucket_config, - seq_bucket_config, - max_num_batched_tokens=None): +def generate_prompt_buckets(bs_bucket_config, + seq_bucket_config, + max_num_batched_tokens=None): buckets = list( itertools.product(warmup_range(bs_bucket_config), warmup_range(seq_bucket_config))) @@ -150,6 +168,19 @@ def warmup_buckets(bs_bucket_config, return captured_buckets, omitted_buckets +def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, + max_blocks): + buckets = [] + for bs in warmup_range(bs_bucket_config): + for blocks in warmup_range(blocks_bucket_config): + if blocks < bs: + continue + if blocks > max_blocks: + break + buckets.append((bs, blocks)) + return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) + + def next_pow2(value: int, base: int): res = base while value > 1: @@ -169,22 +200,6 @@ def find_bucket(value: int, config: Tuple[int, int, int]): return max(bmin, min(next_step, next_pow)) -def subtuple(obj: object, - typename: str, - to_copy: List[str], - to_override: Optional[Dict[str, object]] = None): - if to_override is None: - to_override = {} - if obj is None: - return None - fields = set(to_copy) | set(to_override.keys()) - values = {f: to_override.get(f, getattr(obj, f)) for f in fields} - if typename not in _TYPE_CACHE: - _TYPE_CACHE[typename] = collections.namedtuple(typename, - ' '.join(fields)) - return _TYPE_CACHE[typename](**values) - - def align_workers(value, op): group = get_world_group().cpu_group world_size = torch.distributed.get_world_size() @@ -195,13 +210,19 @@ def align_workers(value, op): return value_t.item() +def pad_list(list, k, v): + target_len = round_up(len(list), k) + padding = target_len - len(list) + return list + [v] * padding + + class HpuModelAdapter(): - def __init__(self, model, enforce_eager): + def __init__(self, model, block_size, enforce_eager): self.model = model self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', '0').lower() in ['1', 'true'] - + self.block_size = block_size if not htorch.utils.internal.is_lazy() and not enforce_eager: self.model = torch.compile(self.model, 
backend='hpu_backend', @@ -225,22 +246,45 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, mask = causal_mask.logical_or(len_mask) attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( mask, -math.inf)) - #FIXME: Restore sliding window support - #if self.sliding_window is not None: attn_metadata = prefill_metadata._replace(attn_bias=attn_bias) return attn_metadata + def _set_block_mapping(self, metadata, batch_size, device, dtype): + mask = torch.arange(0, + self.block_size, + device=device, + dtype=torch.int32).unsqueeze(0) + mask = mask >= metadata.block_usage.unsqueeze(-1) + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( + mask, -math.inf)) + block_mapping = torch.nn.functional.one_hot( + metadata.block_mapping.to(torch.long), + num_classes=batch_size).to(dtype) + metadata = metadata._replace(block_mapping=block_mapping, + attn_bias=attn_bias) + return metadata + + def _update_metadata(self, attn_metadata, batch_size, seq_len, device, + dtype): + if attn_metadata.is_prompt: + meta = attn_metadata + attn_metadata = self._set_attn_bias(meta, batch_size, seq_len, + device, dtype) + else: + meta = attn_metadata + attn_metadata = self._set_block_mapping(meta, batch_size, device, + dtype) + return attn_metadata + def forward(self, *args, **kwargs): kwargs = kwargs.copy() selected_token_indices = kwargs.pop('selected_token_indices') if 'warmup_mode' in kwargs: kwargs.pop('warmup_mode') input_ids = kwargs['input_ids'] - kwargs['attn_metadata'] = self._set_attn_bias(kwargs['attn_metadata'], - input_ids.size(0), - input_ids.size(1), - input_ids.device, - torch.bfloat16) + kwargs['attn_metadata'] = self._update_metadata( + kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), + input_ids.device, torch.bfloat16) LoraMask.setLoraMask(kwargs.pop('lora_mask')) hidden_states = self.model(*args, **kwargs) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) @@ -536,7 +580,9 @@ def load_model(self) -> None: # RuntimeErrors. 
This needs to be debugged with HabanaMemoryProfiler() as m_wrap: self.model = _maybe_wrap_in_hpu_graph( - self.model, enforce_eager=self.enforce_eager) + self.model, + self.block_size, + enforce_eager=self.enforce_eager) msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}" logger.info(msg) @@ -553,73 +599,48 @@ def _is_valid_bucket(self, bucket): return bucket[0] * bucket[1] <= self.max_num_batched_tokens def _setup_buckets(self) -> None: + align_bs = lambda x: min(self.max_num_seqs, x) max_bucket_cfg = 64 if self.lora_config and \ max_bucket_cfg > self.max_num_batched_tokens // self.block_size: max_bucket_cfg = self.max_num_batched_tokens // self.block_size - self.prompt_bs_bucket_cfg = read_bucket_settings('prompt', - 'bs', - min=1, - step=32, - max=min( - self.max_num_seqs, - max_bucket_cfg)) + blocks_step = 128 + #FIXME: The default values should be max_model_len + max_prompt_seq = 1024 + max_decode_seq = 2048 + self.prompt_bs_bucket_cfg = read_bucket_settings( + 'prompt', + 'bs', + min=1, + step=align_bs(32), + max=align_bs(max_bucket_cfg)) self.decode_bs_bucket_cfg = read_bucket_settings('decode', 'bs', - min=1, - step=128, + min=align_bs(32), + step=align_bs(32), max=self.max_num_seqs) self.prompt_seq_bucket_cfg = read_bucket_settings('prompt', 'seq', min=self.block_size, step=self.block_size, - max=1024) - self.decode_seq_bucket_cfg = read_bucket_settings('decode', - 'seq', - min=self.block_size, - step=self.block_size, - max=2048) + max=max_prompt_seq) + self.decode_block_bucket_cfg = read_bucket_settings( + 'decode', + 'block', + min=blocks_step, + step=blocks_step, + max=max(blocks_step, + self.max_num_seqs * max_decode_seq // self.block_size)) self.graphed_buckets: Set[Any] = set() msg = ("Prompt bucket config (min, step, max_warmup) " f"bs:{self.prompt_bs_bucket_cfg}, " f"seq:{self.prompt_seq_bucket_cfg}") logger.info(msg) - self.prompt_buckets, prompt_omitted_buckets = warmup_buckets( - self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg, - self.max_num_batched_tokens) - - if self.lora_config: - self.prompt_buckets[:] = [ - bucket for bucket in self.prompt_buckets - if self._is_valid_bucket(bucket) - ] - - msg = (f"Generated {len(self.prompt_buckets)} " - f"prompt buckets: {list(sorted(self.prompt_buckets))}") - logger.info(msg) - - msg = (f"Omitted {len(prompt_omitted_buckets)} " - "prompt buckets due to exceeded token budget " - f"(max_num_batched_tokens={self.max_num_batched_tokens})") - logger.info(msg) - - msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" - logger.debug(msg) msg = ("Decode bucket config (min, step, max_warmup) " f"bs:{self.decode_bs_bucket_cfg}, " - f"seq:{self.decode_seq_bucket_cfg}") - logger.info(msg) - self.decode_buckets, _ = warmup_buckets(self.decode_bs_bucket_cfg, - self.decode_seq_bucket_cfg) - if self.lora_config: - self.decode_buckets[:] = [ - bucket for bucket in self.decode_buckets - if self._is_valid_bucket(bucket) - ] - msg = (f"Generated {len(self.decode_buckets)} decode buckets: " - f"{list(sorted(self.decode_buckets))}") + f"block:{self.decode_block_bucket_cfg}") logger.info(msg) def _prepare_prompt( @@ -735,10 +756,6 @@ def _prepare_prompt( real_num_seqs = len(query_lens) assert max_query_len > 0 - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) - if multi_modal_input_list: assert self.multimodal_config, ( "Multi-modal inputs are only supported by " @@ -748,7 +765,6 @@ def _prepare_prompt( else: multi_modal_input = None - max_prompt_block_table_len = 
max(len(t) for t in prefix_block_tables) max_prompt_len = max( find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg), self.block_size) @@ -814,37 +830,17 @@ def _prepare_prompt( dtype=torch.long, device=self.device) - block_tables = make_tensor_with_pad(prefix_block_tables, - max_len=max_prompt_block_table_len, - pad=0, - dtype=torch.int, - device=self.device) - - # Query length can be shorter than key (i.e., prompt) when prefill - # is chunked or prefix cached. - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=self.device) - subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.long, device=self.device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) attn_metadata = self.attn_backend.make_metadata( is_prompt=True, - seq_lens=seq_lens, + block_list=None, + block_mapping=None, + block_usage=None, + attn_bias=None, seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - subquery_start_loc=subquery_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, num_prefills=real_num_seqs, num_prefill_tokens=sum_query_len, num_decode_tokens=0, @@ -950,32 +946,50 @@ def _prepare_decode( s if s != _PAD_SLOT_ID else next(dummy_slots) for s in sl ] for sl in slot_mapping] + num_decode_tokens = sum(seq_lens) + + blocks_used = [len(bt) for bt in block_tables] + block_list = list(itertools.chain(*block_tables)) + block_mapping_nested: List[List[int]] = [ + [i] * b_u for i, b_u in enumerate(blocks_used) + ] + block_mapping: List[int] = list( + itertools.chain.from_iterable(block_mapping_nested)) + + last_block = [ + sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping) + ] + block_usage = [[self.block_size] * (b_u - 1) + [lb] + for b_u, lb in zip(blocks_used, last_block)] + block_usage = list(itertools.chain(*block_usage)) + + block_bucket_size = find_bucket(len(block_list), + self.decode_block_bucket_cfg) + block_list = pad_list(block_list, block_bucket_size, _PAD_SLOT_ID) + block_mapping = pad_list(block_mapping, block_bucket_size, 0) + block_usage = pad_list(block_usage, block_bucket_size, 0) + + block_list = torch.tensor(block_list, + dtype=torch.int, + device=self.device) + block_mapping = torch.tensor(block_mapping, + dtype=torch.int, + device=self.device) + block_usage = torch.tensor(block_usage, + dtype=torch.bfloat16, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) - num_decode_tokens = sum(seq_lens) - max_block_table_len = max( - len(block_table) for block_table in block_tables) - block_tables = make_tensor_with_pad( - block_tables, - max_len=max_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) + attn_metadata = self.attn_backend.make_metadata( is_prompt=False, - seq_lens=None, - seq_lens_tensor=seq_lens_tensor, - max_query_len=None, - subquery_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=block_tables, - use_cuda_graph=False, + block_list=block_list, + block_mapping=block_mapping, + block_usage=block_usage, + attn_bias=None, + seq_lens_tensor=None, num_prefills=0, num_prefill_tokens=0, num_decode_tokens=num_decode_tokens, @@ -1163,7 +1177,7 @@ def _seq_len(self, attn_metadata): if attn_metadata.num_prefills != 0: 
return attn_metadata.slot_mapping.size(1) else: - return attn_metadata.block_tables.size(1) * self.block_size + return attn_metadata.block_list.numel() def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: # NOTE(kzawora): To anyone working on this in the future: @@ -1187,8 +1201,8 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: # input_hash(123) != input_hash(321) # input_hash("abc") != input_hash("cba") attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ - 'block_tables', 'seq_lens_tensor', 'attn_bias', 'slot_mapping', - 'is_prompt' + 'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping', + 'block_usage', 'slot_mapping', 'is_prompt' ]) return attention_metadata @@ -1222,9 +1236,8 @@ def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers max_batch_size = self.prompt_bs_bucket_cfg[-1] - max_seq_len = self.prompt_seq_bucket_cfg[-1] - if self.lora_config: - max_seq_len = self.max_num_batched_tokens // max_batch_size + max_seq_len = min(self.prompt_seq_bucket_cfg[-1], + self.max_num_batched_tokens // max_batch_size) self.warmup_scenario(max_batch_size, max_seq_len, @@ -1277,21 +1290,34 @@ def warmup_scenario(self, [0] * batch_size * seq_len, ) self.set_active_loras(set(), lora_mapping) - seqs = [ - self.create_dummy_seq_group_metadata( - i, - seq_len, - is_prompt, - lora_request=dummy_lora_requests_per_seq[i] - if dummy_lora_requests_per_seq else None) - for i in range(batch_size) - ] + if is_prompt: + seqs = [ + self.create_dummy_seq_group_metadata( + i, + seq_len, + is_prompt, + lora_request=dummy_lora_requests_per_seq[i] + if dummy_lora_requests_per_seq else None) + for i in range(batch_size) + ] + else: + # FIXME: seq_len is actually number of blocks + blocks = [seq_len // batch_size for _ in range(batch_size)] + blocks[0] += seq_len % batch_size + seqs = [ + self.create_dummy_seq_group_metadata( + i, + b * self.block_size - 1, + is_prompt, + lora_request=dummy_lora_requests_per_seq[i] + if dummy_lora_requests_per_seq else None) + for i, b in enumerate(blocks) + ] torch.hpu.synchronize() for _ in range(times): inputs = self.prepare_model_input(seqs) - self.execute_model(inputs, kv_caches, warmup_mode=True) + self.execute_model(inputs, kv_caches, warmup_mode=False) torch.hpu.synchronize() - self.profiler.end() gc.collect() def remove_all_loras(self): @@ -1328,9 +1354,12 @@ def list_loras(self) -> Set[int]: def log_warmup(self, phase, i, max_i, batch_size, seq_len): free_mem = format_bytes( HabanaMemoryProfiler.current_free_device_memory()) + dim = "num_blocks" + if phase == "Prompt": + dim = "seq_len" msg = (f"[Warmup][{phase}][{i+1}/{max_i}] " f"batch_size:{batch_size} " - f"seq_len:{seq_len} " + f"{dim}:{seq_len} " f"free_mem:{free_mem}") logger.info(msg) @@ -1390,6 +1419,8 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): phase = f'Graph/{"Prompt" if is_prompt else "Decode"}' graphed = list(c[:2] for c in self.graphed_buckets if c[2] == is_prompt) + if num_candidates == 0: + num_candidates = 1 msg = (f'{phase} captured:{len(graphed)} ' f'({100 * len(graphed) / num_candidates:.1f}%) ' f'used_mem:{format_bytes(total_mem)} ' @@ -1402,6 +1433,42 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: logger.info("Skipping warmup...") return self.profiler.start('internal', 'warmup') + max_blocks = kv_caches[0][0].size(0) + + self.prompt_buckets, prompt_omitted_buckets = generate_prompt_buckets( + self.prompt_bs_bucket_cfg, 
self.prompt_seq_bucket_cfg, + self.max_num_batched_tokens) + if self.lora_config: + self.prompt_buckets[:] = [ + bucket for bucket in self.prompt_buckets + if self._is_valid_bucket(bucket) + ] + + msg = ( + f"Generated {len(self.prompt_buckets)} " + f"prompt buckets [bs, seq]: {list(sorted(self.prompt_buckets))}") + logger.info(msg) + + msg = (f"Omitted {len(prompt_omitted_buckets)} " + "prompt buckets due to exceeded token budget " + f"(max_num_batched_tokens={self.max_num_batched_tokens})") + logger.info(msg) + + msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" + logger.debug(msg) + + self.decode_buckets = generate_decode_buckets( + self.decode_bs_bucket_cfg, self.decode_block_bucket_cfg, + max_blocks) + if self.lora_config: + self.decode_buckets[:] = [ + bucket for bucket in self.decode_buckets + if self._is_valid_bucket(bucket) + ] + logger.info("Generated %d decode buckets [bs, total_blocks]: %s", + len(self.decode_buckets), + list(sorted(self.decode_buckets))) + start_mem = HabanaMemoryProfiler.current_device_memory_usage() start_time = time.perf_counter() From 4052bdb728ba3bbddca82af1a71574c8db706179 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Tue, 10 Sep 2024 15:04:34 +0200 Subject: [PATCH 2/9] Add disable_tensor_cache=True to HPUGraph capture (#252) RuntimeErrors are not observed anymore on habana_main when disable_tensor_cache is used. This PR enables disable_tensor_cache. --- vllm/worker/habana_model_runner.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index a6bd5e5f68745..dfc2ee152076f 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -576,8 +576,6 @@ def load_model(self) -> None: htcore.mark_step() torch.hpu.synchronize() - # FIXME: Running with disable_tensor_cache=True causes - # RuntimeErrors. This needs to be debugged with HabanaMemoryProfiler() as m_wrap: self.model = _maybe_wrap_in_hpu_graph( self.model, @@ -1576,10 +1574,9 @@ def mem_margin(self, value): def _maybe_wrap_in_hpu_graph(*args, **kwargs): - return htorch.hpu.wrap_in_hpu_graph(HpuModelAdapter( - *args, ** - kwargs)) if htorch.utils.internal.is_lazy() else HpuModelAdapter( - *args, **kwargs) + return htorch.hpu.wrap_in_hpu_graph( + HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True + ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(*args, **kwargs) class HabanaProfilerCounterHelper(): From 69df1e7e3f6b580945ce0d0cab88233829dae205 Mon Sep 17 00:00:00 2001 From: Michal Adamczyk Date: Tue, 10 Sep 2024 15:43:20 +0200 Subject: [PATCH 3/9] Fix dispersed slots (#261) On habana_main the slots are calculated by adding an offset to the block which breaks the check for _PAD_SLOT_ID. Reworked it so that in case of _PAD_BLOCK_ID we're automatically inserting the right value. 
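For illustration, a small sketch of the arithmetic behind this change. `_PAD_SLOT_ID` and `_PAD_BLOCK_ID` are both `0` in `habana_model_runner.py`; `block_size = 4` and the positions are assumed toy values.

```python
import itertools

_PAD_SLOT_ID = 0
_PAD_BLOCK_ID = 0
block_size = 4
position = 3

# Old approach: compute the slot first, then try to spot padding by value.
# A padded block with a non-zero offset yields 0 * 4 + 3 == 3, so the later
# `slot != _PAD_SLOT_ID` check never recognizes it as padding.
old_slot = _PAD_BLOCK_ID * block_size + position % block_size  # == 3

# New approach: check the block id up front and draw the slot from a dedicated
# dummy range, so padded entries always land on known dummy slots.
dummy_slots = itertools.cycle(range(_PAD_SLOT_ID, _PAD_SLOT_ID + block_size))


def slot_for(block_number, position):
    if block_number == _PAD_BLOCK_ID:
        return next(dummy_slots)
    return block_number * block_size + position % block_size


print(old_slot)                           # 3, silently treated as a real slot
print(slot_for(_PAD_BLOCK_ID, position))  # 0, 1, 2, 3, ... on successive calls
print(slot_for(7, position))              # 31, a real slot
```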
--- vllm/worker/habana_model_runner.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index dfc2ee152076f..8d6c386a9975e 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -887,6 +887,9 @@ def _prepare_decode( self.lora_config.max_lora_rank, dtype=self.lora_config.lora_dtype) + dummy_slots = itertools.cycle( + range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size)) + for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt assert seq_group_metadata.token_chunk_size == 1 @@ -916,8 +919,11 @@ def _prepare_decode( block_table = seq_group_metadata.block_tables[seq_id] block_number = block_table[position // self.block_size] - block_offset = position % self.block_size - slot = block_number * self.block_size + block_offset + if block_number == _PAD_BLOCK_ID: + slot = next(dummy_slots) + else: + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset slot_mapping.append([slot]) lora_index_mapping.append(lora_id) lora_prompt_mapping.append(lora_id) @@ -938,12 +944,6 @@ def _prepare_decode( dtype=torch.long, device=self.device) - dummy_slots = itertools.cycle( - range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size)) - slot_mapping = [[ - s if s != _PAD_SLOT_ID else next(dummy_slots) for s in sl - ] for sl in slot_mapping] - num_decode_tokens = sum(seq_lens) blocks_used = [len(bt) for bt in block_tables] From 53f96b784980b60ca12418b39c4785210931fb09 Mon Sep 17 00:00:00 2001 From: Jan Kaniecki Date: Tue, 10 Sep 2024 15:53:11 +0200 Subject: [PATCH 4/9] Skip compilation warnings during warmup phase (#262) --- vllm/worker/habana_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 8d6c386a9975e..b6218f3cc4cfb 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -1314,7 +1314,7 @@ def warmup_scenario(self, torch.hpu.synchronize() for _ in range(times): inputs = self.prepare_model_input(seqs) - self.execute_model(inputs, kv_caches, warmup_mode=False) + self.execute_model(inputs, kv_caches, warmup_mode=True) torch.hpu.synchronize() gc.collect() From 2091161b4a2e3acaa531d1a1a3c0cba65bb50b21 Mon Sep 17 00:00:00 2001 From: Agata Dobrzyniewicz <160237065+adobrzyniewicz-habana@users.noreply.github.com> Date: Wed, 11 Sep 2024 10:15:09 +0200 Subject: [PATCH 5/9] Port PT Profiler to habana_main (#256) Porting PT Profiler from: https://github.com/HabanaAI/vllm-fork/commit/81a23a708195faef6167919890cefa225a721907 and https://github.com/HabanaAI/vllm-fork/commit/e805b885d32a749d9409f13b6446895d13e8b885 --- vllm/worker/habana_model_runner.py | 46 ++++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index b6218f3cc4cfb..2360e39fcba28 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -210,6 +210,26 @@ def align_workers(value, op): return value_t.item() +def setup_profiler(): + schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1) + DEVICE = 'hpu' + activities = [torch.profiler.ProfilerActivity.CPU] + activities.extend([torch.profiler.ProfilerActivity.HPU] if DEVICE == + 'hpu' else []) + #from habana_frameworks.torch.activity_profiler import DebugActivity + 
#debug_activities=[DebugActivity.BRIDGE_FUNCTION_CALLS] + + profiler = torch.profiler.profile( + schedule=schedule, + activities=activities, + #debug_activities=debug_activities, + on_trace_ready=torch.profiler.tensorboard_trace_handler('.', + use_gzip=True), + record_shapes=False, + with_stack=True) + return profiler + + def pad_list(list, k, v): target_len = round_up(len(list), k) padding = target_len - len(list) @@ -1237,11 +1257,7 @@ def profile_run(self) -> None: max_seq_len = min(self.prompt_seq_bucket_cfg[-1], self.max_num_batched_tokens // max_batch_size) - self.warmup_scenario(max_batch_size, - max_seq_len, - True, - kv_caches, - is_profile_run=True) + self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches) return def warmup_scenario(self, @@ -1281,7 +1297,7 @@ def warmup_scenario(self, for idx in range(max_num_seqs) ] self.profiler.start('internal', scenario_name) - times = 3 if use_graphs else 1 + times = 3 if use_graphs or is_profile_run else 1 if self.lora_config and not is_profile_run: lora_mapping = LoRAMapping( [0] * batch_size * seq_len, @@ -1312,10 +1328,19 @@ def warmup_scenario(self, for i, b in enumerate(blocks) ] torch.hpu.synchronize() + profiler = None + if is_profile_run and self.is_driver_worker: + profiler = setup_profiler() + profiler.start() for _ in range(times): inputs = self.prepare_model_input(seqs) self.execute_model(inputs, kv_caches, warmup_mode=True) torch.hpu.synchronize() + if profiler: + profiler.step() + if profiler: + profiler.stop() + self.profiler.end() gc.collect() def remove_all_loras(self): @@ -1427,6 +1452,15 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): @torch.inference_mode() def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: + if profile := os.environ.get('VLLM_PT_PROFILE', None): + phase, bs, seq_len, graph = profile.split('_') + is_prompt = phase == 'prompt' + graphs = graph == 't' + if graphs: + self.graphed_buckets.add((int(bs), int(seq_len), is_prompt)) + self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches, + True) + raise AssertionError("Finished profiling") if os.environ.get('VLLM_SKIP_WARMUP', 'false').lower() == 'true': logger.info("Skipping warmup...") return From b776d5e8fa287018e7e373e6588f2d15176e0d72 Mon Sep 17 00:00:00 2001 From: Sanju C Sudhakaran Date: Wed, 11 Sep 2024 12:49:20 +0300 Subject: [PATCH 6/9] Fix LoRA test by handling mask creation inside the test --- tests/lora/test_lora_hpu.py | 93 +++++++++++++++++++++++++------------ 1 file changed, 64 insertions(+), 29 deletions(-) diff --git a/tests/lora/test_lora_hpu.py b/tests/lora/test_lora_hpu.py index ddbab66e166b3..01b6472745e1c 100644 --- a/tests/lora/test_lora_hpu.py +++ b/tests/lora/test_lora_hpu.py @@ -1,6 +1,7 @@ import pytest import torch +from vllm.hpu.ops import LoraMask from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice from .utils import DummyLoRAManager @@ -19,7 +20,19 @@ torch.float16: (5e-3, 5e-3), torch.bfloat16: (3e-2, 2e-2), } -MAX_LORAS = 8 + + +def createLoraMask(indices, batch_size, seq_len, max_loras, max_lora_rank, + lora_dtype): + indices = indices.view(-1, 1) + mask = torch.arange(max_loras * max_lora_rank, device=indices.device) + mask = mask.view(1, -1) + mask = ((mask >= ((indices) * max_lora_rank)) * + (mask < ((indices + 1) * max_lora_rank))).to(dtype=lora_dtype) + mask = mask.view(batch_size, 1, + -1).expand(batch_size, seq_len, + -1).reshape(batch_size * seq_len, -1) + return mask @pytest.mark.parametrize("m", TENSOR_SIZES) @@ -39,32 +52,40 @@ def 
test_apply_lora(m, n, k, rank, dtype) -> None: input = torch.rand(k, n, device="hpu", dtype=dtype) expected = input @ lora.lora_a @ lora.lora_b * lora.scaling - lora_a_stack = torch.zeros(MAX_LORAS + 1, + lora_a_stack = torch.zeros(8, 1, lora.lora_a.shape[1], lora.lora_a.shape[0], device="hpu", dtype=dtype) - lora_b_stack = torch.zeros(MAX_LORAS + 1, + lora_b_stack = torch.zeros(8, 1, lora.lora_b.shape[1], lora.lora_b.shape[0], device="hpu", dtype=dtype) - for i in range(MAX_LORAS): + for i in range(lora_a_stack.shape[0]): lora_a_stack[i][0] = lora.lora_a.T lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T output = torch.zeros(k, m, device="hpu", dtype=dtype) - _apply_lora(input, lora_a_stack, lora_b_stack, - torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), - output) + indices = torch.randint(0, + lora_a_stack.shape[0], (len(input), ), + device="hpu") + mask = createLoraMask(indices, k, 1, 8, rank, dtype) + LoraMask.setLoraMask(mask) + + _apply_lora(input, lora_a_stack, lora_b_stack, indices, output) + rtol, atol = TOLERANCES[dtype] assert torch.allclose(expected, output, rtol=rtol, atol=atol) output[:] = 0 - _apply_lora(input, lora_a_stack, lora_b_stack, - torch.full((len(input), ), -1, device="hpu"), output) + indices = torch.full((len(input), ), -1, device="hpu") + mask = createLoraMask(indices, k, 1, 8, rank, dtype) + LoraMask.setLoraMask(mask) + + _apply_lora(input, lora_a_stack, lora_b_stack, indices, output) assert torch.allclose(torch.zeros_like(output), output) manager.reset_lora() @@ -99,7 +120,7 @@ def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None: dim=1) lora_a_stacks = [ - torch.zeros(MAX_LORAS + 1, + torch.zeros(8, 1, lora_1.lora_a.shape[1], lora_1.lora_a.shape[0], @@ -107,31 +128,38 @@ def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None: dtype=dtype) for i in range(2) ] lora_b_stacks = [ - torch.zeros(MAX_LORAS + 1, + torch.zeros(8, 1, lora_1.lora_b.shape[1], lora_1.lora_b.shape[0], device="hpu", dtype=dtype) for i in range(2) ] - for i in range(MAX_LORAS): + for i in range(lora_a_stacks[0].shape[0]): lora_a_stacks[0][i][0] = lora_1.lora_a.T lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T lora_a_stacks[1][i][0] = lora_2.lora_a.T lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T output = torch.zeros(k, m, device="hpu", dtype=dtype) - _apply_lora_packed_nslice( - input, lora_a_stacks, lora_b_stacks, - torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), output, - (m // 2, m // 2)) + indices = torch.randint(0, + lora_a_stacks[0].shape[0], (len(input), ), + device="hpu") + mask = createLoraMask(indices, k, 1, 8, rank, dtype) + LoraMask.setLoraMask(mask) + + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices, + output, (m // 2, m // 2)) rtol, atol = TOLERANCES[dtype] assert torch.allclose(expected, output, rtol=rtol, atol=atol) output[:] = 0 - _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, - torch.full((len(input), ), -1, device="hpu"), + indices = torch.full((len(input), ), -1, device="hpu") + mask = createLoraMask(indices, k, 1, 8, rank, dtype) + LoraMask.setLoraMask(mask) + + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices, output, (m // 2, m // 2)) assert torch.allclose(torch.zeros_like(output), output) @@ -166,14 +194,14 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: dim=1) lora_a_stacks = [ - torch.zeros(MAX_LORAS + 1, + torch.zeros(8, 1, lora_q.lora_a.shape[1], lora_q.lora_a.shape[0], device="hpu", dtype=dtype) ] + [ - 
torch.zeros(MAX_LORAS + 1, + torch.zeros(8, 1, lora_k.lora_a.shape[1], lora_k.lora_a.shape[0], @@ -181,21 +209,21 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: dtype=dtype) for i in range(2) ] lora_b_stacks = [ - torch.zeros(MAX_LORAS + 1, + torch.zeros(8, 1, lora_q.lora_b.shape[1], lora_q.lora_b.shape[0], device="hpu", dtype=dtype) ] + [ - torch.zeros(MAX_LORAS + 1, + torch.zeros(8, 1, lora_k.lora_b.shape[1], lora_k.lora_b.shape[0], device="hpu", dtype=dtype) for i in range(2) ] - for i in range(MAX_LORAS): + for i in range(lora_a_stacks[0].shape[0]): lora_a_stacks[0][i][0] = lora_q.lora_a.T lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T lora_a_stacks[1][i][0] = lora_k.lora_a.T @@ -204,17 +232,24 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T output = torch.zeros(k, sum(qkv), device="hpu", dtype=dtype) - _apply_lora_packed_nslice( - input, lora_a_stacks, lora_b_stacks, - torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), output, - (qkv[0], qkv[1], qkv[2])) + indices = torch.randint(0, + lora_a_stacks[0].shape[0], (len(input), ), + device="hpu") + mask = createLoraMask(indices, k, 1, 8, rank, dtype) + LoraMask.setLoraMask(mask) + + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices, + output, (qkv[0], qkv[1], qkv[2])) rtol, atol = TOLERANCES[dtype] assert torch.allclose(expected, output, rtol=rtol, atol=atol) output[:] = 0 - _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, - torch.full((len(input), ), -1, device="hpu"), + indices = torch.full((len(input), ), -1, device="hpu") + mask = createLoraMask(indices, k, 1, 8, rank, dtype) + LoraMask.setLoraMask(mask) + + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices, output, (qkv[0], qkv[1], qkv[2])) assert torch.allclose(torch.zeros_like(output), output) From f858d4359657db1ea01f39e8a8b39ec68076d6a6 Mon Sep 17 00:00:00 2001 From: Himangshu Lahkar <49579433+hlahkar@users.noreply.github.com> Date: Thu, 12 Sep 2024 09:57:03 +0530 Subject: [PATCH 7/9] Attn MetaData dtype should be same as model dtype (#271) Attn MetaData was hard coded to bfloat16, leading to a runtime error for float32 model instantiation. 
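A minimal sketch of the fix; the class below is a simplified stand-in for `HpuModelAdapter` with the causal mask reduced to its essentials.

```python
import torch


class AdapterSketch:

    def __init__(self, block_size, dtype):
        self.block_size = block_size
        # dtype now comes from model_config.dtype instead of being assumed bfloat16.
        self.dtype = dtype

    def make_prompt_bias(self, batch_size, seq_len, device):
        mask = torch.ones(seq_len, seq_len, device=device).triu(diagonal=1).bool()
        # The bias is created in the model's dtype, so float32 models no longer
        # mix a hard-coded bfloat16 bias into their attention computation.
        bias = torch.zeros(batch_size, 1, seq_len, seq_len,
                           dtype=self.dtype, device=device)
        return bias.masked_fill_(mask, float("-inf"))


adapter = AdapterSketch(block_size=128, dtype=torch.float32)
print(adapter.make_prompt_bias(2, 4, "cpu").dtype)  # torch.float32
```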
--- vllm/worker/habana_model_runner.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 2360e39fcba28..55f205915ea8c 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -238,11 +238,12 @@ def pad_list(list, k, v): class HpuModelAdapter(): - def __init__(self, model, block_size, enforce_eager): + def __init__(self, model, block_size, dtype, enforce_eager): self.model = model self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', '0').lower() in ['1', 'true'] self.block_size = block_size + self.dtype = dtype if not htorch.utils.internal.is_lazy() and not enforce_eager: self.model = torch.compile(self.model, backend='hpu_backend', @@ -304,7 +305,7 @@ def forward(self, *args, **kwargs): input_ids = kwargs['input_ids'] kwargs['attn_metadata'] = self._update_metadata( kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), - input_ids.device, torch.bfloat16) + input_ids.device, self.dtype) LoraMask.setLoraMask(kwargs.pop('lora_mask')) hidden_states = self.model(*args, **kwargs) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) @@ -600,6 +601,7 @@ def load_model(self) -> None: self.model = _maybe_wrap_in_hpu_graph( self.model, self.block_size, + dtype=self.model_config.dtype, enforce_eager=self.enforce_eager) msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}" logger.info(msg) From acf7d548ee0352c5482d0c424ddb4a0558007ef7 Mon Sep 17 00:00:00 2001 From: Dudi Lester <160421192+dudilester@users.noreply.github.com> Date: Thu, 12 Sep 2024 11:42:31 +0300 Subject: [PATCH 8/9] Support Mixtral quantization using INC (#267) --- vllm/hpu/ops.py | 88 ++++++++++++------- vllm/model_executor/layers/fused_moe/layer.py | 42 ++++++--- .../model_executor/layers/quantization/inc.py | 6 +- vllm/model_executor/model_loader/utils.py | 2 +- 4 files changed, 96 insertions(+), 42 deletions(-) diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index b2705429906c4..3d76c36f2648b 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -86,36 +86,6 @@ def silu_and_mul(x: torch.Tensor) -> torch.Tensor: return F.silu(x[..., :d]) * x[..., d:] -def static_fused_moe(hidden_states, w1, w2, score, topk): - B, D = hidden_states.shape - num_experts = w1.shape[0] - routing_weights = F.softmax(score, dim=1, dtype=torch.float32) - routing_weights, selected_experts = torch.topk(routing_weights, - topk, - dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(hidden_states.dtype) - final_hidden_states = torch.zeros((1, B, D), - dtype=hidden_states.dtype, - device=hidden_states.device) - padded_weights = torch.zeros((B, num_experts), - dtype=hidden_states.dtype, - device=hidden_states.device) - padded_weights.scatter_(-1, selected_experts, routing_weights) - padded_weights = padded_weights.reshape(-1, B, w1.shape[0]) - padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1) - - htorch.core.mark_step() - - for expert_idx in range(num_experts): - w_output = torch.matmul(hidden_states, w1[expert_idx].transpose(0, 1)) - w_output = silu_and_mul(w_output) - w_output = torch.matmul(w_output, w2[expert_idx].transpose(0, 1)) - final_hidden_states += w_output * padded_weights[expert_idx] - - return final_hidden_states.view(-1, D) - - #TODO: remove after fusedsdpa fix for query_head != kv_head def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor: """ @@ -252,3 +222,61 @@ def dispatch_bgmv_embedding( wb 
= wb.reshape(wb.shape[0] * wb.shape[1], wb.shape[2]) out = x @ wb y += out * scale + + +class MoeMatmul(torch.nn.Module): + + def __init__(self): + super().__init__() + + def set_weight(self, w): + self.weight = w + + def calc(self, state, expert_id, w): + self.weight = w[expert_id].transpose(0, 1) + return self.forward(state) + + def forward(self, state): + return torch.matmul(state, self.weight) + + +class StaticFusedMOE(torch.nn.Module): + + def __init__(self, num_total_experts): + super().__init__() + self.w13_list = torch.nn.ModuleList( + [MoeMatmul() for _ in range(num_total_experts)]) + self.w2_list = torch.nn.ModuleList( + [MoeMatmul() for _ in range(num_total_experts)]) + self.num_total_experts = num_total_experts + + def forward(self, hidden_states, w1, w2, score, topk): + B, D = hidden_states.shape + routing_weights = F.softmax(score, dim=1, dtype=torch.float32) + routing_weights, selected_experts = torch.topk(routing_weights, + topk, + dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + final_hidden_states = torch.zeros((1, B, D), + dtype=hidden_states.dtype, + device=hidden_states.device) + padded_weights = torch.zeros((B, self.num_total_experts), + dtype=hidden_states.dtype, + device=hidden_states.device) + padded_weights.scatter_(-1, selected_experts, routing_weights) + padded_weights = padded_weights.reshape(-1, B, self.num_total_experts) + padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1) + htorch.core.mark_step() + + for expert_idx in range(self.num_total_experts): + padded_weight = padded_weights[expert_idx] + current_state_static = hidden_states.reshape(-1, D) + w_output = self.w13_list[expert_idx].calc(current_state_static, + expert_idx, w1) + w_output = silu_and_mul(w_output) + w_output = self.w2_list[expert_idx].calc(w_output, expert_idx, w2) + current_hidden_states_static = w_output * padded_weight + final_hidden_states += current_hidden_states_static + + return final_hidden_states.view(-1, D) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index b49bf40d4746e..cf0d5f98f1b01 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -13,9 +13,6 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.utils import is_hpu -if is_hpu(): - from vllm.hpu.ops import static_fused_moe - logger = init_logger(__name__) @@ -78,7 +75,8 @@ def apply( ) -> torch.Tensor: return self.forward(x, layer.w13_weight, layer.w2_weight, router_logits, top_k, renormalize, - use_grouped_topk, num_expert_group, topk_group) + use_grouped_topk, num_expert_group, topk_group, + layer) def forward_cuda( self, @@ -91,6 +89,7 @@ def forward_cuda( use_grouped_topk: bool, num_expert_group: Optional[int], topk_group: Optional[int], + layer: Optional[torch.nn.Module], ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe return fused_moe(x, @@ -104,15 +103,25 @@ def forward_cuda( num_expert_group=num_expert_group, topk_group=topk_group) - def forward_hpu(self, x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - router_logits: torch.Tensor, top_k: int, renormalize: bool, - use_grouped_topk: bool, num_expert_group: Optional[int], - topk_group: Optional[int]): + def forward_hpu( + self, + x: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + num_expert_group: 
+        topk_group: Optional[int],
+        layer: Optional[torch.nn.Module],
+    ):
         assert not use_grouped_topk, 'use_grouped_topk must be False on HPU'
         assert num_expert_group is None, ('num_expert_group is '
                                           'not supported on HPU')
         assert topk_group is None, 'topk_group is not supported on HPU'
-        return static_fused_moe(x, w1, w2, router_logits, top_k)
+        if layer is not None:
+            return layer.hpu_static_fused_moe(x, w1, w2, router_logits, top_k)
 
     def forward_cpu(self, *args, **kwargs):
         raise NotImplementedError(
@@ -129,6 +138,7 @@ def forward_tpu(
         use_grouped_topk: bool,
         num_expert_group: Optional[int],
         topk_group: Optional[int],
+        layer: Optional[torch.nn.Module],
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
         assert not use_grouped_topk
@@ -140,7 +150,7 @@ def forward_tpu(
 class FusedMoE(torch.nn.Module):
     """FusedMoE layer for MoE models.
 
-    This layer contains both MergedColumnParallel weights (gate_up_proj / 
+    This layer contains both MergedColumnParallel weights (gate_up_proj /
     w13) and RowParallelLinear weights (down_proj/ w2).
 
     Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
@@ -191,6 +201,9 @@ def __init__(
             assert num_expert_group is not None and topk_group is not None
             self.num_expert_group = num_expert_group
             self.topk_group = topk_group
+        if is_hpu():
+            from vllm.hpu.ops import StaticFusedMOE
+            self.hpu_static_fused_moe = StaticFusedMOE(self.num_experts)
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -245,13 +258,22 @@ def weight_loader(self, param: torch.nn.Parameter,
            if shard_id == 0:
                param_data[expert_id,
                           0:shard_size, :] = loaded_weight[shard, :]
+                if is_hpu():
+                    self.hpu_static_fused_moe.w13_list[expert_id].set_weight(
+                        param_data[expert_id])
            # w3, up_proj case: Load into second shard of w13.
            elif shard_id == 2:
                param_data[expert_id,
                           shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+                if is_hpu():
+                    self.hpu_static_fused_moe.w13_list[expert_id].set_weight(
+                        param_data[expert_id])
            # w2, down_proj case: Load into only shard of w2.
            elif shard_id == 1:
                param_data[expert_id, :, :] = loaded_weight[:, shard]
+                if is_hpu():
+                    self.hpu_static_fused_moe.w2_list[expert_id].set_weight(
+                        param_data[expert_id])
            else:
                raise ValueError(
                    f"Shard id must be in [0,1,2] but got {shard_id}")
diff --git a/vllm/model_executor/layers/quantization/inc.py b/vllm/model_executor/layers/quantization/inc.py
index f6718ec2ac9e7..ec0141b61f58f 100644
--- a/vllm/model_executor/layers/quantization/inc.py
+++ b/vllm/model_executor/layers/quantization/inc.py
@@ -5,6 +5,8 @@
 from torch.nn.parameter import Parameter
 
 from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE, UnquantizedFusedMoEMethod)
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -52,6 +54,8 @@ def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["INCLinearMethod"]:
         if isinstance(layer, LinearBase):
             return INCLinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return UnquantizedFusedMoEMethod()
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -78,7 +82,7 @@ class INCLinearMethod(LinearMethodBase):
     1. Only support per-tensor quantization due to torch._scaled_mm support.
     2. Only support float8_e4m3fn data type due to the limitation of
        torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
-    
+
     Args:
         quant_config: The quantization config.
     """
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index f7e0f56c1a46e..a8b0a7b07ed8e 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -24,7 +24,7 @@ def get_model_architecture(
     # Special handling for quantized Mixtral.
     # FIXME(woosuk): This is a temporary hack.
     if (model_config.quantization is not None
-            and model_config.quantization != "fp8"
+            and model_config.quantization not in ["fp8", "inc"]
             and "MixtralForCausalLM" in architectures):
         architectures = ["QuantMixtralForCausalLM"]

From 6a734f4d2b14040b3bbcd8cb9843fac9dfc8318b Mon Sep 17 00:00:00 2001
From: Ilia Taraban
Date: Thu, 12 Sep 2024 11:51:05 +0200
Subject: [PATCH 9/9] Fixed ALiBi (#254)

Fixed ALiBi and the [MPT-7B](https://www.databricks.com/blog/mpt-7b) model.
Accuracy results compared to CPU (collected using
[EleutherAI](https://github.com/EleutherAI/lm-evaluation-harness)):

| Tasks          | CPU    | HPU    |
| -------------- | ------ | ------ |
| arc_challenge  | 0.4224 | 0.4189 |
| arc_easy       | 0.6974 | 0.6999 |
| hellaswag      | 0.7603 | 0.7626 |
| lambada_openai | 0.7306 | 0.7326 |
| mmlu           | 0.293  | 0.2925 |
| winogrande     | 0.6851 | 0.6811 |

---
 vllm/attention/backends/habana_attn.py | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py
index 20b0f2bc7630b..56b71a431aca7 100644
--- a/vllm/attention/backends/habana_attn.py
+++ b/vllm/attention/backends/habana_attn.py
@@ -108,17 +108,10 @@ def __init__(
         self.v_cache = VLLMKVCache()
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
         self.sliding_window = sliding_window
-        self.position_bias = None
         self.alibi_slopes = alibi_slopes
         if alibi_slopes is not None:
-            # FIXME(kzawora): Need a general method to set max_seq_len on
-            # per-model basis.
             alibi_slopes_tensor = torch.tensor(alibi_slopes,
                                                dtype=torch.bfloat16)
-            self.position_bias = _make_alibi_bias(alibi_slopes_tensor,
-                                                  num_kv_heads,
-                                                  alibi_slopes_tensor.dtype,
-                                                  max_seq_len)
             self.alibi_slopes = alibi_slopes_tensor
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -190,11 +183,13 @@ def forward(
                 assert attn_metadata.attn_bias is not None, \
                     'attn_bias must be set before calling model.forward!'
                 attn_bias = attn_metadata.attn_bias
-                if self.alibi_slopes is not None and \
-                   self.position_bias is not None:
-                    attn_bias.add_(self.position_bias[:, :,
-                                                      -attn_bias.size(2):,
-                                                      -attn_bias.size(3):])
+                if self.alibi_slopes is not None:
+                    position_bias = _make_alibi_bias(self.alibi_slopes,
+                                                     self.num_kv_heads,
+                                                     attn_bias.dtype,
+                                                     attn_bias.shape[-1])
+                    attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
+                    attn_bias.add_(position_bias)
             else:
                 attn_bias = None
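
For reviewers, a minimal standalone sketch (not part of the patch) of the ALiBi change above: the bias is built from the slopes at the actual attention-bias length on each forward pass and added to the attention bias after tiling it per head, instead of slicing a bias precomputed for `max_seq_len` at init time. The helper name and the exact bias layout below are illustrative assumptions, not the fork's `_make_alibi_bias` implementation.

```python
import torch


def alibi_bias_sketch(slopes: torch.Tensor, seq_len: int) -> torch.Tensor:
    # slopes: [num_heads] -> bias of shape [1, num_heads, 1, seq_len];
    # 0 at the newest key position, increasingly negative for older keys.
    positions = torch.arange(seq_len, dtype=slopes.dtype) - (seq_len - 1)
    return slopes.view(1, -1, 1, 1) * positions.view(1, 1, 1, -1)


num_heads, seq_len = 4, 8
slopes = torch.tensor([2.0 ** -i for i in range(1, num_heads + 1)])

# Causal mask shared across heads, then tiled per head as in the patch,
# with the ALiBi bias built at the mask's own length and added on top.
causal = torch.full((1, 1, seq_len, seq_len), float("-inf")).triu(1)
attn_bias = causal.tile((1, num_heads, 1, 1))            # [1, H, S, S]
attn_bias = attn_bias + alibi_bias_sketch(slopes, seq_len)
print(attn_bias.shape)  # torch.Size([1, 4, 8, 8])
```

This per-key-column form of the bias, combined with a causal mask, yields the same softmax attention weights as the relative (query-minus-key) form, since the two differ only by a per-row constant over the unmasked entries.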