Integrate global step with progress tracking (Lightning-AI#11805)
carmocca authored and Borda committed Mar 10, 2022
1 parent 28bc4f0 commit 97e1d28
Showing 32 changed files with 144 additions and 192 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -303,6 +303,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The `trainer.current_epoch` value is now increased by 1 during and after `on_train_end` ([#8578](https://github.com/PyTorchLightning/pytorch-lightning/pull/8578))


- The `trainer.global_step` value now accounts for multiple optimizers and TBPTT splits ([#11805](https://github.com/PyTorchLightning/pytorch-lightning/pull/11805))


- The `trainer.global_step` value is now increased right after the `optimizer.step()` call which will impact users who access it during an intra-training validation hook ([#11805](https://github.com/PyTorchLightning/pytorch-lightning/pull/11805))


- The filename of checkpoints created with `ModelCheckpoint(filename='{step}')` is different compared to previous versions. A checkpoint saved after 1 step will be named `step=1.ckpt` instead of `step=0.ckpt` ([#11805](https://github.com/PyTorchLightning/pytorch-lightning/pull/11805))


- Inherit from `ABC` for `Accelerator`: Users need to implement `auto_device_count` ([#11521](https://github.com/PyTorchLightning/pytorch-lightning/pull/11521))
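
The three `global_step` entries above can be illustrated with a toy model. This is only a sketch of the counting rule described in the changelog, not Lightning's implementation: `count_steps` and its parameters are hypothetical names, and the model assumes every optimizer steps once per TBPTT split on each batch that is not accumulating gradients.

```python
def count_steps(num_batches, num_optimizers=1, tbptt_splits=1, accumulate_grad_batches=1):
    """Toy model of the two counters; all names here are illustrative assumptions."""
    global_step = 0           # optimizer steps, summed over optimizers and TBPTT splits
    batches_that_stepped = 0  # batches on which optimizer steps were taken
    for batch_idx in range(num_batches):
        # skip batches that only accumulate gradients
        if (batch_idx + 1) % accumulate_grad_batches != 0:
            continue
        global_step += num_optimizers * tbptt_splits
        batches_that_stepped += 1
    return global_step, batches_that_stepped


# two optimizers over ten batches: 20 optimizer steps, but only 10 stepping batches
steps, batches = count_steps(10, num_optimizers=2)
```

Under these assumptions the two counters agree only in the single-optimizer, no-TBPTT case, which is why the callbacks in this commit log against a separate per-batch counter.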


13 changes: 7 additions & 6 deletions docs/source/common/lightning_module.rst
@@ -916,7 +916,7 @@ These are properties available in a LightningModule.
current_epoch
~~~~~~~~~~~~~

The current epoch
The number of epochs run.

.. code-block:: python
@@ -946,12 +946,13 @@ usually do not need to use this property, but it is useful to know how to access
def training_step(self, batch, batch_idx):
if self.global_rank == 0:
# do something only once across all the nodes
self.log("global_step", self.trainer.global_step)
...
global_step
~~~~~~~~~~~

The current step (does not reset each epoch)
The number of optimizer steps taken (does not reset each epoch).
This includes multiple optimizers and TBPTT steps (if enabled).

.. code-block:: python
@@ -1003,16 +1004,16 @@ The list of loggers currently being used by the Trainer.
local_rank
~~~~~~~~~~~

The ``global_rank`` is the index of the current process across all the devices for the current node.
The ``local_rank`` is the index of the current process across all the devices for the current node.
You usually do not need to use this property, but it is useful to know how to access it if needed.
For example, if using 10 machines (or nodes), the GPU at index 0 on each machine has local_rank = 0.

.. code-block:: python
def training_step(self, batch, batch_idx):
if self.global_rank == 0:
if self.local_rank == 0:
# do something only once across each node
self.log("global_step", self.trainer.global_step)
...
precision
~~~~~~~~~
25 changes: 16 additions & 9 deletions docs/source/common/trainer.rst
@@ -934,7 +934,7 @@ max_steps

|
Stop training after this number of steps
Stop training after this number of :ref:`global steps <common/trainer:global_step>`.
Training will stop if max_steps or max_epochs have reached (earliest).

.. testcode::
@@ -959,7 +959,7 @@ min_steps

|
Force training for at least these number of steps.
Force training for at least this number of :ref:`global steps <common/trainer:global_step>`.
Trainer will train model for at least min_steps or min_epochs (latest).

.. testcode::
@@ -1732,16 +1732,23 @@ The metrics available to callbacks. These are automatically set when you log via
current_epoch
*************

The current epoch
The number of epochs run.

.. code-block:: python
def training_step(self, batch, batch_idx):
current_epoch = self.trainer.current_epoch
if current_epoch > 100:
# do something
pass
if trainer.current_epoch >= 10:
...
global_step
***********

The number of optimizer steps taken (does not reset each epoch).
This includes multiple optimizers and TBPTT steps (if enabled).

.. code-block:: python
if trainer.global_step >= 100:
...
logger
*******
@@ -1822,4 +1829,4 @@ The metrics sent to the progress bar.
estimated_stepping_batches
**************************

Check out :paramref:`~pytorch_lightning.trainer.trainer.Trainer.estimated_stepping_batches`.
Check out :meth:`~pytorch_lightning.trainer.trainer.Trainer.estimated_stepping_batches`.
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/device_stats_monitor.py
@@ -66,7 +66,7 @@ def on_train_batch_start(
for logger in trainer.loggers:
separator = logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
logger.log_metrics(prefixed_device_stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped)

def on_train_batch_end(
self,
@@ -88,7 +88,7 @@ def on_train_batch_end(
for logger in trainer.loggers:
separator = logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
logger.log_metrics(prefixed_device_stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped)


def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -162,7 +162,7 @@ def on_train_batch_start(
logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000

for logger in trainer.loggers:
logger.log_metrics(logs, step=trainer.global_step)
logger.log_metrics(logs, step=trainer.fit_loop.epoch_loop._batches_that_stepped)

@rank_zero_only
def on_train_batch_end(
@@ -187,7 +187,7 @@ def on_train_batch_end(
logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000

for logger in trainer.loggers:
logger.log_metrics(logs, step=trainer.global_step)
logger.log_metrics(logs, step=trainer.fit_loop.epoch_loop._batches_that_stepped)

@staticmethod
def _get_gpu_ids(device_ids: List[int]) -> List[str]:
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/lr_monitor.py
@@ -158,7 +158,7 @@ def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any)

if latest_stat:
for logger in trainer.loggers:
logger.log_metrics(latest_stat, step=trainer.global_step)
logger.log_metrics(latest_stat, step=trainer.fit_loop.epoch_loop._batches_that_stepped)

def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
if self.logging_interval != "step":
@@ -167,7 +167,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any)

if latest_stat:
for logger in trainer.loggers:
logger.log_metrics(latest_stat, step=trainer.global_step)
logger.log_metrics(latest_stat, step=trainer.fit_loop.epoch_loop._batches_that_stepped)

def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]:
latest_stat = {}
24 changes: 2 additions & 22 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -227,7 +227,7 @@ def __init__(
self.save_weights_only = save_weights_only
self.auto_insert_metric_name = auto_insert_metric_name
self._save_on_train_epoch_end = save_on_train_epoch_end
self._last_global_step_saved = -1
self._last_global_step_saved = 0 # no need to save when no steps were taken
self._last_time_checked: Optional[float] = None
self.current_score = None
self.best_k_models = {}
@@ -278,8 +278,7 @@ def on_train_batch_end(
"""Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
if self._should_skip_saving_checkpoint(trainer):
return
step = trainer.global_step
skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0)
skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)

train_time_interval = self._train_time_interval
skip_time = True
@@ -300,16 +299,13 @@ def on_train_batch_end(

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Save a checkpoint at the end of the training epoch."""
# as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates
trainer.fit_loop.global_step -= 1
if (
not self._should_skip_saving_checkpoint(trainer)
and self._save_on_train_epoch_end
and self._every_n_epochs > 0
and (trainer.current_epoch + 1) % self._every_n_epochs == 0
):
self.save_checkpoint(trainer)
trainer.fit_loop.global_step += 1

def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Save a checkpoint at the end of the validation stage."""
@@ -322,22 +318,6 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
return
self.save_checkpoint(trainer)

def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Save a checkpoint when training stops.
This will only save a checkpoint if `save_last` is also enabled as the monitor metrics logged during
training/validation steps or end of epochs are not guaranteed to be available at this stage.
"""
if self._should_skip_saving_checkpoint(trainer) or not self.save_last:
return
if self.verbose:
rank_zero_info("Saving latest checkpoint...")
# as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates
monitor_candidates = self._monitor_candidates(trainer, trainer.current_epoch, trainer.global_step - 1)
trainer.fit_loop.global_step -= 1
self._save_last_checkpoint(trainer, monitor_candidates)
trainer.fit_loop.global_step += 1

def on_save_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> Dict[str, Any]:
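
The `on_train_batch_end` change in this file swaps `(step + 1) % every_n_train_steps` for `global_step % every_n_train_steps`. A minimal sketch of why the save cadence should stay the same while the `{step}` value in the filename shifts by one follows. It assumes a single optimizer without accumulation, so that after the k-th optimizer step the old counter reads `k - 1` and the new one reads `k`; the helper names are hypothetical.

```python
def should_save_old(step, every_n_train_steps):
    """Old rule: `step` had not yet been incremented when the hook ran."""
    return every_n_train_steps >= 1 and (step + 1) % every_n_train_steps == 0


def should_save_new(global_step, every_n_train_steps):
    """New rule: `global_step` already counts the step that was just taken."""
    return every_n_train_steps >= 1 and global_step % every_n_train_steps == 0


# after the k-th optimizer step: the old counter reads k - 1, the new one reads k
same_cadence = all(should_save_old(k - 1, 5) == should_save_new(k, 5) for k in range(1, 21))
```

Both rules fire on the same batches; only the number written into the filename differs, which matches the `step=1.ckpt` versus `step=0.ckpt` note in the changelog.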
4 changes: 1 addition & 3 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -66,8 +66,6 @@ def num_dataloaders(self) -> int:
# case where user does:
# return dl1, dl2
dataloaders = self.dataloaders
if dataloaders is None:
return 0
length = len(dataloaders)
if length > 0 and isinstance(dataloaders[0], (list, tuple)):
length = len(dataloaders[0])
@@ -78,7 +76,7 @@ def dataloaders(self) -> Sequence[DataLoader]:
"""Returns the validation or test dataloaders."""
dataloaders = self.trainer.test_dataloaders if self.trainer.testing else self.trainer.val_dataloaders
if dataloaders is None:
raise RuntimeError("Dataloaders should be available.")
return []
return dataloaders

@property
26 changes: 15 additions & 11 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -60,7 +60,6 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None
self.min_steps = min_steps
self.max_steps = max_steps

self.global_step: int = 0
self.batch_progress = BatchProgress()
self.scheduler_progress = SchedulerProgress()

@@ -72,6 +71,7 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None
self._warning_cache = WarningCache()
# caches the loaded dataloader state until dataloader objects are available
self._dataloader_state_dict: Dict[str, Any] = {}
self._batches_that_stepped: int = 0

@property
def total_batch_idx(self) -> int:
@@ -87,6 +87,13 @@ def batch_idx(self) -> int:
# but before the next `ready` increase
return self.batch_progress.current.ready - 1

@property
def global_step(self) -> int:
lightning_module = self.trainer.lightning_module
if lightning_module is None or lightning_module.automatic_optimization:
return self.batch_loop.optimizer_loop.optim_progress.optimizer_steps
return self.batch_loop.manual_loop.optim_step_progress.total.completed

@property
def _is_training_done(self) -> bool:
max_steps_reached = _is_max_limit_reached(self.global_step, self.max_steps)
@@ -247,17 +254,14 @@ def on_advance_end(self) -> None:
self._run_validation()
self.trainer.training = True

# -----------------------------------------
# SAVE LOGGERS (ie: Tensorboard, etc...)
# -----------------------------------------
self._save_loggers_on_train_batch_end()

# update plateau LR scheduler after metrics are logged
self.update_lr_schedulers("step", update_plateau_schedulers=True)

if not self._should_accumulate():
# progress global step according to grads progress
self.global_step += 1
# this is increased once per batch disregarding multiple optimizers or tbptt on purpose for loggers
self._batches_that_stepped += 1
# this will save based on the `batches_that_stepped` value
self._save_loggers_on_train_batch_end()

# if training finished, defer exit to the parent. this assumes there will be enough time in between
# which might not be the case depending on what's in the `*_epoch_end` hooks
@@ -503,9 +507,9 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:

def _save_loggers_on_train_batch_end(self) -> None:
"""Flushes loggers to disk."""
# when loggers should save to disk
should_flush_logs = self.trainer._logger_connector.should_flush_logs
if should_flush_logs:
# this assumes that `batches_that_stepped` was increased before
should_flush = self._batches_that_stepped % self.trainer.flush_logs_every_n_steps == 0
if should_flush or self.trainer.should_stop:
for logger in self.trainer.loggers:
logger.save()
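
The new flush condition in `_save_loggers_on_train_batch_end` can be sketched as a standalone predicate. The function below is a hypothetical restatement for illustration only; it assumes, as the comment in the hunk says, that the per-batch counter has already been increased before the check.

```python
def should_flush_loggers(batches_that_stepped, flush_logs_every_n_steps, should_stop=False):
    """Flush on every n-th stepping batch, or immediately when training is stopping."""
    return batches_that_stepped % flush_logs_every_n_steps == 0 or should_stop


# with n = 5, loggers are saved after the 5th and 10th stepping batches
flush_points = [s for s in range(1, 11) if should_flush_loggers(s, 5)]
```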

27 changes: 6 additions & 21 deletions pytorch_lightning/loops/fit_loop.py
@@ -68,19 +68,6 @@ def __init__(
self._outputs: _EPOCH_OUTPUTS_TYPE = []
self._data_fetcher: Optional[AbstractDataFetcher] = None

@property
def global_step(self) -> int:
"""Returns the global step."""
lightning_module = self.trainer.lightning_module
if lightning_module is None or lightning_module.automatic_optimization:
return self.epoch_loop.global_step
return self.epoch_loop.batch_loop.manual_loop.optim_step_progress.total.completed

@global_step.setter
def global_step(self, value: int) -> None:
"""Sets the global step (forwards to epoch_loop)"""
self.epoch_loop.global_step = value

@property
def total_batch_idx(self) -> int:
"""Returns the current batch index (across epochs)"""
@@ -177,7 +164,7 @@ def _results(self) -> _ResultCollection:
def done(self) -> bool:
"""Evaluates when to leave the loop."""
# TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
stop_steps = _is_max_limit_reached(self.epoch_loop.global_step, self.max_steps)
# `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
# we use it here because the checkpoint data won't have `completed` increased yet
stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
@@ -186,7 +173,7 @@ def done(self) -> bool:
if self.trainer.should_stop:
# early stopping
met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
if met_min_epochs and met_min_steps:
should_stop = True
else:
@@ -319,14 +306,12 @@ def on_advance_end(self) -> None:

self.epoch_progress.increment_completed()

# the global step is manually decreased here due to backwards compatibility with existing loggers
# as they expect that the same step is used when logging epoch end metrics even when the batch loop has
# finished. this means the attribute does not exactly track the number of optimizer steps applied.
# TODO(@carmocca): deprecate and rename so users don't get confused
self.global_step -= 1
# we manually decrease here because loggers expect that the same step is used when logging epoch-end metrics
# even when the batch loop has finished
self.epoch_loop._batches_that_stepped -= 1
# log epoch metrics
self.trainer._logger_connector.update_train_epoch_metrics()
self.global_step += 1
self.epoch_loop._batches_that_stepped += 1

# if fault tolerant is enabled and process has been notified, exit.
self.trainer._exit_gracefully_on_signal()
9 changes: 7 additions & 2 deletions pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -359,7 +359,11 @@ def _optimizer_step(
else:
optimizer = self.trainer.strategy._lightning_optimizers[opt_idx]

self.optim_progress.optimizer.step.increment_ready()
# if `strategy.handles_gradient_accumulation`, this method will be called to route into the strategy, but we
# need to check again if `should_accumulate` before increasing the counters
should_accumulate = self.trainer.fit_loop._should_accumulate()
if not should_accumulate:
self.optim_progress.optimizer.step.increment_ready()

# model hook
self.trainer._call_lightning_module_hook(
@@ -374,7 +378,8 @@ def _optimizer_step(
using_lbfgs=is_lbfgs,
)

self.optim_progress.optimizer.step.increment_completed()
if not should_accumulate:
self.optim_progress.optimizer.step.increment_completed()

def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
"""Calls the ``on_before_zero_grad`` hook.
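
The `_optimizer_step` hunks above gate both progress counters on `should_accumulate`, while the step hook itself is still called every time. A minimal sketch of that pattern follows. `StepCounter` and `optimizer_step` are hypothetical stand-ins for the progress tracker and the hook call, not Lightning classes.

```python
from dataclasses import dataclass


@dataclass
class StepCounter:
    """Stand-in for the optimizer step progress tracker (illustrative only)."""
    ready: int = 0
    completed: int = 0


def optimizer_step(counter, should_accumulate, step_fn):
    """Always route into `step_fn`, but only count the call when not accumulating."""
    if not should_accumulate:
        counter.ready += 1
    step_fn()  # stands in for the `optimizer_step` hook, which runs either way
    if not should_accumulate:
        counter.completed += 1


counter = StepCounter()
calls = []
for batch_idx in range(4):
    # accumulate over 2 batches: only every second call counts as a real step
    accumulating = (batch_idx + 1) % 2 != 0
    optimizer_step(counter, accumulating, lambda: calls.append(batch_idx))
```

In this sketch the hook is reached four times but the counter records two completed steps, which is the behaviour the added `should_accumulate` check is meant to give when a strategy handles gradient accumulation itself.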