Gemma2 and flash-attention #32188
Changes from 3 commits. Commits in the PR: fc97eab, cad10d1, b4ca3ca, 9c0f447, dc9266f
@@ -16,8 +16,11 @@

 import unittest

+from pytest import mark
+
 from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline
 from transformers.testing_utils import (
+    require_flash_attn,
     require_read_token,
     require_torch,
     require_torch_gpu,
@@ -161,3 +164,29 @@ def test_model_9b_pipeline_bf16(self):

         self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
         self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])
+
+    @require_read_token
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_model_9b_flash_attn(self):
+        # See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context
+        # NOTE: the quality is a lot better when fp16 is used, and worse for bf16

[Inline review comment on the NOTE line above]
"Could you elaborate more on this? Specifically, does using […]"

[Author reply]
"Oh, this was true until I discovered a bug I made. It doesn't matter much which precision is used; removed the comment."
(A hedged fp16-vs-bf16 comparison sketch is included after the diff below.)

+        model_id = "google/gemma-2-9b"
+        EXPECTED_TEXTS = [
+            '<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
+            "<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic consisting of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the"
+        ]  # fmt: skip
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, attn_implementation="flash_attention_2", torch_dtype="float16"
+        ).to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
+        output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
+        print(output_text)
+
+        self.assertEqual(output_text, EXPECTED_TEXTS)
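
To make the reviewer's precision question above concrete, here is a minimal, hedged sketch (not part of the PR) that compares fp16 and bf16 greedy generations for the same prompt with flash attention. It assumes a CUDA GPU, flash-attn installed, and access to the gated checkpoint; the prompt and generation length are arbitrary placeholders.

# Hedged sketch, not part of the PR: load the same checkpoint in fp16 and in bf16
# with flash attention and compare greedy generations for a single prompt.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-9b"  # same checkpoint as the test above
prompt = "Hello I am doing"     # placeholder prompt

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

for dtype in (torch.float16, torch.bfloat16):
    model = AutoModelForCausalLM.from_pretrained(
        model_id, attn_implementation="flash_attention_2", torch_dtype=dtype
    ).to("cuda")
    out = model.generate(**inputs, max_new_tokens=50, do_sample=False)
    print(dtype, tokenizer.decode(out[0], skip_special_tokens=True))
    del model
    torch.cuda.empty_cache()  # free GPU memory before loading the other dtype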
[Author reply in a separate review thread]
"Ah, yeah, static cache does not support that because it's input dependent, but great catch."
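
For readers who have not used it: the static cache mentioned in the reply above pre-allocates key/value buffers to a fixed maximum length, which is why input-dependent behaviour cannot be folded into it. Below is a minimal, hedged sketch (not part of the PR) of the difference at the generate() level; the checkpoint and prompt are placeholders, and whether a given model accepts cache_implementation="static" depends on its cache support (Gemma2 ships its own hybrid, sliding-window-aware cache by default).

# Hedged sketch, not part of the PR: default (growing) cache vs. a pre-allocated
# static cache at generation time.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-9b"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
inputs = tokenizer("Hello I am doing", return_tensors="pt").to("cuda")

# Default: the key/value cache grows with the sequence as tokens are generated.
out_default = model.generate(**inputs, max_new_tokens=50, do_sample=False)

# Static cache: buffers are allocated up front for a fixed maximum length,
# which keeps shapes fixed but rules out input-dependent sizing.
out_static = model.generate(
    **inputs, max_new_tokens=50, do_sample=False, cache_implementation="static"
)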