diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 24737931530..efbbc46321c 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -268,7 +268,13 @@ def __init__(
         if model is not None and not self.args.model_parallel:
             model = model.to(args.device)
 
-        self.model = model
+        if args.deepspeed:
+            self.model = model.module
+            self.wrapped_model = model
+        else:
+            self.model = model
+            self.wrapped_model = None
+
         default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
         self.data_collator = data_collator if data_collator is not None else default_collator
         self.train_dataset = train_dataset
@@ -703,8 +709,9 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
         # Check if saved optimizer or scheduler states exist
         self._load_optimizer_and_scheduler(model_path)
 
+        model = self.wrapped_model if self.wrapped_model else self.model
+
         # Mixed precision training with apex (torch < 1.6)
-        model = self.model
         if self.use_apex:
             model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
 
@@ -729,6 +736,14 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
             # find_unused_parameters breaks checkpointing as per
             # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
 
+        # for the rest of this function ``model`` is the outside model, whether it was wrapped or not
+        if model != self.model:
+            self.wrapped_model = model
+
+        # important: at this point:
+        # self.model is the Transformers Model
+        # self.wrapped_model is DDP(Transformers Model), DDP(Deepspeed(Transformers Model)), etc.
+
         # Train!
         if is_torch_tpu_available():
             total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
@@ -1199,7 +1214,8 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         elif self.args.deepspeed:
-            model.module.backward(loss)
+            # calling on DS engine (wrapped_model == DDP(Deepspeed(PretrainedModule)))
+            self.wrapped_model.module.backward(loss)
         else:
             loss.backward()
 
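
For reference, a minimal sketch (not part of the patch) of the bookkeeping this diff introduces: `self.model` keeps the unwrapped Transformers model while `self.wrapped_model` keeps the outer wrapper, and the training step routes `backward()` through the wrapper when one exists. `ToyTrainer` and `FakeDeepSpeedEngine` below are illustrative stand-ins, not the real transformers/DeepSpeed classes.

```python
# Illustrative only: shows the self.model / self.wrapped_model split and how the
# backward call is dispatched. Class names here are made up for the sketch.
import torch
import torch.nn as nn


class FakeDeepSpeedEngine(nn.Module):
    """Stand-in for a DeepSpeed engine: wraps a model and owns .backward()."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def backward(self, loss: torch.Tensor):
        # in real DeepSpeed the engine drives the backward pass (loss scaling, ZeRO, ...)
        loss.backward()


class ToyTrainer:
    def __init__(self, model: nn.Module, deepspeed: bool = False):
        if deepspeed:
            # `model` arrives pre-wrapped; keep both the bare model and the wrapper
            self.model = model.module          # unwrapped Transformers-style model
            self.wrapped_model = model         # outer wrapper (engine / DDP / ...)
        else:
            self.model = model
            self.wrapped_model = None

    def training_step(self, batch: torch.Tensor) -> torch.Tensor:
        # forward/backward go through the wrapper when there is one
        model = self.wrapped_model if self.wrapped_model is not None else self.model
        loss = model(batch).sum()
        if isinstance(model, FakeDeepSpeedEngine):
            model.backward(loss)               # the engine owns the backward call
        else:
            loss.backward()
        return loss.detach()


inner = nn.Linear(4, 2)
trainer = ToyTrainer(FakeDeepSpeedEngine(inner), deepspeed=True)
assert trainer.model is inner                  # bare model stays reachable for saving/config
print(trainer.training_step(torch.randn(3, 4)))
```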