From 4d0015d6f5eaa1b05f1ae80971becfc38e0c836a Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Fri, 1 Apr 2022 14:55:59 +0530
Subject: [PATCH 1/6] Run main progress bar independent of val progress bar

---
 .../callbacks/progress/tqdm_progress.py   | 21 +++++++++++--------
 tests/callbacks/test_tqdm_progress_bar.py |  6 +++---
 2 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py
index 5d5b17bfb9922..f527f53135c19 100644
--- a/pytorch_lightning/callbacks/progress/tqdm_progress.py
+++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py
@@ -268,8 +268,9 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
         self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
 
     def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
-        if self._should_update(self.train_batch_idx, self.total_train_batches):
-            _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)
+        current = self.train_batch_idx + self._val_processed
+        if self._should_update(current, self.main_progress_bar.total):
+            _update_n(self.main_progress_bar, current)
         self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
 
     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -294,10 +295,12 @@ def on_validation_batch_start(
         self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")
 
     def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
-        if self._should_update(self.val_batch_idx, self.total_val_batches_current_dataloader):
+        if self._should_update(self.val_batch_idx, self.val_progress_bar.total):
             _update_n(self.val_progress_bar, self.val_batch_idx)
-            if trainer.state.fn == "fit":
-                _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)
+
+        current = self.train_batch_idx + self._val_processed
+        if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total):
+            _update_n(self.main_progress_bar, current)
 
     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if self._main_progress_bar is not None and trainer.state.fn == "fit":
@@ -318,7 +321,7 @@ def on_test_batch_start(
         self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")
 
     def on_test_batch_end(self, *_: Any) -> None:
-        if self._should_update(self.test_batch_idx, self.total_test_batches_current_dataloader):
+        if self._should_update(self.test_batch_idx, self.test_progress_bar.total):
             _update_n(self.test_progress_bar, self.test_batch_idx)
 
     def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -338,7 +341,7 @@ def on_predict_batch_start(
         self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")
 
     def on_predict_batch_end(self, *_: Any) -> None:
-        if self._should_update(self.predict_batch_idx, self.total_predict_batches_current_dataloader):
+        if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total):
             _update_n(self.predict_progress_bar, self.predict_batch_idx)
 
     def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -361,8 +364,8 @@ def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None:
             s = sep.join(map(str, args))
             active_progress_bar.write(s, **kwargs)
 
-    def _should_update(self, current: int, total: Union[int, float]) -> bool:
-        return self.refresh_rate > 0 and (current % self.refresh_rate == 0 or current == total)
+    def _should_update(self, current: int, total: int) -> bool:
+        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
 
     @staticmethod
     def _resolve_refresh_rate(refresh_rate: int) -> int:
diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py
index dcda45b63e499..ad2ebe1fe79f1 100644
--- a/tests/callbacks/test_tqdm_progress_bar.py
+++ b/tests/callbacks/test_tqdm_progress_bar.py
@@ -347,10 +347,10 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir):
         [2, 3, 1, [1, 2, 3, 4, 5], [1, 2, 3]],
         [0, 0, 3, None, None],
         [1, 0, 3, [1], None],
-        [1, 1, 3, [1, 2], [1]],
+        [1, 1, 3, [2], [1]],
         [5, 0, 3, [3, 5], None],
-        [5, 2, 3, [3, 5, 7], [2]],
-        [5, 2, 6, [5, 7], [2]],
+        [5, 2, 3, [3, 6, 7], [2]],
+        [5, 2, 6, [6, 7], [2]],
     ],
 )
 def test_main_progress_bar_update_amount(

From 9b8bdf3e5de62c0c728a5bc4a7d5ae2a8b2eff69 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Fri, 1 Apr 2022 15:44:54 +0530
Subject: [PATCH 2/6] add better test

---
 CHANGELOG.md                              |  2 +-
 tests/callbacks/test_tqdm_progress_bar.py | 59 ++++++++++++++++++++---
 2 files changed, 52 insertions(+), 9 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 006d3e33862a9..6f52fb20e4624 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -79,7 +79,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Run main progress bar independent of val progress bar in `TQDMProgressBar` ([#12563](https://github.com/PyTorchLightning/pytorch-lightning/pull/12563))
 
 -
 
diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py
index ad2ebe1fe79f1..1edfa1dc43a32 100644
--- a/tests/callbacks/test_tqdm_progress_bar.py
+++ b/tests/callbacks/test_tqdm_progress_bar.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 import os
 import pickle
 import sys
@@ -549,16 +550,59 @@ def test_tqdm_progress_bar_can_be_pickled():
     pickle.dumps(bar)
 
 
-@RunIf(min_gpus=2, standalone=True)
 @pytest.mark.parametrize(
-    ["total_train_samples", "train_batch_size", "total_val_samples", "val_batch_size", "val_check_interval"],
-    [(8, 4, 2, 1, 0.2), (8, 4, 2, 1, 0.5)],
+    ["val_check_interval", "main_progress_bar_updates", "val_progress_bar_updates"],
+    [(4, [3, 6, 9, 12, 14], [3, 6, 7]), (0.5, [3, 6, 9, 12, 15, 18, 21], [3, 6, 7])],
 )
 def test_progress_bar_max_val_check_interval(
-    tmpdir, total_train_samples, train_batch_size, total_val_samples, val_batch_size, val_check_interval
+    tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates
 ):
+    limit_batches = 7
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=0,
+        max_epochs=1,
+        enable_model_summary=False,
+        val_check_interval=val_check_interval,
+        limit_train_batches=limit_batches,
+        limit_val_batches=limit_batches,
+        callbacks=TQDMProgressBar(refresh_rate=3),
+    )
+    with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
+        trainer.fit(model)
+
+    pbar = trainer.progress_bar_callback
+    assert pbar.main_progress_bar.n_values == main_progress_bar_updates
+    assert pbar.val_progress_bar.n_values == val_progress_bar_updates
+
+    val_check_batch = (
+        max(1, int(limit_batches * val_check_interval)) if isinstance(val_check_interval, float) else val_check_interval
+    )
+    assert trainer.val_check_batch == val_check_batch
+    val_checks_per_epoch = math.ceil(limit_batches // val_check_batch)
+    pbar_callback = trainer.progress_bar_callback
+    total_val_batches = limit_batches * val_checks_per_epoch
+
+    assert pbar_callback.val_progress_bar.n == limit_batches
+    assert pbar_callback.val_progress_bar.total == limit_batches
+    assert pbar_callback.main_progress_bar.n == limit_batches + total_val_batches
+    assert pbar_callback.main_progress_bar.total == limit_batches + total_val_batches
+    assert pbar_callback.is_enabled
+
+
+@RunIf(standalone=True)
+@pytest.mark.parametrize(
+    "val_check_interval",
+    [0.2, 0.5],
+)
+def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval):
     world_size = 2
-    train_data = DataLoader(RandomDataset(32, total_train_samples), batch_size=train_batch_size)
+    total_train_samples = 16
+    train_batch_size = 4
+    total_val_samples = 2
+    val_batch_size = 1
+    train_data = DataLoader(RandomDataset(32, 8), batch_size=train_batch_size)
     val_data = DataLoader(RandomDataset(32, total_val_samples), batch_size=val_batch_size)
 
     model = BoringModel()
@@ -568,7 +612,6 @@ def test_progress_bar_max_val_check_interval(
         max_epochs=1,
         enable_model_summary=False,
         val_check_interval=val_check_interval,
-        accelerator="gpu",
         devices=world_size,
         strategy="ddp",
     )
@@ -585,8 +628,8 @@ def test_progress_bar_max_val_check_interval(
     assert pbar_callback.val_progress_bar.n == total_val_batches
     assert pbar_callback.val_progress_bar.total == total_val_batches
     total_val_batches = total_val_batches * val_checks_per_epoch
-    assert pbar_callback.main_progress_bar.n == total_train_batches + total_val_batches
-    assert pbar_callback.main_progress_bar.total == total_train_batches + total_val_batches
+    assert pbar_callback.main_progress_bar.n == (total_train_batches + total_val_batches) // world_size
+    assert pbar_callback.main_progress_bar.total == (total_train_batches + total_val_batches) // world_size
     assert pbar_callback.is_enabled

From ec53d5855059bc05b7889308c24294d6e3dde0ef Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Fri, 1 Apr 2022 15:50:34 +0530
Subject: [PATCH 3/6] revert test

---
 tests/callbacks/test_tqdm_progress_bar.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py
index 1edfa1dc43a32..53bbef0017030 100644
--- a/tests/callbacks/test_tqdm_progress_bar.py
+++ b/tests/callbacks/test_tqdm_progress_bar.py
@@ -591,7 +591,7 @@ def test_progress_bar_max_val_check_interval(
     assert pbar_callback.is_enabled
 
 
-@RunIf(standalone=True)
+@RunIf(min_gpus=2, standalone=True)
 @pytest.mark.parametrize(
     "val_check_interval",
     [0.2, 0.5],
 )
@@ -613,6 +613,7 @@ def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval):
         enable_model_summary=False,
         val_check_interval=val_check_interval,
         devices=world_size,
+        accelerator="gpu",
         strategy="ddp",
     )

From 7f380be3582f54a005878b1d0dbca6448a97ee1f Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Fri, 1 Apr 2022 15:51:26 +0530
Subject: [PATCH 4/6] code format

---
 tests/callbacks/test_tqdm_progress_bar.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py
index 53bbef0017030..0ab086bf55f90 100644
--- a/tests/callbacks/test_tqdm_progress_bar.py
+++ b/tests/callbacks/test_tqdm_progress_bar.py
@@ -592,10 +592,7 @@ def test_progress_bar_max_val_check_interval(
 
 
 @RunIf(min_gpus=2, standalone=True)
-@pytest.mark.parametrize(
-    "val_check_interval",
-    [0.2, 0.5],
-)
+@pytest.mark.parametrize("val_check_interval", [0.2, 0.5])
 def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval):
     world_size = 2
     total_train_samples = 16

From 8991b6bc858e9d09893a4c1616e6c4300298a67a Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Fri, 1 Apr 2022 15:51:57 +0530
Subject: [PATCH 5/6] code format

---
 tests/callbacks/test_tqdm_progress_bar.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py
index 0ab086bf55f90..0ec38648a674d 100644
--- a/tests/callbacks/test_tqdm_progress_bar.py
+++ b/tests/callbacks/test_tqdm_progress_bar.py
@@ -609,8 +609,8 @@ def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval):
         max_epochs=1,
         enable_model_summary=False,
         val_check_interval=val_check_interval,
-        devices=world_size,
         accelerator="gpu",
+        devices=world_size,
         strategy="ddp",
     )
     trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)

From 6051014d13627dca66c5c2c4f388eaeaeb1745cc Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Mon, 11 Apr 2022 14:45:24 +0530
Subject: [PATCH 6/6] Update CHANGELOG.md
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3cfd61b164771..2e26bba634d92 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -82,7 +82,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
-- Run main progress bar independent of val progress bar in `TQDMProgressBar` ([#12563](https://github.com/PyTorchLightning/pytorch-lightning/pull/12563))
+- Run main progress bar updates independent of val progress bar updates in `TQDMProgressBar` ([#12563](https://github.com/PyTorchLightning/pytorch-lightning/pull/12563))
 
 - Avoid calling `average_parameters` multiple times per optimizer step ([#12452](https://github.com/PyTorchLightning/pytorch-lightning/pull/12452))
 
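
Note: the behavioral core of the series is that the main bar's counter is the number of completed train batches plus the val batches processed so far, and each bar's refresh check is made against that bar's own total. Below is a minimal standalone sketch of that logic; should_update, simulate, and the hard-coded batch counts are illustrative assumptions (a simplified stand-in for the patched _should_update, which gates on self.is_enabled), not the actual TQDMProgressBar/Trainer API.

# Standalone sketch (assumed helper names, not the real API) of the refresh
# logic the patches above introduce: the main bar counts train + val batches
# and each bar refreshes every `refresh_rate` steps or on its final step.


def should_update(current: int, total: int, refresh_rate: int = 3) -> bool:
    # Simplified stand-in for the patched `_should_update`.
    return refresh_rate > 0 and (current % refresh_rate == 0 or current == total)


def simulate(train_batches: int = 7, val_batches: int = 7, val_check_batch: int = 3):
    val_runs = train_batches // val_check_batch
    main_total = train_batches + val_batches * val_runs  # 7 + 7 * 2 = 21

    main_updates, val_updates, val_processed = [], [], 0
    for train_batch_idx in range(1, train_batches + 1):
        current = train_batch_idx + val_processed
        if should_update(current, main_total):
            main_updates.append(current)  # main bar refresh after a train batch

        if train_batch_idx % val_check_batch == 0:  # validation loop runs here
            for val_batch_idx in range(1, val_batches + 1):
                val_processed += 1
                if should_update(val_batch_idx, val_batches):
                    val_updates.append(val_batch_idx)  # val bar refresh
                current = train_batch_idx + val_processed
                if should_update(current, main_total):
                    main_updates.append(current)  # main bar keeps advancing during val
    return main_updates, val_updates


main_updates, val_updates = simulate()
# Matches the val_check_interval=0.5 expectation in the parametrized test above.
assert main_updates == [3, 6, 9, 12, 15, 18, 21]
# The real test sees [3, 6, 7] because a fresh val bar is created per validation run.
assert val_updates == [3, 6, 7, 3, 6, 7]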