From 74ff0529f6f7896a9851d3a260efdde795ca1eea Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 31 Jan 2022 15:22:02 +0530 Subject: [PATCH 01/19] fix tqdm counter for multiple dataloaders --- pytorch_lightning/callbacks/progress/base.py | 53 ++++++++++- .../callbacks/progress/tqdm_progress.py | 51 ++++++----- .../loops/dataloader/evaluation_loop.py | 3 - pytorch_lightning/trainer/trainer.py | 3 + tests/callbacks/test_tqdm_progress_bar.py | 91 +++++++++++++------ 5 files changed, 144 insertions(+), 57 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py index 4babd823e82d5..6846b3979609a 100644 --- a/pytorch_lightning/callbacks/progress/base.py +++ b/pytorch_lightning/callbacks/progress/base.py @@ -48,6 +48,9 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch_idx): def __init__(self) -> None: self._trainer: Optional["pl.Trainer"] = None + self._val_progress: Optional[int] = None + self._test_progress: Optional[int] = None + self._predict_progress: Optional[int] = None @property def trainer(self) -> "pl.Trainer": @@ -63,6 +66,34 @@ def train_batch_idx(self) -> int: """ return self.trainer.fit_loop.epoch_loop.batch_progress.current.processed + def on_validation_batch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int + ): + if self._val_progress is None or batch_idx == 0: + max_batches = trainer.num_sanity_val_batches if trainer.sanity_checking else trainer.num_val_batches + self._val_progress = sum(max_batches[:dataloader_idx]) + + def on_test_batch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int + ): + if self._test_progress is None or batch_idx == 0: + self._test_progress = sum(trainer.num_test_batches[:dataloader_idx]) + + def on_predict_batch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int + ): + if self._predict_progress is None or batch_idx == 0: + self._predict_progress = sum(trainer.num_predict_batches[:dataloader_idx]) + + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._val_progress = None + + def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._test_progress = None + + def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._predict_progress = None + @property def val_batch_idx(self) -> int: """The number of batches processed during validation. @@ -70,8 +101,13 @@ def val_batch_idx(self) -> int: Use this to update your progress bar. """ if self.trainer.state.fn == "fit": - return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.current.processed - return self.trainer.validate_loop.epoch_loop.batch_progress.current.processed + loop = self.trainer.fit_loop.epoch_loop.val_loop + else: + loop = self.trainer.validate_loop + + current_batch_idx = loop.epoch_loop.batch_progress.current.processed + batch_idx = self._val_progress + current_batch_idx + return batch_idx @property def test_batch_idx(self) -> int: @@ -79,7 +115,10 @@ def test_batch_idx(self) -> int: Use this to update your progress bar. 
""" - return self.trainer.test_loop.epoch_loop.batch_progress.current.processed + loop = self.trainer.test_loop + current_batch_idx = loop.epoch_loop.batch_progress.current.processed + batch_idx = self._test_progress + current_batch_idx + return batch_idx @property def predict_batch_idx(self) -> int: @@ -87,7 +126,10 @@ def predict_batch_idx(self) -> int: Use this to update your progress bar. """ - return self.trainer.predict_loop.epoch_loop.batch_progress.current.processed + loop = self.trainer.predict_loop + current_batch_idx = loop.epoch_loop.batch_progress.current.processed + batch_idx = self._predict_progress + current_batch_idx + return batch_idx @property def total_train_batches(self) -> Union[int, float]: @@ -105,6 +147,9 @@ def total_val_batches(self) -> Union[int, float]: Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation dataloader is of infinite size. """ + if self.trainer.sanity_checking: + return sum(self.trainer.num_sanity_val_batches) + total_val_batches = 0 if self.trainer.enable_validation: is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index 95c5666957635..8c12439e7abae 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -173,10 +173,7 @@ def is_disabled(self) -> bool: @property def _val_processed(self) -> int: - if self.trainer.state.fn == "fit": - # use total in case validation runs more than once per training epoch - return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed - return self.trainer.validate_loop.epoch_loop.batch_progress.current.processed + return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed def disable(self) -> None: self._enabled = False @@ -227,12 +224,12 @@ def init_predict_tqdm(self) -> Tqdm: def init_validation_tqdm(self) -> Tqdm: """Override this to customize the tqdm bar for validation.""" # The main progress bar doesn't exist in `trainer.validate()` - has_main_bar = self._main_progress_bar is not None + has_main_bar = not self.trainer.state.fn == "validate" bar = Tqdm( desc="Validating", position=(2 * self.process_position + has_main_bar), disable=self.is_disabled, - leave=False, + leave=not has_main_bar, dynamic_ncols=True, file=sys.stdout, ) @@ -256,6 +253,7 @@ def on_sanity_check_start(self, *_: Any) -> None: def on_sanity_check_end(self, *_: Any) -> None: self.main_progress_bar.close() + self.main_progress_bar = None self.val_progress_bar.close() def on_train_start(self, *_: Any) -> None: @@ -273,63 +271,68 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None: self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}") def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None: - if self._should_update(self.train_batch_idx): + if self._should_update(self.train_batch_idx, self.total_train_batches): _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed) self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed) if not self.main_progress_bar.disable: 
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) def on_train_end(self, *_: Any) -> None: self.main_progress_bar.close() - def on_validation_start(self, trainer: "pl.Trainer", *_: Any) -> None: + def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if trainer.sanity_checking: self.val_progress_bar.total = sum(trainer.num_sanity_val_batches) else: self.val_progress_bar = self.init_validation_tqdm() self.val_progress_bar.total = convert_inf(self.total_val_batches) + def on_validation_batch_start(self, *args: Any, **kwargs: Any): + return super().on_validation_batch_start(*args, **kwargs) + + def on_test_batch_start(self, *args: Any, **kwargs: Any): + return super().on_test_batch_start(*args, **kwargs) + + def on_predict_batch_start(self, *args: Any, **kwargs: Any): + return super().on_predict_batch_start(*args, **kwargs) + def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None: - if self._should_update(self.val_batch_idx): + if self._should_update(self.val_batch_idx, self.total_val_batches): _update_n(self.val_progress_bar, self.val_batch_idx) if trainer.state.fn == "fit": _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed) - def on_validation_epoch_end(self, *_: Any) -> None: - _update_n(self.val_progress_bar, self._val_processed) - def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if self._main_progress_bar is not None and trainer.state.fn == "fit": self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) self.val_progress_bar.close() + super().on_validation_end(trainer, pl_module) - def on_test_start(self, *_: Any) -> None: + def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.test_progress_bar = self.init_test_tqdm() self.test_progress_bar.total = convert_inf(self.total_test_batches) def on_test_batch_end(self, *_: Any) -> None: - if self._should_update(self.test_batch_idx): + if self._should_update(self.test_batch_idx, self.total_test_batches): _update_n(self.test_progress_bar, self.test_batch_idx) - def on_test_epoch_end(self, *_: Any) -> None: - _update_n(self.test_progress_bar, self.test_batch_idx) - - def on_test_end(self, *_: Any) -> None: + def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.test_progress_bar.close() + super().on_test_end(trainer, pl_module) - def on_predict_epoch_start(self, *_: Any) -> None: + def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.predict_progress_bar = self.init_predict_tqdm() self.predict_progress_bar.total = convert_inf(self.total_predict_batches) def on_predict_batch_end(self, *_: Any) -> None: - if self._should_update(self.predict_batch_idx): + if self._should_update(self.predict_batch_idx, self.total_predict_batches): _update_n(self.predict_progress_bar, self.predict_batch_idx) - def on_predict_end(self, *_: Any) -> None: + def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.predict_progress_bar.close() + super().on_predict_end(trainer, pl_module) def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None: active_progress_bar = None @@ -347,8 +350,8 @@ def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None: s = sep.join(map(str, args)) active_progress_bar.write(s, **kwargs) - def _should_update(self, idx: int) -> bool: - return self.refresh_rate > 0 and idx % self.refresh_rate == 0 + def 
_should_update(self, current: int, total: Union[int, float]) -> bool: + return self.refresh_rate > 0 and (current % self.refresh_rate == 0 or current == total) @staticmethod def _resolve_refresh_rate(refresh_rate: int) -> int: diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 1e0b30cab03c7..076fe1b52c96b 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -178,9 +178,6 @@ def _get_max_batches(self) -> List[int]: max_batches = self.trainer.num_test_batches else: if self.trainer.sanity_checking: - self.trainer.num_sanity_val_batches = [ - min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches - ] max_batches = self.trainer.num_sanity_val_batches else: max_batches = self.trainer.num_val_batches diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ac01227fd00ac..b60e168f2cfe8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1349,6 +1349,9 @@ def _run_sanity_check(self) -> None: # reload dataloaders val_loop._reload_evaluation_dataloaders() + self.num_sanity_val_batches = [ + min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches + ] # run eval step with torch.no_grad(): diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index e484e1cb5b32f..cb7f6ac2f02e9 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -75,55 +75,97 @@ def test_tqdm_progress_bar_misconfiguration(): Trainer(callbacks=TQDMProgressBar(), enable_progress_bar=False) -def test_tqdm_progress_bar_totals(tmpdir): +@pytest.mark.parametrize("num_dl", [1, 2]) +def test_tqdm_progress_bar_totals(tmpdir, num_dl): """Test that the progress finishes with the correct total steps processed.""" - model = BoringModel() + class CustomModel(BoringModel): + def _get_dataloaders(self): + dls = [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))] + return dls[0] if num_dl == 1 else dls - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - bar = trainer.progress_bar_callback + def val_dataloader(self): + return self._get_dataloaders() + + def test_dataloader(self): + return self._get_dataloaders() + + def predict_dataloader(self): + return self._get_dataloaders() + + def validation_step(self, batch, batch_idx, dataloader_idx=None): + return + + def test_step(self, batch, batch_idx, dataloader_idx=None): + return + def predict_step(self, batch, batch_idx, dataloader_idx=None): + return + + model = CustomModel() + model.validation_epoch_end = None + model.test_epoch_end = None + + # check the sanity dataloaders + num_sanity_val_steps = 4 + trainer = Trainer( + default_root_dir=tmpdir, max_epochs=1, limit_train_batches=0, num_sanity_val_steps=num_sanity_val_steps + ) + bar = trainer.progress_bar_callback trainer.fit(model) + expected_sanity_steps = num_sanity_val_steps * num_dl + assert not bar.val_progress_bar.leave + assert sum(trainer.num_sanity_val_batches) == expected_sanity_steps + assert bar.val_progress_bar.total == expected_sanity_steps + assert bar.val_progress_bar.n == expected_sanity_steps - # check main progress bar total + # fit + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + bar = trainer.progress_bar_callback + trainer.fit(model) n = bar.total_train_batches m = bar.total_val_batches + 
assert trainer.num_training_batches == n assert len(trainer.train_dataloader) == n assert bar.main_progress_bar.total == n + m + assert bar.main_progress_bar.leave - # check val progress bar total - assert sum(len(loader) for loader in trainer.val_dataloaders) == m - assert bar.val_progress_bar.total == m - - # main progress bar should have reached the end (train batches + val batches) + assert sum(trainer.num_val_batches) == m assert bar.main_progress_bar.n == n + m assert bar.train_batch_idx == n - # val progress bar should have reached the end + # check val progress bar total + assert bar.val_progress_bar.total == m assert bar.val_progress_bar.n == m - assert bar.val_batch_idx == m + assert not bar.val_progress_bar.leave # check that the test progress bar is off assert 0 == bar.total_test_batches with pytest.raises(TypeError, match="test_progress_bar` .* not been set"): assert bar.test_progress_bar is None + # validate trainer.validate(model) - + assert bar.val_progress_bar.leave + assert sum(trainer.num_val_batches) == m assert bar.val_progress_bar.total == m assert bar.val_progress_bar.n == m - assert bar.val_batch_idx == m + # test trainer.test(model) - - # check test progress bar total + assert bar.test_progress_bar.leave k = bar.total_test_batches - assert sum(len(loader) for loader in trainer.test_dataloaders) == k + assert sum(trainer.num_test_batches) == k assert bar.test_progress_bar.total == k + assert bar.test_progress_bar.n == k - # test progress bar should have reached the end + # predict + trainer.predict(model) + assert bar.predict_progress_bar.leave + k = bar.total_predict_batches + assert sum(trainer.num_predict_batches) == k + assert bar.test_progress_bar.total == k assert bar.test_progress_bar.n == k - assert bar.test_batch_idx == k def test_tqdm_progress_bar_fast_dev_run(tmpdir): @@ -139,8 +181,7 @@ def test_tqdm_progress_bar_fast_dev_run(tmpdir): assert 1 == progress_bar.total_val_batches assert 1 == progress_bar.train_batch_idx - assert 1 == progress_bar.val_batch_idx - assert 0 == progress_bar.test_batch_idx + assert 1 == progress_bar.val_progress_bar.n # the main progress bar should display 2 batches (1 train, 1 val) assert 2 == progress_bar.main_progress_bar.total @@ -149,14 +190,12 @@ def test_tqdm_progress_bar_fast_dev_run(tmpdir): trainer.validate(model) # the validation progress bar should display 1 batch - assert 1 == progress_bar.val_batch_idx assert 1 == progress_bar.val_progress_bar.total assert 1 == progress_bar.val_progress_bar.n trainer.test(model) # the test progress bar should display 1 batch - assert 1 == progress_bar.test_batch_idx assert 1 == progress_bar.test_progress_bar.total assert 1 == progress_bar.test_progress_bar.n @@ -298,10 +337,10 @@ def n(self, value): [2, 3, 1, [1, 2, 3, 4, 5], [1, 2, 3]], [0, 0, 3, None, None], [1, 0, 3, [1], None], - [1, 1, 3, [2], [1]], + [1, 1, 3, [1, 2], [1]], [5, 0, 3, [3, 5], None], - [5, 2, 3, [3, 7], [2]], - [5, 2, 6, [7], [2]], + [5, 2, 3, [3, 5, 7], [2]], + [5, 2, 6, [5, 7], [2]], ], ) def test_main_progress_bar_update_amount( From 3596f1c412811a42a663581609e542b981d67e43 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 31 Jan 2022 16:02:59 +0530 Subject: [PATCH 02/19] chlog and mypy --- CHANGELOG.md | 3 +++ pytorch_lightning/callbacks/progress/base.py | 6 +++--- pytorch_lightning/callbacks/progress/tqdm_progress.py | 9 --------- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aa7c4f9b056bc..98c04d2c1d023 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ 
-457,6 +457,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Disabled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))
 
+- Fixed `TQDMProgressBar` counter when using multiple validation dataloaders ([#11657](https://github.com/PyTorchLightning/pytorch-lightning/pull/11657))
+
+
 ## [1.5.8] - 2022-01-05
 
 ### Fixed
diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py
index 6846b3979609a..90d4b49ff4ff5 100644
--- a/pytorch_lightning/callbacks/progress/base.py
+++ b/pytorch_lightning/callbacks/progress/base.py
@@ -68,20 +68,20 @@ def train_batch_idx(self) -> int:
 
     def on_validation_batch_start(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
-    ):
+    ) -> None:
         if self._val_progress is None or batch_idx == 0:
             max_batches = trainer.num_sanity_val_batches if trainer.sanity_checking else trainer.num_val_batches
             self._val_progress = sum(max_batches[:dataloader_idx])
 
     def on_test_batch_start(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
-    ):
+    ) -> None:
         if self._test_progress is None or batch_idx == 0:
             self._test_progress = sum(trainer.num_test_batches[:dataloader_idx])
 
     def on_predict_batch_start(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
-    ):
+    ) -> None:
         if self._predict_progress is None or batch_idx == 0:
             self._predict_progress = sum(trainer.num_predict_batches[:dataloader_idx])
 
diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py
index 8c12439e7abae..98d4651a995f9 100644
--- a/pytorch_lightning/callbacks/progress/tqdm_progress.py
+++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py
@@ -289,15 +289,6 @@ def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMod
         self.val_progress_bar = self.init_validation_tqdm()
         self.val_progress_bar.total = convert_inf(self.total_val_batches)
 
-    def on_validation_batch_start(self, *args: Any, **kwargs: Any):
-        return super().on_validation_batch_start(*args, **kwargs)
-
-    def on_test_batch_start(self, *args: Any, **kwargs: Any):
-        return super().on_test_batch_start(*args, **kwargs)
-
-    def on_predict_batch_start(self, *args: Any, **kwargs: Any):
-        return super().on_predict_batch_start(*args, **kwargs)
-
     def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
         if self._should_update(self.val_batch_idx, self.total_val_batches):
             _update_n(self.val_progress_bar, self.val_batch_idx)

From 406c224d82bec21574d43edc7557256fe7d543fe Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Tue, 1 Feb 2022 03:14:00 +0530
Subject: [PATCH 03/19] sep progress bars for each dataloader

---
 pytorch_lightning/callbacks/progress/base.py  |  72 +++-----
 .../callbacks/progress/rich_progress.py       |  36 ++--
 .../callbacks/progress/tqdm_progress.py       |  63 +++++--
 tests/callbacks/test_tqdm_progress_bar.py     | 154 ++++++++++--------
 4 files changed, 186 insertions(+), 139 deletions(-)

diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py
index 90d4b49ff4ff5..68ede43791dc4 100644
--- a/pytorch_lightning/callbacks/progress/base.py
+++ b/pytorch_lightning/callbacks/progress/base.py
@@ -48,9 +48,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch_idx):
 
     def __init__(self) -> None:
self._trainer: Optional["pl.Trainer"] = None - self._val_progress: Optional[int] = None - self._test_progress: Optional[int] = None - self._predict_progress: Optional[int] = None + self._current_eval_dataloader_idx: Optional[int] = None @property def trainer(self) -> "pl.Trainer": @@ -58,6 +56,20 @@ def trainer(self) -> "pl.Trainer": raise TypeError(f"The `{self.__class__.__name__}._trainer` reference has not been set yet.") return self._trainer + def is_dataloader_changed(self, dataloader_idx: int) -> bool: + old_dataloader_idx = self._current_eval_dataloader_idx + self._current_eval_dataloader_idx = dataloader_idx + return old_dataloader_idx != dataloader_idx + + def on_validation_end(self, *_: Any) -> None: + self._current_eval_dataloader_idx = None + + def on_test_end(self, *_: Any) -> None: + self._current_eval_dataloader_idx = None + + def on_predict_end(self, *_: Any) -> None: + self._current_eval_dataloader_idx = None + @property def train_batch_idx(self) -> int: """The number of batches processed during training. @@ -66,34 +78,6 @@ def train_batch_idx(self) -> int: """ return self.trainer.fit_loop.epoch_loop.batch_progress.current.processed - def on_validation_batch_start( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int - ) -> None: - if self._val_progress is None or batch_idx == 0: - max_batches = trainer.num_sanity_val_batches if trainer.sanity_checking else trainer.num_val_batches - self._val_progress = sum(max_batches[:dataloader_idx]) - - def on_test_batch_start( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int - ) -> None: - if self._test_progress is None or batch_idx == 0: - self._test_progress = sum(trainer.num_test_batches[:dataloader_idx]) - - def on_predict_batch_start( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int - ) -> None: - if self._predict_progress is None or batch_idx == 0: - self._predict_progress = sum(trainer.num_predict_batches[:dataloader_idx]) - - def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - self._val_progress = None - - def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - self._test_progress = None - - def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - self._predict_progress = None - @property def val_batch_idx(self) -> int: """The number of batches processed during validation. @@ -106,8 +90,7 @@ def val_batch_idx(self) -> int: loop = self.trainer.validate_loop current_batch_idx = loop.epoch_loop.batch_progress.current.processed - batch_idx = self._val_progress + current_batch_idx - return batch_idx + return current_batch_idx @property def test_batch_idx(self) -> int: @@ -115,10 +98,7 @@ def test_batch_idx(self) -> int: Use this to update your progress bar. """ - loop = self.trainer.test_loop - current_batch_idx = loop.epoch_loop.batch_progress.current.processed - batch_idx = self._test_progress + current_batch_idx - return batch_idx + return self.trainer.test_loop.epoch_loop.batch_progress.current.processed @property def predict_batch_idx(self) -> int: @@ -126,10 +106,7 @@ def predict_batch_idx(self) -> int: Use this to update your progress bar. 
""" - loop = self.trainer.predict_loop - current_batch_idx = loop.epoch_loop.batch_progress.current.processed - batch_idx = self._predict_progress + current_batch_idx - return batch_idx + return self.trainer.predict_loop.epoch_loop.batch_progress.current.processed @property def total_train_batches(self) -> Union[int, float]: @@ -148,14 +125,9 @@ def total_val_batches(self) -> Union[int, float]: dataloader is of infinite size. """ if self.trainer.sanity_checking: - return sum(self.trainer.num_sanity_val_batches) - - total_val_batches = 0 - if self.trainer.enable_validation: - is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 - total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0 + return self.trainer.num_sanity_val_batches[self._current_eval_dataloader_idx] - return total_val_batches + return self.trainer.num_val_batches[self._current_eval_dataloader_idx] @property def total_test_batches(self) -> Union[int, float]: @@ -164,7 +136,7 @@ def total_test_batches(self) -> Union[int, float]: Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is of infinite size. """ - return sum(self.trainer.num_test_batches) + return self.trainer.num_test_batches[self._current_eval_dataloader_idx] @property def total_predict_batches(self) -> Union[int, float]: @@ -173,7 +145,7 @@ def total_predict_batches(self) -> Union[int, float]: Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader is of infinite size. """ - return sum(self.trainer.num_predict_batches) + return self.trainer.num_predict_batches[self._current_eval_dataloader_idx] def disable(self) -> None: """You should provide a way to disable the progress bar. 
diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 570c6d7df669c..1f4f48bc03c2f 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -16,6 +16,7 @@ from datetime import timedelta from typing import Any, Dict, Optional, Union +import pytorch_lightning as pl from pytorch_lightning.callbacks.progress.base import ProgressBarBase from pytorch_lightning.utilities.imports import _RICH_AVAILABLE @@ -325,7 +326,7 @@ def on_sanity_check_end(self, trainer, pl_module): def on_train_epoch_start(self, trainer, pl_module): total_train_batches = self.total_train_batches - total_val_batches = self.total_val_batches + total_val_batches = sum(self.trainer.num_val_batches) if total_train_batches != float("inf"): # val can be checked multiple times per epoch val_checks_per_epoch = total_train_batches // trainer.val_check_batch @@ -345,8 +346,13 @@ def on_train_epoch_start(self, trainer, pl_module): ) self.refresh() - def on_validation_epoch_start(self, trainer, pl_module): - if self.total_val_batches > 0: + def on_validation_batch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: + if self.is_dataloader_changed(dataloader_idx): + if self.val_progress_bar_id is not None: + self.progress.update(self.val_progress_bar_id, advance=self.refresh_rate, visible=False) + total_val_batches = self.total_val_batches if self.total_train_batches != float("inf") and hasattr(trainer, "val_check_batch"): # val can be checked multiple times per epoch @@ -369,17 +375,27 @@ def _update(self, progress_bar_id: int, current: int, total: int, visible: bool def _should_update(self, current: int, total: int) -> bool: return self.is_enabled and (current % self.refresh_rate == 0 or current == total) - def on_validation_epoch_end(self, trainer, pl_module): + def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): if self.val_progress_bar_id is not None: self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False) - def on_test_epoch_start(self, trainer, pl_module): - self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description) - self.refresh() + def on_test_batch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: + if self.is_dataloader_changed(dataloader_idx): + if self.test_progress_bar_id is not None: + self.progress.update(self.test_progress_bar_id, advance=self.refresh_rate, visible=False) + self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description) + self.refresh() - def on_predict_epoch_start(self, trainer, pl_module): - self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description) - self.refresh() + def on_predict_batch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: + if self.is_dataloader_changed(dataloader_idx): + if self.predict_progress_bar_id is not None: + self.progress.update(self.predict_progress_bar_id, advance=self.refresh_rate, visible=False) + self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description) + self.refresh() def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 
self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches) diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index 98d4651a995f9..f8141422a334b 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -115,6 +115,26 @@ def __getstate__(self) -> Dict: # can't pickle the tqdm objects return {k: v if not isinstance(v, _tqdm) else None for k, v in vars(self).items()} + @property + def sanity_check_description(self) -> str: + return "Validation Sanity Check" + + @property + def train_description(self) -> str: + return "Training" + + @property + def validation_description(self) -> str: + return "Validation" + + @property + def test_description(self) -> str: + return "Testing" + + @property + def predict_description(self) -> str: + return "Predicting" + @property def main_progress_bar(self) -> _tqdm: if self._main_progress_bar is None: @@ -184,7 +204,7 @@ def enable(self) -> None: def init_sanity_tqdm(self) -> Tqdm: """Override this to customize the tqdm bar for the validation sanity run.""" bar = Tqdm( - desc="Validation sanity check", + desc=self.sanity_check_description, position=(2 * self.process_position), disable=self.is_disabled, leave=False, @@ -196,7 +216,7 @@ def init_sanity_tqdm(self) -> Tqdm: def init_train_tqdm(self) -> Tqdm: """Override this to customize the tqdm bar for training.""" bar = Tqdm( - desc="Training", + desc=self.train_description, initial=self.train_batch_idx, position=(2 * self.process_position), disable=self.is_disabled, @@ -210,7 +230,7 @@ def init_train_tqdm(self) -> Tqdm: def init_predict_tqdm(self) -> Tqdm: """Override this to customize the tqdm bar for predicting.""" bar = Tqdm( - desc="Predicting", + desc=self.predict_description, initial=self.train_batch_idx, position=(2 * self.process_position), disable=self.is_disabled, @@ -226,7 +246,7 @@ def init_validation_tqdm(self) -> Tqdm: # The main progress bar doesn't exist in `trainer.validate()` has_main_bar = not self.trainer.state.fn == "validate" bar = Tqdm( - desc="Validating", + desc=self.validation_description, position=(2 * self.process_position + has_main_bar), disable=self.is_disabled, leave=not has_main_bar, @@ -261,7 +281,7 @@ def on_train_start(self, *_: Any) -> None: def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None: total_train_batches = self.total_train_batches - total_val_batches = self.total_val_batches + total_val_batches = sum(trainer.num_val_batches) if total_train_batches != float("inf") and total_val_batches != float("inf"): # val can be checked multiple times per epoch val_checks_per_epoch = total_train_batches // trainer.val_check_batch @@ -283,11 +303,16 @@ def on_train_end(self, *_: Any) -> None: self.main_progress_bar.close() def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if trainer.sanity_checking: - self.val_progress_bar.total = sum(trainer.num_sanity_val_batches) - else: + if not trainer.sanity_checking: self.val_progress_bar = self.init_validation_tqdm() + + def on_validation_batch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: + if self.is_dataloader_changed(dataloader_idx): self.val_progress_bar.total = convert_inf(self.total_val_batches) + desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description + 
self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}") def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None: if self._should_update(self.val_batch_idx, self.total_val_batches): @@ -299,11 +324,17 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul if self._main_progress_bar is not None and trainer.state.fn == "fit": self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) self.val_progress_bar.close() - super().on_validation_end(trainer, pl_module) + super().on_validation_end() def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.test_progress_bar = self.init_test_tqdm() - self.test_progress_bar.total = convert_inf(self.total_test_batches) + + def on_test_batch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: + if self.is_dataloader_changed(dataloader_idx): + self.test_progress_bar.total = convert_inf(self.total_test_batches) + self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}") def on_test_batch_end(self, *_: Any) -> None: if self._should_update(self.test_batch_idx, self.total_test_batches): @@ -311,11 +342,18 @@ def on_test_batch_end(self, *_: Any) -> None: def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.test_progress_bar.close() - super().on_test_end(trainer, pl_module) + super().on_test_end() def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.predict_progress_bar = self.init_predict_tqdm() - self.predict_progress_bar.total = convert_inf(self.total_predict_batches) + super().on_predict_end() + + def on_predict_batch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: + if self.is_dataloader_changed(dataloader_idx): + self.predict_progress_bar.total = convert_inf(self.total_predict_batches) + self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}") def on_predict_batch_end(self, *_: Any) -> None: if self._should_update(self.predict_batch_idx, self.total_predict_batches): @@ -323,7 +361,6 @@ def on_predict_batch_end(self, *_: Any) -> None: def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.predict_progress_bar.close() - super().on_predict_end(trainer, pl_module) def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None: active_progress_bar = None diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index cb7f6ac2f02e9..451351b3f5662 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -32,6 +32,44 @@ from tests.helpers.runif import RunIf +class MockTqdm(Tqdm): + def __init__(self, *args, **kwargs): + self.n_values = [] + self.total_values = [] + self.descriptions = [] + super().__init__(*args, **kwargs) + self.__n = 0 + self.__total = 0 + # again to reset additions from `super().__init__` + self.n_values = [] + self.total_values = [] + self.descriptions = [] + + @property + def n(self): + return self.__n + + @n.setter + def n(self, value): + self.__n = value + # track the changes in the `n` value + if not len(self.n_values) or value != self.n_values[-1]: + self.n_values.append(value) + + @property + def total(self): + return self.__total + + @total.setter + def total(self, 
value): + self.__total = value + self.total_values.append(value) + + def set_description(self, *args, **kwargs): + super().set_description(*args, **kwargs) + self.descriptions.append(self.desc) + + @pytest.mark.parametrize( "kwargs", [ @@ -112,60 +150,61 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): default_root_dir=tmpdir, max_epochs=1, limit_train_batches=0, num_sanity_val_steps=num_sanity_val_steps ) bar = trainer.progress_bar_callback - trainer.fit(model) - expected_sanity_steps = num_sanity_val_steps * num_dl + with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): + trainer.fit(model) + + expected_sanity_steps = [num_sanity_val_steps] * num_dl assert not bar.val_progress_bar.leave - assert sum(trainer.num_sanity_val_batches) == expected_sanity_steps - assert bar.val_progress_bar.total == expected_sanity_steps - assert bar.val_progress_bar.n == expected_sanity_steps + assert trainer.num_sanity_val_batches == expected_sanity_steps + assert bar.val_progress_bar.total_values == expected_sanity_steps + assert bar.val_progress_bar.n_values == list(range(1, num_sanity_val_steps + 1)) * num_dl + assert bar.val_progress_bar.descriptions == [f"Validation Sanity Check DataLoader {i}: " for i in range(num_dl)] # fit trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) bar = trainer.progress_bar_callback - trainer.fit(model) - n = bar.total_train_batches - m = bar.total_val_batches - assert trainer.num_training_batches == n + with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): + trainer.fit(model) + + n = trainer.num_training_batches + m = trainer.num_val_batches assert len(trainer.train_dataloader) == n - assert bar.main_progress_bar.total == n + m + assert bar.main_progress_bar.total == n + sum(m) assert bar.main_progress_bar.leave - - assert sum(trainer.num_val_batches) == m - assert bar.main_progress_bar.n == n + m - assert bar.train_batch_idx == n + assert bar.main_progress_bar.n == n + sum(m) # check val progress bar total - assert bar.val_progress_bar.total == m - assert bar.val_progress_bar.n == m + assert bar.val_progress_bar.total_values == m + assert bar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl + assert bar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] assert not bar.val_progress_bar.leave - # check that the test progress bar is off - assert 0 == bar.total_test_batches - with pytest.raises(TypeError, match="test_progress_bar` .* not been set"): - assert bar.test_progress_bar is None - # validate - trainer.validate(model) + with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): + trainer.validate(model) assert bar.val_progress_bar.leave - assert sum(trainer.num_val_batches) == m - assert bar.val_progress_bar.total == m - assert bar.val_progress_bar.n == m + assert trainer.num_val_batches == m + assert bar.val_progress_bar.total_values == m + assert bar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl + assert bar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] # test - trainer.test(model) + with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): + trainer.test(model) assert bar.test_progress_bar.leave - k = bar.total_test_batches - assert sum(trainer.num_test_batches) == k - assert bar.test_progress_bar.total == k - assert bar.test_progress_bar.n == k + k = trainer.num_test_batches + assert 
bar.test_progress_bar.total_values == k + assert bar.test_progress_bar.n_values == list(range(1, k[0] + 1)) * num_dl + assert bar.test_progress_bar.descriptions == [f"Testing DataLoader {i}: " for i in range(num_dl)] # predict - trainer.predict(model) + with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): + trainer.predict(model) assert bar.predict_progress_bar.leave - k = bar.total_predict_batches - assert sum(trainer.num_predict_batches) == k - assert bar.test_progress_bar.total == k - assert bar.test_progress_bar.n == k + k = trainer.num_predict_batches + assert bar.predict_progress_bar.total_values == k + assert bar.predict_progress_bar.n_values == list(range(1, k[0] + 1)) * num_dl + assert bar.predict_progress_bar.descriptions == [f"Predicting DataLoader {i}: " for i in range(num_dl)] def test_tqdm_progress_bar_fast_dev_run(tmpdir): @@ -176,12 +215,9 @@ def test_tqdm_progress_bar_fast_dev_run(tmpdir): trainer.fit(model) progress_bar = trainer.progress_bar_callback - assert 1 == progress_bar.total_train_batches - # total val batches are known only after val dataloaders have reloaded - assert 1 == progress_bar.total_val_batches - assert 1 == progress_bar.train_batch_idx assert 1 == progress_bar.val_progress_bar.n + assert 1 == progress_bar.val_progress_bar.total # the main progress bar should display 2 batches (1 train, 1 val) assert 2 == progress_bar.main_progress_bar.total @@ -237,19 +273,25 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal assert trainer.progress_bar_callback.refresh_rate == refresh_rate trainer.fit(model) - assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches - assert progress_bar.val_batches_seen == 3 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + assert ( + progress_bar.train_batches_seen + progress_bar.val_batches_seen + == 3 * progress_bar.main_progress_bar.total + trainer.num_sanity_val_steps + ) assert progress_bar.test_batches_seen == 0 trainer.validate(model) - assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches - assert progress_bar.val_batches_seen == 4 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + assert ( + progress_bar.train_batches_seen + progress_bar.val_batches_seen + == 3 * progress_bar.main_progress_bar.total + progress_bar.val_progress_bar.total + trainer.num_sanity_val_steps + ) assert progress_bar.test_batches_seen == 0 trainer.test(model) - assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches - assert progress_bar.val_batches_seen == 4 * progress_bar.total_val_batches + trainer.num_sanity_val_steps - assert progress_bar.test_batches_seen == progress_bar.total_test_batches + assert ( + progress_bar.train_batches_seen + progress_bar.val_batches_seen + == 3 * progress_bar.main_progress_bar.total + progress_bar.val_progress_bar.total + trainer.num_sanity_val_steps + ) + assert progress_bar.test_batches_seen == progress_bar.test_progress_bar.total @pytest.mark.parametrize("limit_val_batches", (0, 5)) @@ -311,26 +353,6 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir): assert trainer.progress_bar_callback.refresh_rate == 19 -class MockTqdm(Tqdm): - def __init__(self, *args, **kwargs): - self.n_values = [] - super().__init__(*args, **kwargs) - self.__n = 0 - # again to reset additions from `super().__init__` - self.n_values = [] - - @property - def n(self): - return self.__n - - @n.setter - def n(self, value): - self.__n = value - # track the changes in the `n` value 
- if not len(self.n_values) or value != self.n_values[-1]: - self.n_values.append(value) - - @pytest.mark.parametrize( "train_batches,val_batches,refresh_rate,train_updates,val_updates", [ From 0c15eb16093dfbad9ac2228bb64880ca03bf73bb Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 1 Feb 2022 03:28:29 +0530 Subject: [PATCH 04/19] mypy --- pytorch_lightning/callbacks/progress/base.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py index 68ede43791dc4..cab7a1405c1f4 100644 --- a/pytorch_lightning/callbacks/progress/base.py +++ b/pytorch_lightning/callbacks/progress/base.py @@ -61,13 +61,13 @@ def is_dataloader_changed(self, dataloader_idx: int) -> bool: self._current_eval_dataloader_idx = dataloader_idx return old_dataloader_idx != dataloader_idx - def on_validation_end(self, *_: Any) -> None: + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._current_eval_dataloader_idx = None - def on_test_end(self, *_: Any) -> None: + def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._current_eval_dataloader_idx = None - def on_predict_end(self, *_: Any) -> None: + def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._current_eval_dataloader_idx = None @property @@ -124,6 +124,7 @@ def total_val_batches(self) -> Union[int, float]: Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation dataloader is of infinite size. """ + assert self._current_eval_dataloader_idx is not None if self.trainer.sanity_checking: return self.trainer.num_sanity_val_batches[self._current_eval_dataloader_idx] @@ -136,6 +137,7 @@ def total_test_batches(self) -> Union[int, float]: Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is of infinite size. """ + assert self._current_eval_dataloader_idx is not None return self.trainer.num_test_batches[self._current_eval_dataloader_idx] @property @@ -145,6 +147,7 @@ def total_predict_batches(self) -> Union[int, float]: Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader is of infinite size. 
""" + assert self._current_eval_dataloader_idx is not None return self.trainer.num_predict_batches[self._current_eval_dataloader_idx] def disable(self) -> None: From 4ac764183c54538279ed201fa08898c36c511f28 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 1 Feb 2022 03:33:29 +0530 Subject: [PATCH 05/19] bug --- pytorch_lightning/callbacks/progress/tqdm_progress.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index f8141422a334b..33764a4a16616 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -324,7 +324,7 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul if self._main_progress_bar is not None and trainer.state.fn == "fit": self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) self.val_progress_bar.close() - super().on_validation_end() + super().on_validation_end(trainer, pl_module) def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.test_progress_bar = self.init_test_tqdm() @@ -342,11 +342,11 @@ def on_test_batch_end(self, *_: Any) -> None: def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.test_progress_bar.close() - super().on_test_end() + super().on_test_end(trainer, pl_module) def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.predict_progress_bar = self.init_predict_tqdm() - super().on_predict_end() + super().on_predict_end(trainer, pl_module) def on_predict_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int From d1f1b11c57b32865e27af9eb4a5a56b2698d68a1 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 1 Feb 2022 04:11:01 +0530 Subject: [PATCH 06/19] fix test --- pytorch_lightning/callbacks/progress/base.py | 10 ++++++++++ pytorch_lightning/callbacks/progress/tqdm_progress.py | 2 +- tests/trainer/flags/test_check_val_every_n_epoch.py | 2 +- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py index cab7a1405c1f4..a394c4763e4a8 100644 --- a/pytorch_lightning/callbacks/progress/base.py +++ b/pytorch_lightning/callbacks/progress/base.py @@ -150,6 +150,16 @@ def total_predict_batches(self) -> Union[int, float]: assert self._current_eval_dataloader_idx is not None return self.trainer.num_predict_batches[self._current_eval_dataloader_idx] + @property + def num_val_batches(self) -> Union[int, float]: + if ( + self.trainer.enable_validation + and (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 + ): + return sum(self.trainer.num_val_batches) + + return 0 + def disable(self) -> None: """You should provide a way to disable the progress bar. 
diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index 33764a4a16616..2af83e23d8765 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -281,7 +281,7 @@ def on_train_start(self, *_: Any) -> None: def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None: total_train_batches = self.total_train_batches - total_val_batches = sum(trainer.num_val_batches) + total_val_batches = self.num_val_batches if total_train_batches != float("inf") and total_val_batches != float("inf"): # val can be checked multiple times per epoch val_checks_per_epoch = total_train_batches // trainer.val_check_batch diff --git a/tests/trainer/flags/test_check_val_every_n_epoch.py b/tests/trainer/flags/test_check_val_every_n_epoch.py index 97c6ddf7803ab..b22b2824dc6c2 100644 --- a/tests/trainer/flags/test_check_val_every_n_epoch.py +++ b/tests/trainer/flags/test_check_val_every_n_epoch.py @@ -27,7 +27,7 @@ class TestModel(BoringModel): val_batches = [] def on_train_epoch_end(self, *args, **kwargs): - self.val_batches.append(self.trainer.progress_bar_callback.total_val_batches) + self.val_batches.append(self.trainer.progress_bar_callback.num_val_batches) def on_validation_epoch_start(self) -> None: self.val_epoch_calls += 1 From 57df96aba5fc89027d14f5535790c969f45f286e Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 1 Feb 2022 04:15:01 +0530 Subject: [PATCH 07/19] fix for rich --- pytorch_lightning/callbacks/progress/rich_progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 1f4f48bc03c2f..2bdcdc904db15 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -326,7 +326,7 @@ def on_sanity_check_end(self, trainer, pl_module): def on_train_epoch_start(self, trainer, pl_module): total_train_batches = self.total_train_batches - total_val_batches = sum(self.trainer.num_val_batches) + total_val_batches = self.num_val_batches if total_train_batches != float("inf"): # val can be checked multiple times per epoch val_checks_per_epoch = total_train_batches // trainer.val_check_batch From 60baf5a15fa8f30c4122b24def5d76e53b2f704e Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 2 Feb 2022 16:52:21 +0530 Subject: [PATCH 08/19] small improvements --- CHANGELOG.md | 6 +-- pytorch_lightning/callbacks/progress/base.py | 38 ++++++++----------- .../callbacks/progress/rich_progress.py | 10 ++--- .../callbacks/progress/tqdm_progress.py | 14 +++---- .../loops/dataloader/evaluation_loop.py | 3 ++ .../loops/epoch/training_epoch_loop.py | 12 +++--- pytorch_lightning/trainer/trainer.py | 3 -- tests/callbacks/test_tqdm_progress_bar.py | 3 +- .../flags/test_check_val_every_n_epoch.py | 2 +- 9 files changed, 43 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 98c04d2c1d023..4e32c70852fd5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -222,6 +222,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
 - Changed `MisconfigurationException` to `ModuleNotFoundError` when `rich` isn't available ([#11360](https://github.com/PyTorchLightning/pytorch-lightning/pull/11360))
 
+- Updated `TQDMProgressBar` to run a separate progress bar for each eval dataloader ([#11657](https://github.com/PyTorchLightning/pytorch-lightning/pull/11657))
+
+
 ### Deprecated
 
 - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/pull/10103))
@@ -457,9 +460,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Disabled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))
 
-- Fixed `TQDMProgressBar` counter when using multiple validation dataloaders ([#11657](https://github.com/PyTorchLightning/pytorch-lightning/pull/11657))
-
-
 ## [1.5.8] - 2022-01-05
 
 ### Fixed
diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py
index a394c4763e4a8..99a18daf0d240 100644
--- a/pytorch_lightning/callbacks/progress/base.py
+++ b/pytorch_lightning/callbacks/progress/base.py
@@ -56,20 +56,6 @@ def trainer(self) -> "pl.Trainer":
             raise TypeError(f"The `{self.__class__.__name__}._trainer` reference has not been set yet.")
         return self._trainer
 
-    def is_dataloader_changed(self, dataloader_idx: int) -> bool:
-        old_dataloader_idx = self._current_eval_dataloader_idx
-        self._current_eval_dataloader_idx = dataloader_idx
-        return old_dataloader_idx != dataloader_idx
-
-    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        self._current_eval_dataloader_idx = None
-
-    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        self._current_eval_dataloader_idx = None
-
-    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        self._current_eval_dataloader_idx = None
-
     @property
     def train_batch_idx(self) -> int:
         """The number of batches processed during training.
@@ -151,14 +137,22 @@ def total_predict_batches(self) -> Union[int, float]:
         return self.trainer.num_predict_batches[self._current_eval_dataloader_idx]
 
     @property
-    def num_val_batches(self) -> Union[int, float]:
-        if (
-            self.trainer.enable_validation
-            and (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
-        ):
-            return sum(self.trainer.num_val_batches)
-
-        return 0
+    def total_val_batches_current_epoch(self) -> Union[int, float]:
+        return sum(self.trainer.num_val_batches) if self._trainer.fit_loop.epoch_loop._is_check_val_epoch() else 0
+
+    def has_dataloader_changed(self, dataloader_idx: int) -> bool:
+        old_dataloader_idx = self._current_eval_dataloader_idx
+        self._current_eval_dataloader_idx = dataloader_idx
+        return old_dataloader_idx != dataloader_idx
+
+    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self._current_eval_dataloader_idx = None
+
+    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self._current_eval_dataloader_idx = None
+
+    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self._current_eval_dataloader_idx = None
 
     def disable(self) -> None:
         """You should provide a way to disable the progress bar.
diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 2bdcdc904db15..f6eb30e1882e3 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -267,7 +267,7 @@ def enable(self) -> None: @property def sanity_check_description(self) -> str: - return "Validation Sanity Check" + return "Sanity Checking" @property def validation_description(self) -> str: @@ -326,7 +326,7 @@ def on_sanity_check_end(self, trainer, pl_module): def on_train_epoch_start(self, trainer, pl_module): total_train_batches = self.total_train_batches - total_val_batches = self.num_val_batches + total_val_batches = self.total_val_batches_current_epoch if total_train_batches != float("inf"): # val can be checked multiple times per epoch val_checks_per_epoch = total_train_batches // trainer.val_check_batch @@ -349,7 +349,7 @@ def on_train_epoch_start(self, trainer, pl_module): def on_validation_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - if self.is_dataloader_changed(dataloader_idx): + if self.has_dataloader_changed(dataloader_idx): if self.val_progress_bar_id is not None: self.progress.update(self.val_progress_bar_id, advance=self.refresh_rate, visible=False) @@ -382,7 +382,7 @@ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin def on_test_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - if self.is_dataloader_changed(dataloader_idx): + if self.has_dataloader_changed(dataloader_idx): if self.test_progress_bar_id is not None: self.progress.update(self.test_progress_bar_id, advance=self.refresh_rate, visible=False) self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description) @@ -391,7 +391,7 @@ def on_test_batch_start( def on_predict_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - if self.is_dataloader_changed(dataloader_idx): + if self.has_dataloader_changed(dataloader_idx): if self.predict_progress_bar_id is not None: self.progress.update(self.predict_progress_bar_id, advance=self.refresh_rate, visible=False) self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description) diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index 2af83e23d8765..071428ba988f0 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -117,7 +117,7 @@ def __getstate__(self) -> Dict: @property def sanity_check_description(self) -> str: - return "Validation Sanity Check" + return "Sanity Checking" @property def train_description(self) -> str: @@ -244,12 +244,12 @@ def init_predict_tqdm(self) -> Tqdm: def init_validation_tqdm(self) -> Tqdm: """Override this to customize the tqdm bar for validation.""" # The main progress bar doesn't exist in `trainer.validate()` - has_main_bar = not self.trainer.state.fn == "validate" + has_main_bar = self._main_progress_bar is not None bar = Tqdm( desc=self.validation_description, position=(2 * self.process_position + has_main_bar), disable=self.is_disabled, - leave=not has_main_bar, + leave=False, dynamic_ncols=True, file=sys.stdout, ) @@ -281,7 +281,7 @@ def 
on_train_start(self, *_: Any) -> None: def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None: total_train_batches = self.total_train_batches - total_val_batches = self.num_val_batches + total_val_batches = self.total_val_batches_current_epoch if total_train_batches != float("inf") and total_val_batches != float("inf"): # val can be checked multiple times per epoch val_checks_per_epoch = total_train_batches // trainer.val_check_batch @@ -309,7 +309,7 @@ def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMod def on_validation_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - if self.is_dataloader_changed(dataloader_idx): + if self.has_dataloader_changed(dataloader_idx): self.val_progress_bar.total = convert_inf(self.total_val_batches) desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}") @@ -332,7 +332,7 @@ def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") def on_test_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - if self.is_dataloader_changed(dataloader_idx): + if self.has_dataloader_changed(dataloader_idx): self.test_progress_bar.total = convert_inf(self.total_test_batches) self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}") @@ -351,7 +351,7 @@ def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule def on_predict_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - if self.is_dataloader_changed(dataloader_idx): + if self.has_dataloader_changed(dataloader_idx): self.predict_progress_bar.total = convert_inf(self.total_predict_batches) self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}") diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 076fe1b52c96b..1e0b30cab03c7 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -178,6 +178,9 @@ def _get_max_batches(self) -> List[int]: max_batches = self.trainer.num_test_batches else: if self.trainer.sanity_checking: + self.trainer.num_sanity_val_batches = [ + min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches + ] max_batches = self.trainer.num_sanity_val_batches else: max_batches = self.trainer.num_val_batches diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index b23608e0efd8a..4cf8eedc8e80e 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -464,13 +464,15 @@ def _get_monitor_value(self, key: str) -> Any: # this is a separate method to aid in testing return self.trainer.callback_metrics.get(key) + def _is_check_val_epoch(self): + return ( + self.trainer.enable_validation + and (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 + ) + def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool: """Decide if we should run validation.""" - if not self.trainer.enable_validation: - return False - - 
is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 - if not is_val_check_epoch: + if not self._is_check_val_epoch(): return False # val_check_batch is inf for iterable datasets with no length defined diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b60e168f2cfe8..ac01227fd00ac 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1349,9 +1349,6 @@ def _run_sanity_check(self) -> None: # reload dataloaders val_loop._reload_evaluation_dataloaders() - self.num_sanity_val_batches = [ - min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches - ] # run eval step with torch.no_grad(): diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index 451351b3f5662..2d8bc3bc47034 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -158,7 +158,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): assert trainer.num_sanity_val_batches == expected_sanity_steps assert bar.val_progress_bar.total_values == expected_sanity_steps assert bar.val_progress_bar.n_values == list(range(1, num_sanity_val_steps + 1)) * num_dl - assert bar.val_progress_bar.descriptions == [f"Validation Sanity Check DataLoader {i}: " for i in range(num_dl)] + assert bar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)] # fit trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) @@ -182,7 +182,6 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): # validate with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): trainer.validate(model) - assert bar.val_progress_bar.leave assert trainer.num_val_batches == m assert bar.val_progress_bar.total_values == m assert bar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl diff --git a/tests/trainer/flags/test_check_val_every_n_epoch.py b/tests/trainer/flags/test_check_val_every_n_epoch.py index b22b2824dc6c2..a29be294b835a 100644 --- a/tests/trainer/flags/test_check_val_every_n_epoch.py +++ b/tests/trainer/flags/test_check_val_every_n_epoch.py @@ -27,7 +27,7 @@ class TestModel(BoringModel): val_batches = [] def on_train_epoch_end(self, *args, **kwargs): - self.val_batches.append(self.trainer.progress_bar_callback.num_val_batches) + self.val_batches.append(self.trainer.progress_bar_callback.total_val_batches_current_epoch) def on_validation_epoch_start(self) -> None: self.val_epoch_calls += 1 From 398d475db2788c5c53051d6f42f3679ca72caba7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Feb 2022 14:05:32 +0000 Subject: [PATCH 09/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/callbacks/progress/rich_progress.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 63ae8c584f9a1..a99943bdb5bac 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -353,7 +353,9 @@ def on_validation_batch_start( if self.val_progress_bar_id is not None: self.progress.update(self.val_progress_bar_id, advance=self.refresh_rate, visible=False) - self.val_progress_bar_id = 
self._add_task(self.total_val_batches, self.validation_description, visible=False)
+        self.val_progress_bar_id = self._add_task(
+            self.total_val_batches, self.validation_description, visible=False
+        )
         self.refresh()

     def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:

From 02c26cf1017f95467d47076af3c50c13e73dfa70 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Thu, 10 Feb 2022 01:22:15 +0530
Subject: [PATCH 10/19] fix rich progress bar from conflict

---
 .../callbacks/progress/rich_progress.py   | 26 ++++++++++++++-----
 tests/callbacks/test_tqdm_progress_bar.py |  1 -
 2 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py
index a99943bdb5bac..e0c871afe1621 100644
--- a/pytorch_lightning/callbacks/progress/rich_progress.py
+++ b/pytorch_lightning/callbacks/progress/rich_progress.py
@@ -350,12 +350,22 @@ def on_validation_batch_start(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
     ) -> None:
         if self.has_dataloader_changed(dataloader_idx):
-            if self.val_progress_bar_id is not None:
-                self.progress.update(self.val_progress_bar_id, advance=self.refresh_rate, visible=False)
+            if trainer.sanity_checking:
+                if self.val_sanity_progress_bar_id is not None:
+                    self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)
+
+                self.val_sanity_progress_bar_id = self._add_task(
+                    self.total_val_batches, self.sanity_check_description, visible=False
+                )
+            else:
+                if self.val_progress_bar_id is not None:
+                    self.progress.update(self.val_progress_bar_id, advance=0, visible=False)
+
+                # TODO: remove old tasks when new ones are created
+                self.val_progress_bar_id = self._add_task(
+                    self.total_val_batches, self.validation_description, visible=False
+                )

-            self.val_progress_bar_id = self._add_task(
-                self.total_val_batches, self.validation_description, visible=False
-            )
             self.refresh()

     def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
@@ -382,13 +392,14 @@ def on_validation_epoch_end(self, trainer, pl_module):
     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if trainer.state.fn == "fit":
             self._update_metrics(trainer, pl_module)
+        super().on_validation_end(trainer, pl_module)

     def on_test_batch_start(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
     ) -> None:
         if self.has_dataloader_changed(dataloader_idx):
             if self.test_progress_bar_id is not None:
-                self.progress.update(self.test_progress_bar_id, advance=self.refresh_rate, visible=False)
+                self.progress.update(self.test_progress_bar_id, advance=0, visible=False)
             self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
             self.refresh()
@@ -397,7 +408,7 @@ def on_predict_batch_start(
     ) -> None:
         if self.has_dataloader_changed(dataloader_idx):
             if self.predict_progress_bar_id is not None:
-                self.progress.update(self.predict_progress_bar_id, advance=self.refresh_rate, visible=False)
+                self.progress.update(self.predict_progress_bar_id, advance=0, visible=False)
             self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)
             self.refresh()
@@ -415,6 +426,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
         elif self.val_progress_bar_id is not None:
             # check to see if we should update the main
training progress bar if self.main_progress_bar_id is not None: + # TODO: Fix this in a follow-up self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches) self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches) self.refresh() diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index c947b90a1da61..9ffd9ff8679d2 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -159,7 +159,6 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): assert bar.val_progress_bar.total_values == expected_sanity_steps assert bar.val_progress_bar.n_values == list(range(1, num_sanity_val_steps + 1)) * num_dl assert bar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)] - assert bar.val_progress_bar.leave # fit trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) From e36d08819a330a0dbfb123e9606f20f7f563b30e Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 21 Feb 2022 13:36:08 -0500 Subject: [PATCH 11/19] update gpu tests --- pytorch_lightning/callbacks/progress/base.py | 1 + tests/callbacks/test_tqdm_progress_bar.py | 10 ++++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py index 4cca7df1932d6..36c9fa76e7f3d 100644 --- a/pytorch_lightning/callbacks/progress/base.py +++ b/pytorch_lightning/callbacks/progress/base.py @@ -138,6 +138,7 @@ def total_predict_batches(self) -> Union[int, float]: @property def total_val_batches_current_epoch(self) -> Union[int, float]: + assert self._trainer is not None return sum(self.trainer.num_val_batches) if self._trainer.fit_loop.epoch_loop._is_check_val_epoch() else 0 def has_dataloader_changed(self, dataloader_idx: int) -> bool: diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index 9ffd9ff8679d2..c35bb240d28b1 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -591,11 +591,13 @@ def test_progress_bar_max_val_check_interval( assert trainer.val_check_batch == val_check_batch val_checks_per_epoch = total_train_batches / val_check_batch total_val_batches = total_val_samples // (val_batch_size * world_size) - assert trainer.progress_bar_callback.total_train_batches == total_train_batches - assert trainer.progress_bar_callback.total_val_batches == total_val_batches + pbar_callback = trainer.progress_bar_callback + assert pbar_callback.val_progress_bar.n == total_val_batches + assert pbar_callback.val_progress_bar.total == total_val_batches total_val_batches = total_val_batches * val_checks_per_epoch - if trainer.is_global_zero: - assert trainer.progress_bar_callback.main_progress_bar.total == total_train_batches + total_val_batches + assert pbar_callback.main_progress_bar.n == total_train_batches + total_val_batches + assert pbar_callback.main_progress_bar.total == total_train_batches + total_val_batches + assert pbar_callback.is_enabled == trainer.is_global_zero def test_get_progress_bar_metrics(tmpdir: str): From f8529792b7f0c341618f9061365d687b9ef6a9b4 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 24 Feb 2022 14:08:18 +0530 Subject: [PATCH 12/19] move the call --- pytorch_lightning/callbacks/progress/tqdm_progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py 
b/pytorch_lightning/callbacks/progress/tqdm_progress.py index cb1c1c62ba677..c916443b10e2b 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -346,7 +346,6 @@ def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.predict_progress_bar = self.init_predict_tqdm() - super().on_predict_end(trainer, pl_module) def on_predict_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int @@ -361,6 +360,7 @@ def on_predict_batch_end(self, *_: Any) -> None: def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.predict_progress_bar.close() + super().on_predict_end(trainer, pl_module) def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None: active_progress_bar = None From 6a40485964b0f7a76f999800cd0faa2d59957e53 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 24 Feb 2022 23:04:26 +0530 Subject: [PATCH 13/19] improvements --- pytorch_lightning/callbacks/progress/base.py | 10 ++-------- pytorch_lightning/callbacks/progress/rich_progress.py | 10 ++++++++-- pytorch_lightning/callbacks/progress/tqdm_progress.py | 8 ++++---- pytorch_lightning/loops/epoch/training_epoch_loop.py | 4 ++-- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py index 36c9fa76e7f3d..e1ae0643e389d 100644 --- a/pytorch_lightning/callbacks/progress/base.py +++ b/pytorch_lightning/callbacks/progress/base.py @@ -139,20 +139,14 @@ def total_predict_batches(self) -> Union[int, float]: @property def total_val_batches_current_epoch(self) -> Union[int, float]: assert self._trainer is not None - return sum(self.trainer.num_val_batches) if self._trainer.fit_loop.epoch_loop._is_check_val_epoch() else 0 + return sum(self.trainer.num_val_batches) if self._trainer.fit_loop.epoch_loop._should_check_val_epoch() else 0 def has_dataloader_changed(self, dataloader_idx: int) -> bool: old_dataloader_idx = self._current_eval_dataloader_idx self._current_eval_dataloader_idx = dataloader_idx return old_dataloader_idx != dataloader_idx - def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - self._current_eval_dataloader_idx = None - - def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - self._current_eval_dataloader_idx = None - - def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def reset_dataloader_idx_tracker(self) -> None: self._current_eval_dataloader_idx = None def disable(self) -> None: diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index e0c871afe1621..0887cf6e2f0df 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -392,7 +392,13 @@ def on_validation_epoch_end(self, trainer, pl_module): def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if trainer.state.fn == "fit": self._update_metrics(trainer, pl_module) - super().on_validation_end(trainer, pl_module) + self.reset_dataloader_idx_tracker() + + def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self.reset_dataloader_idx_tracker() + + def 
on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self.reset_dataloader_idx_tracker() def on_test_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int @@ -426,7 +432,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, elif self.val_progress_bar_id is not None: # check to see if we should update the main training progress bar if self.main_progress_bar_id is not None: - # TODO: Fix this in a follow-up + # TODO: Use total val_processed here just like TQDM in a follow-up self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches) self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches) self.refresh() diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index c916443b10e2b..1fcfeed0383e3 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -193,6 +193,7 @@ def is_disabled(self) -> bool: @property def _val_processed(self) -> int: + # use total in case validation runs more than once per training epoch return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed def disable(self) -> None: @@ -273,7 +274,6 @@ def on_sanity_check_start(self, *_: Any) -> None: def on_sanity_check_end(self, *_: Any) -> None: self.main_progress_bar.close() - self.main_progress_bar = None self.val_progress_bar.close() def on_train_start(self, *_: Any) -> None: @@ -324,7 +324,7 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul if self._main_progress_bar is not None and trainer.state.fn == "fit": self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) self.val_progress_bar.close() - super().on_validation_end(trainer, pl_module) + self.reset_dataloader_idx_tracker() def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.test_progress_bar = self.init_test_tqdm() @@ -342,7 +342,7 @@ def on_test_batch_end(self, *_: Any) -> None: def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.test_progress_bar.close() - super().on_test_end(trainer, pl_module) + self.reset_dataloader_idx_tracker() def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.predict_progress_bar = self.init_predict_tqdm() @@ -360,7 +360,7 @@ def on_predict_batch_end(self, *_: Any) -> None: def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.predict_progress_bar.close() - super().on_predict_end(trainer, pl_module) + self.reset_dataloader_idx_tracker() def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None: active_progress_bar = None diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index e755bce6f8056..c931105262cfd 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -475,7 +475,7 @@ def _get_monitor_value(self, key: str) -> Any: # this is a separate method to aid in testing return self.trainer.callback_metrics.get(key) - def _is_check_val_epoch(self): + def _should_check_val_epoch(self): return ( self.trainer.enable_validation and (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 @@ -483,7 +483,7 @@ def 
_is_check_val_epoch(self):
     def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         """Decide if we should run validation."""
-        if not self._is_check_val_epoch():
+        if not self._should_check_val_epoch():
             return False

         # val_check_batch is inf for iterable datasets with no length defined

From 58f492a3c5490d2a62121f960935231898dba956 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Mon, 28 Feb 2022 18:30:29 +0530
Subject: [PATCH 14/19] address Jirka comments

---
 .../callbacks/progress/rich_progress.py   | 56 ++++++++++---------
 .../callbacks/progress/tqdm_progress.py   | 26 +++++----
 tests/callbacks/test_tqdm_progress_bar.py |  3 +-
 3 files changed, 49 insertions(+), 36 deletions(-)

diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py
index 0887cf6e2f0df..7f61758ae80cb 100644
--- a/pytorch_lightning/callbacks/progress/rich_progress.py
+++ b/pytorch_lightning/callbacks/progress/rich_progress.py
@@ -349,24 +349,26 @@ def on_train_epoch_start(self, trainer, pl_module):
     def on_validation_batch_start(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
     ) -> None:
-        if self.has_dataloader_changed(dataloader_idx):
-            if trainer.sanity_checking:
-                if self.val_sanity_progress_bar_id is not None:
-                    self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)
+        if not self.has_dataloader_changed(dataloader_idx):
+            return

-                self.val_sanity_progress_bar_id = self._add_task(
-                    self.total_val_batches, self.sanity_check_description, visible=False
-                )
-            else:
-                if self.val_progress_bar_id is not None:
-                    self.progress.update(self.val_progress_bar_id, advance=0, visible=False)
+        if trainer.sanity_checking:
+            if self.val_sanity_progress_bar_id is not None:
+                self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)

-                # TODO: remove old tasks when new ones are created
-                self.val_progress_bar_id = self._add_task(
-                    self.total_val_batches, self.validation_description, visible=False
-                )
+            self.val_sanity_progress_bar_id = self._add_task(
+                self.total_val_batches, self.sanity_check_description, visible=False
+            )
+        else:
+            if self.val_progress_bar_id is not None:
+                self.progress.update(self.val_progress_bar_id, advance=0, visible=False)

-            self.refresh()
+            # TODO: remove old tasks when new ones are created
+            self.val_progress_bar_id = self._add_task(
+                self.total_val_batches, self.validation_description, visible=False
+            )
+
+        self.refresh()

     def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
         if self.progress is not None:
@@ -403,20 +405,24 @@ def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule")
     def on_test_batch_start(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
     ) -> None:
-        if self.has_dataloader_changed(dataloader_idx):
-            if self.test_progress_bar_id is not None:
-                self.progress.update(self.test_progress_bar_id, advance=0, visible=False)
-            self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
-            self.refresh()
+        if not self.has_dataloader_changed(dataloader_idx):
+            return
+
+        if self.test_progress_bar_id is not None:
+            self.progress.update(self.test_progress_bar_id, advance=0, visible=False)
+        self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
+        self.refresh()

     def on_predict_batch_start(
         self, trainer: "pl.Trainer", pl_module:
"pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - if self.has_dataloader_changed(dataloader_idx): - if self.predict_progress_bar_id is not None: - self.progress.update(self.predict_progress_bar_id, advance=0, visible=False) - self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description) - self.refresh() + if not self.has_dataloader_changed(dataloader_idx): + return + + if self.predict_progress_bar_id is not None: + self.progress.update(self.predict_progress_bar_id, advance=0, visible=False) + self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description) + self.refresh() def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches) diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index 1fcfeed0383e3..b37253f8d9419 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -309,10 +309,12 @@ def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMod def on_validation_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - if self.has_dataloader_changed(dataloader_idx): - self.val_progress_bar.total = convert_inf(self.total_val_batches) - desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description - self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}") + if not self.has_dataloader_changed(dataloader_idx): + return + + self.val_progress_bar.total = convert_inf(self.total_val_batches) + desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description + self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}") def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None: if self._should_update(self.val_batch_idx, self.total_val_batches): @@ -332,9 +334,11 @@ def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") def on_test_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - if self.has_dataloader_changed(dataloader_idx): - self.test_progress_bar.total = convert_inf(self.total_test_batches) - self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}") + if not self.has_dataloader_changed(dataloader_idx): + return + + self.test_progress_bar.total = convert_inf(self.total_test_batches) + self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}") def on_test_batch_end(self, *_: Any) -> None: if self._should_update(self.test_batch_idx, self.total_test_batches): @@ -350,9 +354,11 @@ def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule def on_predict_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - if self.has_dataloader_changed(dataloader_idx): - self.predict_progress_bar.total = convert_inf(self.total_predict_batches) - self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}") + if not self.has_dataloader_changed(dataloader_idx): + return + + self.predict_progress_bar.total = 
convert_inf(self.total_predict_batches)
+        self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")

     def on_predict_batch_end(self, *_: Any) -> None:
         if self._should_update(self.predict_batch_idx, self.total_predict_batches):
diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py
index c35bb240d28b1..4ced9191391dc 100644
--- a/tests/callbacks/test_tqdm_progress_bar.py
+++ b/tests/callbacks/test_tqdm_progress_bar.py
@@ -169,9 +169,10 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
         n = trainer.num_training_batches
         m = trainer.num_val_batches
         assert len(trainer.train_dataloader) == n
+        # main progress bar should have reached the end (train batches + val batches)
         assert bar.main_progress_bar.total == n + sum(m)
-        assert bar.main_progress_bar.leave
         assert bar.main_progress_bar.n == n + sum(m)
+        assert bar.main_progress_bar.leave

         # check val progress bar total
         assert bar.val_progress_bar.total_values == m

From b82d717e45d6c79dbfec5ba574c336cd70a4f581 Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Mon, 7 Mar 2022 12:54:30 +0400
Subject: [PATCH 15/19] Update CHANGELOG.md

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e84526b7a625a..8cacaa1804bd6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -302,7 +302,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `parallel_devices` property in `ParallelStrategy` to be lazy initialized ([#11572](https://github.com/PyTorchLightning/pytorch-lightning/pull/11572))

-- Update `TQDMProgressBar` to run a separate progress bar for each eval dataloader ([#11657](https://github.com/PyTorchLightning/pytorch-lightning/pull/11657))
+- Updated `TQDMProgressBar` to run a separate progress bar for each eval dataloader ([#11657](https://github.com/PyTorchLightning/pytorch-lightning/pull/11657))

 - Sorted `SimpleProfiler(extended=False)` summary based on mean duration for each hook ([#11671](https://github.com/PyTorchLightning/pytorch-lightning/pull/11671))

From 3adacd5eb2f354ba7ec10ffc61a7f06da0173ce0 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Mon, 7 Mar 2022 14:06:18 +0400
Subject: [PATCH 16/19] better property names

---
 pytorch_lightning/callbacks/progress/base.py | 19 +++++++++------
 .../callbacks/progress/rich_progress.py      | 24 +++++++++++--------
 .../callbacks/progress/tqdm_progress.py      | 14 +++++------
 3 files changed, 33 insertions(+), 24 deletions(-)

diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py
index 1ed8252ea6aa2..3a32dead7e9bb 100644
--- a/pytorch_lightning/callbacks/progress/base.py
+++ b/pytorch_lightning/callbacks/progress/base.py
@@ -105,8 +105,8 @@ def total_train_batches(self) -> Union[int, float]:
         return self.trainer.num_training_batches

     @property
-    def total_val_batches(self) -> Union[int, float]:
-        """The total number of validation batches, which may change from epoch to epoch.
+    def total_val_batches_current_dataloader(self) -> Union[int, float]:
+        """The total number of validation batches, which may change from epoch to epoch for the current dataloader.

         Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation
         dataloader is of infinite size.
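
The `_current_dataloader` suffix introduced here separates the two totals the bars need: the per-dataloader count that each eval bar counts up to, and the aggregate count that the main bar folds in on top of the training batches. A rough illustration with made-up batch counts (plain Python, not taken from the patch):

    # two validation dataloaders with 8 and 3 batches respectively
    num_val_batches = [8, 3]
    current_eval_dataloader_idx = 1

    # what the "Validation DataLoader 1" bar uses as its total
    total_val_batches_current_dataloader = num_val_batches[current_eval_dataloader_idx]

    # what on_train_epoch_start adds on top of the training batches
    total_val_batches = sum(num_val_batches)

    assert total_val_batches_current_dataloader == 3
    assert total_val_batches == 11
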
@@ -118,8 +118,8 @@ def total_val_batches(self) -> Union[int, float]:
         return self.trainer.num_val_batches[self._current_eval_dataloader_idx]

     @property
-    def total_test_batches(self) -> Union[int, float]:
-        """The total number of testing batches, which may change from epoch to epoch.
+    def total_test_batches_current_dataloader(self) -> Union[int, float]:
+        """The total number of testing batches, which may change from epoch to epoch for the current dataloader.

         Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader
         is of infinite size.
@@ -128,8 +128,8 @@ def total_test_batches(self) -> Union[int, float]:
         return self.trainer.num_test_batches[self._current_eval_dataloader_idx]

     @property
-    def total_predict_batches(self) -> Union[int, float]:
-        """The total number of prediction batches, which may change from epoch to epoch.
+    def total_predict_batches_current_dataloader(self) -> Union[int, float]:
+        """The total number of prediction batches, which may change from epoch to epoch for the current dataloader.

         Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict
         dataloader is of infinite size.
@@ -138,7 +138,12 @@ def total_predict_batches(self) -> Union[int, float]:
         return self.trainer.num_predict_batches[self._current_eval_dataloader_idx]

     @property
-    def total_val_batches_current_epoch(self) -> Union[int, float]:
+    def total_val_batches(self) -> Union[int, float]:
+        """The total number of validation batches, which may change from epoch to epoch for all validation dataloaders.
+
+        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if any validation
+        dataloader is of infinite size.
+        """
         assert self._trainer is not None
         return sum(self.trainer.num_val_batches) if self._trainer.fit_loop.epoch_loop._should_check_val_epoch() else 0

diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py
index f96828616a54a..0c25c2dfab3ac 100644
--- a/pytorch_lightning/callbacks/progress/rich_progress.py
+++ b/pytorch_lightning/callbacks/progress/rich_progress.py
@@ -334,7 +334,7 @@ def on_sanity_check_end(self, trainer, pl_module):
     def on_train_epoch_start(self, trainer, pl_module):
         total_train_batches = self.total_train_batches
-        total_val_batches = self.total_val_batches_current_epoch
+        total_val_batches = self.total_val_batches
         if total_train_batches != float("inf"):
             # val can be checked multiple times per epoch
             val_checks_per_epoch = total_train_batches // trainer.val_check_batch
@@ -365,7 +365,7 @@ def on_validation_batch_start(
                 self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)

             self.val_sanity_progress_bar_id = self._add_task(
-                self.total_val_batches, self.sanity_check_description, visible=False
+                self.total_val_batches_current_dataloader, self.sanity_check_description, visible=False
             )
         else:
             if self.val_progress_bar_id is not None:
@@ -373,7 +373,7 @@ def on_validation_batch_start(

             # TODO: remove old tasks when new ones are created
             self.val_progress_bar_id = self._add_task(
-                self.total_val_batches, self.validation_description, visible=False
+                self.total_val_batches_current_dataloader, self.validation_description, visible=False
             )
@@ -418,7 +418,7 @@ def on_test_batch_start(

         if self.test_progress_bar_id is not None:
             self.progress.update(self.test_progress_bar_id, advance=0, visible=False)
-        self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
+
self.test_progress_bar_id = self._add_task(self.total_test_batches_current_dataloader, self.test_description) self.refresh() def on_predict_batch_start( @@ -429,7 +429,9 @@ def on_predict_batch_start( if self.predict_progress_bar_id is not None: self.progress.update(self.predict_progress_bar_id, advance=0, visible=False) - self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description) + self.predict_progress_bar_id = self._add_task( + self.total_predict_batches_current_dataloader, self.predict_description + ) self.refresh() def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): @@ -442,21 +444,23 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): if trainer.sanity_checking: - self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches) + self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader) elif self.val_progress_bar_id is not None: # check to see if we should update the main training progress bar if self.main_progress_bar_id is not None: # TODO: Use total val_processed here just like TQDM in a follow-up - self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches) - self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches) + self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader) + self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader) self.refresh() def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches) + self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches_current_dataloader) self.refresh() def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - self._update(self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches) + self._update( + self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches_current_dataloader + ) self.refresh() def _get_train_description(self, current_epoch: int) -> str: diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index b37253f8d9419..1e8059bee5841 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -281,7 +281,7 @@ def on_train_start(self, *_: Any) -> None: def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None: total_train_batches = self.total_train_batches - total_val_batches = self.total_val_batches_current_epoch + total_val_batches = self.total_val_batches if total_train_batches != float("inf") and total_val_batches != float("inf"): # val can be checked multiple times per epoch val_checks_per_epoch = total_train_batches // trainer.val_check_batch @@ -312,12 +312,12 @@ def on_validation_batch_start( if not self.has_dataloader_changed(dataloader_idx): return - self.val_progress_bar.total = convert_inf(self.total_val_batches) + self.val_progress_bar.total = convert_inf(self.total_val_batches_current_dataloader) desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description 
self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}") def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None: - if self._should_update(self.val_batch_idx, self.total_val_batches): + if self._should_update(self.val_batch_idx, self.total_val_batches_current_dataloader): _update_n(self.val_progress_bar, self.val_batch_idx) if trainer.state.fn == "fit": _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed) @@ -337,11 +337,11 @@ def on_test_batch_start( if not self.has_dataloader_changed(dataloader_idx): return - self.test_progress_bar.total = convert_inf(self.total_test_batches) + self.test_progress_bar.total = convert_inf(self.total_test_batches_current_dataloader) self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}") def on_test_batch_end(self, *_: Any) -> None: - if self._should_update(self.test_batch_idx, self.total_test_batches): + if self._should_update(self.test_batch_idx, self.total_test_batches_current_dataloader): _update_n(self.test_progress_bar, self.test_batch_idx) def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: @@ -357,11 +357,11 @@ def on_predict_batch_start( if not self.has_dataloader_changed(dataloader_idx): return - self.predict_progress_bar.total = convert_inf(self.total_predict_batches) + self.predict_progress_bar.total = convert_inf(self.total_predict_batches_current_dataloader) self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}") def on_predict_batch_end(self, *_: Any) -> None: - if self._should_update(self.predict_batch_idx, self.total_predict_batches): + if self._should_update(self.predict_batch_idx, self.total_predict_batches_current_dataloader): _update_n(self.predict_progress_bar, self.predict_batch_idx) def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: From bfd8127eaf43af421bebc74153f2b18014bdcf77 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 7 Mar 2022 14:08:54 +0400 Subject: [PATCH 17/19] add properties to base --- pytorch_lightning/callbacks/progress/base.py | 20 +++++++++++++++++++ .../callbacks/progress/rich_progress.py | 16 --------------- .../callbacks/progress/tqdm_progress.py | 20 ------------------- 3 files changed, 20 insertions(+), 36 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py index 3a32dead7e9bb..234d62a68c308 100644 --- a/pytorch_lightning/callbacks/progress/base.py +++ b/pytorch_lightning/callbacks/progress/base.py @@ -57,6 +57,26 @@ def trainer(self) -> "pl.Trainer": raise TypeError(f"The `{self.__class__.__name__}._trainer` reference has not been set yet.") return self._trainer + @property + def sanity_check_description(self) -> str: + return "Sanity Checking" + + @property + def train_description(self) -> str: + return "Training" + + @property + def validation_description(self) -> str: + return "Validation" + + @property + def test_description(self) -> str: + return "Testing" + + @property + def predict_description(self) -> str: + return "Predicting" + @property def train_batch_idx(self) -> int: """The number of batches processed during training. 
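
With the description strings hoisted into shared properties on the base class, either bar implementation can relabel a stage by overriding a single property. A hedged sketch of one intended use (the `ShortLabelBar` subclass is made up for illustration):

    from pytorch_lightning.callbacks import TQDMProgressBar


    class ShortLabelBar(TQDMProgressBar):
        @property
        def validation_description(self) -> str:
            # picked up by on_validation_batch_start when it builds the
            # "<description> DataLoader <idx>" label
            return "Val"
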
diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 0c25c2dfab3ac..4d7c4b7864055 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -262,22 +262,6 @@ def is_enabled(self) -> bool: def is_disabled(self) -> bool: return not self.is_enabled - @property - def sanity_check_description(self) -> str: - return "Sanity Checking" - - @property - def validation_description(self) -> str: - return "Validation" - - @property - def test_description(self) -> str: - return "Testing" - - @property - def predict_description(self) -> str: - return "Predicting" - def _update_for_light_colab_theme(self) -> None: if _detect_light_colab_theme(): attributes = ["description", "batch_progress", "metrics"] diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index 1e8059bee5841..19090a1efaf03 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -115,26 +115,6 @@ def __getstate__(self) -> Dict: # can't pickle the tqdm objects return {k: v if not isinstance(v, _tqdm) else None for k, v in vars(self).items()} - @property - def sanity_check_description(self) -> str: - return "Sanity Checking" - - @property - def train_description(self) -> str: - return "Training" - - @property - def validation_description(self) -> str: - return "Validation" - - @property - def test_description(self) -> str: - return "Testing" - - @property - def predict_description(self) -> str: - return "Predicting" - @property def main_progress_bar(self) -> _tqdm: if self._main_progress_bar is None: From f574bf52302acc00164b49919d3f6bcc53d716bb Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 7 Mar 2022 17:36:10 +0400 Subject: [PATCH 18/19] update test --- tests/trainer/flags/test_check_val_every_n_epoch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/flags/test_check_val_every_n_epoch.py b/tests/trainer/flags/test_check_val_every_n_epoch.py index a29be294b835a..97c6ddf7803ab 100644 --- a/tests/trainer/flags/test_check_val_every_n_epoch.py +++ b/tests/trainer/flags/test_check_val_every_n_epoch.py @@ -27,7 +27,7 @@ class TestModel(BoringModel): val_batches = [] def on_train_epoch_end(self, *args, **kwargs): - self.val_batches.append(self.trainer.progress_bar_callback.total_val_batches_current_epoch) + self.val_batches.append(self.trainer.progress_bar_callback.total_val_batches) def on_validation_epoch_start(self) -> None: self.val_epoch_calls += 1 From 3daa4762618a2cdc8394a0170ac2f6b4ef7a1502 Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 14 Mar 2022 10:59:20 +0100 Subject: [PATCH 19/19] pbar --- tests/callbacks/test_tqdm_progress_bar.py | 120 +++++++++++----------- 1 file changed, 60 insertions(+), 60 deletions(-) diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index 4ced9191391dc..7f9de366d01d4 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -149,20 +149,20 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, limit_train_batches=0, num_sanity_val_steps=num_sanity_val_steps ) - bar = trainer.progress_bar_callback + pbar = trainer.progress_bar_callback with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): 
trainer.fit(model) expected_sanity_steps = [num_sanity_val_steps] * num_dl - assert not bar.val_progress_bar.leave + assert not pbar.val_progress_bar.leave assert trainer.num_sanity_val_batches == expected_sanity_steps - assert bar.val_progress_bar.total_values == expected_sanity_steps - assert bar.val_progress_bar.n_values == list(range(1, num_sanity_val_steps + 1)) * num_dl - assert bar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)] + assert pbar.val_progress_bar.total_values == expected_sanity_steps + assert pbar.val_progress_bar.n_values == list(range(1, num_sanity_val_steps + 1)) * num_dl + assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)] # fit trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - bar = trainer.progress_bar_callback + pbar = trainer.progress_bar_callback with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): trainer.fit(model) @@ -170,43 +170,43 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): m = trainer.num_val_batches assert len(trainer.train_dataloader) == n # main progress bar should have reached the end (train batches + val batches) - assert bar.main_progress_bar.total == n + sum(m) - assert bar.main_progress_bar.n == n + sum(m) - assert bar.main_progress_bar.leave + assert pbar.main_progress_bar.total == n + sum(m) + assert pbar.main_progress_bar.n == n + sum(m) + assert pbar.main_progress_bar.leave # check val progress bar total - assert bar.val_progress_bar.total_values == m - assert bar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl - assert bar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] - assert not bar.val_progress_bar.leave + assert pbar.val_progress_bar.total_values == m + assert pbar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl + assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] + assert not pbar.val_progress_bar.leave # validate with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): trainer.validate(model) assert trainer.num_val_batches == m - assert bar.val_progress_bar.total_values == m - assert bar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl - assert bar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] + assert pbar.val_progress_bar.total_values == m + assert pbar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl + assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] # test with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): trainer.test(model) - assert bar.test_progress_bar.leave + assert pbar.test_progress_bar.leave k = trainer.num_test_batches - assert bar.test_progress_bar.total_values == k - assert bar.test_progress_bar.n_values == list(range(1, k[0] + 1)) * num_dl - assert bar.test_progress_bar.descriptions == [f"Testing DataLoader {i}: " for i in range(num_dl)] - assert bar.test_progress_bar.leave + assert pbar.test_progress_bar.total_values == k + assert pbar.test_progress_bar.n_values == list(range(1, k[0] + 1)) * num_dl + assert pbar.test_progress_bar.descriptions == [f"Testing DataLoader {i}: " for i in range(num_dl)] + assert pbar.test_progress_bar.leave # predict with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): trainer.predict(model) - 
assert bar.predict_progress_bar.leave + assert pbar.predict_progress_bar.leave k = trainer.num_predict_batches - assert bar.predict_progress_bar.total_values == k - assert bar.predict_progress_bar.n_values == list(range(1, k[0] + 1)) * num_dl - assert bar.predict_progress_bar.descriptions == [f"Predicting DataLoader {i}: " for i in range(num_dl)] - assert bar.predict_progress_bar.leave + assert pbar.predict_progress_bar.total_values == k + assert pbar.predict_progress_bar.n_values == list(range(1, k[0] + 1)) * num_dl + assert pbar.predict_progress_bar.descriptions == [f"Predicting DataLoader {i}: " for i in range(num_dl)] + assert pbar.predict_progress_bar.leave def test_tqdm_progress_bar_fast_dev_run(tmpdir): @@ -216,26 +216,26 @@ def test_tqdm_progress_bar_fast_dev_run(tmpdir): trainer.fit(model) - progress_bar = trainer.progress_bar_callback + pbar = trainer.progress_bar_callback - assert 1 == progress_bar.val_progress_bar.n - assert 1 == progress_bar.val_progress_bar.total + assert 1 == pbar.val_progress_bar.n + assert 1 == pbar.val_progress_bar.total # the main progress bar should display 2 batches (1 train, 1 val) - assert 2 == progress_bar.main_progress_bar.total - assert 2 == progress_bar.main_progress_bar.n + assert 2 == pbar.main_progress_bar.total + assert 2 == pbar.main_progress_bar.n trainer.validate(model) # the validation progress bar should display 1 batch - assert 1 == progress_bar.val_progress_bar.total - assert 1 == progress_bar.val_progress_bar.n + assert 1 == pbar.val_progress_bar.total + assert 1 == pbar.val_progress_bar.n trainer.test(model) # the test progress bar should display 1 batch - assert 1 == progress_bar.test_progress_bar.total - assert 1 == progress_bar.test_progress_bar.n + assert 1 == pbar.test_progress_bar.total + assert 1 == pbar.test_progress_bar.n @pytest.mark.parametrize("refresh_rate", [0, 1, 50]) @@ -262,11 +262,11 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) self.test_batches_seen += 1 - progress_bar = CurrentProgressBar(refresh_rate=refresh_rate) + pbar = CurrentProgressBar(refresh_rate=refresh_rate) with pytest.deprecated_call(match=r"progress_bar_refresh_rate=101\)` is deprecated"): trainer = Trainer( default_root_dir=tmpdir, - callbacks=[progress_bar], + callbacks=[pbar], progress_bar_refresh_rate=101, # should not matter if custom callback provided limit_train_batches=1.0, num_sanity_val_steps=2, @@ -276,24 +276,24 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal trainer.fit(model) assert ( - progress_bar.train_batches_seen + progress_bar.val_batches_seen - == 3 * progress_bar.main_progress_bar.total + trainer.num_sanity_val_steps + pbar.train_batches_seen + pbar.val_batches_seen + == 3 * pbar.main_progress_bar.total + trainer.num_sanity_val_steps ) - assert progress_bar.test_batches_seen == 0 + assert pbar.test_batches_seen == 0 trainer.validate(model) assert ( - progress_bar.train_batches_seen + progress_bar.val_batches_seen - == 3 * progress_bar.main_progress_bar.total + progress_bar.val_progress_bar.total + trainer.num_sanity_val_steps + pbar.train_batches_seen + pbar.val_batches_seen + == 3 * pbar.main_progress_bar.total + pbar.val_progress_bar.total + trainer.num_sanity_val_steps ) - assert progress_bar.test_batches_seen == 0 + assert pbar.test_batches_seen == 0 trainer.test(model) assert ( - progress_bar.train_batches_seen + progress_bar.val_batches_seen - == 3 * 
progress_bar.main_progress_bar.total + progress_bar.val_progress_bar.total + trainer.num_sanity_val_steps + pbar.train_batches_seen + pbar.val_batches_seen + == 3 * pbar.main_progress_bar.total + pbar.val_progress_bar.total + trainer.num_sanity_val_steps ) - assert progress_bar.test_batches_seen == progress_bar.test_progress_bar.total + assert pbar.test_batches_seen == pbar.test_progress_bar.total @pytest.mark.parametrize("limit_val_batches", (0, 5)) @@ -313,7 +313,7 @@ def on_validation_epoch_end(self, *args): super().on_validation_epoch_end(*args) model = BoringModel() - progress_bar = CurrentProgressBar() + pbar = CurrentProgressBar() num_sanity_val_steps = 2 trainer = Trainer( @@ -322,14 +322,14 @@ def on_validation_epoch_end(self, *args): num_sanity_val_steps=num_sanity_val_steps, limit_train_batches=1, limit_val_batches=limit_val_batches, - callbacks=[progress_bar], + callbacks=[pbar], logger=False, enable_checkpointing=False, ) trainer.fit(model) - assert progress_bar.sanity_pbar_total == min(num_sanity_val_steps, limit_val_batches) - assert progress_bar.val_pbar_total == limit_val_batches + assert pbar.sanity_pbar_total == min(num_sanity_val_steps, limit_val_batches) + assert pbar.val_pbar_total == limit_val_batches def test_tqdm_progress_bar_default_value(tmpdir): @@ -690,25 +690,25 @@ def test_step(self, batch, batch_idx): @mock.patch("pytorch_lightning.trainer.trainer.Trainer.is_global_zero", new_callable=PropertyMock, return_value=False) def test_tqdm_progress_bar_disabled_when_not_rank_zero(is_global_zero): """Test that the progress bar is disabled when not in global rank zero.""" - progress_bar = TQDMProgressBar() + pbar = TQDMProgressBar() model = BoringModel() trainer = Trainer( - callbacks=[progress_bar], + callbacks=[pbar], fast_dev_run=True, ) - progress_bar.enable() + pbar.enable() trainer.fit(model) - assert progress_bar.is_disabled + assert pbar.is_disabled - progress_bar.enable() + pbar.enable() trainer.predict(model) - assert progress_bar.is_disabled + assert pbar.is_disabled - progress_bar.enable() + pbar.enable() trainer.validate(model) - assert progress_bar.is_disabled + assert pbar.is_disabled - progress_bar.enable() + pbar.enable() trainer.test(model) - assert progress_bar.is_disabled + assert pbar.is_disabled
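
For reference, the `_should_update(current, total)` gate used throughout the tqdm callback above is never shown in this series; a plausible reconstruction of the check, treated purely as an assumption rather than the actual implementation, is:

    def should_update(current: int, total: int, refresh_rate: int = 1) -> bool:
        # refresh every `refresh_rate` batches, and always on the last batch
        # so the bar finishes exactly at `total`
        return current % refresh_rate == 0 or current == total


    assert should_update(50, 64, refresh_rate=50)     # periodic refresh
    assert should_update(64, 64, refresh_rate=50)     # final batch always refreshes
    assert not should_update(3, 64, refresh_rate=50)  # skipped in between
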