Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update TQDM progress bar tracking with multiple dataloaders #11657

Merged
merged 25 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Disbled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))


- Fixed `TQDMProgressBar` counter when using multple validation dataloaders ([#11657](https://github.com/PyTorchLightning/pytorch-lightning/pull/11657))


## [1.5.8] - 2022-01-05

### Fixed
Expand Down
48 changes: 39 additions & 9 deletions pytorch_lightning/callbacks/progress/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,28 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch_idx):

def __init__(self) -> None:
self._trainer: Optional["pl.Trainer"] = None
self._current_eval_dataloader_idx: Optional[int] = None
carmocca marked this conversation as resolved.
Show resolved Hide resolved

@property
def trainer(self) -> "pl.Trainer":
if self._trainer is None:
raise TypeError(f"The `{self.__class__.__name__}._trainer` reference has not been set yet.")
return self._trainer

def is_dataloader_changed(self, dataloader_idx: int) -> bool:
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
old_dataloader_idx = self._current_eval_dataloader_idx
self._current_eval_dataloader_idx = dataloader_idx
return old_dataloader_idx != dataloader_idx

rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._current_eval_dataloader_idx = None
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved

def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._current_eval_dataloader_idx = None

def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._current_eval_dataloader_idx = None

@property
def train_batch_idx(self) -> int:
"""The number of batches processed during training.
Expand All @@ -70,8 +85,12 @@ def val_batch_idx(self) -> int:
Use this to update your progress bar.
"""
if self.trainer.state.fn == "fit":
return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.current.processed
return self.trainer.validate_loop.epoch_loop.batch_progress.current.processed
loop = self.trainer.fit_loop.epoch_loop.val_loop
else:
loop = self.trainer.validate_loop

current_batch_idx = loop.epoch_loop.batch_progress.current.processed
return current_batch_idx

@property
def test_batch_idx(self) -> int:
Expand Down Expand Up @@ -105,12 +124,11 @@ def total_val_batches(self) -> Union[int, float]:
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation
dataloader is of infinite size.
"""
total_val_batches = 0
if self.trainer.enable_validation:
is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0
assert self._current_eval_dataloader_idx is not None
if self.trainer.sanity_checking:
return self.trainer.num_sanity_val_batches[self._current_eval_dataloader_idx]

return total_val_batches
return self.trainer.num_val_batches[self._current_eval_dataloader_idx]
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved

@property
def total_test_batches(self) -> Union[int, float]:
Expand All @@ -119,7 +137,8 @@ def total_test_batches(self) -> Union[int, float]:
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is
of infinite size.
"""
return sum(self.trainer.num_test_batches)
assert self._current_eval_dataloader_idx is not None
return self.trainer.num_test_batches[self._current_eval_dataloader_idx]

@property
def total_predict_batches(self) -> Union[int, float]:
Expand All @@ -128,7 +147,18 @@ def total_predict_batches(self) -> Union[int, float]:
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader
is of infinite size.
"""
return sum(self.trainer.num_predict_batches)
assert self._current_eval_dataloader_idx is not None
return self.trainer.num_predict_batches[self._current_eval_dataloader_idx]

@property
def num_val_batches(self) -> Union[int, float]:
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
if (
self.trainer.enable_validation
and (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
):
return sum(self.trainer.num_val_batches)

return 0

def disable(self) -> None:
"""You should provide a way to disable the progress bar.
Expand Down
36 changes: 26 additions & 10 deletions pytorch_lightning/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from datetime import timedelta
from typing import Any, Dict, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE

Expand Down Expand Up @@ -325,7 +326,7 @@ def on_sanity_check_end(self, trainer, pl_module):

def on_train_epoch_start(self, trainer, pl_module):
total_train_batches = self.total_train_batches
total_val_batches = self.total_val_batches
total_val_batches = sum(self.trainer.num_val_batches)
if total_train_batches != float("inf"):
# val can be checked multiple times per epoch
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
Expand All @@ -345,8 +346,13 @@ def on_train_epoch_start(self, trainer, pl_module):
)
self.refresh()

