diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c0501843f926..d3029b55b27cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -503,6 +503,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Disbled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507)) +- Fixed an issue to avoid val bar disappear after `trainer.validate()` ([#11700](https://github.com/PyTorchLightning/pytorch-lightning/pull/11700)) + + - Fixed the mid-epoch warning call while resuming training ([#11556](https://github.com/PyTorchLightning/pytorch-lightning/pull/11556)) diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index 95c5666957635..bb80c22d3ff88 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -227,12 +227,12 @@ def init_predict_tqdm(self) -> Tqdm: def init_validation_tqdm(self) -> Tqdm: """Override this to customize the tqdm bar for validation.""" # The main progress bar doesn't exist in `trainer.validate()` - has_main_bar = self._main_progress_bar is not None + has_main_bar = self.trainer.state.fn != "validate" bar = Tqdm( desc="Validating", position=(2 * self.process_position + has_main_bar), disable=self.is_disabled, - leave=False, + leave=not has_main_bar, dynamic_ncols=True, file=sys.stdout, ) diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index 9e52e1f9a14e1..7897a1be798bb 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -90,10 +90,12 @@ def test_tqdm_progress_bar_totals(tmpdir): m = bar.total_val_batches assert len(trainer.train_dataloader) == n assert bar.main_progress_bar.total == n + m + assert bar.main_progress_bar.leave # check val progress bar total assert sum(len(loader) for loader in trainer.val_dataloaders) == m assert bar.val_progress_bar.total == m + assert not bar.val_progress_bar.leave # main progress bar should have reached the end (train batches + val batches) assert bar.main_progress_bar.n == n + m @@ -113,6 +115,7 @@ def test_tqdm_progress_bar_totals(tmpdir): assert bar.val_progress_bar.total == m assert bar.val_progress_bar.n == m assert bar.val_batch_idx == m + assert bar.val_progress_bar.leave trainer.test(model) @@ -120,6 +123,7 @@ def test_tqdm_progress_bar_totals(tmpdir): k = bar.total_test_batches assert sum(len(loader) for loader in trainer.test_dataloaders) == k assert bar.test_progress_bar.total == k + assert bar.test_progress_bar.leave # test progress bar should have reached the end assert bar.test_progress_bar.n == k