
Update TQDM progress bar tracking with multiple dataloaders #11657

Merged · 25 commits · Mar 25, 2022 · Changes shown from 15 commits
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -287,6 +287,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `parallel_devices` property in `ParallelStrategy` to be lazy initialized ([#11572](https://github.com/PyTorchLightning/pytorch-lightning/pull/11572))


- Updated `TQDMProgressBar` to run a separate progress bar for each eval dataloader ([#11657](https://github.com/PyTorchLightning/pytorch-lightning/pull/11657))


- Sorted `SimpleProfiler(extended=False)` summary based on mean duration for each hook ([#11671](https://github.com/PyTorchLightning/pytorch-lightning/pull/11671))


@@ -304,6 +307,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Rewrote `accelerator_connector` ([#11448](https://github.com/PyTorchLightning/pytorch-lightning/pull/11448))


### Deprecated

- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))
44 changes: 33 additions & 11 deletions pytorch_lightning/callbacks/progress/base.py
@@ -48,6 +48,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch_idx):

    def __init__(self) -> None:
        self._trainer: Optional["pl.Trainer"] = None
+        self._current_eval_dataloader_idx: Optional[int] = None

@property
def trainer(self) -> "pl.Trainer":
@@ -70,8 +71,12 @@ def val_batch_idx(self) -> int:
        Use this to update your progress bar.
        """
        if self.trainer.state.fn == "fit":
-            return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.current.processed
-        return self.trainer.validate_loop.epoch_loop.batch_progress.current.processed
+            loop = self.trainer.fit_loop.epoch_loop.val_loop
+        else:
+            loop = self.trainer.validate_loop
+
+        current_batch_idx = loop.epoch_loop.batch_progress.current.processed
+        return current_batch_idx

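These per-dataloader counters are what concrete bars consume. As a rough illustration of the API (a hypothetical subclass, not part of this PR), a custom bar could read them directly:

```python
# Hypothetical sketch: a bare-bones "bar" subclassing ProgressBarBase and
# reading the counters above; class name and print format are illustrative.
from pytorch_lightning.callbacks.progress.base import ProgressBarBase

class PrintingProgressBar(ProgressBarBase):
    def disable(self) -> None:  # required override; no-op for this sketch
        pass

    def enable(self) -> None:
        pass

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        # record the active dataloader so total_val_batches can index into it
        self.has_dataloader_changed(dataloader_idx)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # with this PR, both values refer to the *current* eval dataloader
        print(f"val dataloader {dataloader_idx}: {self.val_batch_idx}/{self.total_val_batches}")
```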
@property
def test_batch_idx(self) -> int:
@@ -105,15 +110,11 @@ def total_val_batches(self) -> Union[int, float]:
        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation
        dataloader is of infinite size.
        """
+        assert self._current_eval_dataloader_idx is not None
        if self.trainer.sanity_checking:
-            return sum(self.trainer.num_sanity_val_batches)
-
-        total_val_batches = 0
-        if self.trainer.enable_validation:
-            is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
-            total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0
-
-        return total_val_batches
+            return self.trainer.num_sanity_val_batches[self._current_eval_dataloader_idx]
+        return self.trainer.num_val_batches[self._current_eval_dataloader_idx]

@property
def total_test_batches(self) -> Union[int, float]:
@@ -122,7 +123,8 @@ def total_test_batches(self) -> Union[int, float]:
        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is
        of infinite size.
        """
-        return sum(self.trainer.num_test_batches)
+        assert self._current_eval_dataloader_idx is not None
+        return self.trainer.num_test_batches[self._current_eval_dataloader_idx]

@property
def total_predict_batches(self) -> Union[int, float]:
@@ -131,7 +133,27 @@ def total_predict_batches(self) -> Union[int, float]:
        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader
        is of infinite size.
        """
-        return sum(self.trainer.num_predict_batches)
+        assert self._current_eval_dataloader_idx is not None
+        return self.trainer.num_predict_batches[self._current_eval_dataloader_idx]

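All three `total_*_batches` properties change the same way: `trainer.num_*_batches` is a per-dataloader list, and each bar is now sized from a single entry rather than the overall sum. A toy illustration with assumed values:

```python
# Assumed values: a test run with two dataloaders.
num_test_batches = [120, 30]

total_before = sum(num_test_batches)  # one bar over everything: 150
total_after = num_test_batches[1]     # one bar per dataloader: 30 for idx 1

assert (total_before, total_after) == (150, 30)
```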
+    @property
+    def total_val_batches_current_epoch(self) -> Union[int, float]:
+        assert self._trainer is not None
+        return sum(self.trainer.num_val_batches) if self._trainer.fit_loop.epoch_loop._is_check_val_epoch() else 0
+
+    def has_dataloader_changed(self, dataloader_idx: int) -> bool:
+        old_dataloader_idx = self._current_eval_dataloader_idx
+        self._current_eval_dataloader_idx = dataloader_idx
+        return old_dataloader_idx != dataloader_idx
+
+    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self._current_eval_dataloader_idx = None
+
+    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self._current_eval_dataloader_idx = None
+
+    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self._current_eval_dataloader_idx = None

def disable(self) -> None:
"""You should provide a way to disable the progress bar."""
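Taken together, the new base-class pieces form a small compare-and-store mechanism. A standalone sketch of just that logic (not Lightning code; names mirror the diff):

```python
from typing import Optional

class DataloaderTracker:
    """Standalone model of the _current_eval_dataloader_idx bookkeeping."""

    def __init__(self) -> None:
        self._current_idx: Optional[int] = None

    def has_dataloader_changed(self, dataloader_idx: int) -> bool:
        # True on the first batch of each new dataloader; None != 0 makes it
        # fire for the very first one, so per-dataloader setup runs exactly once.
        old_idx, self._current_idx = self._current_idx, dataloader_idx
        return old_idx != dataloader_idx

    def reset(self) -> None:
        # mirrors the on_*_end hooks above: forget the index so the next
        # validate/test/predict run re-triggers setup for dataloader 0
        self._current_idx = None

tracker = DataloaderTracker()
assert tracker.has_dataloader_changed(0)      # first batch of dataloader 0
assert not tracker.has_dataloader_changed(0)  # subsequent batches: no-op
assert tracker.has_dataloader_changed(1)      # switch detected
tracker.reset()
assert tracker.has_dataloader_changed(0)      # next run starts fresh
```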
57 changes: 41 additions & 16 deletions pytorch_lightning/callbacks/progress/rich_progress.py
@@ -269,7 +269,7 @@ def enable(self) -> None:

    @property
    def sanity_check_description(self) -> str:
-        return "Validation Sanity Check"
+        return "Sanity Checking"

@property
def validation_description(self) -> str:
@@ -326,7 +326,7 @@ def on_sanity_check_end(self, trainer, pl_module):

    def on_train_epoch_start(self, trainer, pl_module):
        total_train_batches = self.total_train_batches
-        total_val_batches = self.total_val_batches
+        total_val_batches = self.total_val_batches_current_epoch
if total_train_batches != float("inf"):
# val can be checked multiple times per epoch
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
@@ -346,14 +373,27 @@ def on_train_epoch_start(self, trainer, pl_module):
        )
        self.refresh()

-    def on_validation_epoch_start(self, trainer, pl_module):
-        if trainer.sanity_checking:
-            self.val_sanity_progress_bar_id = self._add_task(self.total_val_batches, self.sanity_check_description)
-        else:
-            self.val_progress_bar_id = self._add_task(
-                self.total_val_batches, self.validation_description, visible=False
-            )
-        self.refresh()
+    def on_validation_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
+        if self.has_dataloader_changed(dataloader_idx):
+            if trainer.sanity_checking:
+                if self.val_sanity_progress_bar_id is not None:
+                    self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)
+
+                self.val_sanity_progress_bar_id = self._add_task(
+                    self.total_val_batches, self.sanity_check_description, visible=False
+                )
+            else:
+                if self.val_progress_bar_id is not None:
+                    self.progress.update(self.val_progress_bar_id, advance=0, visible=False)
+
+                # TODO: remove old tasks when new ones are created
+                self.val_progress_bar_id = self._add_task(
+                    self.total_val_batches, self.validation_description, visible=False
+                )
+
+            self.refresh()
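The Rich handler above hides the previous dataloader's task rather than removing it (hence the TODO). A minimal standalone sketch of that pattern with `rich.progress`, batch counts assumed:

```python
from rich.progress import Progress

with Progress() as progress:
    task_id = None
    for dataloader_idx, num_batches in enumerate([100, 40]):  # assumed sizes
        if task_id is not None:
            # advance=0 leaves the finished count untouched; visible=False
            # hides the old task instead of deleting it, as in the callback
            progress.update(task_id, advance=0, visible=False)
        task_id = progress.add_task(f"Validation DataLoader {dataloader_idx}", total=num_batches)
        for _ in range(num_batches):
            progress.update(task_id, advance=1)
```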

def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
if self.progress is not None:
@@ -379,14 +392,25 @@ def on_validation_epoch_end(self, trainer, pl_module):
    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if trainer.state.fn == "fit":
            self._update_metrics(trainer, pl_module)
+        super().on_validation_end(trainer, pl_module)

-    def on_test_epoch_start(self, trainer, pl_module):
-        self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
-        self.refresh()
+    def on_test_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
+        if self.has_dataloader_changed(dataloader_idx):
+            if self.test_progress_bar_id is not None:
+                self.progress.update(self.test_progress_bar_id, advance=0, visible=False)
+            self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
+            self.refresh()

-    def on_predict_epoch_start(self, trainer, pl_module):
-        self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)
-        self.refresh()
+    def on_predict_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
+        if self.has_dataloader_changed(dataloader_idx):
+            if self.predict_progress_bar_id is not None:
+                self.progress.update(self.predict_progress_bar_id, advance=0, visible=False)
+            self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)
+            self.refresh()

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches)
@@ -402,6 +426,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
        elif self.val_progress_bar_id is not None:
            # check to see if we should update the main training progress bar
            if self.main_progress_bar_id is not None:
+                # TODO: Fix this in a follow-up
                self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches)
            self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches)
        self.refresh()
95 changes: 63 additions & 32 deletions pytorch_lightning/callbacks/progress/tqdm_progress.py
@@ -115,6 +115,26 @@ def __getstate__(self) -> Dict:
# can't pickle the tqdm objects
return {k: v if not isinstance(v, _tqdm) else None for k, v in vars(self).items()}

+    @property
+    def sanity_check_description(self) -> str:
+        return "Sanity Checking"
+
+    @property
+    def train_description(self) -> str:
+        return "Training"
+
+    @property
+    def validation_description(self) -> str:
+        return "Validation"
+
+    @property
+    def test_description(self) -> str:
+        return "Testing"
+
+    @property
+    def predict_description(self) -> str:
+        return "Predicting"

@property
def main_progress_bar(self) -> _tqdm:
if self._main_progress_bar is None:
@@ -173,10 +193,7 @@ def is_disabled(self) -> bool:

    @property
    def _val_processed(self) -> int:
-        if self.trainer.state.fn == "fit":
-            # use total in case validation runs more than once per training epoch
-            return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed
-        return self.trainer.validate_loop.epoch_loop.batch_progress.current.processed
+        return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed

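The simplification leans on the two counters tracked by the batch progress object: `current` resets at the start of every validation run, while `total` accumulates across runs, which is what the main training bar needs when validation fires more than once per epoch. A toy model with assumed numbers:

```python
# Toy model, values assumed: validation runs twice in one training epoch,
# 40 batches per run. `current` is per-run, `total` is cumulative.
class Counters:
    def __init__(self) -> None:
        self.current = 0
        self.total = 0

progress = Counters()
for _ in range(2):          # two validation runs within the epoch
    progress.current = 0    # reset at the start of each run
    for _ in range(40):
        progress.current += 1
        progress.total += 1

assert progress.current == 40  # what a per-run bar would show
assert progress.total == 80    # what the main bar adds on top of train batches
```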
def disable(self) -> None:
self._enabled = False
@@ -187,7 +204,7 @@ def enable(self) -> None:
def init_sanity_tqdm(self) -> Tqdm:
"""Override this to customize the tqdm bar for the validation sanity run."""
bar = Tqdm(
desc="Validation sanity check",
desc=self.sanity_check_description,
position=(2 * self.process_position),
disable=self.is_disabled,
leave=False,
@@ -199,7 +216,7 @@ def init_train_tqdm(self) -> Tqdm:
def init_train_tqdm(self) -> Tqdm:
"""Override this to customize the tqdm bar for training."""
bar = Tqdm(
desc="Training",
desc=self.train_description,
initial=self.train_batch_idx,
position=(2 * self.process_position),
disable=self.is_disabled,
@@ -213,7 +230,7 @@ def init_train_tqdm(self) -> Tqdm:
def init_predict_tqdm(self) -> Tqdm:
"""Override this to customize the tqdm bar for predicting."""
bar = Tqdm(
desc="Predicting",
desc=self.predict_description,
initial=self.train_batch_idx,
position=(2 * self.process_position),
disable=self.is_disabled,
@@ -229,7 +246,7 @@ def init_validation_tqdm(self) -> Tqdm:
# The main progress bar doesn't exist in `trainer.validate()`
has_main_bar = self.trainer.state.fn != "validate"
bar = Tqdm(
desc="Validating",
desc=self.validation_description,
position=(2 * self.process_position + has_main_bar),
disable=self.is_disabled,
leave=not has_main_bar,
@@ -256,14 +273,15 @@ def on_sanity_check_start(self, *_: Any) -> None:

def on_sanity_check_end(self, *_: Any) -> None:
self.main_progress_bar.close()
+        self.main_progress_bar = None
self.val_progress_bar.close()

def on_train_start(self, *_: Any) -> None:
self.main_progress_bar = self.init_train_tqdm()

def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
total_train_batches = self.total_train_batches
-        total_val_batches = self.total_val_batches
+        total_val_batches = self.total_val_batches_current_epoch
if total_train_batches != float("inf") and total_val_batches != float("inf"):
# val can be checked multiple times per epoch
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
@@ -273,63 +291,76 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
-        if self._should_update(self.train_batch_idx):
+        if self._should_update(self.train_batch_idx, self.total_train_batches):
_update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)
if not self.main_progress_bar.disable:
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

def on_train_end(self, *_: Any) -> None:
self.main_progress_bar.close()

-    def on_validation_start(self, trainer: "pl.Trainer", *_: Any) -> None:
-        if trainer.sanity_checking:
-            self.val_progress_bar.total = sum(trainer.num_sanity_val_batches)
-        else:
+    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if not trainer.sanity_checking:
            self.val_progress_bar = self.init_validation_tqdm()

+    def on_validation_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
+        if self.has_dataloader_changed(dataloader_idx):
+            self.val_progress_bar.total = convert_inf(self.total_val_batches)
+            desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description
+            self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")

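Since a single validation bar is reused, a dataloader switch just retargets it. A standalone sketch of the same tqdm moves (sizes assumed, `_update_n` behavior approximated by setting `n` and refreshing):

```python
from tqdm import tqdm

bar = tqdm(desc="Validation DataLoader 0", total=50)  # assumed size
for batch_idx in range(50):
    bar.n = batch_idx + 1  # roughly what _update_n does: set n, then redraw
    bar.refresh()

# dataloader switch: new total and description, count starts over
bar.total = 80                                        # assumed size
bar.set_description("Validation DataLoader 1")
bar.n = 0
bar.refresh()
bar.close()
```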
def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
-        if self._should_update(self.val_batch_idx):
+        if self._should_update(self.val_batch_idx, self.total_val_batches):
_update_n(self.val_progress_bar, self.val_batch_idx)
if trainer.state.fn == "fit":
_update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)

def on_validation_epoch_end(self, *_: Any) -> None:
_update_n(self.val_progress_bar, self._val_processed)

def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._main_progress_bar is not None and trainer.state.fn == "fit":
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
self.val_progress_bar.close()
+        super().on_validation_end(trainer, pl_module)

-    def on_test_start(self, *_: Any) -> None:
+    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.test_progress_bar = self.init_test_tqdm()
-        self.test_progress_bar.total = convert_inf(self.total_test_batches)

+    def on_test_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
+        if self.has_dataloader_changed(dataloader_idx):
+            self.test_progress_bar.total = convert_inf(self.total_test_batches)
+            self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")

def on_test_batch_end(self, *_: Any) -> None:
-        if self._should_update(self.test_batch_idx):
+        if self._should_update(self.test_batch_idx, self.total_test_batches):
_update_n(self.test_progress_bar, self.test_batch_idx)

def on_test_epoch_end(self, *_: Any) -> None:
_update_n(self.test_progress_bar, self.test_batch_idx)

-    def on_test_end(self, *_: Any) -> None:
+    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.test_progress_bar.close()
+        super().on_test_end(trainer, pl_module)

-    def on_predict_epoch_start(self, *_: Any) -> None:
+    def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.predict_progress_bar = self.init_predict_tqdm()
-        self.predict_progress_bar.total = convert_inf(self.total_predict_batches)

+    def on_predict_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
+        if self.has_dataloader_changed(dataloader_idx):
+            self.predict_progress_bar.total = convert_inf(self.total_predict_batches)
+            self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")

def on_predict_batch_end(self, *_: Any) -> None:
-        if self._should_update(self.predict_batch_idx):
+        if self._should_update(self.predict_batch_idx, self.total_predict_batches):
_update_n(self.predict_progress_bar, self.predict_batch_idx)

-    def on_predict_end(self, *_: Any) -> None:
+    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.predict_progress_bar.close()
+        super().on_predict_end(trainer, pl_module)

def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None:
active_progress_bar = None
@@ -347,8 +378,8 @@ def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None:
s = sep.join(map(str, args))
active_progress_bar.write(s, **kwargs)

-    def _should_update(self, idx: int) -> bool:
-        return self.refresh_rate > 0 and idx % self.refresh_rate == 0
+    def _should_update(self, current: int, total: Union[int, float]) -> bool:
+        return self.refresh_rate > 0 and (current % self.refresh_rate == 0 or current == total)
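The extra `total` argument makes the bar redraw on the final batch even when the batch count is not a multiple of `refresh_rate`. A quick check of the new predicate (refresh rate assumed):

```python
from typing import Union

def should_update(current: int, total: Union[int, float], refresh_rate: int = 5) -> bool:
    # new rule: fire on every refresh_rate-th batch, and on the last one
    return refresh_rate > 0 and (current % refresh_rate == 0 or current == total)

# 12 batches at refresh_rate=5: updates at 5, 10, and now also 12
assert [i for i in range(1, 13) if should_update(i, 12)] == [5, 10, 12]
```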

@staticmethod
def _resolve_refresh_rate(refresh_rate: int) -> int: