From 97e1d2897b417ffc8520debf36f70ccb6fba26ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 7 Mar 2022 20:21:37 +0100 Subject: [PATCH] Integrate global step with progress tracking (#11805) --- CHANGELOG.md | 9 ++++ docs/source/common/lightning_module.rst | 13 ++--- docs/source/common/trainer.rst | 25 ++++++---- .../callbacks/device_stats_monitor.py | 4 +- .../callbacks/gpu_stats_monitor.py | 4 +- pytorch_lightning/callbacks/lr_monitor.py | 4 +- .../callbacks/model_checkpoint.py | 24 +-------- .../loops/dataloader/evaluation_loop.py | 4 +- .../loops/epoch/training_epoch_loop.py | 26 +++++----- pytorch_lightning/loops/fit_loop.py | 27 +++------- .../loops/optimization/optimizer_loop.py | 9 +++- .../connectors/checkpoint_connector.py | 16 +++--- .../logger_connector/logger_connector.py | 12 ++--- pytorch_lightning/trainer/trainer.py | 6 ++- pytorch_lightning/tuner/batch_size_scaling.py | 2 - pytorch_lightning/tuner/lr_finder.py | 2 - tests/callbacks/test_lr_monitor.py | 7 +-- tests/callbacks/test_rich_progress_bar.py | 14 +++--- tests/callbacks/test_tqdm_progress_bar.py | 14 +++--- .../test_checkpoint_callback_frequency.py | 4 +- tests/checkpointing/test_model_checkpoint.py | 49 +++++-------------- .../checkpointing/test_trainer_checkpoint.py | 16 ------ tests/loggers/test_comet.py | 2 +- tests/loggers/test_mlflow.py | 4 +- tests/loggers/test_wandb.py | 4 +- tests/loops/test_loops.py | 7 ++- tests/loops/test_training_loop.py | 2 +- tests/models/test_amp.py | 4 +- tests/models/test_restore.py | 13 +++-- tests/plugins/test_checkpoint_io_plugin.py | 4 +- tests/trainer/optimization/test_optimizers.py | 2 +- tests/trainer/test_trainer.py | 3 +- 32 files changed, 144 insertions(+), 192 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 524540530b247..fd8d25ccafe6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -303,6 +303,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - The `trainer.current_epoch` value is now increased by 1 during and after `on_train_end` ([#8578](https://github.com/PyTorchLightning/pytorch-lightning/pull/8578)) +- The `trainer.global_step` value now accounts for multiple optimizers and TBPTT splits ([#11805](https://github.com/PyTorchLightning/pytorch-lightning/pull/11805)) + + +- The `trainer.global_step` value is now increased right after the `optimizer.step()` call which will impact users who access it during an intra-training validation hook ([#11805](https://github.com/PyTorchLightning/pytorch-lightning/pull/11805)) + + +- The filename of checkpoints created with `ModelCheckpoint(filename='{step}')` is different compared to previous versions. A checkpoint saved after 1 step will be named `step=1.ckpt` instead of `step=0.ckpt` ([#11805](https://github.com/PyTorchLightning/pytorch-lightning/pull/11805)) + + - Inherit from `ABC` for `Accelerator`: Users need to implement `auto_device_count` ([#11521](https://github.com/PyTorchLightning/pytorch-lightning/pull/11521)) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index dc10f235ceb39..2156c1567ac33 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -916,7 +916,7 @@ These are properties available in a LightningModule. current_epoch ~~~~~~~~~~~~~ -The current epoch +The number of epochs run. .. 
code-block:: python @@ -946,12 +946,13 @@ usually do not need to use this property, but it is useful to know how to access def training_step(self, batch, batch_idx): if self.global_rank == 0: # do something only once across all the nodes - self.log("global_step", self.trainer.global_step) + ... global_step ~~~~~~~~~~~ -The current step (does not reset each epoch) +The number of optimizer steps taken (does not reset each epoch). +This includes multiple optimizers and TBPTT steps (if enabled). .. code-block:: python @@ -1003,16 +1004,16 @@ The list of loggers currently being used by the Trainer. local_rank ~~~~~~~~~~~ -The ``global_rank`` is the index of the current process across all the devices for the current node. +The ``local_rank`` is the index of the current process across all the devices for the current node. You usually do not need to use this property, but it is useful to know how to access it if needed. For example, if using 10 machines (or nodes), the GPU at index 0 on each machine has local_rank = 0. .. code-block:: python def training_step(self, batch, batch_idx): - if self.global_rank == 0: + if self.local_rank == 0: # do something only once across each node - self.log("global_step", self.trainer.global_step) + ... precision ~~~~~~~~~ diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index b7c65929cb7b7..1f39355afd242 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -934,7 +934,7 @@ max_steps | -Stop training after this number of steps +Stop training after this number of :ref:`global steps `. Training will stop if max_steps or max_epochs have reached (earliest). .. testcode:: @@ -959,7 +959,7 @@ min_steps | -Force training for at least these number of steps. +Force training for at least this number of :ref:`global steps `. Trainer will train model for at least min_steps or min_epochs (latest). .. testcode:: @@ -1732,16 +1732,23 @@ The metrics available to callbacks. These are automatically set when you log via current_epoch ************* -The current epoch +The number of epochs run. .. code-block:: python - def training_step(self, batch, batch_idx): - current_epoch = self.trainer.current_epoch - if current_epoch > 100: - # do something - pass + if trainer.current_epoch >= 10: + ... + +global_step +*********** + +The number of optimizer steps taken (does not reset each epoch). +This includes multiple optimizers and TBPTT steps (if enabled). +.. code-block:: python + + if trainer.global_step >= 100: + ... logger ******* @@ -1822,4 +1829,4 @@ The metrics sent to the progress bar. estimated_stepping_batches ************************** -Check out :paramref:`~pytorch_lightning.trainer.trainer.Trainer.estimated_stepping_batches`. +Check out :meth:`~pytorch_lightning.trainer.trainer.Trainer.estimated_stepping_batches`. 
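The documentation changes above redefine `trainer.global_step` as the number of completed `optimizer.step()` calls, incremented right after each step rather than at batch end. A minimal sketch of how user hooks observe the new values; the callback name and the logging thresholds are illustrative and not part of this patch:

```python
import pytorch_lightning as pl


class StepAwareCallback(pl.Callback):
    """Illustrative only: observes the new step/epoch accounting from user hooks."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # With automatic optimization, a single optimizer and no gradient accumulation,
        # `global_step` already equals the number of completed optimizer steps here,
        # so it reads 1 after the first batch (it used to still read 0 at this point).
        if trainer.global_step % 100 == 0:
            print(f"{trainer.global_step} optimizer steps completed")

    def on_validation_end(self, trainer, pl_module):
        # Intra-epoch validation (e.g. triggered by `val_check_interval`) now sees the
        # counter already advanced past the batch that triggered it.
        print(f"validation finished at step={trainer.global_step}, epoch={trainer.current_epoch}")
```

Because the counter is already advanced when checkpoints are written, a checkpoint saved after the first step with `filename="{step}"` picks up the post-step value, which is where the `step=1.ckpt` vs `step=0.ckpt` difference noted in the CHANGELOG comes from.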
diff --git a/pytorch_lightning/callbacks/device_stats_monitor.py b/pytorch_lightning/callbacks/device_stats_monitor.py index 93d440d016086..0929358cf0f74 100644 --- a/pytorch_lightning/callbacks/device_stats_monitor.py +++ b/pytorch_lightning/callbacks/device_stats_monitor.py @@ -66,7 +66,7 @@ def on_train_batch_start( for logger in trainer.loggers: separator = logger.group_separator prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator) - logger.log_metrics(prefixed_device_stats, step=trainer.global_step) + logger.log_metrics(prefixed_device_stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped) def on_train_batch_end( self, @@ -88,7 +88,7 @@ def on_train_batch_end( for logger in trainer.loggers: separator = logger.group_separator prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator) - logger.log_metrics(prefixed_device_stats, step=trainer.global_step) + logger.log_metrics(prefixed_device_stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped) def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]: diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py index 2e9e817bf9be4..8fb92006708f7 100644 --- a/pytorch_lightning/callbacks/gpu_stats_monitor.py +++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py @@ -162,7 +162,7 @@ def on_train_batch_start( logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000 for logger in trainer.loggers: - logger.log_metrics(logs, step=trainer.global_step) + logger.log_metrics(logs, step=trainer.fit_loop.epoch_loop._batches_that_stepped) @rank_zero_only def on_train_batch_end( @@ -187,7 +187,7 @@ def on_train_batch_end( logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000 for logger in trainer.loggers: - logger.log_metrics(logs, step=trainer.global_step) + logger.log_metrics(logs, step=trainer.fit_loop.epoch_loop._batches_that_stepped) @staticmethod def _get_gpu_ids(device_ids: List[int]) -> List[str]: diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 4f226f7fdec51..b149858575118 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -158,7 +158,7 @@ def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) if latest_stat: for logger in trainer.loggers: - logger.log_metrics(latest_stat, step=trainer.global_step) + logger.log_metrics(latest_stat, step=trainer.fit_loop.epoch_loop._batches_that_stepped) def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: if self.logging_interval != "step": @@ -167,7 +167,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) if latest_stat: for logger in trainer.loggers: - logger.log_metrics(latest_stat, step=trainer.global_step) + logger.log_metrics(latest_stat, step=trainer.fit_loop.epoch_loop._batches_that_stepped) def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]: latest_stat = {} diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 704d0a7a52253..d9b5f13e6fa8a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -227,7 +227,7 @@ def __init__( self.save_weights_only = save_weights_only self.auto_insert_metric_name = 
auto_insert_metric_name self._save_on_train_epoch_end = save_on_train_epoch_end - self._last_global_step_saved = -1 + self._last_global_step_saved = 0 # no need to save when no steps were taken self._last_time_checked: Optional[float] = None self.current_score = None self.best_k_models = {} @@ -278,8 +278,7 @@ def on_train_batch_end( """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`""" if self._should_skip_saving_checkpoint(trainer): return - step = trainer.global_step - skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0) + skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0) train_time_interval = self._train_time_interval skip_time = True @@ -300,8 +299,6 @@ def on_train_batch_end( def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Save a checkpoint at the end of the training epoch.""" - # as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates - trainer.fit_loop.global_step -= 1 if ( not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end @@ -309,7 +306,6 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu and (trainer.current_epoch + 1) % self._every_n_epochs == 0 ): self.save_checkpoint(trainer) - trainer.fit_loop.global_step += 1 def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Save a checkpoint at the end of the validation stage.""" @@ -322,22 +318,6 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul return self.save_checkpoint(trainer) - def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """Save a checkpoint when training stops. - - This will only save a checkpoint if `save_last` is also enabled as the monitor metrics logged during - training/validation steps or end of epochs are not guaranteed to be available at this stage. 
- """ - if self._should_skip_saving_checkpoint(trainer) or not self.save_last: - return - if self.verbose: - rank_zero_info("Saving latest checkpoint...") - # as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates - monitor_candidates = self._monitor_candidates(trainer, trainer.current_epoch, trainer.global_step - 1) - trainer.fit_loop.global_step -= 1 - self._save_last_checkpoint(trainer, monitor_candidates) - trainer.fit_loop.global_step += 1 - def on_save_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] ) -> Dict[str, Any]: diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 833a43e4cf019..859ded3a98e72 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -66,8 +66,6 @@ def num_dataloaders(self) -> int: # case where user does: # return dl1, dl2 dataloaders = self.dataloaders - if dataloaders is None: - return 0 length = len(dataloaders) if length > 0 and isinstance(dataloaders[0], (list, tuple)): length = len(dataloaders[0]) @@ -78,7 +76,7 @@ def dataloaders(self) -> Sequence[DataLoader]: """Returns the validation or test dataloaders.""" dataloaders = self.trainer.test_dataloaders if self.trainer.testing else self.trainer.val_dataloaders if dataloaders is None: - raise RuntimeError("Dataloaders should be available.") + return [] return dataloaders @property diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index e28a864b46f16..67400ce0472de 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -60,7 +60,6 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None self.min_steps = min_steps self.max_steps = max_steps - self.global_step: int = 0 self.batch_progress = BatchProgress() self.scheduler_progress = SchedulerProgress() @@ -72,6 +71,7 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None self._warning_cache = WarningCache() # caches the loaded dataloader state until dataloader objects are available self._dataloader_state_dict: Dict[str, Any] = {} + self._batches_that_stepped: int = 0 @property def total_batch_idx(self) -> int: @@ -87,6 +87,13 @@ def batch_idx(self) -> int: # but before the next `ready` increase return self.batch_progress.current.ready - 1 + @property + def global_step(self) -> int: + lightning_module = self.trainer.lightning_module + if lightning_module is None or lightning_module.automatic_optimization: + return self.batch_loop.optimizer_loop.optim_progress.optimizer_steps + return self.batch_loop.manual_loop.optim_step_progress.total.completed + @property def _is_training_done(self) -> bool: max_steps_reached = _is_max_limit_reached(self.global_step, self.max_steps) @@ -247,17 +254,14 @@ def on_advance_end(self) -> None: self._run_validation() self.trainer.training = True - # ----------------------------------------- - # SAVE LOGGERS (ie: Tensorboard, etc...) 
- # ----------------------------------------- - self._save_loggers_on_train_batch_end() - # update plateau LR scheduler after metrics are logged self.update_lr_schedulers("step", update_plateau_schedulers=True) if not self._should_accumulate(): - # progress global step according to grads progress - self.global_step += 1 + # this is increased once per batch disregarding multiple optimizers or tbptt on purpose for loggers + self._batches_that_stepped += 1 + # this will save based on the `batches_that_stepped` value + self._save_loggers_on_train_batch_end() # if training finished, defer exit to the parent. this assumes there will be enough time in between # which might not be the case depending on what's in the `*_epoch_end` hooks @@ -503,9 +507,9 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool: def _save_loggers_on_train_batch_end(self) -> None: """Flushes loggers to disk.""" - # when loggers should save to disk - should_flush_logs = self.trainer._logger_connector.should_flush_logs - if should_flush_logs: + # this assumes that `batches_that_stepped` was increased before + should_flush = self._batches_that_stepped % self.trainer.flush_logs_every_n_steps == 0 + if should_flush or self.trainer.should_stop: for logger in self.trainer.loggers: logger.save() diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 361e104fa878b..7087bcbad0442 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -68,19 +68,6 @@ def __init__( self._outputs: _EPOCH_OUTPUTS_TYPE = [] self._data_fetcher: Optional[AbstractDataFetcher] = None - @property - def global_step(self) -> int: - """Returns the global step.""" - lightning_module = self.trainer.lightning_module - if lightning_module is None or lightning_module.automatic_optimization: - return self.epoch_loop.global_step - return self.epoch_loop.batch_loop.manual_loop.optim_step_progress.total.completed - - @global_step.setter - def global_step(self, value: int) -> None: - """Sets the global step (forwards to epoch_loop)""" - self.epoch_loop.global_step = value - @property def total_batch_idx(self) -> int: """Returns the current batch index (across epochs)""" @@ -177,7 +164,7 @@ def _results(self) -> _ResultCollection: def done(self) -> bool: """Evaluates when to leave the loop.""" # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop - stop_steps = _is_max_limit_reached(self.global_step, self.max_steps) + stop_steps = _is_max_limit_reached(self.epoch_loop.global_step, self.max_steps) # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved. 
# we use it here because the checkpoint data won't have `completed` increased yet stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs) @@ -186,7 +173,7 @@ def done(self) -> bool: if self.trainer.should_stop: # early stopping met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True - met_min_steps = self.global_step >= self.min_steps if self.min_steps else True + met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True if met_min_epochs and met_min_steps: should_stop = True else: @@ -319,14 +306,12 @@ def on_advance_end(self) -> None: self.epoch_progress.increment_completed() - # the global step is manually decreased here due to backwards compatibility with existing loggers - # as they expect that the same step is used when logging epoch end metrics even when the batch loop has - # finished. this means the attribute does not exactly track the number of optimizer steps applied. - # TODO(@carmocca): deprecate and rename so users don't get confused - self.global_step -= 1 + # we manually decrease here because loggers expect that the same step is used when logging epoch-end metrics + # even when the batch loop has finished + self.epoch_loop._batches_that_stepped -= 1 # log epoch metrics self.trainer._logger_connector.update_train_epoch_metrics() - self.global_step += 1 + self.epoch_loop._batches_that_stepped += 1 # if fault tolerant is enabled and process has been notified, exit. self.trainer._exit_gracefully_on_signal() diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py index f8d692d688035..bab025466789a 100644 --- a/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -359,7 +359,11 @@ def _optimizer_step( else: optimizer = self.trainer.strategy._lightning_optimizers[opt_idx] - self.optim_progress.optimizer.step.increment_ready() + # if `strategy.handles_gradient_accumulation`, this method will be called to route into the strategy, but we + # need to check again if `should_accumulate` before increasing the counters + should_accumulate = self.trainer.fit_loop._should_accumulate() + if not should_accumulate: + self.optim_progress.optimizer.step.increment_ready() # model hook self.trainer._call_lightning_module_hook( @@ -374,7 +378,8 @@ def _optimizer_step( using_lbfgs=is_lbfgs, ) - self.optim_progress.optimizer.step.increment_completed() + if not should_accumulate: + self.optim_progress.optimizer.step.increment_completed() def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None: """Calls the ``on_before_zero_grad`` hook. diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 9a15db2e8e561..be2c2d3dfa8eb 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -232,16 +232,20 @@ def restore_loops(self) -> None: if not self._loaded_checkpoint: return - self.trainer.fit_loop.global_step = self._loaded_checkpoint["global_step"] - # set the `current_epoch` value for old checkpoints without the progress tracking state. + fit_loop = self.trainer.fit_loop + # set the `global_step` value for checkpoints before v1.6 without the progress tracking state. 
# it will be overwritten by the loop's state if it was also saved - self.trainer.fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"] + optimizer_loop = fit_loop.epoch_loop.batch_loop.optimizer_loop + optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint["global_step"] + # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. + # it will be overwritten by the loop's state if it was also saved + fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"] assert self.trainer.state.fn is not None state_dict = self._loaded_checkpoint.get("loops") if state_dict is not None: if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): - self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"]) + fit_loop.load_state_dict(state_dict["fit_loop"]) elif self.trainer.state.fn == TrainerFn.VALIDATING: self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"]) elif self.trainer.state.fn == TrainerFn.TESTING: @@ -330,9 +334,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: model = self.trainer.lightning_module checkpoint = { - # the epoch is saved for compatibility but it's not relevant for restoration + # the epoch and global step are saved for compatibility but they are not relevant for restoration "epoch": self.trainer.current_epoch, - "global_step": self.trainer.global_step + model.automatic_optimization, + "global_step": self.trainer.global_step, "pytorch-lightning_version": pl.__version__, "state_dict": self._get_lightning_module_state_dict(), "loops": self._get_loops_state_dict(), diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 428713ff3347e..0e3a69bfc9d98 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -77,15 +77,11 @@ def on_trainer_init( ) break - @property - def should_flush_logs(self) -> bool: - should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0 - return should_flush or self.trainer.should_stop - @property def should_update_logs(self) -> bool: - should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 - return should_log_every_n_steps or self.trainer.should_stop + # `+ 1` because it can be checked before a step is executed, for example, in `on_train_batch_start` + should_log = (self.trainer.fit_loop.epoch_loop._batches_that_stepped + 1) % self.trainer.log_every_n_steps == 0 + return should_log or self.trainer.should_stop def configure_logger(self, logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]]) -> None: if isinstance(logger, bool): @@ -123,7 +119,7 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None: if step is None: # added metrics for convenience scalar_metrics.setdefault("epoch", self.trainer.current_epoch) - step = self.trainer.global_step + step = self.trainer.fit_loop.epoch_loop._batches_that_stepped # log actual metrics for logger in self.trainer.loggers: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b3b81fe4554fd..5dbd9d3cadc3a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -2486,7 +2486,11 @@ def sanity_checking(self, val: bool) -> None: @property def 
global_step(self) -> int: - return self.fit_loop.global_step + """The number of optimizer steps taken (does not reset each epoch). + + This includes multiple optimizers and TBPTT steps (if enabled). + """ + return self.fit_loop.epoch_loop.global_step @property def current_epoch(self) -> int: diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 3d5916e3f8bd9..6f4ac72bd7e8b 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -60,9 +60,7 @@ def scale_batch_size( # Save initial model, that is loaded after batch size is found ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt") - trainer.fit_loop.global_step -= 1 trainer.save_checkpoint(ckpt_path) - trainer.fit_loop.global_step += 1 params = __scale_batch_dump_params(trainer) # Set to values that are required by the algorithm diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index d929bbe2f87c7..36b09c130056c 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -204,9 +204,7 @@ def lr_find( # Save initial model, that is loaded after learning rate is found ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt") - trainer.fit_loop.global_step -= 1 trainer.save_checkpoint(ckpt_path) - trainer.fit_loop.global_step += 1 params = __lr_finder_dump_params(trainer) # Set to values that are required by the algorithm diff --git a/tests/callbacks/test_lr_monitor.py b/tests/callbacks/test_lr_monitor.py index 82a4a5b99894a..391e74bb10221 100644 --- a/tests/callbacks/test_lr_monitor.py +++ b/tests/callbacks/test_lr_monitor.py @@ -217,7 +217,6 @@ def configure_optimizers(self): optimizer2 = optim.Adam(self.parameters(), lr=1e-2) lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1) lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1) - return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] model = CustomBoringModel() @@ -241,7 +240,8 @@ def configure_optimizers(self): assert list(lr_monitor.lrs) == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly" if logging_interval == "step": - expected_number_logged = trainer.global_step // log_every_n_steps + # divide by 2 because we have 2 optimizers + expected_number_logged = trainer.global_step // 2 // log_every_n_steps if logging_interval == "epoch": expected_number_logged = trainer.max_epochs @@ -284,7 +284,8 @@ def configure_optimizers(self): assert list(lr_monitor.lrs) == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly" if logging_interval == "step": - expected_number_logged = trainer.global_step // log_every_n_steps + # divide by 2 because we have 2 optimizers + expected_number_logged = trainer.global_step // 2 // log_every_n_steps if logging_interval == "epoch": expected_number_logged = trainer.max_epochs diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index cfe32dc495f8d..29ef3aa98f89b 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -368,15 +368,15 @@ def test_step(self, batch, batch_idx): trainer.fit(model) assert pbar.calls["fit"] == [ ("sanity_check", 0, 0, {"b": 0}), - ("train", 0, 0, {}), ("train", 0, 1, {}), - ("validate", 0, 1, {"b": 1}), # validation end + ("train", 0, 2, {}), + ("validate", 0, 2, {"b": 2}), # validation end # epoch end over, 
`on_epoch=True` metrics are computed - ("train", 0, 2, {"a": 1, "b": 1}), # training epoch end - ("train", 1, 2, {"a": 1, "b": 1}), - ("train", 1, 3, {"a": 1, "b": 1}), - ("validate", 1, 3, {"a": 1, "b": 3}), # validation end - ("train", 1, 4, {"a": 3, "b": 3}), # training epoch end + ("train", 0, 2, {"a": 1, "b": 2}), # training epoch end + ("train", 1, 3, {"a": 1, "b": 2}), + ("train", 1, 4, {"a": 1, "b": 2}), + ("validate", 1, 4, {"a": 1, "b": 4}), # validation end + ("train", 1, 4, {"a": 3, "b": 4}), # training epoch end ] trainer.validate(model, verbose=False) diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index 7897a1be798bb..3cfe54c992247 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -608,15 +608,15 @@ def test_step(self, batch, batch_idx): trainer.fit(model) assert pbar.calls["fit"] == [ ("sanity_check", 0, 0, {"b": 0}), - ("train", 0, 0, {}), ("train", 0, 1, {}), - ("validate", 0, 1, {"b": 1}), # validation end + ("train", 0, 2, {}), + ("validate", 0, 2, {"b": 2}), # validation end # epoch end over, `on_epoch=True` metrics are computed - ("train", 0, 2, {"a": 1, "b": 1}), # training epoch end - ("train", 1, 2, {"a": 1, "b": 1}), - ("train", 1, 3, {"a": 1, "b": 1}), - ("validate", 1, 3, {"a": 1, "b": 3}), # validation end - ("train", 1, 4, {"a": 3, "b": 3}), # training epoch end + ("train", 0, 2, {"a": 1, "b": 2}), # training epoch end + ("train", 1, 3, {"a": 1, "b": 2}), + ("train", 1, 4, {"a": 1, "b": 2}), + ("validate", 1, 4, {"a": 1, "b": 4}), # validation end + ("train", 1, 4, {"a": 3, "b": 4}), # training epoch end ] trainer.validate(model, verbose=False) diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index 90665a6db476e..eeec11c6ecd14 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -81,8 +81,8 @@ def training_step(self, batch, batch_idx): trainer.fit(model) if save_last: - # last epochs are saved every step (so double the save calls) and once `on_train_end` - expected = expected * 2 + 1 + # last epochs are saved every step (so double the save calls) + expected = expected * 2 assert save_mock.call_count == expected diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index b9e63d28f4234..3dadf0b733a74 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
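The `ModelCheckpoint` changes earlier in this patch (dropping the `global_step -= 1` adjustments and checking `trainer.global_step % every_n_train_steps`) shift checkpoint filenames by one step, which is what the updated expectations in the checkpointing tests below reflect. A rough sketch assuming the post-patch behaviour; the model, the `dirpath`, and the exact file list are hypothetical:

```python
from pathlib import Path

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class TinyModel(pl.LightningModule):
    """Throwaway model, only here to demonstrate the new checkpoint step numbering."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(12, 4)), batch_size=2)


checkpoint = ModelCheckpoint(dirpath="checkpoints/", filename="{step}", every_n_train_steps=2, save_top_k=-1)
trainer = pl.Trainer(max_steps=6, logger=False, callbacks=[checkpoint])
trainer.fit(TinyModel())

# Filenames now reflect the number of completed optimizer steps at save time,
# so this run should produce roughly: step=2.ckpt, step=4.ckpt, step=6.ckpt
print(sorted(p.name for p in Path("checkpoints").iterdir()))
```

With the pre-patch numbering the same run would have produced `step=1.ckpt`, `step=3.ckpt` and `step=5.ckpt`, since the counter was not yet incremented when `on_train_batch_end` fired.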
-import logging import math import os import pickle @@ -469,7 +468,7 @@ def test_model_checkpoint_file_extension(tmpdir): trainer = Trainer(default_root_dir=tmpdir, callbacks=[model_checkpoint], max_steps=1, logger=False) trainer.fit(model) - expected = ["epoch=0-step=0.tpkc", "last.tpkc"] + expected = ["epoch=0-step=1.tpkc", "last.tpkc"] assert set(expected) == set(os.listdir(tmpdir)) @@ -490,12 +489,12 @@ def test_model_checkpoint_save_last(tmpdir): ) trainer.fit(model) last_filename = model_checkpoint._format_checkpoint_name( - ModelCheckpoint.CHECKPOINT_NAME_LAST, {"epoch": trainer.current_epoch} + ModelCheckpoint.CHECKPOINT_NAME_LAST, {"epoch": trainer.current_epoch - 1} ) last_filename = last_filename + ".ckpt" assert str(tmpdir / last_filename) == model_checkpoint.last_model_path assert set(os.listdir(tmpdir)) == set( - [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [9, 19, 29])] + [last_filename] + [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [10, 20, 30])] + [last_filename] ) ModelCheckpoint.CHECKPOINT_NAME_LAST = "last" @@ -583,14 +582,14 @@ def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog): # these should not be set if monitor is None assert checkpoint_callback.monitor is None - assert checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=19.ckpt" + assert checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=20.ckpt" assert checkpoint_callback.last_model_path == tmpdir / "last.ckpt" assert checkpoint_callback.best_model_score is None assert checkpoint_callback.best_k_models == {} assert checkpoint_callback.kth_best_model_path == "" # check that the correct ckpts were created - expected = [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [9, 19])] + expected = [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [10, 20])] expected.append("last.ckpt") assert set(os.listdir(tmpdir)) == set(expected) @@ -642,7 +641,7 @@ def test_ckpt_every_n_train_steps(tmpdir): trainer.fit(model) expected = [ - f"step={i}.ckpt" for i in range(every_n_train_steps - 1, max_epochs * epoch_length, every_n_train_steps) + f"step={i}.ckpt" for i in range(every_n_train_steps, max_epochs * epoch_length + 1, every_n_train_steps) ] assert set(os.listdir(tmpdir)) == set(expected) @@ -766,34 +765,14 @@ def test_default_checkpoint_behavior(tmpdir): save_weights_only = trainer.checkpoint_callback.save_weights_only save_mock.assert_has_calls( [ - call(save_dir / "epoch=0-step=4.ckpt", save_weights_only), - call(save_dir / "epoch=1-step=9.ckpt", save_weights_only), - call(save_dir / "epoch=2-step=14.ckpt", save_weights_only), + call(save_dir / "epoch=0-step=5.ckpt", save_weights_only), + call(save_dir / "epoch=1-step=10.ckpt", save_weights_only), + call(save_dir / "epoch=2-step=15.ckpt", save_weights_only), ] ) ckpts = os.listdir(save_dir) assert len(ckpts) == 1 - assert ckpts[0] == "epoch=2-step=14.ckpt" - - -@pytest.mark.parametrize("max_epochs", [1, 2]) -@pytest.mark.parametrize("should_validate", [True, False]) -@pytest.mark.parametrize("save_last", [True, False]) -@pytest.mark.parametrize("verbose", [True, False]) -def test_model_checkpoint_save_last_warning( - tmpdir, caplog, max_epochs: int, should_validate: bool, save_last: bool, verbose: bool -): - """Tests 'Saving latest checkpoint...' 
log.""" - model = LogInTwoMethods() - if not should_validate: - model.validation_step = None - ckpt = ModelCheckpoint(monitor="early_stop_on", dirpath=tmpdir, save_top_k=0, save_last=save_last, verbose=verbose) - trainer = Trainer( - default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs, limit_train_batches=1, limit_val_batches=1 - ) - with caplog.at_level(logging.INFO): - trainer.fit(model) - assert caplog.messages.count("Saving latest checkpoint...") == (verbose and save_last) + assert ckpts[0] == "epoch=2-step=15.ckpt" def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): @@ -821,9 +800,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): ckpt_last_epoch = torch.load(path_last_epoch) ckpt_last = torch.load(path_last) - # `-1` because this checkpoint is saved `on_train_epoch_end` which is considered part of the epoch so the - # `current_epoch` count has not been increased yet - assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"] - 1 + assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"] assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"] ckpt_id = ( @@ -1041,7 +1018,7 @@ def test_val_check_interval_checkpoint_files(tmpdir): ) trainer.fit(model) files = {p.basename for p in tmpdir.listdir()} - assert files == {f"epoch=0-step={s}.ckpt" for s in [1, 3, 5, 7, 9]} + assert files == {f"epoch=0-step={s}.ckpt" for s in [2, 4, 6, 8, 10]} def test_current_score(tmpdir): @@ -1303,4 +1280,4 @@ def test_last_global_step_saved(): trainer = MagicMock() trainer.callback_metrics = {"foo": 123} model_checkpoint.save_checkpoint(trainer) - assert model_checkpoint._last_global_step_saved == -1 + assert model_checkpoint._last_global_step_saved == 0 diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py index 24268e3cfca84..5d129179c7c5d 100644 --- a/tests/checkpointing/test_trainer_checkpoint.py +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -71,19 +71,3 @@ def validation_step(self, batch, batch_idx): assert best_model_path.endswith(f"epoch=0{idx}.ckpt") else: assert f"epoch={idx + 1}" in best_model_path - - -def test_accumulated_gradient_batches_with_ckpt_path(tmpdir): - """This test validates that accumulated gradient is properly recomputed and reset on the trainer.""" - - ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True) - model = BoringModel() - trainer_kwargs = dict( - max_epochs=1, accumulate_grad_batches={0: 2}, callbacks=ckpt, limit_train_batches=1, limit_val_batches=0 - ) - trainer = Trainer(**trainer_kwargs) - trainer.fit(model) - - trainer_kwargs["max_epochs"] = 2 - trainer = Trainer(**trainer_kwargs) - trainer.fit(model, ckpt_path=ckpt.last_model_path) diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index 37758e904256a..e09b954a61a6a 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -156,7 +156,7 @@ def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch trainer.fit(model) assert trainer.checkpoint_callback.dirpath == (tmpdir / "test" / "1" / "checkpoints") - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=2.ckpt"} + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=3.ckpt"} assert trainer.log_dir == logger.save_dir diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index 46c85f13e29e4..5ce5ceb75a0b1 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -136,7 +136,7 @@ def 
test_mlflow_log_dir(client, mlflow, tmpdir): assert trainer.log_dir == logger.save_dir trainer.fit(model) assert trainer.checkpoint_callback.dirpath == (tmpdir / "exp-id" / "run-id" / "checkpoints") - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=0.ckpt"} + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=1.ckpt"} assert trainer.log_dir == logger.save_dir @@ -177,7 +177,7 @@ def training_epoch_end(self, *args, **kwargs): assert "epoch" in os.listdir(tmpdir / exp_id / run_id / "metrics") assert set(os.listdir(tmpdir / exp_id / run_id / "params")) == model.hparams.keys() assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / "checkpoints") - assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"epoch=0-step={limit_batches - 1}.ckpt"] + assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"epoch=0-step={limit_batches}.ckpt"] @mock.patch("pytorch_lightning.loggers.mlflow.mlflow") diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 280303a3f7318..adb91aab6da32 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -156,7 +156,7 @@ def test_wandb_logger_dirs_creation(wandb, monkeypatch, tmpdir): trainer.fit(model) assert trainer.checkpoint_callback.dirpath == str(tmpdir / "project" / version / "checkpoints") - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=2.ckpt"} + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=3.ckpt"} assert trainer.log_dir == logger.save_dir @@ -212,7 +212,7 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir): type="model", metadata={ "score": None, - "original_filename": "epoch=1-step=5-v3.ckpt", + "original_filename": "epoch=1-step=6-v3.ckpt", "ModelCheckpoint": { "monitor": None, "mode": "min", diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index cfc347293484c..9f3c63da4d1e8 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -743,7 +743,7 @@ def test_fit_loop_reset(tmpdir): trainer.fit(model) # reset state loaded from a checkpoint from mid-epoch - mid_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=1.ckpt")) + mid_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=2.ckpt")) fit_loop = trainer.fit_loop epoch_loop = fit_loop.epoch_loop optimizer_loop = epoch_loop.batch_loop.optimizer_loop @@ -776,7 +776,7 @@ def test_fit_loop_reset(tmpdir): assert optimizer_loop.optim_progress.optimizer_position == 1 # reset state loaded from a checkpoint from the end of an epoch - end_of_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=3.ckpt")) + end_of_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=4.ckpt")) fit_loop = trainer.fit_loop epoch_loop = fit_loop.epoch_loop fit_loop.restarting = False @@ -943,8 +943,7 @@ def val_dataloader(self): ) trainer.fit(model, ckpt_path=ckpt_path) - # TODO: -1 because there's a bug where global step is off by one on reload - assert trainer.global_step - 1 == expected_global_step + assert trainer.global_step == expected_global_step state_dict_after_restart = trainer.fit_loop.state_dict() diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index bcec1bb8bc13f..3de02d5f8bb1c 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -133,7 +133,7 @@ def validation_step(self, *args): # even though we stopped mid epoch, the fit loop finished normally and the current epoch was increased assert trainer.current_epoch == 1 assert 
trainer.global_step == 5 - assert model.validation_called_at == (0, 4) + assert model.validation_called_at == (0, 5) def test_warning_valid_train_step_end(tmpdir): diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index 17135b98c16f5..917bb4d224194 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -199,7 +199,9 @@ def configure_optimizers(self): assert str(trainer.amp_backend) == "AMPType.APEX" trainer.fit(model) assert trainer.state.finished, f"Training failed with {trainer.state}" - assert bwd_mock.call_count == 10 + # `max_steps` is fulfilled in the third batch first optimizer, but we don't check the loop + # `done` condition until all optimizers have run, so the number of backwards is higher than `max_steps` + assert bwd_mock.call_count == 6 assert isinstance(trainer.lr_scheduler_configs[0].scheduler.optimizer, optim.Adam) assert isinstance(trainer.lr_scheduler_configs[1].scheduler.optimizer, optim.SGD) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index c04e36bbc09bd..e5259c4047ad2 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -199,7 +199,9 @@ def on_train_start(self): if self.trainer.state.fn == TrainerFn.TUNING: self._test_on_val_test_predict_tune_start() else: - assert self.trainer.current_epoch == state_dict["epoch"] + # `-1` because this checkpoint is saved `on_train_epoch_end` which is considered part of the epoch so + # the `current_epoch` count has not been increased yet + assert self.trainer.current_epoch - 1 == state_dict["epoch"] assert self.trainer.global_step == state_dict["global_step"] assert self._check_model_state_dict() assert self._check_optimizers() @@ -241,8 +243,7 @@ def test_correct_step_and_epoch(tmpdir): ckpt = torch.load(ckpt_path) assert ckpt["epoch"] == first_max_epochs - # TODO(@carmocca): should not need `+1` - assert ckpt["global_step"] == first_max_epochs * train_batches + 1 + assert ckpt["global_step"] == first_max_epochs * train_batches max_epochs = first_max_epochs + 2 trainer = Trainer( @@ -255,13 +256,11 @@ def test_correct_step_and_epoch(tmpdir): class TestModel(BoringModel): def on_train_start(self) -> None: assert self.trainer.current_epoch == first_max_epochs - # TODO(@carmocca): should not need `+1` - assert self.trainer.global_step == first_max_epochs * train_batches + 1 + assert self.trainer.global_step == first_max_epochs * train_batches trainer.fit(TestModel(), ckpt_path=ckpt_path) assert trainer.current_epoch == max_epochs - # TODO(@carmocca): should not need `+1` - assert trainer.global_step == max_epochs * train_batches + 1 + assert trainer.global_step == max_epochs * train_batches def test_fit_twice(tmpdir): diff --git a/tests/plugins/test_checkpoint_io_plugin.py b/tests/plugins/test_checkpoint_io_plugin.py index 7a1352804ba3d..56aadad353b2e 100644 --- a/tests/plugins/test_checkpoint_io_plugin.py +++ b/tests/plugins/test_checkpoint_io_plugin.py @@ -52,7 +52,7 @@ def test_checkpoint_plugin_called(tmpdir): ) trainer.fit(model) - assert checkpoint_plugin.save_checkpoint.call_count == 5 + assert checkpoint_plugin.save_checkpoint.call_count == 4 assert checkpoint_plugin.remove_checkpoint.call_count == 1 trainer.test(model, ckpt_path=ck.last_model_path) @@ -71,7 +71,7 @@ def test_checkpoint_plugin_called(tmpdir): ) trainer.fit(model) - assert checkpoint_plugin.save_checkpoint.call_count == 5 + assert checkpoint_plugin.save_checkpoint.call_count == 4 assert checkpoint_plugin.remove_checkpoint.call_count == 1 trainer.test(model, 
ckpt_path=ck.last_model_path) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 99071ce3d8f8a..38c0a83cabb65 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -628,7 +628,7 @@ def configure_optimizers(self): def on_save_checkpoint(self, checkpoint): lr_scheduler_config = checkpoint["lr_schedulers"][0] # 2 batches ran. since the lr_scheduler_config interval is `step`, the step count should be 2 - assert self.trainer.global_step + 1 == batches # the global step hasn't been increased yet + assert self.trainer.global_step == batches compare_to = max_epochs if epoch_interval else batches assert lr_scheduler_config["_step_count"] - 1 == compare_to # step count starts at 1 assert lr_scheduler_config["_last_lr"] == [lr * gamma ** compare_to] diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index d5f0aeea1e0e4..6f4d7300220e5 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -332,7 +332,8 @@ def mock_save_function(filepath, *args): # emulate callback's calls during the training for i, loss in enumerate(losses, 1): - trainer.fit_loop.global_step = i + # sets `trainer.global_step` + trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = i trainer.callback_metrics.update({"checkpoint_on": torch.tensor(loss)}) checkpoint_callback.on_validation_end(trainer, trainer.lightning_module) trainer.fit_loop.epoch_progress.current.completed = i # sets `trainer.current_epoch`