Based transformers version needed for modifying models/modeling_llama.py #31

Open
yeahjack opened this issue May 14, 2024 · 3 comments

Comments

@yeahjack

I noticed that models/modeling_llama.py is based on the code from here. However, your implementation does not support Flash Attention 2, so I would like to modify it further. Could you please specify the exact version of transformers that your implementation is based on? That would make it much easier for me to diff your file against the upstream one.

@RenShuhuai-Andy
Owner

Hi, our current code is based on transformers==4.34.0 (https://github.com/RenShuhuai-Andy/TimeChat/blob/master/environment.yml#L273).

To use Flash Attention 2, you can upgrade transformers and load the model with the following code:

from transformers import AutoTokenizer, AutoModelForCausalLM

# Load LLaMA with the Flash Attention 2 backend.
self.llama_model = AutoModelForCausalLM.from_pretrained(
    llama_model,
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
    low_cpu_mem_usage=True,
)
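Since the repository pins transformers==4.34.0 and the snippet above relies on the `attn_implementation` keyword, it may help to check the installed version before switching. The following is a minimal sketch, not part of TimeChat: the helper names are hypothetical, and the (4, 36) threshold is an assumption based on when that keyword appears to have been introduced, so please confirm it against the transformers release notes.

```python
import re
from importlib import metadata

# Assumption: `attn_implementation` is accepted from roughly transformers 4.36
# onward. Verify this threshold against the release notes before relying on it.
MIN_VERSION_FOR_ATTN_IMPLEMENTATION = (4, 36)


def parse_major_minor(version):
    """Return (major, minor) from a version string such as '4.34.0'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def supports_attn_implementation(version):
    """True if the version meets the assumed minimum for `attn_implementation`."""
    return parse_major_minor(version) >= MIN_VERSION_FOR_ATTN_IMPLEMENTATION


def installed_transformers_version():
    """Return the installed transformers version, or None if it is not installed."""
    try:
        return metadata.version("transformers")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    current = installed_transformers_version()
    if current is None:
        print("transformers is not installed")
    elif supports_attn_implementation(current):
        print(f"transformers {current}: attn_implementation should be usable")
    else:
        print(f"transformers {current}: upgrade before passing attn_implementation")
```

With the pinned 4.34.0 environment, this check would report that an upgrade is needed first.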

@yeahjack
Author

yeahjack commented May 14, 2024

Hi, I tried your suggestion. When I use the code from demo.ipynb to generate text with flash-attn-2 and 8-bit loading (with your low resource mode on), it raises RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16. Loading the LLaMA model in 8-bit only works well. Do you have any suggestions?

My test

I searched the Internet and found a suggestion here to wrap the call in with torch.cuda.amp.autocast():. With that change, it raises RuntimeError: query and key must have the same dtype instead. I hope this helps.

Thank you very much!

@yeahjack
Author

It seems that removing torch_dtype="torch.bfloat16" could work, but I am not sure it is the right solution.
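To compare the variants discussed in this thread (with or without an explicit torch_dtype, with or without 8-bit loading), one option is to build the from_pretrained keyword arguments in one place. This is only a sketch under stated assumptions: build_load_kwargs is a hypothetical helper, not something in TimeChat or transformers, and passing load_in_8bit directly is assumed to be accepted by the installed transformers version (newer releases prefer a quantization config object). It does not by itself resolve the dtype mismatch.

```python
def build_load_kwargs(use_flash_attn=True, torch_dtype=None, load_in_8bit=False):
    """Assemble keyword arguments for AutoModelForCausalLM.from_pretrained.

    Hypothetical helper for experimentation. `torch_dtype` is only included
    when given, so the "remove torch_dtype" variant is simply torch_dtype=None.
    """
    kwargs = {"low_cpu_mem_usage": True}
    if use_flash_attn:
        kwargs["attn_implementation"] = "flash_attention_2"
    if torch_dtype is not None:
        kwargs["torch_dtype"] = torch_dtype
    if load_in_8bit:
        # Assumption: the installed transformers version accepts this flag directly.
        kwargs["load_in_8bit"] = True
    return kwargs


# Variant reported to fail in this thread: flash-attn-2 + 8-bit + explicit dtype.
failing_variant = build_load_kwargs(torch_dtype="auto", load_in_8bit=True)

# Variant reported as possibly working: same settings, but no torch_dtype.
candidate_variant = build_load_kwargs(load_in_8bit=True)

# Usage (requires transformers, a GPU, and the model weights):
# model = AutoModelForCausalLM.from_pretrained(llama_model, **candidate_variant)
```

Keeping the variants side by side makes it easier to report exactly which combination triggers which RuntimeError.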

2 participants