Chunk prefill cache writes, remove div_i32 from insert_or_update_cache #289

Merged
requirements-hpu.txt (3 changes: 1 addition & 2 deletions)
@@ -6,5 +6,4 @@ ray == 2.32.0
 triton
 pandas
 tabulate
-
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@0a7adab
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@940fdb7
vllm/attention/backends/habana_attn.py (17 changes: 9 additions & 8 deletions)
@@ -8,7 +8,6 @@
 
 import torch
 import vllm_hpu_extension.ops as ops
-from vllm_hpu_extension import cache_ops
 from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -166,20 +165,22 @@ def forward(
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
+        block_indices = attn_metadata.block_indices
+        block_offsets = attn_metadata.block_offsets
+        if attn_metadata.is_prompt:
+            key = key.unflatten(0, (block_indices.size(0), -1))
+            value = value.unflatten(0, (block_indices.size(0), -1))
         if kv_cache is not None:
             key_cache, value_cache = HabanaPagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
 
             # Reshape the input keys and values and store them in the cache.
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory profiling run.
-            num_kv_cache_passes, num_slots_available, indices, offsets = \
-                cache_ops.prepare_to_cache(key_cache,
-                                           attn_metadata.slot_mapping)
-            key_cache = self.k_cache(key, key_cache, num_kv_cache_passes,
-                                     num_slots_available, indices, offsets)
-            value_cache = self.v_cache(value, value_cache, num_kv_cache_passes,
-                                       num_slots_available, indices, offsets)
+            key_cache = self.k_cache(key, key_cache, block_indices,
+                                     block_offsets)
+            value_cache = self.v_cache(value, value_cache, block_indices,
+                                       block_offsets)
 
         if attn_metadata.is_prompt:
             # Prompt run.
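The new call signature above passes host-side precomputed `block_indices`/`block_offsets` into the `VLLMKVCache` writes, replacing the on-device `cache_ops.prepare_to_cache` step (the source of the `div_i32` op this PR removes). For prompts, `key`/`value` are first unflattened into whole `block_size`-long chunks so each write covers a full cache block. Below is a minimal PyTorch sketch of this write pattern, assuming a cache of shape `(num_blocks, block_size, num_kv_heads, head_size)`; it is an illustration only, not the actual `vllm-hpu-extension` kernel, whose source is not part of this diff:

```python
import torch

def insert_or_update_cache_sketch(src, cache, block_indices, block_offsets):
    """Hypothetical stand-in for the extension's cache write."""
    if block_offsets is None:
        # Prompt path: src is pre-chunked to
        # (num_blocks_written, block_size, num_kv_heads, head_size),
        # so whole blocks are copied in one shot.
        cache.index_copy_(0, block_indices, src)
    else:
        # Decode path: one new token per sequence; each slot is written
        # individually at its (block, offset) position.
        cache.index_put_((block_indices, block_offsets), src)
    return cache
```

Doing the divide/modulo on the host means the compiled HPU graph only sees gather/scatter-style writes, and prefill writes are chunked to full blocks rather than scattered slot by slot.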
vllm/attention/ops/habana_paged_attn.py (2 changes: 2 additions & 0 deletions)
@@ -18,6 +18,8 @@ class HabanaPagedAttentionMetadata:
     block_list: Optional[torch.Tensor]
     block_mapping: Optional[torch.Tensor]
     block_usage: Optional[torch.Tensor]
+    block_indices: Optional[torch.Tensor]
+    block_offsets: Optional[torch.Tensor]
 
 
 class HabanaPagedAttention:
vllm/worker/habana_model_runner.py (22 changes: 21 additions & 1 deletion)
@@ -245,6 +245,17 @@ def pad_list(list, k, v):
     return list + [v] * padding
 
 
+def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt):
+    slot_mapping = slot_mapping.flatten()
+    indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
+    if is_prompt:
+        indices = indices.unflatten(0, (-1, block_size))[:, 0]
+        offsets = None
+    else:
+        offsets = torch.fmod(slot_mapping, block_size)
+    return indices, offsets
+
+
 class HpuModelAdapter():
 
     def __init__(self, model, block_size, dtype, enforce_eager):
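A quick worked example of the helper above (values chosen for illustration, with `block_size = 4`): prompt slot mappings cover whole, block-aligned spans, so one index per block suffices and offsets are skipped; decode supplies one arbitrary slot per sequence, so per-slot offsets are needed as well:

```python
import torch

# Prompt: slots fill two whole blocks (block 2 and block 5).
slots = torch.tensor([8, 9, 10, 11, 20, 21, 22, 23])
indices, offsets = precompute_indices_and_offsets(4, slots, True)
# indices == tensor([2, 5]); offsets is None

# Decode: one new slot per sequence, anywhere within its block.
slots = torch.tensor([9, 23])
indices, offsets = precompute_indices_and_offsets(4, slots, False)
# indices == tensor([2, 5]); offsets == tensor([1, 3])
```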
@@ -890,11 +901,15 @@ def _prepare_prompt(
                                     dtype=torch.long,
                                     device=self.device)
 
+        block_indices, block_offsets = precompute_indices_and_offsets(
+            self.block_size, slot_mapping, True)
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=True,
             block_list=None,
             block_mapping=None,
             block_usage=None,
+            block_indices=block_indices,
+            block_offsets=block_offsets,
             attn_bias=None,
             seq_lens_tensor=seq_lens_tensor,
             num_prefills=real_num_seqs,
@@ -1044,11 +1059,15 @@ def _prepare_decode(
                                     dtype=torch.long,
                                     device=self.device)
 
+        block_indices, block_offsets = precompute_indices_and_offsets(
+            self.block_size, slot_mapping, False)
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=False,
             block_list=block_list,
             block_mapping=block_mapping,
             block_usage=block_usage,
+            block_indices=block_indices,
+            block_offsets=block_offsets,
             attn_bias=None,
             seq_lens_tensor=None,
             num_prefills=0,
@@ -1266,7 +1285,8 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
         # input_hash("abc") != input_hash("cba")
         attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
             'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
-            'block_usage', 'slot_mapping', 'is_prompt'
+            'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
+            'block_offsets'
         ])
         return attention_metadata

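`block_indices` and `block_offsets` are added to the trimmed metadata so they survive the `subtuple` projection whose result is hashed for HPU graph caching (per the `input_hash` comment above). The `subtuple` helper itself is outside this diff; a plausible sketch of its behavior, assuming it simply projects the listed attributes into a namedtuple (hypothetical reconstruction):

```python
import collections

def subtuple_sketch(obj, typename, fields):
    # Project only the listed attributes into a fresh namedtuple, giving
    # the graph-capture machinery a small, order-stable input to hash.
    return collections.namedtuple(typename, fields)(
        *[getattr(obj, f) for f in fields])
```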