def on_validation_epoch_start(self, trainer, pl_module):
if self.total_val_batches > 0:
def on_validation_batch_start(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
if self.is_dataloader_changed(dataloader_idx):
if self.val_progress_bar_id is not None:
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
self.progress.update(self.val_progress_bar_id, advance=self.refresh_rate, visible=False)

total_val_batches = self.total_val_batches
if self.total_train_batches != float("inf") and hasattr(trainer, "val_check_batch"):
# val can be checked multiple times per epoch
Expand All @@ -369,17 +375,27 @@ def _update(self, progress_bar_id: int, current: int, total: int, visible: bool
def _should_update(self, current: int, total: int) -> bool:
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

def on_validation_epoch_end(self, trainer, pl_module):
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
if self.val_progress_bar_id is not None:
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False)

def on_test_epoch_start(self, trainer, pl_module):
self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
self.refresh()
def on_test_batch_start(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
if self.is_dataloader_changed(dataloader_idx):
if self.test_progress_bar_id is not None:
self.progress.update(self.test_progress_bar_id, advance=self.refresh_rate, visible=False)
self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
self.refresh()

def on_predict_epoch_start(self, trainer, pl_module):
self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)
self.refresh()
def on_predict_batch_start(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
if self.is_dataloader_changed(dataloader_idx):
if self.predict_progress_bar_id is not None:
self.progress.update(self.predict_progress_bar_id, advance=self.refresh_rate, visible=False)
self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)
self.refresh()

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches)
Expand Down
99 changes: 65 additions & 34 deletions pytorch_lightning/callbacks/progress/tqdm_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,26 @@ def __getstate__(self) -> Dict:
# can't pickle the tqdm objects
return {k: v if not isinstance(v, _tqdm) else None for k, v in vars(self).items()}

@property
def sanity_check_description(self) -> str:
return "Validation Sanity Check"
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved

@property
def train_description(self) -> str:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
return "Training"

@property
def validation_description(self) -> str:
return "Validation"

@property
def test_description(self) -> str:
return "Testing"

@property
def predict_description(self) -> str:
return "Predicting"

@property
def main_progress_bar(self) -> _tqdm:
if self._main_progress_bar is None:
Expand Down Expand Up @@ -173,10 +193,7 @@ def is_disabled(self) -> bool:

@property
def _val_processed(self) -> int:
if self.trainer.state.fn == "fit":
# use total in case validation runs more than once per training epoch
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed
return self.trainer.validate_loop.epoch_loop.batch_progress.current.processed
return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed

