From 387488bc4c7ab2ae311fb0632b34cab5cbfbab78 Mon Sep 17 00:00:00 2001
From: cpuhrsch
Date: Fri, 15 Dec 2023 14:46:44 -0600
Subject: [PATCH] Add unit test and fix for flash_4 (#108)

---
 segment_anything_fast/flash_4.py | 10 +++++---
 test/test_flash_4.py             | 43 ++++++++++++++++++++++++++++++++
 2 files changed, 49 insertions(+), 4 deletions(-)
 create mode 100644 test/test_flash_4.py

diff --git a/segment_anything_fast/flash_4.py b/segment_anything_fast/flash_4.py
index fffaef8..0d31df5 100644
--- a/segment_anything_fast/flash_4.py
+++ b/segment_anything_fast/flash_4.py
@@ -107,14 +107,14 @@ def _fwd_kernel_aligned(
         v = tl.load(V_block_ptr)
         # -- compute qk ---
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=OUT_DTYPE)
-        qk += tl.dot(q, k, out_dtype=OUT_DTYPE)
+        qk += tl.dot(q, k) #, out_dtype=OUT_DTYPE)

         # -- compute rel_h[:, None] + rel_w[None, :] bias ---

         # Bias
         b0 = tl.load(B0 + b_offset +
                      ((start_m * BLOCK_M + b_ptr_offsets_m) * stride_b0m)[:, None] + start_n // BLOCK_N)
-        qk += (b0 + b1)
+        qk += ((b0 + b1) * 1.44269504)

         # -- compute scaling constant ---
         m_i_new = tl.maximum(m_i, tl.max(qk, 1))
@@ -198,6 +198,7 @@ def _attention_rel_h_rel_w_kernel_aligned_device(q, k, v, rel_h_w, sm_scale, o,
     P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]
     assert P_SEQ == 0
     assert rel_h_w.is_contiguous(), str(rel_h_w.stride())
+    OUT_DTYPE = tl.float16 if q.dtype == torch.float16 else tl.bfloat16
     _fwd_kernel_aligned[grid](
         q, k, v,
         rel_h_w,
@@ -212,7 +213,7 @@ def _attention_rel_h_rel_w_kernel_aligned_device(q, k, v, rel_h_w, sm_scale, o,
         q.shape[1],
         q.shape[2],
         P_SEQ,
-        OUT_DTYPE=tl.float16 if q.dtype == torch.float16 else tl.bfloat16,
+        OUT_DTYPE=OUT_DTYPE,
         BIAS_LAST_SIZE=(rel_h_w.size(-1) // 2),
         B0_NUMEL=rel_h_w.size(-1),
         BLOCK_M=BLOCK_M,
@@ -346,7 +347,8 @@ def _attention_rel_h_rel_w(q_, k_, v_, rel_h_, rel_w_):
     def kernel_guards(q_, k_, v_):
         return (q_.dtype == torch.bfloat16 or q_.dtype == torch.float16) and q_.dtype == k_.dtype and k_.dtype == v_.dtype and USE_CUSTOM_KERNEL
     # vit_b and vit_l
-    if q_size_2_padded == 0 and q_.size(-1) == 64 and kernel_guards(q_, k_, v_):
+    # TODO: This kernel currently does not produce correct results for batch size 1 for this case
+    if q_.size(0) > 1 and q_size_2_padded == 0 and q_.size(-1) == 64 and kernel_guards(q_, k_, v_):
         rel_h_w = torch.cat([rel_h_.squeeze(-1), rel_w_.squeeze(-2)], dim=-1)
         o = torch.ops.customflash.custom_flash_aligned(
             q_, k_, v_, rel_h_w, sm_scale)
diff --git a/test/test_flash_4.py b/test/test_flash_4.py
new file mode 100644
index 0000000..9e20e0c
--- /dev/null
+++ b/test/test_flash_4.py
@@ -0,0 +1,43 @@
+import torch
+import itertools
+from segment_anything_fast.flash_4 import _attention_rel_h_rel_w
+
+def test_op(batch, head, seq_len, hidden_dim, dtype):
+    import math
+
+    sm_scale = 1.0 / math.sqrt(hidden_dim)
+    device = "cuda"
+    torch.manual_seed(20)
+    q = torch.empty(
+        (batch, head, seq_len, hidden_dim), dtype=dtype, device=device
+    ).normal_(mean=0.0, std=0.5)
+    k = torch.empty(
+        (batch, head, seq_len, hidden_dim), dtype=dtype, device=device
+    ).normal_(mean=0.0, std=0.5)
+    v = torch.empty(
+        (batch, head, seq_len, hidden_dim), dtype=dtype, device=device
+    ).normal_(mean=0.0, std=0.5)
+    w = int((seq_len) ** 0.5)
+    assert w * w == seq_len, "seq_len must be a perfect square"
+
+    rel_h = torch.empty(
+        (batch, head, seq_len, w, 1), dtype=dtype, device=device
+    ).normal_(mean=0, std=0.5)
+    rel_w = torch.empty(
+        (batch, head, seq_len, 1, w), dtype=dtype, device=device
+    ).normal_(mean=0, std=0.5)
+
+    tri_out = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w)
+    # reference implementation
+    attn_bias = (rel_h + rel_w).view(
+        q.size(0), q.size(1), rel_h.size(2), rel_h.size(3) * rel_w.size(4)
+    )
+    ref_out = torch.nn.functional.scaled_dot_product_attention(
+        q, k, v, attn_mask=attn_bias
+    )
+
+    torch.testing.assert_close(ref_out, tri_out, rtol=1e-3, atol=1e-3)
+
+for batch, (head, seq_len), dtype in itertools.product([1, 8], [(16, 80), (12, 64)], [torch.float16, torch.bfloat16]):
+    print(f"batch: {batch} head: {head} seq_len: {seq_len} dtype: {dtype}")
+    test_op(batch, head, 4096, seq_len, dtype)
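
Background on the `* 1.44269504` factor (a sketch, not part of the diff above): 1.44269504 is
log2(e). Kernels written in the style of the Triton fused-attention tutorial exponentiate with
exp2 rather than exp, which is presumably what this kernel does as well, so an additive bias
folded into `qk` after the dot product has to be rescaled into the base-2 domain or the result
no longer matches a standard softmax. The toy check below illustrates that identity in plain
PyTorch; the tensor names are chosen here for illustration and nothing in it ships with the
patch.

    import torch

    LOG2_E = 1.44269504  # log2(e), same constant used in the kernel change

    scores = torch.randn(4, 4, dtype=torch.float32)
    bias = torch.randn(4, 4, dtype=torch.float32)

    # Reference: ordinary (base-e) softmax over the biased scores.
    ref = torch.softmax(scores + bias, dim=-1)

    # Base-2 formulation: exp(x) == exp2(x * log2(e)), so scaling the biased
    # scores by log2(e) and using exp2 reproduces the same distribution.
    p = torch.exp2((scores + bias) * LOG2_E)
    base2 = p / p.sum(dim=-1, keepdim=True)

    torch.testing.assert_close(ref, base2)

The new test runs at import time, so it can be executed directly with `python test/test_flash_4.py`;
it requires a CUDA device (and Triton), since the inputs are allocated on `"cuda"` and the custom
kernel path only runs on GPU.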