Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix attn mask for static cache #30414

Closed

Conversation

zucchini-nlp
Copy link
Member

What does this PR do?

Fixes #30400.

It was found that when static cache returns key and values if "length=max_length", the zeros are not masked out. That is why the generation starts returning gibberish at larger max_new_tokens, and is more expressed in "sdpa" attention which just falls back to its internal causal mask.

All slow tests in the models are passing for me locally.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really sure about this fix, the target_length is probably not computed properly
What's weird about this is that we do take into account the previous zeros here:

causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)

with this anything that is bigger than the target length is mask, no matter the attention mask

@ArthurZucker
Copy link
Collaborator

Also Let's add tests in generation common or integration for llama at least.

@zucchini-nlp
Copy link
Member Author

Ah i did not see it is supposed to be calculated here. The main error for teh code block in linked issue was in this line, where we skip all together the causal mask since it's batch_size=1 and no zeros in attention mask.

That is for sdpa attn but eager also had errors, even though not so big. Let me check with these new info and tests ofc :)

@ArthurZucker
Copy link
Collaborator

Thanks

@zucchini-nlp
Copy link
Member Author

@ArthurZucker I found that indeed th causal mask was being calculated correctly and the gibberish output is only when we use SDPA and batch_size=1. Can be solved by removing the following so that we always rely on our own attn mask, which cannot be simply causal if static cache is used.

I am just not sure, if this line is crucial for Flash Attention 2? 🤔

if self.config._attn_implementation == "sdpa":
            # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
            # in order to dispatch on Flash Attention 2.
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
            ):
                return None

@gante
Copy link
Member

gante commented Apr 23, 2024

@zucchini-nlp #30437 -- this PR also fixes the issue, and I think that is the way to go :D

The fix in this PR adds complexity outside the SDPA path and needs extra computations

@zucchini-nlp
Copy link
Member Author

@gante I see, yeah the cropping of the cache seems a better solution

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Llama generation with static cache fails in certain sequence lengths
4 participants