
Commit

introduce the concept of self.wrapped_model
stas00 committed Dec 23, 2020
1 parent 9cc3b63 commit 1510444
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions src/transformers/trainer.py
@@ -268,7 +268,13 @@ def __init__(
         if model is not None and not self.args.model_parallel:
             model = model.to(args.device)
 
-        self.model = model
+        if args.deepspeed:
+            self.model = model.module
+            self.wrapped_model = model
+        else:
+            self.model = model
+            self.wrapped_model = None
+
         default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
         self.data_collator = data_collator if data_collator is not None else default_collator
         self.train_dataset = train_dataset
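
To make the intent of this hunk concrete, here is a minimal standalone sketch (not part of the commit; the helper name split_model is illustrative) of the split it introduces: under DeepSpeed the object handed to the Trainer is already the wrapping engine, and its .module attribute holds the underlying Transformers model.

    import torch.nn as nn

    def split_model(model: nn.Module, deepspeed: bool):
        # Mirrors the __init__ logic above: under DeepSpeed the incoming object is
        # the wrapper, so keep references to both the inner model and the wrapper.
        if deepspeed:
            return model.module, model  # becomes (self.model, self.wrapped_model)
        return model, None              # plain model, nothing wrapping it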
@@ -703,8 +709,9 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
         # Check if saved optimizer or scheduler states exist
         self._load_optimizer_and_scheduler(model_path)
 
+        model = self.wrapped_model if self.wrapped_model else self.model
+
Comment from sgugger (Collaborator), Dec 23, 2020:

    Nit: we like to explicitly test for is None:

        model = self.model if self.wrapped_model is None else self.wrapped_model

         # Mixed precision training with apex (torch < 1.6)
-        model = self.model
         if self.use_apex:
             model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

@@ -729,6 +736,14 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
             # find_unused_parameters breaks checkpointing as per
             # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
 
+        # for the rest of this function ``model`` is the outside model, whether it was wrapped or not
+        if model != self.model:
+            self.wrapped_model = model
+
+        # important: at this point:
+        # self.model is the Transformers Model
+        # self.wrapped_model is DDP(Transformers Model), DDP(Deepspeed(Transformers Model)), etc.
+
         # Train!
         if is_torch_tpu_available():
             total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
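
As a standalone illustration of the layering described in the added comment above (a minimal sketch, not part of the commit; it assumes every wrapper involved, DDP or the DeepSpeed engine, exposes the inner module as .module, and the helper name unwrap_model is illustrative), the chain of wrappers always leads back to self.model:

    import torch.nn as nn

    def unwrap_model(model: nn.Module) -> nn.Module:
        # Peel off wrapper layers (DDP, DeepSpeed engine, ...) one at a time;
        # each stores the next layer down on its .module attribute.
        while hasattr(model, "module"):
            model = model.module
        return model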
@@ -1199,7 +1214,8 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         elif self.args.deepspeed:
-            model.module.backward(loss)
+            # calling on DS engine (wrapped_model == DDP(Deepspeed(PretrainedModule)))
+            self.wrapped_model.module.backward(loss)
         else:
             loss.backward()

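This hunk is the reason self.wrapped_model has to be kept around: under DeepSpeed, backward is driven through the engine, which handles loss scaling and gradient bookkeeping, rather than called on the loss tensor directly. A minimal sketch of that routing (illustrative only; backward_step and engine are not names from the commit):

    def backward_step(loss, engine=None):
        # Route backward through the DeepSpeed engine when one is present;
        # otherwise fall back to plain autograd on the loss tensor.
        if engine is not None:
            engine.backward(loss)  # DeepSpeed engine scales and backpropagates the loss
        else:
            loss.backward()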
