Merge pull request #91 from madamczykhabana/flat_block_table
Flat block table
bgoldberg-habana authored Jul 7, 2024
2 parents ca1dbf6 + 7aeb218 commit 33b3f41
Showing 5 changed files with 341 additions and 315 deletions.
129 changes: 30 additions & 99 deletions vllm/attention/backends/habana_attn.py
@@ -59,59 +59,12 @@ def copy_blocks(
         HabanaPagedAttention.copy_blocks(kv_caches, src_to_dists)
 
 
-@dataclass
-class HabanaAttentionMetadata(AttentionMetadataPerStage, HabanaPagedAttentionMetadata):
-    """Metadata for HabanaAttentionbackend.
-    NOTE: Any python object stored here is not updated when it is
-    cuda-graph replayed. If you have values that need to be changed
-    dynamically, it should be stored in tensor. The tensor has to be
-    updated from `CUDAGraphRunner.forward` API.
-    """
-    # Currently, input sequences can only contain all prompts
-    # or all decoding. True if all sequences are prompts.
-    is_prompt: bool
-    # (batch_size,). The sequence length per sequence. Sequence length means
-    # the computed tokens + new tokens None if it is a decoding.
-    seq_lens: Optional[List[int]]
-    # seq_lens stored as a tensor.
+@dataclass(frozen=True)
+class HabanaAttentionMetadata(HabanaPagedAttentionMetadata, AttentionMetadataPerStage):
+    """Metadata for HabanaAttentionbackend."""
+    attn_bias: Optional[torch.Tensor]
     seq_lens_tensor: Optional[torch.Tensor]
-
-    # |---------- N-1 iteration --------|
-    # |---------------- N iteration ---------------------|
-    # |- tokenA -|......................|-- newTokens ---|
-    # |---------- context_len ----------|
-    # |-------------------- seq_len ----------------------|
-    #                                   |-- query_len ---|
-
-    # Maximum query length in the batch.
-    max_query_len: Optional[int]
-    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
-    # the batch, used to index into subquery. E.g., if the subquery length
-    # is [4, 6], it is [0, 4, 10].
-    subquery_start_loc: Optional[torch.Tensor]
-    # FIXME: It is for flash attn.
-    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
-    # the batch, used to index into sequence. E.g., if the sequence length is
-    # [4, 6], it is [0, 4, 10].
-    seq_start_loc: Optional[torch.Tensor]
-    # (batch_size,) A tensor of context lengths (tokens that are computed
-    # so far).
-    context_lens_tensor: Optional[torch.Tensor]
-
-    # Whether or not if cuda graph is enabled.
-    # Cuda-graph is currently enabled for decoding only.
-    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
-    use_cuda_graph: bool
-
-    def __post_init__(self):
-        # Set during the execution of the first attention op.
-        # It is a list because it is needed to set per prompt
-        # when alibi slopes is used. It is because of the limitation
-        # from xformer API.
-        # will not appear in the __repr__ and __init__
-        self.attn_bias: Optional[List[AttentionBias]] = None
 
 
 class HabanaAttentionImpl(AttentionImpl, torch.nn.Module):
     """
@@ -202,57 +155,35 @@ def forward(
 
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
-            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
-                # TODO: move this outside of model
-                assert prefill_meta.attn_bias is not None, 'attn_bias must be set before calling model.forward!'
-                query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
-                kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size)
-                out = xops.prompt_attention(
-                    query.view(query_shape),
-                    key.view(kv_shape),
-                    value.view(kv_shape),
-                    attn_bias=prefill_meta.attn_bias,
-                    p=0.0,
-                    scale=self.scale,
-                    qk_matmul_op=self.qk_matmul,
-                    softmax_op=self.softmax,
-                    kv_matmul_op=self.kv_matmul,
-                )
-                output = out.reshape(batch_size, seq_len, hidden_size)
-            else:
-                # prefix-enabled attention
-                output = HabanaPagedAttention.forward_prefix(
-                    query,
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    prefill_meta.block_tables,
-                    prefill_meta.subquery_start_loc,
-                    prefill_meta.seq_lens_tensor,
-                    prefill_meta.context_lens_tensor,
-                    prefill_meta.max_query_len,
-                    self.alibi_slopes,
-                )
+            assert prefill_meta.attn_bias is not None, 'attn_bias must be set before calling model.forward!'
+            query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
+            kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size)
+            out = xops.prompt_attention(
+                query.view(query_shape),
+                key.view(kv_shape),
+                value.view(kv_shape),
+                attn_bias=prefill_meta.attn_bias,
+                p=0.0,
+                scale=self.scale,
+                qk_matmul_op=self.qk_matmul,
+                softmax_op=self.softmax,
+                kv_matmul_op=self.kv_matmul,
+            )
+            output = out.reshape(batch_size, seq_len, hidden_size)
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
             output = HabanaPagedAttention.forward_decode(
-                query,
-                key_cache,
-                value_cache,
-                decode_meta.block_tables,
-                decode_meta.seq_lens_tensor,
-                attn_metadata.kv_cache_dtype,
-                self.num_kv_heads,
-                self.scale,
-                self.alibi_slopes,
-                kv_scale,
-                self.qk_matmul,
-                self.softmax,
-                self.kv_matmul,
-                self.key_cache.fetch_from_cache,
-                self.value_cache.fetch_from_cache,
-            )
+                query=query,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                block_list=decode_meta.block_list,
+                block_mapping=decode_meta.block_mapping,
+                block_bias=decode_meta.attn_bias,
+                scale=self.scale,
+                qk_matmul_op=self.qk_matmul,
+                kv_matmul_op=self.kv_matmul,
+                keys_fetch_func=self.key_cache.fetch_from_cache,
+                values_fetch_func=self.value_cache.fetch_from_cache)
 
         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)
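The decode path above now consumes flattened block metadata (block_list, block_mapping, and a per-block attention bias) instead of a padded per-sequence block table. The runner code that builds this metadata is not among the files shown in this diff, so the following is only a rough sketch, with assumed helper and variable names, of how such metadata could be derived from per-sequence block tables and sequence lengths:

import torch

def make_flat_decode_metadata(block_tables, seq_lens, block_size):
    # block_tables: list of per-sequence lists of physical block ids (illustrative)
    # seq_lens: tokens seen so far for each sequence
    flat_blocks, owners, usage = [], [], []
    for seq_idx, (blocks, seq_len) in enumerate(zip(block_tables, seq_lens)):
        for blk_idx, blk in enumerate(blocks):
            flat_blocks.append(blk)
            owners.append(seq_idx)
            # tokens actually stored in this block (the last block may be partial)
            usage.append(min(block_size, seq_len - blk_idx * block_size))
    block_list = torch.tensor(flat_blocks, dtype=torch.long)
    # one-hot (num_blocks, batch_size): row b selects the sequence that owns flat block b
    block_mapping = torch.nn.functional.one_hot(
        torch.tensor(owners), num_classes=len(block_tables)).to(torch.float32)
    # additive bias that masks the unused tail positions of each block
    positions = torch.arange(block_size).view(1, -1)
    used = torch.tensor(usage).view(-1, 1)
    block_bias = torch.where(positions < used,
                             torch.zeros(1),
                             torch.full((1,), float('-inf')))
    return block_list, block_mapping, block_bias

# e.g. two sequences: seq 0 holds 7 tokens in blocks [7, 2], seq 1 holds 2 tokens in block [5]
block_list, block_mapping, block_bias = make_flat_decode_metadata(
    [[7, 2], [5]], [7, 2], block_size=4)

Here block_mapping is the kind of one-hot (num_blocks, batch_size) matrix that batch2block and block2batch in vllm/hpu/ops.py multiply by, and the -inf entries of block_bias mask the unused tail of each block.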
52 changes: 6 additions & 46 deletions vllm/attention/ops/habana_paged_attn.py
@@ -13,19 +13,12 @@
 _PARTITION_SIZE = 512
 
 
-@dataclass
+@dataclass(frozen=True)
 class HabanaPagedAttentionMetadata:
     """Metadata for PagedAttention."""
-    # (batch_size,). The length of sequences (entire tokens seen so far) per
-    # sequence.
-    seq_lens_tensor: Optional[torch.Tensor]
-    # (batch_size, max_blocks_per_seq).
-    # Block addresses per sequence. (Seq id -> list of physical block)
-    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
-    # in the kv cache. Each block can contain up to block_size tokens.
-    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
-    # captured.
-    block_tables: Optional[torch.Tensor]
+    block_list: Optional[torch.Tensor]
+    block_mapping: Optional[torch.Tensor]
+    block_usage: Optional[torch.Tensor]
 
 
 class HabanaPagedAttention:
@@ -74,41 +67,8 @@ def write_to_paged_cache(
         )
 
     @staticmethod
-    def forward_decode(
-        query: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        block_tables: torch.Tensor,
-        seq_lens: torch.Tensor,
-        kv_cache_dtype: str,
-        num_kv_heads: int,
-        scale: float,
-        alibi_slopes: Optional[torch.Tensor],
-        kv_scale: float,
-        qk_op=torch.matmul,
-        softmax_op=torch.softmax,
-        kv_op=torch.matmul,
-        keys_fetch=ops.fetch_from_cache,
-        values_fetch=ops.fetch_from_cache,
-    ) -> torch.Tensor:
-        block_size = value_cache.shape[1]
-        return ops.paged_attention_v1(
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            scale,
-            block_tables,
-            seq_lens,
-            block_size,
-            alibi_slopes,
-            kv_cache_dtype,
-            qk_op,
-            softmax_op,
-            kv_op,
-            keys_fetch,
-            values_fetch,
-        )
+    def forward_decode(**kwargs) -> torch.Tensor:
+        return ops.flat_pa(**kwargs)
 
     @staticmethod
     def forward_prefix(
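For reference, a minimal sketch (with made-up values, importable only where vllm is installed) of what an instance of the new flat metadata could look like for two sequences spanning three physical blocks in total; the field meanings follow the flat_pa implementation in vllm/hpu/ops.py:

import torch
from vllm.attention.ops.habana_paged_attn import HabanaPagedAttentionMetadata

meta = HabanaPagedAttentionMetadata(
    # flat list of physical block ids, concatenated across the batch
    block_list=torch.tensor([7, 2, 5]),
    # (num_blocks, batch_size) one-hot: blocks 0 and 1 belong to sequence 0,
    # block 2 belongs to sequence 1
    block_mapping=torch.tensor([[1., 0.],
                                [1., 0.],
                                [0., 1.]]),
    # assumed meaning: how many token slots of each block are occupied
    # (block_size of 4 here); flat_pa itself consumes the derived block bias
    block_usage=torch.tensor([4, 3, 2]),
)

Because forward_decode is now a thin pass-through, callers must supply exactly the keyword arguments that ops.flat_pa expects (as the backend change in habana_attn.py does).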
91 changes: 54 additions & 37 deletions vllm/hpu/ops.py
@@ -31,45 +31,62 @@ def gelu_fast(output, input):
     raise NotImplementedError
 
 
-def fetch_from_cache(cache, blocks, permutations):
-    return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))]
-
-
-def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, block_size, alibi_slopes, kv_cache_dtype=None,
-                       qk_matmul_op=torch.matmul, softmax_op=torch.softmax, kv_matmul_op=torch.matmul, keys_fetch_func=fetch_from_cache, values_fetch_func=fetch_from_cache) -> None:
-    seq_len = block_tables.size(1)
-    batch_size, query_heads, _ = query.shape
-    _, _, kv_heads, _ = key_cache.shape
-    min_inf = torch.finfo(query.dtype).min
-    mask = (torch.arange(0, seq_len * block_size, dtype=torch.int32, device=key_cache.device)
-            .view(1, -1)
-            .expand(batch_size, -1)
-            .ge(context_lens.view(-1, 1))
-            .view(batch_size, 1, 1, -1))
-    query.mul_(scale)
-    query = query.unsqueeze(-2)
-    keys = keys_fetch_func(key_cache, block_tables, (0, 2, 3, 1))
-    if query_heads != kv_heads:
-        query = query.unflatten(1, (kv_heads, -1))
-        keys = [k.unflatten(1, (kv_heads, 1)) for k in keys]
-        mask = mask.unsqueeze(2)
-    attn_weights = [qk_matmul_op(query, k) for k in keys]
-    attn_weights = softmax_op(torch.cat(attn_weights, dim=-1).masked_fill(mask, min_inf),
-                              dim=-1)
-
-    values = values_fetch_func(value_cache, block_tables, (0, 2, 1, 3))
-    if PA_SPLIT_VALUE:
-        attn_weights = attn_weights.split(block_size, dim=-1)
-    else:
-        values = [torch.cat(values, dim=-2)]
-        attn_weights = [attn_weights]
-    if query_heads != kv_heads:
-        values = [v.unflatten(1, (kv_heads, 1)) for v in values]
-    attn_weights = [kv_matmul_op(a, v) for a, v in zip(attn_weights, values)]
-    if query_heads != kv_heads:
-        attn_weights = [a.flatten(1, 2) for a in attn_weights]
-    attn_weights = sum(attn_weights)
-    return attn_weights.squeeze(-2)
+def batch2block(tensor, block_mapping):
+    shape = tuple(tensor.shape)
+    return (block_mapping @ tensor.view(shape[0], -1)).view(-1, *shape[1:])
+
+
+def block2batch(tensor, block_mapping):
+    shape = tuple(tensor.shape)
+    return (block_mapping.t() @ tensor.view(shape[0], -1)).view(-1, *shape[1:])
+
+
+def block_softmax(batch_size, attn, block_mapping):
+    attn = attn.exp_()
+    sums = attn.sum(dim=-1).unsqueeze(-1)
+    sums = block2batch(sums, block_mapping)
+    sums = batch2block(sums, block_mapping)
+    attn.div_(sums)
+    return attn
+
+
+def flat_pa(query,
+            key_cache,
+            value_cache,
+            block_list,
+            block_mapping,
+            block_bias,
+            scale,
+            qk_matmul_op,
+            kv_matmul_op,
+            keys_fetch_func,
+            values_fetch_func):
+    batch_size = query.size(0)
+    q_heads = query.size(1)
+    kv_heads = key_cache.size(2)
+
+    query = batch2block(scale * query, block_mapping).unsqueeze(-2)
+    key = keys_fetch_func(key_cache, block_list).transpose(1, 2)
+    value = values_fetch_func(value_cache, block_list).transpose(1, 2)
+    block_bias = block_bias.view(key.size(0), 1, 1, -1)
+
+    if kv_heads != q_heads:
+        block_bias = block_bias.unsqueeze(1)
+        query = query.unflatten(1, (kv_heads, -1))
+        key = key.unflatten(1, (kv_heads, 1))
+        value = value.unflatten(1, (kv_heads, 1))
+        key = key.transpose(3, 4)
+    else:
+        key = key.transpose(2, 3)
+
+    attn = qk_matmul_op(query, key) + block_bias
+    attn = block_softmax(batch_size, attn, block_mapping)
+    attn = kv_matmul_op(attn, value)
+    attn = block2batch(attn, block_mapping)
+    attn = attn.squeeze(-2)
+    if kv_heads != q_heads:
+        attn = attn.flatten(1, 2)
+    return attn
 
 
 def rms_norm(out, hidden_states, weight, eps):
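Since flat_pa leans entirely on the two routing helpers introduced above, here is a tiny standalone demonstration of their behavior with a one-hot block_mapping (toy values only; the helpers are copied verbatim from the diff so the snippet runs without the HPU stack):

import torch

def batch2block(tensor, block_mapping):
    shape = tuple(tensor.shape)
    return (block_mapping @ tensor.view(shape[0], -1)).view(-1, *shape[1:])

def block2batch(tensor, block_mapping):
    shape = tuple(tensor.shape)
    return (block_mapping.t() @ tensor.view(shape[0], -1)).view(-1, *shape[1:])

# blocks 0 and 1 belong to sequence 0, block 2 to sequence 1
block_mapping = torch.tensor([[1., 0.],
                              [1., 0.],
                              [0., 1.]])
per_batch = torch.tensor([[10.], [20.]])           # one value per sequence
per_block = batch2block(per_batch, block_mapping)  # -> [[10.], [10.], [20.]]
back = block2batch(per_block, block_mapping)       # -> [[20.], [20.]] (sums over each sequence's blocks)
print(per_block.squeeze(-1).tolist(), back.squeeze(-1).tolist())

block_softmax combines the two: per-block exponentiated scores are summed back to their owning sequence (block2batch), the totals are broadcast out again (batch2block), and each block's scores are divided by its sequence's total, so the softmax is normalized per sequence even though the work is laid out per block.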
4 changes: 2 additions & 2 deletions vllm/hpu/utils.py
@@ -125,5 +125,5 @@ def forward(self, input, cache, block_indices, block_offset):
         insert_or_update_cache(input, cache, block_indices, block_offset)
         return cache
 
-    def fetch_from_cache(self, cache, blocks, permutations):
-        return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))]
+    def fetch_from_cache(self, cache, blocks):
+        return cache.index_select(0, blocks)
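The simplified fetch_from_cache replaces the per-column gather loop with a single index_select along the cache's block dimension. A self-contained toy illustration (shapes are made up):

import torch

num_cache_blocks, block_size, kv_heads, head_size = 16, 4, 2, 8
cache = torch.randn(num_cache_blocks, block_size, kv_heads, head_size)

block_list = torch.tensor([7, 2, 5])         # flat physical block ids
fetched = cache.index_select(0, block_list)  # (3, block_size, kv_heads, head_size)
print(fetched.shape)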
(The diff for the fifth changed file did not load on this page.)
