Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Llama3 and Llama2 are ExecuTorch compatible #34101

Merged
merged 1 commit into from
Oct 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions tests/models/llama/test_modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from parameterized import parameterized

from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
backend_empty_cache,
require_bitsandbytes,
Expand Down Expand Up @@ -916,6 +917,74 @@ def test_compile_static_cache(self):
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)

@slow
@require_read_token
def test_export_static_cache(self):
if version.parse(torch.__version__) < version.parse("2.4.0"):
self.skipTest(reason="This test requires torch >= 2.4 to run.")

from transformers.integrations.executorch import (
TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
)

llama_models = {
"meta-llama/Llama-3.2-1B": [
"Simply put, the theory of relativity states that 1) the speed of light is the same for all "
"observers, regardless of their location, and 2) the laws of physics are the same for all observers"
],
"meta-llama/Llama-3.2-3B": [
"Simply put, the theory of relativity states that 1. the speed of light is constant, and 2. "
"the speed of light is the fastest speed possible"
],
"meta-llama/Llama-2-7b-hf": [
"Simply put, the theory of relativity states that 1) the speed of light is a constant, and 2) "
"the laws of physics are the same for all",
],
}

for llama_model_ckp, EXPECTED_TEXT_COMPLETION in llama_models.items():
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(llama_model_ckp, pad_token="</s>", padding_side="right")
max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
"input_ids"
].shape[-1]

# Load model
device = "cpu"
dtype = torch.bfloat16
cache_implementation = "static"
attn_implementation = "sdpa"
batch_size = 1
model = LlamaForCausalLM.from_pretrained(
llama_model_ckp,
device_map=device,
torch_dtype=dtype,
attn_implementation=attn_implementation,
generation_config=GenerationConfig(
use_cache=True,
cache_implementation=cache_implementation,
max_length=max_generation_length,
cache_config={
"batch_size": batch_size,
"max_cache_len": max_generation_length,
},
),
)

prompts = ["Simply put, the theory of relativity states that "]
prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
prompt_token_ids = prompt_tokens["input_ids"]
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]

# Static Cache + export
exported_program = convert_and_export_with_cache(model)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)
ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)


@slow
@require_torch_accelerator
Expand Down
Loading