Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Jun 26, 2022
1 parent cc5c061 commit 93f48da
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,14 +649,14 @@ def test_number_of_steps_in_training_with_ipex(self):
# Regular training has n_epochs * len(train_dl) steps
trainer = get_regression_trainer(learning_rate=0.1, use_ipex=True, bf16=mix_bf16, no_cuda=True)
train_output = trainer.train()
self.assertEqual(train_output.global_step, self.n_epochs * 64 / self.batch_size)
self.assertEqual(train_output.global_step, self.n_epochs * 64 / trainer.args.train_batch_size)

# Check passing num_train_epochs works (and a float version too):
trainer = get_regression_trainer(
learning_rate=0.1, num_train_epochs=1.5, use_ipex=True, bf16=mix_bf16, no_cuda=True
)
train_output = trainer.train()
self.assertEqual(train_output.global_step, int(1.5 * 64 / self.batch_size))
self.assertEqual(train_output.global_step, int(1.5 * 64 / trainer.args.train_batch_size))

# If we pass a max_steps, num_train_epochs is ignored
trainer = get_regression_trainer(
Expand Down

0 comments on commit 93f48da

Please sign in to comment.