def disable(self) -> None:
self._enabled = False
Expand All @@ -187,7 +204,7 @@ def enable(self) -> None:
def init_sanity_tqdm(self) -> Tqdm:
"""Override this to customize the tqdm bar for the validation sanity run."""
bar = Tqdm(
desc="Validation sanity check",
desc=self.sanity_check_description,
position=(2 * self.process_position),
disable=self.is_disabled,
leave=False,
Expand All @@ -199,7 +216,7 @@ def init_sanity_tqdm(self) -> Tqdm:
def init_train_tqdm(self) -> Tqdm:
"""Override this to customize the tqdm bar for training."""
bar = Tqdm(
desc="Training",
desc=self.train_description,
initial=self.train_batch_idx,
position=(2 * self.process_position),
disable=self.is_disabled,
Expand All @@ -213,7 +230,7 @@ def init_train_tqdm(self) -> Tqdm:
def init_predict_tqdm(self) -> Tqdm:
"""Override this to customize the tqdm bar for predicting."""
bar = Tqdm(
desc="Predicting",
desc=self.predict_description,
initial=self.train_batch_idx,
position=(2 * self.process_position),
disable=self.is_disabled,
Expand All @@ -227,12 +244,12 @@ def init_predict_tqdm(self) -> Tqdm:
def init_validation_tqdm(self) -> Tqdm:
"""Override this to customize the tqdm bar for validation."""
# The main progress bar doesn't exist in `trainer.validate()`
has_main_bar = self._main_progress_bar is not None
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
has_main_bar = not self.trainer.state.fn == "validate"
bar = Tqdm(
desc="Validating",
desc=self.validation_description,
position=(2 * self.process_position + has_main_bar),
disable=self.is_disabled,
leave=False,
leave=not has_main_bar,
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
dynamic_ncols=True,
file=sys.stdout,
)
Expand All @@ -256,14 +273,15 @@ def on_sanity_check_start(self, *_: Any) -> None:

def on_sanity_check_end(self, *_: Any) -> None:
self.main_progress_bar.close()
self.main_progress_bar = None
carmocca marked this conversation as resolved.
Show resolved Hide resolved
self.val_progress_bar.close()

def on_train_start(self, *_: Any) -> None:
self.main_progress_bar = self.init_train_tqdm()

def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
total_train_batches = self.total_train_batches
total_val_batches = self.total_val_batches
total_val_batches = self.num_val_batches
if total_train_batches != float("inf") and total_val_batches != float("inf"):
# val can be checked multiple times per epoch
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
Expand All @@ -273,62 +291,75 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
if self._should_update(self.train_batch_idx):
if self._should_update(self.train_batch_idx, self.total_train_batches):
_update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
_update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)
carmocca marked this conversation as resolved.
Show resolved Hide resolved
if not self.main_progress_bar.disable:
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

def on_train_end(self, *_: Any) -> None:
self.main_progress_bar.close()

def on_validation_start(self, trainer: "pl.Trainer", *_: Any) -> None:
if trainer.sanity_checking:
self.val_progress_bar.total = sum(trainer.num_sanity_val_batches)
else:
def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not trainer.sanity_checking:
self.val_progress_bar = self.init_validation_tqdm()

def on_validation_batch_start(
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
if self.is_dataloader_changed(dataloader_idx):
self.val_progress_bar.total = convert_inf(self.total_val_batches)
desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description
self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")

def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
if self._should_update(self.val_batch_idx):
if self._should_update(self.val_batch_idx, self.total_val_batches):
_update_n(self.val_progress_bar, self.val_batch_idx)
if trainer.state.fn == "fit":
_update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)

def on_validation_epoch_end(self, *_: Any) -> None:
_update_n(self.val_progress_bar, self._val_processed)

def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._main_progress_bar is not None and trainer.state.fn == "fit":
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
self.val_progress_bar.close()
super().on_validation_end(trainer, pl_module)

def on_test_start(self, *_: Any) -> None:
def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.test_progress_bar = self.init_test_tqdm()
self.test_progress_bar.total = convert_inf(self.total_test_batches)

def on_test_batch_start(
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
if self.is_dataloader_changed(dataloader_idx):
self.test_progress_bar.total = convert_inf(self.total_test_batches)
self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")

def on_test_batch_end(self, *_: Any) -> None:
if self._should_update(self.test_batch_idx):
if self._should_update(self.test_batch_idx, self.total_test_batches):
_update_n(self.test_progress_bar, self.test_batch_idx)

def on_test_epoch_end(self, *_: Any) -> None:
_update_n(self.test_progress_bar, self.test_batch_idx)

def on_test_end(self, *_: Any) -> None:
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.test_progress_bar.close()
super().on_test_end(trainer, pl_module)

def on_predict_epoch_start(self, *_: Any) -> None:
def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.predict_progress_bar = self.init_predict_tqdm()
self.predict_progress_bar.total = convert_inf(self.total_predict_batches)
super().on_predict_end(trainer, pl_module)

def on_predict_batch_start(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
if self.is_dataloader_changed(dataloader_idx):
self.predict_progress_bar.total = convert_inf(self.total_predict_batches)
self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")

def on_predict_batch_end(self, *_: Any) -> None:
if self._should_update(self.predict_batch_idx):
if self._should_update(self.predict_batch_idx, self.total_predict_batches):
_update_n(self.predict_progress_bar, self.predict_batch_idx)

def on_predict_end(self, *_: Any) -> None:
def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.predict_progress_bar.close()

def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None:
Expand All @@ -347,8 +378,8 @@ def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None:
s = sep.join(map(str, args))
active_progress_bar.write(s, **kwargs)

def _should_update(self, idx: int) -> bool:
return self.refresh_rate > 0 and idx % self.refresh_rate == 0
def _should_update(self, current: int, total: Union[int, float]) -> bool:
return self.refresh_rate > 0 and (current % self.refresh_rate == 0 or current == total)
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def _resolve_refresh_rate(refresh_rate: int) -> int:
Expand Down
3 changes: 0 additions & 3 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,6 @@ def _get_max_batches(self) -> List[int]:
max_batches = self.trainer.num_test_batches
else:
if self.trainer.sanity_checking:
self.trainer.num_sanity_val_batches = [
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches
]
max_batches = self.trainer.num_sanity_val_batches
else:
max_batches = self.trainer.num_val_batches
Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,9 @@ def _run_sanity_check(self) -> None:

# reload dataloaders
val_loop._reload_evaluation_dataloaders()
self.num_sanity_val_batches = [
min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches
]

# run eval step
with torch.no_grad():
Expand Down
Loading