From f4883d6ead6e95ba767b64da1164ae1cd3cb3864 Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Mon, 11 Apr 2022 19:08:49 +0530
Subject: [PATCH] Run main progress bar independent of val progress bar in
 `TQDMProgressBar` (#12563)

Co-authored-by: carmocca
---
 CHANGELOG.md                               |  3 +
 .../callbacks/progress/tqdm_progress.py    | 21 ++++---
 tests/callbacks/test_tqdm_progress_bar.py  | 61 ++++++++++++++++---
 3 files changed, 66 insertions(+), 19 deletions(-)
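
The behavioral change, in brief: the train hook used to gate refreshes on
`train_batch_idx` against `total_train_batches`, while the val hook redrew the
main bar whenever the val bar was due, regardless of the main bar's own
refresh schedule. Both hooks now derive the same counter,
`current = train_batch_idx + self._val_processed`, and gate it against
`main_progress_bar.total`, with `_should_update` honoring `is_enabled` instead
of `refresh_rate > 0`. The sketch below (kept above the first diff, so
`git am` ignores it) traces the new refresh schedule in plain Python;
`should_update` and its inputs are illustrative stand-ins for the callback
state, not Lightning API:

    # Illustrative sketch of the refresh schedule after this patch.
    def should_update(current: int, total: int, refresh_rate: int = 3) -> bool:
        # mirrors TQDMProgressBar._should_update, assuming is_enabled is True
        return current % refresh_rate == 0 or current == total

    train_batches, val_batches = 5, 2
    main_total = train_batches + val_batches  # main bar spans train + val batches
    updates, val_processed = [], 0

    for train_batch_idx in range(1, train_batches + 1):  # training loop
        current = train_batch_idx + val_processed
        if should_update(current, main_total):
            updates.append(current)

    for _ in range(val_batches):  # validation keeps the main bar moving
        val_processed += 1
        current = train_batches + val_processed
        if should_update(current, main_total):
            updates.append(current)

    assert updates == [3, 6, 7]  # the [5, 2, 3, [3, 6, 7], [2]] row in the tests

Under the old gating the same run also redrew at 5 (`train_batch_idx` equal to
`total_train_batches`), which is why the expected update rows in
`test_main_progress_bar_update_amount` change below.
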
DataLoader {dataloader_idx}") def on_predict_batch_end(self, *_: Any) -> None: - if self._should_update(self.predict_batch_idx, self.total_predict_batches_current_dataloader): + if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total): _update_n(self.predict_progress_bar, self.predict_batch_idx) def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: @@ -356,8 +359,8 @@ def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None: s = sep.join(map(str, args)) active_progress_bar.write(s, **kwargs) - def _should_update(self, current: int, total: Union[int, float]) -> bool: - return self.refresh_rate > 0 and (current % self.refresh_rate == 0 or current == total) + def _should_update(self, current: int, total: int) -> bool: + return self.is_enabled and (current % self.refresh_rate == 0 or current == total) @staticmethod def _resolve_refresh_rate(refresh_rate: int) -> int: diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index dcda45b63e499..0ec38648a674d 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math import os import pickle import sys @@ -347,10 +348,10 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir): [2, 3, 1, [1, 2, 3, 4, 5], [1, 2, 3]], [0, 0, 3, None, None], [1, 0, 3, [1], None], - [1, 1, 3, [1, 2], [1]], + [1, 1, 3, [2], [1]], [5, 0, 3, [3, 5], None], - [5, 2, 3, [3, 5, 7], [2]], - [5, 2, 6, [5, 7], [2]], + [5, 2, 3, [3, 6, 7], [2]], + [5, 2, 6, [6, 7], [2]], ], ) def test_main_progress_bar_update_amount( @@ -549,16 +550,56 @@ def test_tqdm_progress_bar_can_be_pickled(): pickle.dumps(bar) -@RunIf(min_gpus=2, standalone=True) @pytest.mark.parametrize( - ["total_train_samples", "train_batch_size", "total_val_samples", "val_batch_size", "val_check_interval"], - [(8, 4, 2, 1, 0.2), (8, 4, 2, 1, 0.5)], + ["val_check_interval", "main_progress_bar_updates", "val_progress_bar_updates"], + [(4, [3, 6, 9, 12, 14], [3, 6, 7]), (0.5, [3, 6, 9, 12, 15, 18, 21], [3, 6, 7])], ) def test_progress_bar_max_val_check_interval( - tmpdir, total_train_samples, train_batch_size, total_val_samples, val_batch_size, val_check_interval + tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates ): + limit_batches = 7 + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=0, + max_epochs=1, + enable_model_summary=False, + val_check_interval=val_check_interval, + limit_train_batches=limit_batches, + limit_val_batches=limit_batches, + callbacks=TQDMProgressBar(refresh_rate=3), + ) + with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): + trainer.fit(model) + + pbar = trainer.progress_bar_callback + assert pbar.main_progress_bar.n_values == main_progress_bar_updates + assert pbar.val_progress_bar.n_values == val_progress_bar_updates + + val_check_batch = ( + max(1, int(limit_batches * val_check_interval)) if isinstance(val_check_interval, float) else val_check_interval + ) + assert trainer.val_check_batch == val_check_batch + val_checks_per_epoch = math.ceil(limit_batches // val_check_batch) + pbar_callback = trainer.progress_bar_callback + total_val_batches = limit_batches * val_checks_per_epoch + + assert pbar_callback.val_progress_bar.n == 
+    assert pbar_callback.val_progress_bar.total == limit_batches
+    assert pbar_callback.main_progress_bar.n == limit_batches + total_val_batches
+    assert pbar_callback.main_progress_bar.total == limit_batches + total_val_batches
+    assert pbar_callback.is_enabled
+
+
+@RunIf(min_gpus=2, standalone=True)
+@pytest.mark.parametrize("val_check_interval", [0.2, 0.5])
+def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval):
     world_size = 2
-    train_data = DataLoader(RandomDataset(32, total_train_samples), batch_size=train_batch_size)
+    total_train_samples = 16
+    train_batch_size = 4
+    total_val_samples = 2
+    val_batch_size = 1
+    train_data = DataLoader(RandomDataset(32, 8), batch_size=train_batch_size)
     val_data = DataLoader(RandomDataset(32, total_val_samples), batch_size=val_batch_size)
 
     model = BoringModel()
@@ -585,8 +626,8 @@ def test_progress_bar_max_val_check_interval(
     assert pbar_callback.val_progress_bar.n == total_val_batches
     assert pbar_callback.val_progress_bar.total == total_val_batches
     total_val_batches = total_val_batches * val_checks_per_epoch
-    assert pbar_callback.main_progress_bar.n == total_train_batches + total_val_batches
-    assert pbar_callback.main_progress_bar.total == total_train_batches + total_val_batches
+    assert pbar_callback.main_progress_bar.n == (total_train_batches + total_val_batches) // world_size
+    assert pbar_callback.main_progress_bar.total == (total_train_batches + total_val_batches) // world_size
     assert pbar_callback.is_enabled
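
A companion sketch for the expectations at the end of
`test_progress_bar_max_val_check_interval_ddp`, in plain Python mirroring the
test's own formulas (illustrative only, no Lightning imports; the
`val_checks_per_epoch` line is not visible in the hunks above, so
`total_train_batches / val_check_batch` is assumed here). The train loader
holds 8 samples, so under DDP each of the two ranks runs a single train batch
plus one val batch per check, and rank zero's main bar ends at 2, which is
what the global estimate halved by `world_size` reconstructs:

    world_size = 2
    total_train_samples, train_batch_size = 16, 4  # formula inputs from the test
    total_val_samples, val_batch_size = 2, 1

    for val_check_interval in (0.2, 0.5):
        total_train_batches = total_train_samples // (train_batch_size * world_size)  # 2
        val_check_batch = max(1, int(total_train_batches * val_check_interval))  # 1
        val_checks_per_epoch = total_train_batches / val_check_batch  # 2.0 (assumed formula)
        total_val_batches = total_val_samples // (val_batch_size * world_size)  # 1 per check
        total_val_batches = total_val_batches * val_checks_per_epoch  # 2.0 per epoch
        assert (total_train_batches + total_val_batches) // world_size == 2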