Hello. At this spot in the code, https://github.com/kijai/ComfyUI-HunyuanVideoWrapper/blob/main/hyvideo/modules/models.py#L978, the following line needs to be added so that the torch.sdpa result matches the flash varlen result exactly (and the video generates correctly):

attn_mask[0, total_len:, total_len:] = True

With this addition the code also runs on torch 2.3, and the result is identical to the original HunyuanVideo version. I have also raised this issue here: https://huggingface.co/hunyuanvideo-community/HunyuanVideo/discussions/1
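For context, the kind of change being proposed looks roughly like this. This is only a paraphrase of the code around the linked line; the surrounding variable names (max_seqlen_q, q.device) are assumptions, and only the final added line is the actual proposal:

    # existing mask: the valid image + text tokens may attend to each other
    attn_mask = torch.zeros((1, max_seqlen_q, max_seqlen_q), dtype=torch.bool, device=q.device)
    attn_mask[0, :total_len, :total_len] = True
    # proposed addition: let the padded text tokens also attend among themselves,
    # which mirrors flash_attn_varlen_func treating the padding as its own sequence
    attn_mask[0, total_len:, total_len:] = True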
Here is my verification:

import torch
import torch.nn.functional as F
import random
import numpy as np


def set_seeds(seed_list, device=None):
    if isinstance(seed_list, (tuple, list)):
        seed = sum(seed_list)
    else:
        seed = seed_list
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def flash_attn_impl(q, k, v, batch_size, seq_len, num_heads, head_dim, img_len, txt_len, effective_condition_sequence_length, device):
    from flash_attn import flash_attn_varlen_func
    cu_seqlens_q = torch.tensor([0, img_len + effective_condition_sequence_length, img_len + txt_len], device=device, dtype=torch.int32)
    cu_seqlens_k = torch.tensor([0, img_len + effective_condition_sequence_length, img_len + txt_len], device=device, dtype=torch.int32)
    max_seqlen_q = img_len + txt_len
    max_seqlen_k = img_len + txt_len
    print(f'cu_seqlens_q: {cu_seqlens_q}')
    print(f'cu_seqlens_k: {cu_seqlens_k}')
    print(f'max_seqlen_q: {max_seqlen_q}')
    print(f'max_seqlen_k: {max_seqlen_k}')
    q = q.view(-1, num_heads, head_dim)
    k = k.view(-1, num_heads, head_dim)
    v = v.view(-1, num_heads, head_dim)
    print(f'flash attn q.shape: {q.shape}')
    print(f'flash attn k.shape: {k.shape}')
    print(f'flash attn v.shape: {v.shape}')
    x = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=0.0,
    )
    x = x.view(batch_size, seq_len, num_heads * head_dim)
    print(f'flash varlen x.shape: {x.shape}')
    return x


def torch_impl(q, k, v, batch_size, seq_len, num_heads, head_dim, img_len, txt_len, effective_condition_sequence_length, device):
    q = q.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
    k = k.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
    v = v.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
    sequence_length = img_len + txt_len
    attn_mask = torch.zeros(
        batch_size, sequence_length, sequence_length, device=device, dtype=torch.bool
    )  # [batch_size, seq_len, seq_len]
    print(f'attn_mask: {attn_mask}, attn_mask.shape: {attn_mask.shape}')
    effective_sequence_length = [img_len + effective_condition_sequence_length]
    for i in range(batch_size):
        attn_mask[i, :effective_sequence_length[i], :effective_sequence_length[i]] = True
        attn_mask[i, effective_sequence_length[i]:, effective_sequence_length[i]:] = True
    print(f'attn_mask: {attn_mask}')
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
    x = out.transpose(1, 2).contiguous().view(batch_size, seq_len, num_heads * head_dim)
    print(f'torch.sdpa x.shape: {x.shape}')
    return x


if __name__ == "__main__":
    set_seeds(43)
    device = 'cuda'
    dtype = torch.bfloat16
    num_heads = 16
    head_dim = 64
    img_len = 1080
    txt_len = 256
    batch_size = 1
    seq_len = img_len + txt_len
    effective_condition_sequence_length = 2
    q = torch.randn([batch_size, seq_len, num_heads, head_dim], device=device, dtype=dtype)
    k = torch.randn([batch_size, seq_len, num_heads, head_dim], device=device, dtype=dtype)
    v = torch.randn([batch_size, seq_len, num_heads, head_dim], device=device, dtype=dtype)
    print(f'attn q.shape: {q.shape}')
    print(f'attn k.shape: {k.shape}')
    print(f'attn v.shape: {v.shape}')
    x1 = flash_attn_impl(q, k, v, batch_size, seq_len, num_heads, head_dim, img_len, txt_len, effective_condition_sequence_length, device)
    x2 = torch_impl(q, k, v, batch_size, seq_len, num_heads, head_dim, img_len, txt_len, effective_condition_sequence_length, device)
    print(f'flash attn x1: {x1}')
    print(f'torch.sdpa x2: {x2}')
    atol, rtol = 1e-3, 1e-8
    # print(f'flash varlen vs torch.sdpa: {torch.allclose(x1, x2, atol=atol, rtol=rtol)}')
    similarity = torch.cosine_similarity(
        x1.to("cpu").ravel().double(), x2.to("cpu").ravel().double(), dim=0
    ).item()
    print(f'flash varlen vs torch.sdpa similarity: {similarity}')
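As a side note, the allclose check that is commented out in the script above can also be used as an elementwise comparison; a sketch with tolerances loosened for bf16 (the exact tolerance values here are my own choice, not from the original script) would be:

    # elementwise check in float32; bf16 outputs need a fairly loose atol
    print(torch.allclose(x1.float(), x2.float(), atol=1e-2, rtol=0))
    print((x1.float() - x2.float()).abs().max())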
I also found in my experiments that changing https://github.com/kijai/ComfyUI-HunyuanVideoWrapper/blob/main/hyvideo/modules/models.py#L978 to the following line gives results that likewise match flash attention varlen (this is also the mask that diffusers currently uses):

attn_mask[0, :, :total_len] = True
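To show what this alternative looks like in the verification script, here is a sketch of a torch_impl variant using that mask. It reuses the names from the script above and is my own illustration, not the exact diffusers code:

def torch_impl_alt(q, k, v, batch_size, seq_len, num_heads, head_dim, img_len, txt_len, effective_condition_sequence_length, device):
    q = q.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
    k = k.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
    v = v.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
    sequence_length = img_len + txt_len
    attn_mask = torch.zeros(batch_size, sequence_length, sequence_length, device=device, dtype=torch.bool)
    effective_sequence_length = [img_len + effective_condition_sequence_length]
    for i in range(batch_size):
        # diffusers-style mask: every query row may attend to the valid tokens,
        # and no query attends to the padded key columns
        attn_mask[i, :, :effective_sequence_length[i]] = True
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
    return out.transpose(1, 2).contiguous().view(batch_size, seq_len, num_heads * head_dim)

For the valid (non-padded) query rows this gives the same attention pattern as the block-diagonal mask, since those rows see exactly the valid keys in both cases; only the padded rows differ.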