Skip to content

Commit

Permalink
Fix to avoid val progress bar disappear after validate (#11700)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <[email protected]>
  • Loading branch information
2 people authored and lexierule committed Feb 9, 2022
1 parent 0bd69c9 commit 6631bb8
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue to make the `step` argument in `WandbLogger.log_image` work ([#11716](https://github.com/PyTorchLightning/pytorch-lightning/pull/11716))
- Fixed `restore_optimizers` for mapping states ([#11757](https://github.com/PyTorchLightning/pytorch-lightning/pull/11757))
- With `DPStrategy`, the batch is not explictly moved to the device ([#11780](https://github.com/PyTorchLightning/pytorch-lightning/pull/11780))

- Fixed an issue to avoid val bar disappear after `trainer.validate()` ([#11700](https://github.com/PyTorchLightning/pytorch-lightning/pull/11700))


## [1.5.9] - 2022-01-18
Expand Down
2 changes: 1 addition & 1 deletion _notebooks
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/progress/tqdm_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,12 @@ def init_predict_tqdm(self) -> Tqdm:
def init_validation_tqdm(self) -> Tqdm:
"""Override this to customize the tqdm bar for validation."""
# The main progress bar doesn't exist in `trainer.validate()`
has_main_bar = self.main_progress_bar is not None
has_main_bar = self.trainer.state.fn != "validate"
bar = Tqdm(
desc="Validating",
position=(2 * self.process_position + has_main_bar),
disable=self.is_disabled,
leave=False,
leave=not has_main_bar,
dynamic_ncols=True,
file=sys.stdout,
)
Expand Down
4 changes: 4 additions & 0 deletions tests/callbacks/test_tqdm_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,12 @@ def test_tqdm_progress_bar_totals(tmpdir):
m = bar.total_val_batches
assert len(trainer.train_dataloader) == n
assert bar.main_progress_bar.total == n + m
assert bar.main_progress_bar.leave

# check val progress bar total
assert sum(len(loader) for loader in trainer.val_dataloaders) == m
assert bar.val_progress_bar.total == m
assert not bar.val_progress_bar.leave

# main progress bar should have reached the end (train batches + val batches)
assert bar.main_progress_bar.n == n + m
Expand All @@ -126,13 +128,15 @@ def test_tqdm_progress_bar_totals(tmpdir):
assert bar.val_progress_bar.total == m
assert bar.val_progress_bar.n == m
assert bar.val_batch_idx == m
assert bar.val_progress_bar.leave

trainer.test(model)

# check test progress bar total
k = bar.total_test_batches
assert sum(len(loader) for loader in trainer.test_dataloaders) == k
assert bar.test_progress_bar.total == k
assert bar.test_progress_bar.leave

# test progress bar should have reached the end
assert bar.test_progress_bar.n == k
Expand Down

0 comments on commit 6631bb8

Please sign in to comment.