diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index ce58d6adebdf5a..9b1aa4001626e7 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -649,14 +649,14 @@ def test_number_of_steps_in_training_with_ipex(self): # Regular training has n_epochs * len(train_dl) steps trainer = get_regression_trainer(learning_rate=0.1, use_ipex=True, bf16=mix_bf16, no_cuda=True) train_output = trainer.train() - self.assertEqual(train_output.global_step, self.n_epochs * 64 / self.batch_size) + self.assertEqual(train_output.global_step, self.n_epochs * 64 / trainer.args.train_batch_size) # Check passing num_train_epochs works (and a float version too): trainer = get_regression_trainer( learning_rate=0.1, num_train_epochs=1.5, use_ipex=True, bf16=mix_bf16, no_cuda=True ) train_output = trainer.train() - self.assertEqual(train_output.global_step, int(1.5 * 64 / self.batch_size)) + self.assertEqual(train_output.global_step, int(1.5 * 64 / trainer.args.train_batch_size)) # If we pass a max_steps, num_train_epochs is ignored trainer = get_regression_trainer(