diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 0785e03ac54..3231558dec1 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3141,14 +3141,30 @@ def evaluation_loop(
 
         prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
 
-        # if eval is called w/o train init deepspeed here
+        # if eval is called w/o train, handle model prep here
         if self.is_deepspeed_enabled and self.model_wrapped is self.model:
             _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
-            model = self.accelerator.prepare(self.model)
-            self.model_wrapped = self.deepspeed = model
 
         model = self._wrap_model(self.model, training=False, dataloader=dataloader)
 
+        if len(self.accelerator._models) == 0 and model is self.model:
+            model = (
+                self.accelerator.prepare(model)
+                if self.is_deepspeed_enabled
+                else self.accelerator.prepare_model(model, evaluation_mode=True)
+            )
+
+            if self.is_fsdp_enabled:
+                self.model = model
+
+            # for the rest of this function `model` is the outside model, whether it was wrapped or not
+            if model is not self.model:
+                self.model_wrapped = model
+
+            # backward compatibility
+            if self.is_deepspeed_enabled:
+                self.deepspeed = self.model_wrapped
+
         # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
         # while ``train`` is running, cast it to the right dtype first and then put on device
         if not self.is_in_train:
@@ -3736,14 +3752,30 @@ def prediction_loop(
 
         prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
 
-        # if eval is called w/o train init deepspeed here
+        # if eval is called w/o train, handle model prep here
         if self.is_deepspeed_enabled and self.model_wrapped is self.model:
             _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
-            model = self.accelerator.prepare(self.model)
-            self.model_wrapped = self.deepspeed = model
 
         model = self._wrap_model(self.model, training=False, dataloader=dataloader)
 
+        if len(self.accelerator._models) == 0 and model is self.model:
+            model = (
+                self.accelerator.prepare(model)
+                if self.is_deepspeed_enabled
+                else self.accelerator.prepare_model(model, evaluation_mode=True)
+            )
+
+            if self.is_fsdp_enabled:
+                self.model = model
+
+            # for the rest of this function `model` is the outside model, whether it was wrapped or not
+            if model is not self.model:
+                self.model_wrapped = model
+
+            # backward compatibility
+            if self.is_deepspeed_enabled:
+                self.deepspeed = self.model_wrapped
+
         # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
         # while ``train`` is running, cast it to the right dtype first and then put on device
         if not self.is_in_train: