
Issue with replacing flash_attn_varlen_func by scaled_dot_product_attention #284

Open
neonhuang opened this issue Jan 15, 2025 · 1 comment


neonhuang commented Jan 15, 2025

Hello, at this point in the code, https://github.com/kijai/ComfyUI-HunyuanVideoWrapper/blob/main/hyvideo/modules/models.py#L978
the following line needs to be added so that torch.sdpa produces exactly the same result as flash varlen (and the video renders correctly):
attn_mask[0, total_len:, total_len:] = True
With this change the code also runs on torch 2.3, and the result matches the original HunyuanVideo implementation exactly. I have also raised this point here:
https://huggingface.co/hunyuanvideo-community/HunyuanVideo/discussions/1
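
For reference, a minimal sketch of what the patched mask construction around models.py#L978 could look like. The names attn_mask, total_len, seq_len and the batch size of 1 are assumptions taken from the snippet above, not the exact code in the repository:

    # Hypothetical sketch (variable names assumed); mirrors the varlen layout:
    # tokens [0, total_len) are the real sequence, [total_len, seq_len) is padding.
    attn_mask = torch.zeros(1, seq_len, seq_len, device=q.device, dtype=torch.bool)
    attn_mask[0, :total_len, :total_len] = True   # valid tokens attend to each other
    attn_mask[0, total_len:, total_len:] = True   # proposed fix: padded tokens attend to
                                                  # padded tokens, so no query row is fully
                                                  # masked (a fully masked row yields NaNs)

This makes the boolean mask block-diagonal, which is what flash_attn_varlen_func computes when the padded tail is passed as a second sequence via cu_seqlens.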

Below is my verification:
import torch
import torch.nn.functional as F

import random
import numpy as np


def set_seeds(seed_list, device=None):
    if isinstance(seed_list, (tuple, list)):
        seed = sum(seed_list)
    else:
        seed = seed_list
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def flash_attn_impl(q, k, v, batch_size, seq_len, num_heads, head_dim, img_len, txt_len,
                    effective_condition_sequence_length, device):
    from flash_attn import flash_attn_varlen_func

    # Two segments: the valid tokens [0, img_len + effective text len) and the padded tail.
    cu_seqlens_q = torch.tensor([0, img_len + effective_condition_sequence_length, img_len + txt_len],
                                device=device, dtype=torch.int32)
    cu_seqlens_k = torch.tensor([0, img_len + effective_condition_sequence_length, img_len + txt_len],
                                device=device, dtype=torch.int32)
    max_seqlen_q = img_len + txt_len
    max_seqlen_k = img_len + txt_len
    print(f'cu_seqlens_q: {cu_seqlens_q}')
    print(f'cu_seqlens_k: {cu_seqlens_k}')
    print(f'max_seqlen_q: {max_seqlen_q}')
    print(f'max_seqlen_k: {max_seqlen_k}')

    # flash_attn_varlen_func expects packed [total_tokens, num_heads, head_dim] tensors.
    q = q.view(-1, num_heads, head_dim)
    k = k.view(-1, num_heads, head_dim)
    v = v.view(-1, num_heads, head_dim)

    print(f'flash attn q.shape: {q.shape}')
    print(f'flash attn k.shape: {k.shape}')
    print(f'flash attn v.shape: {v.shape}')
    x = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=0.0,
    )
    x = x.view(batch_size, seq_len, num_heads * head_dim)
    print(f'flash varlen x.shape: {x.shape}')
    return x


def torch_impl(q, k, v, batch_size, seq_len, num_heads, head_dim, img_len, txt_len,
               effective_condition_sequence_length, device):
    q = q.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
    k = k.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
    v = v.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
    sequence_length = img_len + txt_len
    attn_mask = torch.zeros(
        batch_size, sequence_length, sequence_length, device=device, dtype=torch.bool
    )  # [batch_size, seq_len, seq_len]
    print(f'attn_mask: {attn_mask}, attn_mask.shape: {attn_mask.shape}')
    effective_sequence_length = [img_len + effective_condition_sequence_length]
    for i in range(batch_size):
        # Block-diagonal mask: valid tokens attend to valid tokens,
        # padded tokens attend to padded tokens (matches the varlen segment layout above).
        attn_mask[i, :effective_sequence_length[i], :effective_sequence_length[i]] = True
        attn_mask[i, effective_sequence_length[i]:, effective_sequence_length[i]:] = True
    print(f'attn_mask: {attn_mask}')
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
    x = out.transpose(1, 2).contiguous().view(batch_size, seq_len, num_heads * head_dim)
    print(f'torch.sdpa x.shape: {x.shape}')
    return x


if __name__ == "__main__":
    set_seeds(43)
    device = 'cuda'
    dtype = torch.bfloat16
    num_heads = 16
    head_dim = 64
    img_len = 1080
    txt_len = 256
    batch_size = 1
    seq_len = img_len + txt_len
    effective_condition_sequence_length = 2
    q = torch.randn([batch_size, seq_len, num_heads, head_dim], device=device, dtype=dtype)
    k = torch.randn([batch_size, seq_len, num_heads, head_dim], device=device, dtype=dtype)
    v = torch.randn([batch_size, seq_len, num_heads, head_dim], device=device, dtype=dtype)

    print(f'attn q.shape: {q.shape}')
    print(f'attn k.shape: {k.shape}')
    print(f'attn v.shape: {v.shape}')

    x1 = flash_attn_impl(q, k, v, batch_size, seq_len, num_heads, head_dim, img_len, txt_len,
                         effective_condition_sequence_length, device)
    x2 = torch_impl(q, k, v, batch_size, seq_len, num_heads, head_dim, img_len, txt_len,
                    effective_condition_sequence_length, device)

    print(f'flash attn x1: {x1}')
    print(f'torch.sdpa x2: {x2}')

    atol, rtol = 1e-3, 1e-8
    # print(f'flash varlen vs torch.sdpa: {torch.allclose(x1, x2, atol=atol, rtol=rtol)}')
    similarity = torch.cosine_similarity(
        x1.to("cpu").ravel().double(),
        x2.to("cpu").ravel().double(),
        dim=0
    ).item()
    print(f'flash varlen vs torch.sdpa similarity: {similarity}')

neonhuang commented Jan 17, 2025

In addition, I found experimentally that changing https://github.com/kijai/ComfyUI-HunyuanVideoWrapper/blob/main/hyvideo/modules/models.py#L978 to the following line also gives the same result as flash attention varlen (this is also the mask that diffusers currently uses):
attn_mask[0, :, :total_len] = True
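
For comparison, a minimal sketch of this alternative mask, with the same assumed variable names as above. Valid query rows still see exactly the keys in [0, total_len), so their outputs are unchanged, and no query row is left fully masked:

    # Hypothetical sketch (variable names assumed as above).
    attn_mask = torch.zeros(1, seq_len, seq_len, device=q.device, dtype=torch.bool)
    attn_mask[0, :, :total_len] = True   # every query attends to all valid keys;
                                         # padded key columns remain excluded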
