Why is enhance_scores multiplied with hidden_states? #22

Open
xUhEngwAng opened this issue Dec 27, 2024 · 2 comments

Comments

xUhEngwAng commented Dec 27, 2024

Hi, thanks for your awesome work! I have some questions about the implementation details.

It seems that enhance_scores is a scalar value reflecting the temporal consistency of the current layer. So hidden_states with higher temporal consistency are enhanced as a whole? I'm not sure whether this is reasonable.

However, in your blog post, the output of the enhance module (and of temporal attention) is added to hidden_states, which is somewhat confusing.

Is my understanding correct? Could you provide more details about the implementation, and perhaps some ablation results?

yangluo7 (Contributor) commented Jan 1, 2025

Thanks for your question. Yes, the enhance score calculated in the Enhance Block is a scalar that multiplies the output of the temporal attention block. The enhanced temporal attention output is then added to hidden_states through the residual connection in the DiT. More details can be found in our blog: https://oahzxl.github.io/Enhance_A_Video/
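A minimal sketch of the dataflow described above, written in plain Python so it runs without PyTorch. It assumes the enhance score scales only the temporal attention branch and that the residual addition happens in the enclosing block. `temporal_attention`, `dit_block`, and the toy "attention" (which replaces every element with the input mean) are hypothetical stand-ins, not the repository's code, and any normalization or gating a real DiT block applies is omitted.

```python
from typing import List

Vector = List[float]


def temporal_attention(hidden_states: Vector, enhance_score: float) -> Vector:
    """Hypothetical temporal attention branch (toy stand-in, not the repo's code).

    The "attention" here just replaces every element with the mean of the input,
    which keeps the example deterministic. The scalar enhance score then scales
    this branch's output only.
    """
    mean = sum(hidden_states) / len(hidden_states)
    attn_out = [mean] * len(hidden_states)
    return [x * enhance_score for x in attn_out]


def dit_block(hidden_states: Vector, enhance_score: float) -> Vector:
    """Hypothetical block: residual connection around the scaled attention branch."""
    attn_out = temporal_attention(hidden_states, enhance_score)
    return [h + a for h, a in zip(hidden_states, attn_out)]


if __name__ == "__main__":
    h = [1.0, 2.0, 3.0]
    # Residual form: h + s * TA(h)
    print(dit_block(h, enhance_score=1.5))  # [4.0, 5.0, 6.0]
    # For contrast, scaling the whole stream: s * h
    print([x * 1.5 for x in h])  # [1.5, 3.0, 4.5]
```

With a score of 1.0 the block is an ordinary residual attention step. Larger scores amplify only the attention branch's contribution, whereas multiplying the whole stream would also rescale the input itself.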

xUhEngwAng (Author) commented

Sorry, I couldn't find where your code adds the enhanced temporal attention output to hidden_states through a residual connection.

```python
if is_enhance_enabled():
    hidden_states = hidden_states * enhance_scores
```

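In models/cogvideox.py, line 145, hidden_states is multiplied directly by enhance_scores. The latter is computed by averaging the temporal attention map with the diagonal excluded (which yields a scalar), multiplying that average by (num_frames + enhance_weight), where enhance_weight is a preset value, and clamping the result to a minimum of 1: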

```python
# Calculate mean for each token's attention matrix
# Number of off-diagonal elements per matrix is n*n - n
num_off_diag = num_frames * num_frames - num_frames
mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag
enhance_scores = mean_scores.mean() * (num_frames + get_enhance_weight())
enhance_scores = enhance_scores.clamp(min=1)
return enhance_scores
```
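To sanity-check the quoted computation outside the model, it can be ported to NumPy. This is a hypothetical port, not the repository's code: `compute_enhance_score` is an illustrative name, `enhance_weight` is passed in rather than read via `get_enhance_weight()`, and the function takes the full attention map of shape `(num_tokens, num_frames, num_frames)` and zeroes the diagonal itself, which is an assumption about how `attn_wo_diag` is produced upstream.

```python
import numpy as np


def compute_enhance_score(attn: np.ndarray, enhance_weight: float) -> float:
    """Hypothetical NumPy port of the quoted enhance-score computation.

    attn: temporal attention maps, shape (num_tokens, num_frames, num_frames).
    """
    num_frames = attn.shape[-1]
    # Assumption: zero the diagonal (each frame attending to itself) per token.
    attn_wo_diag = attn * (1.0 - np.eye(num_frames))
    # Mean of the off-diagonal entries of each token's attention matrix.
    num_off_diag = num_frames * num_frames - num_frames
    mean_scores = attn_wo_diag.sum(axis=(1, 2)) / num_off_diag
    # Average over tokens (a scalar), scale by (num_frames + enhance_weight).
    enhance_score = float(mean_scores.mean()) * (num_frames + enhance_weight)
    # Equivalent of .clamp(min=1) for a scalar.
    return max(enhance_score, 1.0)


if __name__ == "__main__":
    uniform = np.full((2, 4, 4), 0.25)  # 2 tokens, 4 frames, uniform attention rows
    print(compute_enhance_score(uniform, enhance_weight=4.0))  # 2.0
```

For perfectly uniform attention over F frames, every off-diagonal entry is 1/F, so the score reduces to (F + enhance_weight) / F; with 4 frames and an enhance weight of 4 that gives 2.0. If each frame attends only to itself, the off-diagonal mean is 0 and the clamp returns 1.0.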

Could you please point me to the exact code snippet for the residual connection?
