From ac0f36b76cedbb4e6ee024b9a5c0fd79c306cf3e Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 6 Jan 2025 17:13:52 -0500
Subject: [PATCH] remove llama fa2 from conftest

---
 tests/conftest.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index f2519cdcf..85e276722 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -120,13 +120,12 @@ def temp_dir():
 @pytest.fixture(scope="function", autouse=True)
 def cleanup_monkeypatches():
     from transformers import Trainer
-    from transformers.models.llama.modeling_llama import (
+    from transformers.models.llama.modeling_llama import (  # LlamaFlashAttention2,
         LlamaAttention,
-        LlamaFlashAttention2,
         LlamaForCausalLM,
     )
 
-    original_fa2_forward = LlamaFlashAttention2.forward
+    # original_fa2_forward = LlamaFlashAttention2.forward
     original_llama_attn_forward = LlamaAttention.forward
     original_llama_forward = LlamaForCausalLM.forward
     original_trainer_inner_training_loop = (
@@ -136,7 +135,7 @@ def cleanup_monkeypatches():
     # monkey patches can happen inside the tests
     yield
     # Reset LlamaFlashAttention2 forward
-    LlamaFlashAttention2.forward = original_fa2_forward
+    # LlamaFlashAttention2.forward = original_fa2_forward
     LlamaAttention.forward = original_llama_attn_forward
     LlamaForCausalLM.forward = original_llama_forward
     Trainer._inner_training_loop = (  # pylint: disable=protected-access
@@ -149,7 +148,10 @@
         ("transformers.models.llama",),
         (
             "transformers.models.llama.modeling_llama",
-            ["LlamaFlashAttention2", "LlamaAttention"],
+            [
+                # "LlamaFlashAttention2",
+                "LlamaAttention",
+            ],
         ),
         ("transformers.trainer",),
         ("transformers", ["Trainer"]),