Skip to content

Commit

Permalink
fix no_sync context for deepspeed across all zero types
Browse files (browse the repository at this point in the history)
  • Loading branch information
winglian committed Dec 11, 2024
1 parent 37773bd commit 41434e6
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2513,13 +2513,9 @@ def _inner_training_loop(
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

# We explicitly want to avoid relying on `accelerator.accumulate` for generation training
- disable_deepspeed_no_sync = (
- self.accelerator.distributed_type == DistributedType.DEEPSPEED
- and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients()
- )
context = (
functools.partial(self.accelerator.no_sync, model=model)
- if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync
+ if i != len(batch_samples) - 1 and self.accelerator.distributed_type != DistributedType.DEEPSPEED
else contextlib.nullcontext
)
with context():
Expand Down

0 comments on commit 41434e6

Please sign in to comment.