Skip to content

Commit

Permalink
Fix mypy typing errors in pytorch_lightning/callbacks/model_checkpoin…
Browse files Browse the repository at this point in the history
…t.py (#13617)

Co-authored-by: Carlos Mocholí <[email protected]>
  • Loading branch information
Jungwon-Lee and carmocca authored Jul 20, 2022
1 parent ca1917e commit b40766c
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 34 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ warn_no_return = "False"
# the list can be generated with:
# mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
module = [
"pytorch_lightning.callbacks.model_checkpoint",
"pytorch_lightning.callbacks.progress.rich_progress",
"pytorch_lightning.callbacks.quantization",
"pytorch_lightning.callbacks.stochastic_weight_avg",
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
# validation, then we run after validation instead of on train epoch end
self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1

def _validate_condition_metric(self, logs: Dict[str, float]) -> bool:
def _validate_condition_metric(self, logs: Dict[str, Tensor]) -> bool:
monitor_val = logs.get(self.monitor)

error_msg = (
Expand Down
42 changes: 24 additions & 18 deletions src/pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.logger import _name, _version
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -231,13 +231,14 @@ def __init__(
self._save_on_train_epoch_end = save_on_train_epoch_end
self._last_global_step_saved = 0 # no need to save when no steps were taken
self._last_time_checked: Optional[float] = None
self.current_score = None
self.best_k_models = {}
self.current_score: Optional[Tensor] = None
self.best_k_models: Dict[str, Tensor] = {}
self.kth_best_model_path = ""
self.best_model_score = None
self.best_model_score: Optional[Tensor] = None
self.best_model_path = ""
self.last_model_path = ""

self.kth_value: Tensor
self.__init_monitor_mode(mode)
self.__init_ckpt_dir(dirpath, filename)
self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval)
Expand All @@ -256,6 +257,7 @@ def state_key(self) -> str:

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
self.__resolve_ckpt_dir(trainer)
assert self.dirpath is not None
if trainer.is_global_zero and stage == "fit":
self.__warn_if_dir_not_empty(self.dirpath)

Expand Down Expand Up @@ -362,7 +364,7 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None: # pragma: no-cover
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)

def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
if self.save_top_k == 0:
return

Expand Down Expand Up @@ -395,7 +397,7 @@ def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
from pytorch_lightning.trainer.states import TrainerFn

return (
trainer.fast_dev_run # disable checkpointing with fast_dev_run
bool(trainer.fast_dev_run) # disable checkpointing with fast_dev_run
or trainer.state.fn != TrainerFn.FITTING # don't save anything during non-fit
or trainer.sanity_checking # don't save anything during sanity check
or self._last_global_step_saved == trainer.global_step # already saved at the last step
Expand Down Expand Up @@ -493,15 +495,15 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] =
should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])

# If using multiple devices, make sure all processes are unanimous on the decision.
should_update_best_and_save = trainer.strategy.reduce_boolean_decision(should_update_best_and_save)
should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))

return should_update_best_and_save

