Skip to content

Commit

Permalink
Fix last and best checkpoints issue (#160)
Browse files Browse the repository at this point in the history
  • Loading branch information
HangJung97 authored May 24, 2024
1 parent db92cd5 commit 0e43ba5
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 11 deletions.
19 changes: 19 additions & 0 deletions ascent/configs/callbacks/latest_checkpoint.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# https://pytorch-lightning.readthedocs.io/en/latest/api/lightning.callbacks.ModelCheckpoint.html

# Save the model periodically by monitoring a quantity.
# Look at the above link for more detailed information.
# Periodic "latest" checkpoint: monitors the epoch counter (mode "max") so the
# most recent epoch always ranks best, giving a rolling latest-state checkpoint
# alongside the metric-based best-model checkpoint.
# See the link above for detailed parameter documentation.
latest_checkpoint:
  _target_: lightning.pytorch.callbacks.ModelCheckpoint
  dirpath: null # directory to save the model file (null: set by the including config)
  filename: "latest_epoch_{epoch:03d}" # checkpoint filename
  monitor: "epoch" # monitor the epoch number itself, so "best" == most recent epoch
  verbose: False # verbosity mode
  save_last: True # additionally always save an exact copy of the last checkpoint to a file last.ckpt
  save_top_k: 1 # keep only one checkpoint (the latest, given monitor="epoch" + mode="max")
  mode: "max" # "max": a higher epoch number is "better", i.e. newest wins
  auto_insert_metric_name: False # when True, the checkpoints filenames will contain the metric name
  save_weights_only: False # if True, then only the model's weights will be saved
  every_n_train_steps: null # number of training steps between checkpoints
  train_time_interval: null # checkpoints are monitored at the specified time interval
  every_n_epochs: 50 # checkpoint every 50 epochs (periodic snapshot cadence)
  save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
6 changes: 5 additions & 1 deletion ascent/configs/callbacks/nnunet.yaml
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
defaults:
- model_checkpoint
- latest_checkpoint
- model_summary
- rich_progress_bar
- learning_rate_monitor
- _self_

model_checkpoint:
dirpath: ${paths.output_dir}/checkpoints
filename: "epoch_{epoch:03d}"
filename: "best_epoch_{epoch:03d}"
monitor: "val/dice_MA"
mode: "max"
save_last: True
auto_insert_metric_name: False

latest_checkpoint:
dirpath: ${paths.output_dir}/checkpoints

model_summary:
max_depth: -1
33 changes: 23 additions & 10 deletions ascent/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ def run_system(cfg: DictConfig) -> Tuple[dict, dict]:
log.info("Instantiating callbacks...")
callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))

# Modify the last checkpoint name to differentiate between the best and last ckpt
for list_id, callback in enumerate(callbacks):
if isinstance(callback, pl.pytorch.callbacks.ModelCheckpoint):
if "best" in callback.filename:
callback.CHECKPOINT_NAME_LAST = "best"
elif "latest" in callback.filename:
callback.CHECKPOINT_NAME_LAST = "last"

log.info("Instantiating loggers...")
logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))

Expand Down Expand Up @@ -101,29 +109,34 @@ def run_system(cfg: DictConfig) -> Tuple[dict, dict]:
else:
trainer.fit(model=model, datamodule=datamodule)

# locate best and last ckpt path
best_model_path = None
last_model_path = None
for callback in trainer.checkpoint_callbacks:
if callback.CHECKPOINT_NAME_LAST == "best":
best_model_path = callback.last_model_path
elif "latest" in callback.filename:
last_model_path = callback.last_model_path

if isinstance(trainer.logger, CometLogger) and cfg.comet_save_model:
if trainer.checkpoint_callback.best_model_path:
trainer.logger.experiment.log_model(
"best-model", trainer.checkpoint_callback.best_model_path
)
if trainer.checkpoint_callback.last_model_path:
trainer.logger.experiment.log_model(
"last-model", trainer.checkpoint_callback.last_model_path
)
if best_model_path:
trainer.logger.experiment.log_model("best-model", best_model_path)
if last_model_path:
trainer.logger.experiment.log_model("last-model", last_model_path)

train_metrics = trainer.callback_metrics

if cfg.get("test"):
log.info("Starting testing!")
if cfg.get("best_model"):
ckpt_path = trainer.checkpoint_callback.best_model_path
ckpt_path = best_model_path
if ckpt_path == "":
log.warning("Best ckpt not found! Using current weights for testing...")
ckpt_path = None
else:
log.info(f"Loading best ckpt: {ckpt_path}")
else:
ckpt_path = trainer.checkpoint_callback.last_model_path
ckpt_path = last_model_path
if ckpt_path == "":
log.warning("Last ckpt not found! Using current weights for testing...")
ckpt_path = None
Expand Down

0 comments on commit 0e43ba5

Please sign in to comment.