diff --git a/pyproject.toml b/pyproject.toml
index 989e63122f640..b6cbbbda15006 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,7 +47,6 @@ warn_no_return = "False"
 # the list can be generated with:
 # mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
 module = [
-    "pytorch_lightning.callbacks.model_checkpoint",
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.callbacks.stochastic_weight_avg",
diff --git a/src/pytorch_lightning/callbacks/early_stopping.py b/src/pytorch_lightning/callbacks/early_stopping.py
index 2fd730482fcc4..72d8445d84407 100644
--- a/src/pytorch_lightning/callbacks/early_stopping.py
+++ b/src/pytorch_lightning/callbacks/early_stopping.py
@@ -135,7 +135,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
         # validation, then we run after validation instead of on train epoch end
         self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
 
-    def _validate_condition_metric(self, logs: Dict[str, float]) -> bool:
+    def _validate_condition_metric(self, logs: Dict[str, Tensor]) -> bool:
         monitor_val = logs.get(self.monitor)
 
         error_msg = (
diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py
index bb6d0a9a9b0b6..9b49b9d44bb10 100644
--- a/src/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/src/pytorch_lightning/callbacks/model_checkpoint.py
@@ -39,7 +39,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.logger import _name, _version
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
-from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache
 
 log = logging.getLogger(__name__)
@@ -231,13 +231,14 @@ def __init__(
         self._save_on_train_epoch_end = save_on_train_epoch_end
         self._last_global_step_saved = 0  # no need to save when no steps were taken
         self._last_time_checked: Optional[float] = None
-        self.current_score = None
-        self.best_k_models = {}
+        self.current_score: Optional[Tensor] = None
+        self.best_k_models: Dict[str, Tensor] = {}
         self.kth_best_model_path = ""
-        self.best_model_score = None
+        self.best_model_score: Optional[Tensor] = None
         self.best_model_path = ""
         self.last_model_path = ""
 
+        self.kth_value: Tensor
         self.__init_monitor_mode(mode)
         self.__init_ckpt_dir(dirpath, filename)
         self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval)
@@ -256,6 +257,7 @@ def state_key(self) -> str:
 
     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
         self.__resolve_ckpt_dir(trainer)
+        assert self.dirpath is not None
         if trainer.is_global_zero and stage == "fit":
             self.__warn_if_dir_not_empty(self.dirpath)
 
@@ -362,7 +364,7 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None:  # pragma: no-cover
         self._save_topk_checkpoint(trainer, monitor_candidates)
         self._save_last_checkpoint(trainer, monitor_candidates)
 
-    def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
+    def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
         if self.save_top_k == 0:
             return
 
@@ -395,7 +397,7 @@ def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
         from pytorch_lightning.trainer.states import TrainerFn
 
         return (
-            trainer.fast_dev_run  # disable checkpointing with fast_dev_run
+            bool(trainer.fast_dev_run)  # disable checkpointing with fast_dev_run
             or trainer.state.fn != TrainerFn.FITTING  # don't save anything during non-fit
             or trainer.sanity_checking  # don't save anything during sanity check
             or self._last_global_step_saved == trainer.global_step  # already saved at the last step
@@ -493,7 +495,7 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] =
         should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
 
         # If using multiple devices, make sure all processes are unanimous on the decision.
-        should_update_best_and_save = trainer.strategy.reduce_boolean_decision(should_update_best_and_save)
+        should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))
 
         return should_update_best_and_save
 
@@ -501,7 +503,7 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] =
     def _format_checkpoint_name(
         cls,
         filename: Optional[str],
-        metrics: Dict[str, _METRIC],
+        metrics: Dict[str, Tensor],
         prefix: str = "",
         auto_insert_metric_name: bool = True,
     ) -> str:
@@ -522,7 +524,7 @@ def _format_checkpoint_name(
                 filename = filename.replace(group, f"{{0[{name}]")
 
                 if name not in metrics:
-                    metrics[name] = 0
+                    metrics[name] = torch.tensor(0)
             filename = filename.format(metrics)
 
         if prefix:
@@ -531,7 +533,7 @@ def _format_checkpoint_name(
         return filename
 
     def format_checkpoint_name(
-        self, metrics: Dict[str, _METRIC], filename: Optional[str] = None, ver: Optional[int] = None
+        self, metrics: Dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None
     ) -> str:
         """Generate a filename according to the defined template.
 
@@ -591,6 +593,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
             ckpt_path = os.path.join(trainer._weights_save_path_internal, "checkpoints")
         elif trainer.loggers:
             if len(trainer.loggers) == 1:
+                assert trainer.logger is not None
                 save_dir = trainer.logger.save_dir or trainer.default_root_dir
             else:
                 save_dir = trainer.default_root_dir
@@ -613,7 +616,7 @@ def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
             rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
 
     def _get_metric_interpolated_filepath_name(
-        self, monitor_candidates: Dict[str, _METRIC], trainer: "pl.Trainer", del_filepath: Optional[str] = None
+        self, monitor_candidates: Dict[str, Tensor], trainer: "pl.Trainer", del_filepath: Optional[str] = None
     ) -> str:
         filepath = self.format_checkpoint_name(monitor_candidates)
 
@@ -624,7 +627,7 @@ def _get_metric_interpolated_filepath_name(
 
         return filepath
 
-    def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
+    def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, Tensor]:
         monitor_candidates = deepcopy(trainer.callback_metrics)
         # cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor
         # or does not exist we overwrite it as it's likely an error
@@ -634,7 +637,7 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
         monitor_candidates["step"] = step.int() if isinstance(step, Tensor) else torch.tensor(trainer.global_step)
         return monitor_candidates
 
-    def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
+    def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
         if not self.save_last:
             return
 
@@ -651,16 +654,18 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
         if previous and previous != filepath:
             trainer.strategy.remove_checkpoint(previous)
 
-    def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
+    def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
+        assert self.monitor
         current = monitor_candidates.get(self.monitor)
         if self.check_monitor_top_k(trainer, current):
+            assert current is not None
             self._update_best_and_save(current, trainer, monitor_candidates)
         elif self.verbose:
             epoch = monitor_candidates["epoch"]
             step = monitor_candidates["step"]
             rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")
 
-    def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
+    def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
         filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
         # set the best model path before saving because it will be part of the state.
         previous, self.best_model_path = self.best_model_path, filepath
@@ -669,7 +674,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
             trainer.strategy.remove_checkpoint(previous)
 
     def _update_best_and_save(
-        self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]
+        self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]
     ) -> None:
         k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k
 
@@ -691,11 +696,11 @@ def _update_best_and_save(
         if len(self.best_k_models) == k:
             # monitor dict has reached k elements
             _op = max if self.mode == "min" else min
-            self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
+            self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
             self.kth_value = self.best_k_models[self.kth_best_model_path]
 
         _op = min if self.mode == "min" else max
-        self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
+        self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
         self.best_model_score = self.best_k_models[self.best_model_path]
 
         if self.verbose:
@@ -715,6 +720,7 @@ def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
         file."""
         best_k = {k: v.item() for k, v in self.best_k_models.items()}
         if filepath is None:
+            assert self.dirpath
             filepath = os.path.join(self.dirpath, "best_k_models.yaml")
         with self._fs.open(filepath, "w") as fp:
             yaml.dump(best_k, fp)
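
Not part of the patch, for illustration only: with the monitor candidates typed as `Dict[str, Tensor]`, a metric referenced by a filename template but missing from the dict is now filled with `torch.tensor(0)` rather than the plain int `0`, which still formats cleanly because zero-dimensional tensors support standard format specs. A minimal sketch (metric names and values are made up):

    import torch
    from pytorch_lightning.callbacks import ModelCheckpoint

    # hypothetical candidates mirroring the Dict[str, Tensor] contract
    metrics = {"epoch": torch.tensor(3), "val_loss": torch.tensor(0.1234)}
    ckpt = ModelCheckpoint(filename="epoch={epoch}-val_loss={val_loss:.2f}", auto_insert_metric_name=False)
    print(ckpt.format_checkpoint_name(metrics))  # e.g. "epoch=3-val_loss=0.12.ckpt"
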
diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py
index 4b8770dd89a4e..d07b272c171c7 100644
--- a/src/pytorch_lightning/core/module.py
+++ b/src/pytorch_lightning/core/module.py
@@ -532,7 +532,7 @@ def __to_tensor(self, value: numbers.Number) -> Tensor:
         return torch.tensor(value, device=self.device)
 
     @staticmethod
-    def __check_numel_1(value: torch.Tensor, name: str) -> None:
+    def __check_numel_1(value: Tensor, name: str) -> None:
         if not torch.numel(value) == 1:
             raise ValueError(
                 f"`self.log({name}, {value})` was called, but the tensor must have a single element."
diff --git a/src/pytorch_lightning/strategies/strategy.py b/src/pytorch_lightning/strategies/strategy.py
index 2cbf14760f83f..0a9b19376bcd4 100644
--- a/src/pytorch_lightning/strategies/strategy.py
+++ b/src/pytorch_lightning/strategies/strategy.py
@@ -285,7 +285,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
         """
 
     def reduce_boolean_decision(self, decision: bool) -> bool:
-        """Reduce the early stopping decision across all processes."""
+        """Reduce a boolean decision across all processes."""
         return decision
 
     def pre_backward(self, closure_loss: Tensor) -> None:
diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py
index 178fe638cc0a3..b27f299fb3722 100644
--- a/src/pytorch_lightning/strategies/tpu_spawn.py
+++ b/src/pytorch_lightning/strategies/tpu_spawn.py
@@ -169,19 +169,13 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         obj = torch.load(buffer)
         return obj
 
-    def reduce_boolean_decision(self, decision: bool) -> bool:
-        decision = torch.tensor(int(decision), device=self.root_device)
-        decision = self.reduce(decision, reduce_op="sum")
-        decision = bool(decision == self.world_size)
-        return decision
-
     def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
         if not isinstance(output, Tensor):
             output = torch.tensor(output, device=self.root_device)
 
-        _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
-        _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
-        if _invalid_reduce_op or _invalid_reduce_op_str:
+        invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
+        invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
+        if invalid_reduce_op or invalid_reduce_op_str:
             raise MisconfigurationException(
                 "Currently, TPUSpawn Strategy only support `sum`, `mean`, `avg` reduce operation."
             )
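
For context (not shown in the diff itself): the `reduce_boolean_decision` override deleted from `tpu_spawn.py` implemented an "all ranks must agree" reduction. A minimal `torch.distributed` sketch of that logic, with an illustrative function name and explicit device argument:

    import torch
    import torch.distributed as dist

    def reduce_boolean_decision(decision: bool, device: torch.device) -> bool:
        # each rank contributes 0 or 1; the result stays True only if every rank agreed
        flag = torch.tensor(int(decision), device=device)
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)
        return bool(flag.item() == dist.get_world_size())
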
diff --git a/src/pytorch_lightning/trainer/connectors/logger_connector/result.py b/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
index a33359a3fe5e9..27cb3cb0323b2 100644
--- a/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -529,7 +529,7 @@ def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
                 result_metric.meta.sync.should = should
         cache = result_metric._computed
         if cache is not None:
-            if not isinstance(cache, torch.Tensor):
+            if not isinstance(cache, Tensor):
                 raise ValueError(
                     f"The `.compute()` return of the metric logged as {result_metric.meta.name!r} must be a tensor."
                     f" Found {cache}"
diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py
index 882326f870de6..b53e19a11e9f6 100644
--- a/src/pytorch_lightning/trainer/trainer.py
+++ b/src/pytorch_lightning/trainer/trainer.py
@@ -2705,7 +2705,9 @@ def loggers(self, loggers: Optional[List[Logger]]) -> None:
         self._loggers = loggers if loggers else []
 
     @property
-    def callback_metrics(self) -> dict:
+    def callback_metrics(self) -> Dict[str, Tensor]:
+        # TODO: the true typing return can include dictionaries as defined in
+        # `pytorch_lightning.trainer.connectors.logger_connector.result._OUT_DICT`
         return self._logger_connector.callback_metrics
 
     @property
diff --git a/src/pytorch_lightning/utilities/distributed.py b/src/pytorch_lightning/utilities/distributed.py
index bc7ed3debaf90..361c6dd12beeb 100644
--- a/src/pytorch_lightning/utilities/distributed.py
+++ b/src/pytorch_lightning/utilities/distributed.py
@@ -99,7 +99,7 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
     return gathered_result
 
 
-def _simple_gather_all_tensors(result: torch.Tensor, group: Any, world_size: int) -> List[torch.Tensor]:
+def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
     gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
     torch.distributed.all_gather(gathered_result, result, group)
    return gathered_result
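
With `Trainer.callback_metrics` now annotated as `Dict[str, Tensor]`, consumers such as `ModelCheckpoint._monitor_candidates` can be type-checked without casts. A small usage sketch, assuming a hypothetical "val_loss" key was logged:

    import pytorch_lightning as pl
    from torch import Tensor

    def current_val_loss(trainer: pl.Trainer) -> float:
        metric: Tensor = trainer.callback_metrics["val_loss"]  # typed as Tensor
        return float(metric)
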