Skip to content

Commit

Permalink
unblock legacy checkpoints (#15798)
Browse files Browse the repository at this point in the history
* fixing legacy checkpoints
* Apply suggestions from code review

Co-authored-by: Akihiro Nitta <[email protected]>
  • Loading branch information
Borda and akihironitta authored Dec 2, 2022
1 parent 993bd67 commit fee52f9
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions tests/legacy/simple_classif_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def main_train(dir_path, max_epochs: int = 20):
model = ClassificationModel()
trainer.fit(model, datamodule=dm)
res = trainer.test(model, datamodule=dm)
assert res[0]["test_loss"] <= 0.7
assert res[0]["test_acc"] >= 0.85
assert res[0]["test_loss"] <= 0.85, str(res[0]["test_loss"])
assert res[0]["test_acc"] >= 0.7, str(res[0]["test_acc"])
assert trainer.current_epoch < (max_epochs - 1)


Expand Down
8 changes: 4 additions & 4 deletions tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def test_load_legacy_checkpoints(tmpdir, pl_version: str):
trainer = Trainer(default_root_dir=str(tmpdir))
dm = ClassifDataModule(num_features=24, length=6000, batch_size=128, n_clusters_per_class=2, n_informative=8)
res = trainer.test(model, datamodule=dm)
assert res[0]["test_loss"] <= 0.7
assert res[0]["test_acc"] >= 0.85
assert res[0]["test_loss"] <= 0.85, str(res[0]["test_loss"])
assert res[0]["test_acc"] >= 0.7, str(res[0]["test_acc"])
print(res)


Expand Down Expand Up @@ -111,5 +111,5 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str):
torch.backends.cudnn.deterministic = True
trainer.fit(model, datamodule=dm, ckpt_path=path_ckpt)
res = trainer.test(model, datamodule=dm)
assert res[0]["test_loss"] <= 0.7
assert res[0]["test_acc"] >= 0.85
assert res[0]["test_loss"] <= 0.85, str(res[0]["test_loss"])
assert res[0]["test_acc"] >= 0.7, str(res[0]["test_acc"])

0 comments on commit fee52f9

Please sign in to comment.