Update TQDM progress bar tracking with multiple dataloaders #11657

Merged · 25 commits · Mar 25, 2022 (viewing changes from 23 commits)
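Before the diff, a minimal sketch of the setup this PR targets (the `ToyModel` and its dataloaders are our own illustration, not code from the PR): a `LightningModule` that evaluates on two validation dataloaders. With this change, the progress bar callbacks track each eval dataloader with its own bar instead of a single bar summed over all dataloaders (exact bar labels may differ by version).

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self(x), y)

    # `dataloader_idx` is passed because `val_dataloader` returns two loaders
    def validation_step(self, batch, batch_idx, dataloader_idx):
        x, y = batch
        self.log("val_loss", F.mse_loss(self(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=8)

    def val_dataloader(self):
        # two eval dataloaders -> one validation progress bar per dataloader
        return [
            DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8),
            DataLoader(TensorDataset(torch.randn(16, 4), torch.randn(16, 1)), batch_size=8),
        ]


trainer = pl.Trainer(max_epochs=1)
trainer.fit(ToyModel())
```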
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -302,6 +302,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `parallel_devices` property in `ParallelStrategy` to be lazy initialized ([#11572](https://github.com/PyTorchLightning/pytorch-lightning/pull/11572))


- Updated `TQDMProgressBar` to run a separate progress bar for each eval dataloader ([#11657](https://github.com/PyTorchLightning/pytorch-lightning/pull/11657))


- Sorted `SimpleProfiler(extended=False)` summary based on mean duration for each hook ([#11671](https://github.com/PyTorchLightning/pytorch-lightning/pull/11671))


75 changes: 58 additions & 17 deletions pytorch_lightning/callbacks/progress/base.py
@@ -49,13 +49,34 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch_idx):

def __init__(self) -> None:
self._trainer: Optional["pl.Trainer"] = None
self._current_eval_dataloader_idx: Optional[int] = None

@property
def trainer(self) -> "pl.Trainer":
if self._trainer is None:
raise TypeError(f"The `{self.__class__.__name__}._trainer` reference has not been set yet.")
return self._trainer

@property
def sanity_check_description(self) -> str:
return "Sanity Checking"

@property
def train_description(self) -> str:
return "Training"

@property
def validation_description(self) -> str:
return "Validation"

@property
def test_description(self) -> str:
return "Testing"

@property
def predict_description(self) -> str:
return "Predicting"

@property
def train_batch_idx(self) -> int:
"""The number of batches processed during training.
@@ -71,8 +92,12 @@ def val_batch_idx(self) -> int:
Use this to update your progress bar.
"""
if self.trainer.state.fn == "fit":
- return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.current.processed
- return self.trainer.validate_loop.epoch_loop.batch_progress.current.processed
+ loop = self.trainer.fit_loop.epoch_loop.val_loop
+ else:
+ loop = self.trainer.validate_loop
+
+ current_batch_idx = loop.epoch_loop.batch_progress.current.processed
+ return current_batch_idx

@property
def test_batch_idx(self) -> int:
@@ -100,39 +125,55 @@ def total_train_batches(self) -> Union[int, float]:
return self.trainer.num_training_batches

@property
- def total_val_batches(self) -> Union[int, float]:
- """The total number of validation batches, which may change from epoch to epoch.
+ def total_val_batches_current_dataloader(self) -> Union[int, float]:
+ """The total number of validation batches for the current dataloader, which may change from epoch to epoch.

Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation
dataloader is of infinite size.
"""
+ assert self._current_eval_dataloader_idx is not None
if self.trainer.sanity_checking:
- return sum(self.trainer.num_sanity_val_batches)
+ return self.trainer.num_sanity_val_batches[self._current_eval_dataloader_idx]

- total_val_batches = 0
- if self.trainer.enable_validation:
- is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
- total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0
- return total_val_batches
+ return self.trainer.num_val_batches[self._current_eval_dataloader_idx]

@property
- def total_test_batches(self) -> Union[int, float]:
- """The total number of testing batches, which may change from epoch to epoch.
+ def total_test_batches_current_dataloader(self) -> Union[int, float]:
+ """The total number of testing batches for the current dataloader, which may change from epoch to epoch.

Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is
of infinite size.
"""
- return sum(self.trainer.num_test_batches)
+ assert self._current_eval_dataloader_idx is not None
+ return self.trainer.num_test_batches[self._current_eval_dataloader_idx]

@property
- def total_predict_batches(self) -> Union[int, float]:
- """The total number of prediction batches, which may change from epoch to epoch.
+ def total_predict_batches_current_dataloader(self) -> Union[int, float]:
+ """The total number of prediction batches for the current dataloader, which may change from epoch to epoch.

Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader
is of infinite size.
"""
- return sum(self.trainer.num_predict_batches)
+ assert self._current_eval_dataloader_idx is not None
+ return self.trainer.num_predict_batches[self._current_eval_dataloader_idx]

@property
def total_val_batches(self) -> Union[int, float]:
"""The total number of validation batches across all val dataloaders, which may change from epoch to epoch.

Use this to set the total number of iterations in the progress bar. Can return ``inf`` if a validation dataloader
is of infinite size.
"""
assert self._trainer is not None
return sum(self.trainer.num_val_batches) if self._trainer.fit_loop.epoch_loop._should_check_val_epoch() else 0

def has_dataloader_changed(self, dataloader_idx: int) -> bool:
old_dataloader_idx = self._current_eval_dataloader_idx
self._current_eval_dataloader_idx = dataloader_idx
return old_dataloader_idx != dataloader_idx

def reset_dataloader_idx_tracker(self) -> None:
self._current_eval_dataloader_idx = None

def disable(self) -> None:
"""You should provide a way to disable the progress bar."""
81 changes: 53 additions & 28 deletions pytorch_lightning/callbacks/progress/rich_progress.py
@@ -262,22 +262,6 @@ def is_enabled(self) -> bool:
def is_disabled(self) -> bool:
return not self.is_enabled

- @property
- def sanity_check_description(self) -> str:
- return "Validation Sanity Check"
-
- @property
- def validation_description(self) -> str:
- return "Validation"
-
- @property
- def test_description(self) -> str:
- return "Testing"
-
- @property
- def predict_description(self) -> str:
- return "Predicting"

def _update_for_light_colab_theme(self) -> None:
if _detect_light_colab_theme():
attributes = ["description", "batch_progress", "metrics"]
@@ -354,13 +338,28 @@ def on_train_epoch_start(self, trainer, pl_module):
)
self.refresh()

- def on_validation_epoch_start(self, trainer, pl_module):
+ def on_validation_batch_start(
+ self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+ ) -> None:
+ if not self.has_dataloader_changed(dataloader_idx):
+ return
+
if trainer.sanity_checking:
- self.val_sanity_progress_bar_id = self._add_task(self.total_val_batches, self.sanity_check_description)
+ if self.val_sanity_progress_bar_id is not None:
+ self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)
+
+ self.val_sanity_progress_bar_id = self._add_task(
+ self.total_val_batches_current_dataloader, self.sanity_check_description, visible=False
+ )
else:
+ if self.val_progress_bar_id is not None:
+ self.progress.update(self.val_progress_bar_id, advance=0, visible=False)
+
+ # TODO: remove old tasks when new ones are created
self.val_progress_bar_id = self._add_task(
- self.total_val_batches, self.validation_description, visible=False
+ self.total_val_batches_current_dataloader, self.validation_description, visible=False
)
+
+ self.refresh()

def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
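The hide-then-recreate pattern in `on_validation_batch_start` above maps directly onto rich's task API. A minimal sketch against `rich.progress.Progress` itself, bypassing Lightning's `_add_task` wrapper (the loop and batch counts are our own): the finished dataloader's task is made invisible with `advance=0, visible=False`, and a fresh task is added for the next dataloader.

```python
import time

from rich.progress import Progress

with Progress() as progress:
    task_id = None
    for dataloader_idx, num_batches in enumerate([30, 20]):
        if task_id is not None:
            # hide the previous dataloader's task instead of deleting it
            progress.update(task_id, advance=0, visible=False)
        task_id = progress.add_task(f"Validation DataLoader {dataloader_idx}", total=num_batches)
        for _ in range(num_batches):
            progress.advance(task_id)
            time.sleep(0.01)
```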
@@ -387,13 +386,36 @@ def on_validation_epoch_end(self, trainer, pl_module):
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if trainer.state.fn == "fit":
self._update_metrics(trainer, pl_module)
self.reset_dataloader_idx_tracker()

def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.reset_dataloader_idx_tracker()

- def on_test_epoch_start(self, trainer, pl_module):
- self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.reset_dataloader_idx_tracker()

def on_test_batch_start(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
if not self.has_dataloader_changed(dataloader_idx):
return

if self.test_progress_bar_id is not None:
self.progress.update(self.test_progress_bar_id, advance=0, visible=False)
self.test_progress_bar_id = self._add_task(self.total_test_batches_current_dataloader, self.test_description)
self.refresh()

- def on_predict_epoch_start(self, trainer, pl_module):
- self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)
def on_predict_batch_start(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
if not self.has_dataloader_changed(dataloader_idx):
return

if self.predict_progress_bar_id is not None:
self.progress.update(self.predict_progress_bar_id, advance=0, visible=False)
self.predict_progress_bar_id = self._add_task(
self.total_predict_batches_current_dataloader, self.predict_description
)
self.refresh()

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
@@ -406,20 +428,23 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule")

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
if trainer.sanity_checking:
- self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches)
+ self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
elif self.val_progress_bar_id is not None:
# check to see if we should update the main training progress bar
if self.main_progress_bar_id is not None:
- self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches)
- self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches)
+ # TODO: Use total val_processed here just like TQDM in a follow-up
+ self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
+ self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
self.refresh()

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
- self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches)
+ self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches_current_dataloader)
self.refresh()

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
- self._update(self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches)
+ self._update(
+ self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches_current_dataloader
+ )
self.refresh()

def _get_train_description(self, current_epoch: int) -> str:
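To close, a usage sketch of the Rich side, reusing the hypothetical `ToyModel` from the first example (requires `rich` to be installed): with the hooks above in place, `RichProgressBar` creates one task per eval dataloader during `trainer.validate`, hiding each finished task as the next dataloader starts.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar

loaders = [
    DataLoader(TensorDataset(torch.randn(24, 4), torch.randn(24, 1)), batch_size=8),
    DataLoader(TensorDataset(torch.randn(40, 4), torch.randn(40, 1)), batch_size=8),
]
trainer = pl.Trainer(callbacks=[RichProgressBar()])
# one "Validation" task per dataloader; earlier tasks are hidden, not removed
trainer.validate(ToyModel(), dataloaders=loaders)
```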