This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Upstream sync 2024 03 14 #127

Merged: 114 commits, Mar 15, 2024
Changes from 1 commit
d7f3964
Update comment (#2934)
ronensc Feb 22, 2024
5574081
Added early stopping to completion APIs (#2939)
Maxusmusti Feb 22, 2024
344020c
Migrate MistralForCausalLM to LlamaForCausalLM (#2868)
esmeetu Feb 22, 2024
95529e3
Use Llama RMSNorm custom op for Gemma (#2974)
WoosukKwon Feb 22, 2024
93dc5a2
chore(vllm): codespell for spell checking (#2820)
mspronesti Feb 22, 2024
fd5dcc5
Optimize GeGLU layer in Gemma (#2975)
WoosukKwon Feb 22, 2024
c530e2c
[FIX] Fix a bug in initializing Yarn RoPE (#2983)
44670 Feb 22, 2024
6f32cdd
Remove Flash Attention in test env (#2982)
WoosukKwon Feb 22, 2024
4caf704
Include tokens from prompt phase in `counter_generation_tokens` (#2802)
ronensc Feb 22, 2024
57f0449
Fix nvcc not found in vlm-openai image (#2781)
zhaoyang-star Feb 22, 2024
f7c1234
[Fix] Fix assertion on YaRN model len (#2984)
WoosukKwon Feb 23, 2024
ef978fe
Port metrics from `aioprometheus` to `prometheus_client` (#2730)
hmellor Feb 25, 2024
70f3e8e
Add LogProbs for Chat Completions in OpenAI (#2918)
jlcmoore Feb 26, 2024
cfc15a1
Optimize Triton MoE Kernel (#2979)
pcmoritz Feb 26, 2024
d6e4a13
[Minor] Remove gather_cached_kv kernel (#3043)
WoosukKwon Feb 26, 2024
d9f726c
[Minor] Remove unused config files (#3039)
esmeetu Feb 27, 2024
c1c0d00
Don't use cupy when `enforce_eager=True` (#3037)
esmeetu Feb 27, 2024
4dd6416
Fix stablelm (#3038)
esmeetu Feb 27, 2024
48a8f4a
Support Orion model (#2539)
dachengai Feb 27, 2024
2410e32
fix `get_ip` error in pure ipv6 environment (#2931)
Jingru Feb 27, 2024
4bd18ec
[Minor] Fix type annotation in fused moe (#3045)
WoosukKwon Feb 27, 2024
e0ade06
Support logit bias for OpenAI API (#3027)
dylanwhawk Feb 27, 2024
8b430d7
[Minor] Fix StableLMEpochForCausalLM -> StableLmForCausalLM (#3046)
WoosukKwon Feb 27, 2024
71bcaf9
Enable GQA support in the prefix prefill kernels (#3007)
sighingnow Feb 27, 2024
a868310
multi-lora documentation fix (#3064)
ElefHead Feb 28, 2024
e46fa5d
Restrict prometheus_client >= 0.18.0 to prevent errors when importing…
AllenDou Feb 28, 2024
3b7178c
[Neuron] Support inference with transformers-neuronx (#2569)
liangfu Feb 28, 2024
929b4f2
Add LoRA support for Gemma (#3050)
WoosukKwon Feb 28, 2024
01a5d18
Add Support for 2/3/8-bit GPTQ Quantization Models (#2330)
chu-tianxiang Feb 29, 2024
a6d471c
Fix: `AttributeError` in OpenAI-compatible server (#3018)
jaywonchung Feb 29, 2024
9289e57
add cache_config's info to prometheus metrics. (#3100)
AllenDou Feb 29, 2024
bfdcfa6
Support starcoder2 architecture (#3089)
sh0416 Feb 29, 2024
2c08ff2
Fix building from source on WSL (#3112)
aliencaocao Feb 29, 2024
29a8d6a
[Fix] Don't deep-copy LogitsProcessors when copying SamplingParams (#…
njhill Feb 29, 2024
703e42e
Add guided decoding for OpenAI API server (#2819)
felixzhu555 Feb 29, 2024
54d3544
Fix: Output text is always truncated in some models (#3016)
HyperdriveHustle Mar 1, 2024
27ca23d
Remove exclude_unset in streaming response (#3143)
sh0416 Mar 1, 2024
49d849b
docs: Add tutorial on deploying vLLM model with KServe (#2586)
terrytangyuan Mar 1, 2024
90fbf12
fix relative import path of protocol.py (#3134)
Huarong Mar 1, 2024
c0c2335
Integrate Marlin Kernels for Int4 GPTQ inference (#2497)
robertgshaw2-neuralmagic Mar 1, 2024
82091b8
Bump up to v0.3.3 (#3129)
WoosukKwon Mar 1, 2024
29e70e3
allow user to choose log level by --log-level instead of fixed 'info'. (#…
AllenDou Mar 1, 2024
baee28c
Reorder kv dtype check to avoid nvcc not found error on AMD platform …
cloudhan Mar 2, 2024
ce4f5a2
Add Automatic Prefix Caching (#2762)
SageMoore Mar 2, 2024
d65fac2
Add vLLM version info to logs and openai API server (#3161)
jasonacox Mar 3, 2024
996d095
[FIX] Fix styles in automatic prefix caching & add a automatic prefix…
zhuohan123 Mar 3, 2024
17c3103
Make it easy to profile workers with nsight (#3162)
pcmoritz Mar 4, 2024
d0fae88
[DOC] add setup document to support neuron backend (#2777)
liangfu Mar 4, 2024
901cf4c
[Minor Fix] Remove unused code in benchmark_prefix_caching.py (#3171)
gty111 Mar 4, 2024
27a7b07
Add document for vllm paged attention kernel. (#2978)
pian13131 Mar 4, 2024
9cbc7e5
enable --gpu-memory-utilization in benchmark_throughput.py (#3175)
AllenDou Mar 4, 2024
76e8a70
[Minor fix] The domain dns.google may cause a socket.gaierror excepti…
ttbachyinsda Mar 4, 2024
22de452
Push logprob generation to LLMEngine (#3065)
Yard1 Mar 4, 2024
ff578ca
Add health check, make async Engine more robust (#3015)
Yard1 Mar 4, 2024
9a4548b
Fix the openai benchmarking requests to work with latest OpenAI apis …
wangchen615 Mar 4, 2024
05af6da
[ROCm] enable cupy in order to enable cudagraph mode for AMD GPUs (#…
hongxiayang Mar 5, 2024
8999ec3
Store `eos_token_id` in `Sequence` for easy access (#3166)
njhill Mar 5, 2024
2efce05
[Fix] Avoid pickling entire LLMEngine for Ray workers (#3207)
njhill Mar 6, 2024
24aecf4
[Tests] Add block manager and scheduler tests (#3108)
rkooo567 Mar 6, 2024
a33ce60
[Testing] Fix core tests (#3224)
cadedaniel Mar 6, 2024
4cb3b92
Add tqdm `dynamic_ncols=True` (#3242)
chujiezheng Mar 6, 2024
d3c04b6
Add GPTQ support for Gemma (#3200)
TechxGenus Mar 7, 2024
cbf4c05
Update requirements-dev.txt to include package for benchmarking scrip…
wangchen615 Mar 7, 2024
2daf23a
Separate attention backends (#3005)
WoosukKwon Mar 7, 2024
385da2d
Measure model memory usage (#3120)
mgoin Mar 7, 2024
8cbba46
Possible fix for conflict between Automated Prefix Caching (#2762) an…
jacobthebanana Mar 7, 2024
b35cc93
Fix auto prefix bug (#3239)
ElizaWszola Mar 8, 2024
d2339d6
Connect engine healthcheck to openai server (#3260)
njhill Mar 8, 2024
c59e120
Feature add lora support for Qwen2 (#3177)
whyiug Mar 8, 2024
1ece1ae
[Minor Fix] Fix comments in benchmark_serving (#3252)
gty111 Mar 8, 2024
99c3cfb
[Docs] Fix Unmocked Imports (#3275)
ywang96 Mar 8, 2024
1cb0cc2
[FIX] Make `flash_attn` optional (#3269)
WoosukKwon Mar 8, 2024
c2c5e09
Move model filelocks from `/tmp/` to `~/.cache/vllm/locks/` dir (#3241)
mgoin Mar 8, 2024
f48c679
[FIX] Fix prefix test error on main (#3286)
zhuohan123 Mar 9, 2024
8437bae
[Speculative decoding 3/9] Worker which speculates, scores, and appli…
cadedaniel Mar 9, 2024
0bba88d
Enhance lora tests with more layer and rank variations (#3243)
tterrysun Mar 10, 2024
e4a28e5
[ROCM] Fix blockReduceSum to use correct warp counts for ROCm and CUD…
dllehr-amd Mar 10, 2024
9e8744a
[BugFix] Fix get tokenizer when using ray (#3301)
esmeetu Mar 11, 2024
4b59f00
[Fix] Fix best_of behavior when n=1 (#3298)
njhill Mar 11, 2024
2f8844b
Re-enable the 80 char line width limit (#3305)
zhuohan123 Mar 11, 2024
657061f
[docs] Add LoRA support information for models (#3299)
pcmoritz Mar 11, 2024
4c92270
Add distributed model executor abstraction (#3191)
zhuohan123 Mar 11, 2024
c9415c1
[ROCm] Fix warp and lane calculation in blockReduceSum (#3321)
kliuae Mar 11, 2024
654865e
Support Mistral Model Inference with transformers-neuronx (#3153)
DAIZHENWEI Mar 11, 2024
b0925b3
docs: Add BentoML deployment doc (#3336)
Sherlock113 Mar 12, 2024
49a3c86
Fixes #1556 double free (#3347)
br3no Mar 13, 2024
602358f
Add kernel for GeGLU with approximate GELU (#3337)
WoosukKwon Mar 13, 2024
b167109
[Fix] Fix quantization="gptq" when using Marlin (#3319)
DreamTeamWangbowen Mar 13, 2024
e221910
add hf_transfer to requirements.txt (#3031)
RonanKMcGovern Mar 13, 2024
ba8dc95
[Minor] Fix bias in if to remove ambiguity (#3259)
hliuca Mar 13, 2024
739c350
[Minor Fix] Use cupy-cuda11x in CUDA 11.8 build (#3256)
chenxu2048 Mar 13, 2024
ae0ccb4
Add missing kernel for CodeLlama-34B on A/H100 (no tensor parallelism…
orsharir Mar 13, 2024
7e9bd08
Add batched RoPE kernel (#3095)
tterrysun Mar 13, 2024
c33afd8
Fix lint (#3388)
Yard1 Mar 13, 2024
eeab52a
[FIX] Simpler fix for async engine running on ray (#3371)
zhuohan123 Mar 13, 2024
81653d9
[Hotfix] [Debug] test_openai_server.py::test_guided_regex_completion …
simon-mo Mar 14, 2024
a37415c
allow user to choose which vllm metrics to display in grafana (#3393)
AllenDou Mar 14, 2024
8fe8386
[Kernel] change benchmark script so that result can be directly used;…
youkaichao Mar 14, 2024
06ec486
Install `flash_attn` in Docker image (#3396)
tdoublep Mar 14, 2024
c17ca8e
Add args for mTLS support (#3410)
declark1 Mar 14, 2024
dfc7740
[issue templates] add some issue templates (#3412)
youkaichao Mar 14, 2024
54be8a0
Fix assertion failure in Qwen 1.5 with prefix caching enabled (#3373)
chenxu2048 Mar 14, 2024
87ad0cb
Merge branch 'upstream-main' into upstream-sync-2024-03-14
robertgshaw2-neuralmagic Mar 14, 2024
4518f5a
format
robertgshaw2-neuralmagic Mar 14, 2024
5bc7a73
formating
robertgshaw2-neuralmagic Mar 14, 2024
6f60731
ruff
robertgshaw2-neuralmagic Mar 14, 2024
5ba2ee1
ruff again
robertgshaw2-neuralmagic Mar 14, 2024
d342426
yapf
robertgshaw2-neuralmagic Mar 14, 2024
e283528
finalized ruff
robertgshaw2-neuralmagic Mar 15, 2024
c5633f2
yapf after ruff :)
robertgshaw2-neuralmagic Mar 15, 2024
1271e3c
yapf after ruff :)
robertgshaw2-neuralmagic Mar 15, 2024
c47bd6b
fixed tests post update
robertgshaw2-neuralmagic Mar 15, 2024
b9c3578
missed one test
robertgshaw2-neuralmagic Mar 15, 2024
1e36b51
Update test-pipeline.yaml
robertgshaw2-neuralmagic Mar 15, 2024
Enable GQA support in the prefix prefill kernels (vllm-project#3007)
Signed-off-by: Tao He <sighingnow@gmail.com>
sighingnow authored Feb 27, 2024

Verified: This commit was created on GitHub.com and signed with GitHub's verified signature.
commit 71bcaf99e2cb2c677bf3a9addb9e8039cbcab22a
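
Before the file-by-file diff, here is a minimal standalone sketch of the head mapping this commit introduces: the prefix-prefill kernels launch one program per query head and derive the KV head to read from by integer division, so GQA/MQA models no longer need their key/value tensors expanded up to num_heads first. The head counts below are illustrative only (they happen to match the test's 64 query heads with an 8-way grouping), not values taken from any particular model.

# Minimal sketch of the GQA head mapping used by the kernels in this diff
# (mirrors `cur_kv_head = cur_head // num_queries_per_kv`); head counts are
# illustrative only.
num_heads = 64                                   # query heads
num_kv_heads = 8                                 # KV heads; 1 would be MQA
num_queries_per_kv = num_heads // num_kv_heads   # = 8

kv_head_of = [cur_head // num_queries_per_kv for cur_head in range(num_heads)]
# Query heads 0..7 share KV head 0, heads 8..15 share KV head 1, and so on.
assert kv_head_of[:9] == [0, 0, 0, 0, 0, 0, 0, 0, 1]
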
61 changes: 42 additions & 19 deletions tests/kernels/test_prefix_prefill.py
@@ -8,7 +8,8 @@
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

NUM_HEADS = [12]
NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64]
HEAD_SIZES = [128]
DTYPES = [torch.float16]
CUDA_DEVICES = [
@@ -17,12 +18,14 @@


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_contexted_kv_attention(
num_heads: int,
num_queries_per_kv: int,
head_size: int,
dtype: torch.dtype,
device: str,
@@ -41,28 +44,29 @@ def test_contexted_kv_attention(
subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)]
num_kv_heads = num_heads // num_queries_per_kv

num_tokens = sum(subquery_lens)
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
query.uniform_(-1e-3, 1e-3)
output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

kv = torch.empty(sum(seq_lens), 2, num_heads, head_size, dtype=dtype)
kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
kv.uniform_(-1e-3, 1e-3)
key, value = kv.unbind(dim=1)

k_cache = torch.zeros(cache_size,
block_size,
num_heads,
num_kv_heads,
head_size,
dtype=dtype)
v_cache = torch.zeros(cache_size,
block_size,
num_heads,
num_kv_heads,
head_size,
dtype=dtype)
k = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype)
v = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype)
k = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long)
values = values[torch.randperm(cache_size)]
block_table = values[:BS * max_block_per_request].view(
@@ -93,19 +97,21 @@ def test_contexted_kv_attention(
end_loc = start_loc + block_size
start_slot = block_table[i, block_id] * block_size
end_slot = start_slot + end_loc - start_loc
k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
key[start_loc:end_loc])
v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
value[start_loc:end_loc])
k_cache.view(-1, num_kv_heads,
head_size)[start_slot:end_slot].copy_(
key[start_loc:end_loc])
v_cache.view(-1, num_kv_heads,
head_size)[start_slot:end_slot].copy_(
value[start_loc:end_loc])
cur_ctx += block_size
block_id += 1
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8,
k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
8).permute(0, 2, 3, 1, 4).contiguous()
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = v_cache.view(-1, block_size, num_heads,
v_cache = v_cache.view(-1, block_size, num_kv_heads,
head_size).permute(0, 2, 3, 1).contiguous()

# Warm up the Triton kernel by calling it once before actually measuring generation time
@@ -123,12 +129,29 @@ def test_contexted_kv_attention(

attn_op = xops.fmha.cutlass.FwOp()

if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
query.shape[-1])
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1])
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)

attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
subquery_lens, seq_lens)
output_ref = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
query,
key,
value,
attn_bias=attn_bias,
p=0.0,
scale=scale,
@@ -137,9 +160,9 @@ def test_contexted_kv_attention(
torch.cuda.synchronize()
start_time = time.time()
output_ref = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
query,
key,
value,
attn_bias=attn_bias,
p=0.0,
scale=scale,
@@ -148,5 +171,5 @@ def test_contexted_kv_attention(
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
output_ref = output_ref.squeeze(0)
output_ref = output_ref.squeeze(0, 2)
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
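
The xformers reference path in this test still assumes plain MHA (per the in-diff comment), so for GQA/MQA the test regroups the query and broadcast-expands key/value before calling it. A self-contained sketch of that regrouping, with illustrative shapes rather than the test's randomized ones:

# Self-contained sketch of the MQA/GQA regrouping applied before the xformers
# MHA reference call; shapes here are illustrative, not the test's.
import torch

num_tokens, num_kv_heads, num_queries_per_kv, head_size = 16, 4, 8, 128
num_heads = num_kv_heads * num_queries_per_kv

query = torch.randn(num_tokens, num_heads, head_size)
orig_key = torch.randn(num_tokens, num_kv_heads, head_size)

# Group query heads by the KV head they share ...
query = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
# ... and broadcast each KV head across its group without copying memory.
key = orig_key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                     num_queries_per_kv, head_size)

# Query head h now lines up with KV head h // num_queries_per_kv.
h = 13
assert torch.equal(key[:, h // num_queries_per_kv, h % num_queries_per_kv],
                   orig_key[:, h // num_queries_per_kv])
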
34 changes: 18 additions & 16 deletions vllm/model_executor/layers/attention.py
@@ -137,25 +137,27 @@ def forward(
)

if input_metadata.is_prompt:
# Prompt run.
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv, query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
key.shape[-1])
value = value[:, :, None, :].expand(value.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
# normal attention
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])

# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
39 changes: 27 additions & 12 deletions vllm/model_executor/layers/triton_kernel/prefix_prefill.py
@@ -45,6 +45,7 @@ def _fwd_kernel(
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
@@ -53,6 +54,8 @@ def _fwd_kernel(
cur_head = tl.program_id(1)
start_m = tl.program_id(2)

cur_kv_head = cur_head // num_queries_per_kv

cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
@@ -85,13 +88,14 @@ def _fwd_kernel(
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (bn[None, :] * stride_k_cache_bs +
cur_head * stride_k_cache_h +
cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
@@ -131,9 +135,9 @@ def _fwd_kernel(
l_i = l_i_new
m_i = m_i_new

off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
@@ -232,6 +236,7 @@ def _fwd_kernel_flash_attn_v2(
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
@@ -240,6 +245,8 @@ def _fwd_kernel_flash_attn_v2(
cur_head = tl.program_id(1)
start_m = tl.program_id(2)

cur_kv_head = cur_head // num_queries_per_kv

cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
@@ -272,13 +279,14 @@ def _fwd_kernel_flash_attn_v2(
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (bn[None, :] * stride_k_cache_bs +
cur_head * stride_k_cache_h +
cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
@@ -317,9 +325,9 @@ def _fwd_kernel_flash_attn_v2(
l_i = l_i_new
m_i = m_i_new

off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
@@ -420,6 +428,7 @@ def _fwd_kernel_alibi(
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
@@ -429,6 +438,8 @@ def _fwd_kernel_alibi(
cur_head = tl.program_id(1)
start_m = tl.program_id(2)

cur_kv_head = cur_head // num_queries_per_kv

# cur_batch_seq_len: the length of prompts
# cur_batch_ctx_len: the length of prefix
# cur_batch_in_all_start_index: the start id of the dim=0
@@ -468,13 +479,14 @@ def _fwd_kernel_alibi(
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (bn[None, :] * stride_k_cache_bs +
cur_head * stride_k_cache_h +
cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
@@ -522,9 +534,9 @@ def _fwd_kernel_alibi(
l_i = l_i_new
m_i = m_i_new

off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
@@ -628,6 +640,7 @@ def context_attention_fwd(q,

sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
num_queries_per_kv = q.shape[1] // k.shape[1]

grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,

@@ -674,6 +687,7 @@ def context_attention_fwd(q,
v_cache.stride(2),
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
num_queries_per_kv=num_queries_per_kv,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
@@ -721,6 +735,7 @@ def context_attention_fwd(q,
v_cache.stride(2),
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
num_queries_per_kv=num_queries_per_kv,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
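
The wrapper change above derives num_queries_per_kv from the query and key head counts and forwards it to every kernel variant, while the launch grid still has one program per query head. A short sketch of that shape relationship under assumed, illustrative tensor shapes (the tensors are stand-ins, not context_attention_fwd's real arguments):

# Sketch of the host-side derivation mirrored from
# `num_queries_per_kv = q.shape[1] // k.shape[1]`; q and k are stand-ins with
# illustrative shapes, not the wrapper's real arguments.
import torch

num_tokens, num_heads, num_kv_heads, head_size = 32, 64, 8, 128
q = torch.randn(num_tokens, num_heads, head_size)
k = torch.randn(num_tokens, num_kv_heads, head_size)

# The grid is still sized by q.shape[1] (one program per *query* head); each
# program then picks its KV head as cur_head // num_queries_per_kv.
assert q.shape[1] % k.shape[1] == 0, "query heads must be a multiple of KV heads"
num_queries_per_kv = q.shape[1] // k.shape[1]   # 8 query heads per KV head here

Passing the ratio as a plain kernel argument, rather than expanding key/value on the host, keeps the cached key/value blocks at num_kv_heads, which is what the block tables in the kernels above index into.
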