Add unit test and fix for flash_4 (#108)
cpuhrsch authored Dec 15, 2023
1 parent 92f7296 commit 387488b
Showing 2 changed files with 49 additions and 4 deletions.
10 changes: 6 additions & 4 deletions segment_anything_fast/flash_4.py
@@ -107,14 +107,14 @@ def _fwd_kernel_aligned(
         v = tl.load(V_block_ptr)
         # -- compute qk ---
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=OUT_DTYPE)
-        qk += tl.dot(q, k, out_dtype=OUT_DTYPE)
+        qk += tl.dot(q, k) #, out_dtype=OUT_DTYPE)

         # -- compute rel_h[:, None] + rel_w[None, :] bias ---

         # Bias
         b0 = tl.load(B0 + b_offset + ((start_m * BLOCK_M + b_ptr_offsets_m)
                                       * stride_b0m)[:, None] + start_n // BLOCK_N)
-        qk += (b0 + b1)
+        qk += ((b0 + b1) * 1.44269504)

         # -- compute scaling constant ---
         m_i_new = tl.maximum(m_i, tl.max(qk, 1))
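
Two things change in this hunk: tl.dot no longer passes out_dtype, so the product stays in Triton's default float32 output for fp16/bf16 inputs instead of being cast down, and the relative-position bias is multiplied by 1.44269504, i.e. log2(e). The latter matches an exp2-based online softmax, where a bias expressed in natural-log units has to be rescaled before it is added to qk. A minimal sketch of that identity (illustrative only, not part of the commit, and it assumes the kernel's softmax is computed with exp2):

import math
import torch

LOG2_E = math.log2(math.e)  # ~1.44269504
bias = torch.randn(16, dtype=torch.float32)
# exp(b) == 2 ** (b * log2(e)), so pre-scaling the bias keeps an exp2-based
# softmax numerically equivalent to the exp-based reference.
assert torch.allclose(torch.exp(bias), torch.exp2(bias * LOG2_E), rtol=1e-5, atol=1e-6)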
@@ -198,6 +198,7 @@ def _attention_rel_h_rel_w_kernel_aligned_device(q, k, v, rel_h_w, sm_scale, o,
     P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]
     assert P_SEQ == 0
     assert rel_h_w.is_contiguous(), str(rel_h_w.stride())
+    OUT_DTYPE = tl.float16 if q.dtype == torch.float16 else tl.bfloat16
     _fwd_kernel_aligned[grid](
         q, k, v,
         rel_h_w,
@@ -212,7 +213,7 @@ def _attention_rel_h_rel_w_kernel_aligned_device(q, k, v, rel_h_w, sm_scale, o,
         q.shape[1],
         q.shape[2],
         P_SEQ,
-        OUT_DTYPE=tl.float16 if q.dtype == torch.float16 else tl.bfloat16,
+        OUT_DTYPE=OUT_DTYPE,
         BIAS_LAST_SIZE=(rel_h_w.size(-1) // 2),
         B0_NUMEL=rel_h_w.size(-1),
         BLOCK_M=BLOCK_M,
@@ -346,7 +347,8 @@ def _attention_rel_h_rel_w(q_, k_, v_, rel_h_, rel_w_):
     def kernel_guards(q_, k_, v_):
         return (q_.dtype == torch.bfloat16 or q_.dtype == torch.float16) and q_.dtype == k_.dtype and k_.dtype == v_.dtype and USE_CUSTOM_KERNEL
     # vit_b and vit_l
-    if q_size_2_padded == 0 and q_.size(-1) == 64 and kernel_guards(q_, k_, v_):
+    # TODO: This kernel currently does not produce correct results for batch size 1 for this case
+    if q_.size(0) > 1 and q_size_2_padded == 0 and q_.size(-1) == 64 and kernel_guards(q_, k_, v_):
         rel_h_w = torch.cat([rel_h_.squeeze(-1), rel_w_.squeeze(-2)], dim=-1)
         o = torch.ops.customflash.custom_flash_aligned(
             q_, k_, v_, rel_h_w, sm_scale)
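
For reference, rel_h_w packs the two squeezed bias tensors along the last dimension, which is what the BIAS_LAST_SIZE=(rel_h_w.size(-1) // 2) and B0_NUMEL=rel_h_w.size(-1) kernel arguments above rely on. A small shape check (illustrative only; the shapes follow the vit_b/vit_l case exercised by the new test below):

import torch

B, H, S, w = 1, 12, 4096, 64  # seq_len S is w * w
rel_h_ = torch.randn(B, H, S, w, 1)
rel_w_ = torch.randn(B, H, S, 1, w)
rel_h_w = torch.cat([rel_h_.squeeze(-1), rel_w_.squeeze(-2)], dim=-1)
assert rel_h_w.shape == (B, H, S, 2 * w)  # so B0_NUMEL == 2 * w and BIAS_LAST_SIZE == w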
43 changes: 43 additions & 0 deletions test/test_flash_4.py
@@ -0,0 +1,43 @@
import torch
import itertools
from segment_anything_fast.flash_4 import _attention_rel_h_rel_w

def test_op(batch, head, seq_len, hidden_dim, dtype):
    import math

    sm_scale = 1.0 / math.sqrt(hidden_dim)
    device = "cuda"
    torch.manual_seed(20)
    q = torch.empty(
        (batch, head, seq_len, hidden_dim), dtype=dtype, device=device
    ).normal_(mean=0.0, std=0.5)
    k = torch.empty(
        (batch, head, seq_len, hidden_dim), dtype=dtype, device=device
    ).normal_(mean=0.0, std=0.5)
    v = torch.empty(
        (batch, head, seq_len, hidden_dim), dtype=dtype, device=device
    ).normal_(mean=0.0, std=0.5)
    w = int((seq_len) ** 0.5)
    assert w * w == seq_len, "seq_len must be a perfect square"

    rel_h = torch.empty(
        (batch, head, seq_len, w, 1), dtype=dtype, device=device
    ).normal_(mean=0, std=0.5)
    rel_w = torch.empty(
        (batch, head, seq_len, 1, w), dtype=dtype, device=device
    ).normal_(mean=0, std=0.5)

    tri_out = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w)
    # reference implementation
    attn_bias = (rel_h + rel_w).view(
        q.size(0), q.size(1), rel_h.size(2), rel_h.size(3) * rel_w.size(4)
    )
    ref_out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_bias
    )

    torch.testing.assert_close(ref_out, tri_out, rtol=1e-3, atol=1e-3)

for batch, (head, seq_len), dtype in itertools.product([1, 8], [(16, 80), (12, 64)], [torch.float16, torch.bfloat16]):
    print(f"batch: {batch} head: {head} seq_len: {seq_len} dtype: {dtype}")
    test_op(batch, head, 4096, seq_len, dtype)
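
Note that the second element of each (16, 80) / (12, 64) tuple is bound to the loop variable named seq_len but is actually passed as hidden_dim; the sequence length itself is fixed at 4096, and the sweep runs at import time. Under pytest, a parametrized wrapper along these lines could be used instead (hypothetical, not part of the commit; it assumes pytest is installed, that it lives in the same file so test_op is in scope, and it skips when no CUDA device is present):

import pytest

@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA device")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head,hidden_dim", [(16, 80), (12, 64)])
@pytest.mark.parametrize("batch", [1, 8])
def test_flash_4_parametrized(batch, head, hidden_dim, dtype):
    test_op(batch, head, seq_len=4096, hidden_dim=hidden_dim, dtype=dtype)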
