adapt _init_fsdp to fp8
eljandoubi committed Oct 16, 2024
1 parent 4827a39 commit 4a84f0f
Showing 1 changed file with 2 additions and 10 deletions.
src/transformers/trainer.py: 12 changes (2 additions & 10 deletions)
@@ -273,7 +273,7 @@ def _get_fsdp_ckpt_kwargs():
         return {}


-def _init_fsdp(model, accelerator, device):
+def _init_fsdp(model, device):
     """
     Initialize Fully Sharded Data Parallel (FSDP) for the model.

@@ -283,13 +283,8 @@ def _init_fsdp(model, accelerator, device):

     Args:
         model: The model to initialize with FSDP.
-        accelerator: The Accelerator object.
         device: The device to run the model on.
-
-    Returns:
-        The initialized FSDP model.
     """
-    model = accelerator.prepare(model)
     model.train()
     with torch.no_grad():
         # Run a forward pass with dummy inputs to initialize FSDP
@@ -303,7 +298,6 @@ def _init_fsdp(model, accelerator, device):
             if name != "self"
         }
         _ = model(**dummy_input)
-    return model


 if TYPE_CHECKING:
@@ -635,9 +629,6 @@ def __init__(
                     " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
                 )

-        if self.is_fsdp_enabled:
-            self.model = _init_fsdp(self.model, self.accelerator, self.args.device)
-
         if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and (
             self.optimizer is not None or self.lr_scheduler is not None
         ):
@@ -2285,6 +2276,7 @@ def _inner_training_loop(
             self.optimizer = self.accelerator.prepare(self.optimizer)

         if self.is_fsdp_enabled:
+            _init_fsdp(self.model, self.args.device)
             self.model = self.model_wrapped = model

         # for the rest of this function `model` is the outside model, whether it was wrapped or not
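For reference, here is a minimal sketch of what the helper looks like after this commit, pieced together from the visible hunks. The dummy-input construction sits in the collapsed part of the diff, so the `signature(...)`-based loop and the tensor shape/dtype below are illustrative assumptions, not copied verbatim from the patch. The behavioural change is that `_init_fsdp` no longer calls `accelerator.prepare(model)` and no longer returns the model; it only runs one throwaway forward pass on an already-prepared model.

    import torch
    from inspect import signature


    def _init_fsdp(model, device):
        """Run one dummy forward pass so FSDP finishes materializing the sharded model.

        The model is expected to have been wrapped by `accelerator.prepare(...)`
        already; this helper no longer prepares it itself.
        """
        model.train()
        with torch.no_grad():
            # Build one dummy tensor per forward() argument (shape/dtype here are
            # placeholders) and run a single forward pass to initialize FSDP.
            dummy_input = {
                name: torch.ones((1, 1), dtype=torch.long, device=device)
                for name in signature(model.forward).parameters.keys()
                if name != "self"
            }
            _ = model(**dummy_input)

On the caller side, `Trainer.__init__` no longer invokes the helper at all; instead `_inner_training_loop` calls `_init_fsdp(self.model, self.args.device)` right after the Accelerator has prepared the model and optimizer, which is presumably what lets the same code path work when Accelerate wraps the model for fp8 training.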