From e01a61aeab87c387ffc700ab7a40f765a9b01cfd Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 15 Nov 2024 16:28:06 -0500
Subject: [PATCH] FSDP grad accum fix (#34645)

* add gradient accumulation steps tests for fsdp

* invert no_sync context to fix training for fsdp
---
 src/transformers/trainer.py |  2 +-
 tests/fsdp/test_fsdp.py     | 12 ++++++++++++
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index a46bbf5445a360..45f0026154839d 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2474,7 +2474,7 @@ def _inner_training_loop(
                     # We explicitly want to avoid relying on `accelerator.accumulate` for generation training
                     context = (
                         functools.partial(self.accelerator.no_sync, model=model)
-                        if i == len(batch_samples) - 1
+                        if i != len(batch_samples) - 1
                         else contextlib.nullcontext
                     )
                     with context():

diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py
index 7e14cc8c9e6fc9..74a3bfe04b7506 100644
--- a/tests/fsdp/test_fsdp.py
+++ b/tests/fsdp/test_fsdp.py
@@ -224,6 +224,18 @@ def test_basic_run(self, sharding_strategy, dtype):
         cmd = launcher + script + args + fsdp_args
         execute_subprocess_async(cmd, env=self.get_env())
 
+    @parameterized.expand(params, name_func=_parameterized_custom_name_func)
+    @require_torch_multi_accelerator
+    @slow
+    def test_basic_run_with_gradient_accumulation(self, sharding_strategy, dtype):
+        launcher = get_launcher(distributed=True, use_accelerate=False)
+        output_dir = self.get_auto_remove_tmp_dir()
+        args = self.get_base_args(output_dir, 1, 50).split() + [f"--{dtype}", "--gradient_accumulation_steps", "2"]
+        fsdp_args = ["--fsdp", f"{sharding_strategy} auto_wrap", "--fsdp_transformer_layer_cls_to_wrap", "BertLayer"]
+        script = [f"{self.examples_dir_str}/pytorch/text-classification/run_glue.py"]
+        cmd = launcher + script + args + fsdp_args
+        execute_subprocess_async(cmd, env=self.get_env())
+
     @parameterized.expand(dtypes)
     @require_torch_multi_accelerator
     @slow
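
Note (context for the fix, not part of the patch): with gradient accumulation under FSDP/DDP, `accelerator.no_sync` should wrap every micro-batch except the last one in the accumulation window, so gradients are only all-reduced on the micro-batch that actually precedes the optimizer step. The original condition (`i == len(batch_samples) - 1`) did the opposite, skipping synchronization only on the last micro-batch. Below is a minimal sketch of the intended pattern outside the Trainer, assuming an `accelerate.Accelerator` instance plus hypothetical `model`, `optimizer`, and `batch_samples` objects (names chosen for illustration only).

import contextlib
import functools

def accumulation_window(accelerator, model, optimizer, batch_samples):
    # Skip gradient synchronization on all but the last micro-batch; note the
    # `!=` comparison, mirroring the one-character change in trainer.py above.
    for i, batch in enumerate(batch_samples):
        context = (
            functools.partial(accelerator.no_sync, model=model)
            if i != len(batch_samples) - 1
            else contextlib.nullcontext
        )
        with context():
            # Average the loss over the accumulation window so the update
            # matches a single large batch.
            loss = model(**batch).loss / len(batch_samples)
            accelerator.backward(loss)
    # Gradients were synchronized on the final micro-batch, so the optimizer
    # step sees the full accumulated gradient on every rank.
    optimizer.step()
    optimizer.zero_grad()

With the inverted condition, only one all-reduce happens per accumulation window instead of one per micro-batch, which is both the correctness fix for FSDP and the usual communication saving expected from gradient accumulation.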