Skip to content

Commit

Permalink
Fix CheckpointSaver log error (#6026)
Browse files Browse the repository at this point in the history
  • Loading branch information
KumoLiu authored Feb 18, 2023
1 parent bf55f22 commit f5708ea
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions monai/handlers/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

import logging
import os
import warnings
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -118,6 +119,7 @@ def __init__(
self._key_metric_checkpoint: Checkpoint | None = None
self._interval_checkpoint: Checkpoint | None = None
self._name = name
self._final_filename = final_filename

class _DiskSaver(DiskSaver):
"""
Expand Down Expand Up @@ -148,7 +150,7 @@ def _final_func(engine: Engine) -> Any:

self._final_checkpoint = Checkpoint(
to_save=self.save_dict,
save_handler=_DiskSaver(dirname=self.save_dir, filename=final_filename),
save_handler=_DiskSaver(dirname=self.save_dir, filename=self._final_filename),
filename_prefix=file_prefix,
score_function=_final_func,
score_name="final_iteration",
Expand Down Expand Up @@ -271,7 +273,11 @@ def completed(self, engine: Engine) -> None:
raise AssertionError
if not hasattr(self.logger, "info"):
raise AssertionError("Error, provided logger has not info attribute.")
self.logger.info(f"Train completed, saved final checkpoint: {self._final_checkpoint.last_checkpoint}")
if self._final_filename is not None:
_final_checkpoint_path = os.path.join(self.save_dir, self._final_filename)
else:
_final_checkpoint_path = self._final_checkpoint.last_checkpoint # type: ignore[assignment]
self.logger.info(f"Train completed, saved final checkpoint: {_final_checkpoint_path}")

def exception_raised(self, engine: Engine, e: Exception) -> None:
"""Callback for train or validation/evaluation exception raised Event.
Expand All @@ -291,7 +297,11 @@ def exception_raised(self, engine: Engine, e: Exception) -> None:
raise AssertionError
if not hasattr(self.logger, "info"):
raise AssertionError("Error, provided logger has not info attribute.")
self.logger.info(f"Exception raised, saved the last checkpoint: {self._final_checkpoint.last_checkpoint}")
if self._final_filename is not None:
_final_checkpoint_path = os.path.join(self.save_dir, self._final_filename)
else:
_final_checkpoint_path = self._final_checkpoint.last_checkpoint # type: ignore[assignment]
self.logger.info(f"Exception raised, saved the last checkpoint: {_final_checkpoint_path}")
raise e

def metrics_completed(self, engine: Engine) -> None:
Expand Down

0 comments on commit f5708ea

Please sign in to comment.