@classmethod
def _format_checkpoint_name(
cls,
filename: Optional[str],
metrics: Dict[str, _METRIC],
metrics: Dict[str, Tensor],
prefix: str = "",
auto_insert_metric_name: bool = True,
) -> str:
Expand All @@ -522,7 +524,7 @@ def _format_checkpoint_name(
filename = filename.replace(group, f"{{0[{name}]")

if name not in metrics:
metrics[name] = 0
metrics[name] = torch.tensor(0)
filename = filename.format(metrics)

if prefix:
Expand All @@ -531,7 +533,7 @@ def _format_checkpoint_name(
return filename

def format_checkpoint_name(
self, metrics: Dict[str, _METRIC], filename: Optional[str] = None, ver: Optional[int] = None
self, metrics: Dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None
) -> str:
"""Generate a filename according to the defined template.
Expand Down Expand Up @@ -591,6 +593,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
ckpt_path = os.path.join(trainer._weights_save_path_internal, "checkpoints")
elif trainer.loggers:
if len(trainer.loggers) == 1:
assert trainer.logger is not None
save_dir = trainer.logger.save_dir or trainer.default_root_dir
else:
save_dir = trainer.default_root_dir
Expand All @@ -613,7 +616,7 @@ def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

def _get_metric_interpolated_filepath_name(
self, monitor_candidates: Dict[str, _METRIC], trainer: "pl.Trainer", del_filepath: Optional[str] = None
self, monitor_candidates: Dict[str, Tensor], trainer: "pl.Trainer", del_filepath: Optional[str] = None
) -> str:
filepath = self.format_checkpoint_name(monitor_candidates)

Expand All @@ -624,7 +627,7 @@ def _get_metric_interpolated_filepath_name(

return filepath

def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, Tensor]:
monitor_candidates = deepcopy(trainer.callback_metrics)
# cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor
# or does not exist we overwrite it as it's likely an error
Expand All @@ -634,7 +637,7 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
monitor_candidates["step"] = step.int() if isinstance(step, Tensor) else torch.tensor(trainer.global_step)
return monitor_candidates

def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
if not self.save_last:
return

Expand All @@ -651,16 +654,18 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
if previous and previous != filepath:
trainer.strategy.remove_checkpoint(previous)

def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
assert self.monitor
current = monitor_candidates.get(self.monitor)
if self.check_monitor_top_k(trainer, current):
assert current is not None
self._update_best_and_save(current, trainer, monitor_candidates)
elif self.verbose:
epoch = monitor_candidates["epoch"]
step = monitor_candidates["step"]
rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")

def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
# set the best model path before saving because it will be part of the state.
previous, self.best_model_path = self.best_model_path, filepath
Expand All @@ -669,7 +674,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
trainer.strategy.remove_checkpoint(previous)

def _update_best_and_save(
self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]
self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]
) -> None:
k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k

Expand All @@ -691,11 +696,11 @@ def _update_best_and_save(
if len(self.best_k_models) == k:
# monitor dict has reached k elements
_op = max if self.mode == "min" else min
self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type]
self.kth_value = self.best_k_models[self.kth_best_model_path]

_op = min if self.mode == "min" else max
self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type]
self.best_model_score = self.best_k_models[self.best_model_path]

if self.verbose:
Expand All @@ -715,6 +720,7 @@ def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
file."""
best_k = {k: v.item() for k, v in self.best_k_models.items()}
if filepath is None:
assert self.dirpath
filepath = os.path.join(self.dirpath, "best_k_models.yaml")
with self._fs.open(filepath, "w") as fp:
yaml.dump(best_k, fp)
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def __to_tensor(self, value: numbers.Number) -> Tensor:
return torch.tensor(value, device=self.device)

@staticmethod
def __check_numel_1(value: torch.Tensor, name: str) -> None:
def __check_numel_1(value: Tensor, name: str) -> None:
if not torch.numel(value) == 1:
raise ValueError(
f"`self.log({name}, {value})` was called, but the tensor must have a single element."
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
"""

def reduce_boolean_decision(self, decision: bool) -> bool:
"""Reduce the early stopping decision across all processes."""
"""Reduce a boolean decision across all processes."""
return decision

def pre_backward(self, closure_loss: Tensor) -> None:
Expand Down
12 changes: 3 additions & 9 deletions src/pytorch_lightning/strategies/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,19 +169,13 @@ def broadcast(self, obj: object, src: int = 0) -> object:
obj = torch.load(buffer)
return obj

def reduce_boolean_decision(self, decision: bool) -> bool:
decision = torch.tensor(int(decision), device=self.root_device)
decision = self.reduce(decision, reduce_op="sum")
decision = bool(decision == self.world_size)
return decision

def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
if not isinstance(output, Tensor):
output = torch.tensor(output, device=self.root_device)

_invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
_invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
if _invalid_reduce_op or _invalid_reduce_op_str:
invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
if invalid_reduce_op or invalid_reduce_op_str:
raise MisconfigurationException(
"Currently, TPUSpawn Strategy only support `sum`, `mean`, `avg` reduce operation."
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
result_metric.meta.sync.should = should
cache = result_metric._computed
if cache is not None:
if not isinstance(cache, torch.Tensor):
if not isinstance(cache, Tensor):
raise ValueError(
f"The `.compute()` return of the metric logged as {result_metric.meta.name!r} must be a tensor."
f" Found {cache}"
Expand Down
4 changes: 3 additions & 1 deletion src/pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2705,7 +2705,9 @@ def loggers(self, loggers: Optional[List[Logger]]) -> None:
self._loggers = loggers if loggers else []

@property
def callback_metrics(self) -> dict:
def callback_metrics(self) -> Dict[str, Tensor]:
# TODO: the true typing return can include dictionaries as defined in
# `pytorch_lightning.trainer.connectors.logger_connector.result._OUT_DICT`
return self._logger_connector.callback_metrics

@property
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
return gathered_result


def _simple_gather_all_tensors(result: torch.Tensor, group: Any, world_size: int) -> List[torch.Tensor]:
def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result, group)
return gathered_result
Expand Down

0 comments on commit b40766c

Please sign in to comment.