Merge branch 'habana_main' into private/jmaksymczuk/fake_hpu_cpu
jmaksymczuk authored Sep 10, 2024
2 parents e5cd53a + 5cf8441 commit 4ab0063
Showing 13 changed files with 375 additions and 438 deletions.
22 changes: 11 additions & 11 deletions README_GAUDI.md
@@ -455,33 +455,33 @@ Environment variables
- `VLLM_{phase}_{dim}_BUCKET_{param}` - collection of 12 environment
variables configuring ranges of bucketing mechanism
- `{phase}` is either `PROMPT` or `DECODE`
- `{dim}` is either `BS` or `SEQ`
- `{dim}` is either `BS`, `SEQ` or `BLOCK`
- `{param}` is either `MIN`, `STEP` or `MAX`
- Default values:
- Prompt:
- batch size min (`VLLM_PROMPT_BS_BUCKET_MIN`): `1`
- batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`): `32`
- batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`): `min(max_num_seqs, 32)`
- batch size max (`VLLM_PROMPT_BS_BUCKET_MAX`):
`min(max_num_seqs, 64)`
- sequence length min (`VLLM_PROMPT_SEQ_BUCKET_MIN`):
`block_size`
- sequence length step
(`VLLM_PROMPT_SEQ_BUCKET_STEP`): `block_size`
- sequence length max (`VLLM_PROMPT_SEQ_BUCKET_MAX`):
`1024`
`max_model_len`

- Decode:
- batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `1`
- batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `min(max_num_seqs, 32)`
- batch size step (`VLLM_DECODE_BS_BUCKET_STEP`):
`128`
`min(max_num_seqs, 32)`
- batch size max (`VLLM_DECODE_BS_BUCKET_MAX`):
`max_num_seqs`
- sequence length min (`VLLM_DECODE_SEQ_BUCKET_MIN`):
`block_size`
- sequence length step
(`VLLM_DECODE_SEQ_BUCKET_STEP`): `block_size`
- sequence length max (`VLLM_DECODE_SEQ_BUCKET_MAX`):
`2048`
- block size min (`VLLM_DECODE_BLOCK_BUCKET_MIN`):
`128`
- block size step
(`VLLM_DECODE_BLOCK_BUCKET_STEP`): `128`
- block size max (`VLLM_DECODE_BLOCK_BUCKET_MAX`):
`max(128, (max_num_seqs*max_model_len)/block_size)`
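
The min/step/max triples above expand into a set of warmed-up bucket values. A minimal sketch of that expansion, for intuition only (it assumes the power-of-two ramp-up from min to step followed by fixed step increments up to max; the warmup_range helper below is illustrative, not the actual runner code, and is not part of this diff):

import os

def warmup_range(bmin: int, bstep: int, bmax: int) -> list:
    # Ramp-up phase: powers of two from bmin up to (but excluding) bstep.
    buckets, value = [], bmin
    while value < bstep and value <= bmax:
        buckets.append(value)
        value *= 2
    # Stable phase: multiples of bstep from bstep up to bmax.
    value = bstep
    while value <= bmax:
        buckets.append(value)
        value += bstep
    return buckets or [bmin]

# Example: overriding the decode batch-size bucketing before launching vLLM.
os.environ["VLLM_DECODE_BS_BUCKET_MIN"] = "1"
os.environ["VLLM_DECODE_BS_BUCKET_STEP"] = "32"
os.environ["VLLM_DECODE_BS_BUCKET_MAX"] = "128"
print(warmup_range(1, 32, 128))  # [1, 2, 4, 8, 16, 32, 64, 96, 128]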

Additionally, there are HPU PyTorch Bridge environment variables
impacting vLLM execution:
14 changes: 7 additions & 7 deletions docs/source/getting_started/gaudi-installation.rst
@@ -335,19 +335,19 @@ Environment variables

- Prompt:
- batch size min (``VLLM_PROMPT_BS_BUCKET_MIN``): ``1``
- batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``32``
- batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)``
- batch size max (``VLLM_PROMPT_BS_BUCKET_MAX``): ``min(max_num_seqs, 64)``
- sequence length min (``VLLM_PROMPT_SEQ_BUCKET_MIN``): ``block_size``
- sequence length step (``VLLM_PROMPT_SEQ_BUCKET_STEP``): ``block_size``
- sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``1024``
- sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``max_model_len``

- Decode:
- batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``1``
- batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``128``
- batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``min(max_num_seqs, 32)``
- batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)``
- batch size max (``VLLM_DECODE_BS_BUCKET_MAX``): ``max_num_seqs``
- sequence length min (``VLLM_DECODE_SEQ_BUCKET_MIN``): ``block_size``
- sequence length step (``VLLM_DECODE_SEQ_BUCKET_STEP``): ``block_size``
- sequence length max (``VLLM_DECODE_SEQ_BUCKET_MAX``): ``2048``
- sequence length min (``VLLM_DECODE_SEQ_BUCKET_MIN``): ``128``
- sequence length step (``VLLM_DECODE_SEQ_BUCKET_STEP``): ``128``
- sequence length max (``VLLM_DECODE_SEQ_BUCKET_MAX``): ``max(128, (max_num_seqs*max_model_len)/block_size)``


Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution:
2 changes: 1 addition & 1 deletion requirements-common.txt
@@ -8,7 +8,7 @@ tqdm
py-cpuinfo
transformers >= 4.43.2 # Required for Chameleon and Llama 3.1 hotfox.
tokenizers >= 0.19.1 # Required for Llama 3.
fastapi
fastapi == 0.112.2 # Hardcoding this to workaround issue with new fastapi.
aiohttp
openai
uvicorn[standard]
2 changes: 1 addition & 1 deletion tests/lora/test_multilora_hpu.py
@@ -96,7 +96,7 @@ def _test_llama_multilora(sql_lora_files, tp_size):
enable_lora=True,
max_loras=2,
max_lora_rank=8,
max_num_seqs=16,
max_num_seqs=256,
dtype='float32',
tensor_parallel_size=tp_size)
engine = LLMEngine.from_engine_args(engine_args)
136 changes: 40 additions & 96 deletions vllm/attention/backends/habana_attn.py
@@ -58,58 +58,14 @@ def copy_blocks(


@dataclass
class HabanaAttentionMetadata(AttentionMetadata, HabanaPagedAttentionMetadata):
"""Metadata for HabanaAttentionbackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
class HabanaAttentionMetadata(HabanaPagedAttentionMetadata, AttentionMetadata):
"""Metadata for HabanaAttentionbackend."""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
attn_bias: Optional[torch.Tensor]
seq_lens_tensor: Optional[torch.Tensor]

# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|

# Maximum query length in the batch.
max_query_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]

# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool

def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[torch.Tensor] = None


class HabanaAttentionImpl(AttentionImpl, torch.nn.Module):
"""
@@ -229,60 +229,48 @@ def forward(

if attn_metadata.is_prompt:
# Prompt run.
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
if not self.prefill_usefusedsdpa:
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
if not self.prefill_usefusedsdpa:
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward!'
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None and \
self.position_bias is not None:
attn_bias.add_(self.position_bias[:, :,
-attn_bias.size(2):,
-attn_bias.size(3):])
else:
attn_bias = None

query_shape = (batch_size, seq_len, self.num_heads,
self.head_size)
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size)
out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
valid_seq_lengths=attn_metadata.seq_lens_tensor,
)
output = out.reshape(batch_size, seq_len, hidden_size)
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None and \
self.position_bias is not None:
attn_bias.add_(self.position_bias[:, :,
-attn_bias.size(2):,
-attn_bias.size(3):])
else:
# prefix-enabled attention
output = HabanaPagedAttention.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
attn_metadata.block_tables,
attn_metadata.subquery_start_loc,
attn_metadata.seq_lens_tensor,
attn_metadata.context_lens_tensor,
attn_metadata.max_query_len,
self.alibi_slopes,
)
attn_bias = None

query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size)
out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
)
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
output = HabanaPagedAttention.forward_decode(
query, key_cache, value_cache, attn_metadata.block_tables,
attn_metadata.seq_lens_tensor, self.kv_cache_dtype,
self.num_kv_heads, self.scale, self.position_bias, k_scale,
v_scale, self.matmul_qk, self.softmax, self.matmul_av,
self.k_cache, self.v_cache)
query=query,
key_cache=key_cache,
value_cache=value_cache,
block_list=attn_metadata.block_list,
block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)

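For intuition about the block_list and block_mapping arguments that the decode path above now hands to HabanaPagedAttention.forward_decode (and the matching fields added to HabanaPagedAttentionMetadata in the next file), here is a minimal sketch of how per-sequence block tables could be flattened into that layout; the values and the preparation loop are illustrative, not the fork's actual code:

import torch

# Hypothetical per-sequence block tables and sequence lengths.
block_tables = [[0, 1, 2], [3, 4]]   # seq id -> physical KV-cache block ids
seq_lens = [300, 180]                # tokens seen so far per sequence
block_size = 128

block_list, block_mapping, block_usage = [], [], []
for seq_id, (blocks, seq_len) in enumerate(zip(block_tables, seq_lens)):
    for i, block in enumerate(blocks):
        block_list.append(block)      # flat list of physical block ids
        block_mapping.append(seq_id)  # which sequence each block serves
        # Number of valid tokens stored in this block.
        block_usage.append(min(block_size, seq_len - i * block_size))

block_list = torch.tensor(block_list)        # [0, 1, 2, 3, 4]
block_mapping = torch.tensor(block_mapping)  # [0, 0, 0, 1, 1]
block_usage = torch.tensor(block_usage)      # [128, 128, 44, 128, 52]
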
51 changes: 5 additions & 46 deletions vllm/attention/ops/habana_paged_attn.py
@@ -16,16 +16,9 @@
@dataclass
class HabanaPagedAttentionMetadata:
"""Metadata for PagedAttention."""
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
block_list: Optional[torch.Tensor]
block_mapping: Optional[torch.Tensor]
block_usage: Optional[torch.Tensor]


class HabanaPagedAttention:
@@ -63,42 +56,8 @@ def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor,
slot_mapping, kv_cache_dtype, is_prompt)

@staticmethod
def forward_decode(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: float,
v_scale: float,
matmul_qk_op,
softmax_op,
matmul_av_op,
k_cache_cls,
v_cache_cls,
) -> torch.Tensor:
block_size = value_cache.shape[1]
return ops.paged_attention_v1(
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
alibi_slopes,
kv_cache_dtype,
matmul_qk_op,
softmax_op,
matmul_av_op,
k_cache_cls,
v_cache_cls,
)
def forward_decode(**kwargs) -> torch.Tensor:
return ops.flat_pa(**kwargs)

@staticmethod
def forward_prefix(
8 changes: 4 additions & 4 deletions vllm/hpu/cache_ops.py
@@ -5,6 +5,8 @@
# LICENSE file in the root directory of this source tree.
###############################################################################

import math

import habana_frameworks.torch as htorch
import torch

@@ -30,8 +32,7 @@ def reshape_and_cache(key,
# lots of padding, or are doing warmup.
# This loop is a workaround for this issue. Please remove it
# once key_cache.index_put_(indices, offsets), key) works.
num_kv_cache_passes = torch.div(num_slots_requested,
num_slots_available).ceil().int().item()
num_kv_cache_passes = math.ceil(num_slots_requested / num_slots_available)
for i in range(num_kv_cache_passes):
start_idx = i * num_slots_available
end_idx = (i + 1) * num_slots_available
@@ -58,8 +59,7 @@ def prepare_to_cache(cache, slot_mapping):
# lots of padding, or are doing warmup.
# This loop is a workaround for this issue. Please remove it
# once key_cache.index_put_(indices, offsets), key) works.
num_kv_cache_passes = torch.div(num_slots_requested,
num_slots_available).ceil().int().item()
num_kv_cache_passes = math.ceil(num_slots_requested / num_slots_available)

return num_kv_cache_passes, num_slots_available, indices, offsets

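The change above swaps a tensor-based ceiling division for plain Python math when computing the number of cache-write passes. A standalone comparison with made-up slot counts, showing the two forms agree while the new one keeps a scalar computation off the tensor path:

import math

import torch

num_slots_requested, num_slots_available = 5000, 1536

# Old form: routes the scalar computation through a tensor.
passes_old = torch.div(torch.tensor(num_slots_requested),
                       num_slots_available).ceil().int().item()

# New form: ordinary integer math on the host.
passes_new = math.ceil(num_slots_requested / num_slots_available)

assert passes_old == passes_new == 4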
