Skip to content

Commit

Permalink
Fix failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Apr 24, 2024
1 parent 865bf2a commit e46a42e
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions tests/trainer/test_trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def test_stateful_callbacks(self):
callbacks=[cb],
load_best_model_at_end=True,
save_strategy="steps",
evaluation_strategy="steps",
eval_strategy="steps",
save_steps=2,
eval_steps=2,
max_steps=2,
Expand All @@ -281,7 +281,7 @@ def test_stateful_callbacks(self):
callbacks=[EarlyStoppingCallback()],
load_best_model_at_end=True,
save_strategy="steps",
evaluation_strategy="steps",
eval_strategy="steps",
save_steps=2,
eval_steps=2,
max_steps=2,
Expand All @@ -307,7 +307,7 @@ def test_stateful_mixed_callbacks(self):
callbacks=cbs,
load_best_model_at_end=True,
save_strategy="steps",
evaluation_strategy="steps",
eval_strategy="steps",
save_steps=2,
eval_steps=2,
max_steps=2,
Expand All @@ -319,7 +319,7 @@ def test_stateful_mixed_callbacks(self):
callbacks=[EarlyStoppingCallback(), MyTestTrainerCallback()],
load_best_model_at_end=True,
save_strategy="steps",
evaluation_strategy="steps",
eval_strategy="steps",
save_steps=2,
eval_steps=2,
max_steps=2,
Expand All @@ -346,7 +346,7 @@ def test_stateful_duplicate_callbacks(self):
callbacks=cbs,
load_best_model_at_end=True,
save_strategy="steps",
evaluation_strategy="steps",
eval_strategy="steps",
save_steps=2,
eval_steps=2,
max_steps=2,
Expand All @@ -358,7 +358,7 @@ def test_stateful_duplicate_callbacks(self):
callbacks=[MyTestExportableCallback(), MyTestExportableCallback()],
load_best_model_at_end=True,
save_strategy="steps",
evaluation_strategy="steps",
eval_strategy="steps",
save_steps=2,
eval_steps=2,
max_steps=2,
Expand All @@ -382,7 +382,7 @@ def test_missing_stateful_callback(self):
callbacks=[cb],
load_best_model_at_end=True,
save_strategy="steps",
evaluation_strategy="steps",
eval_strategy="steps",
save_steps=2,
eval_steps=2,
max_steps=2,
Expand All @@ -391,6 +391,10 @@ def test_missing_stateful_callback(self):

# Create a new trainer with defaults
trainer = self.get_trainer(
save_strategy="steps",
eval_strategy="steps",
save_steps=2,
eval_steps=2,
max_steps=2,
restore_callback_states_from_checkpoint=True,
)
Expand Down

0 comments on commit e46a42e

Please sign in to comment.