

Run main progress bar independent of val progress bar in `TQDMProgressBar` (#12563)

Co-authored-by: carmocca <[email protected]>
rohitgr7 and carmocca authored Apr 11, 2022
1 parent cf0e3c6 commit f4883d6
Showing 3 changed files with 66 additions and 19 deletions.
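Before this change, `on_train_batch_end` decided whether to refresh the main bar from `train_batch_idx` alone (checked against `total_train_batches`), while `on_validation_batch_end` pushed main-bar updates on every val batch with no refresh-rate check. Both paths now gate on the combined position `train_batch_idx + _val_processed` against the main bar's own `total`. A minimal standalone sketch of the new rule (illustrative names, not the library code):

# Sketch of the refresh rule introduced here: the main bar advances when the
# combined count of train batches plus processed val batches hits a refresh
# boundary or reaches the bar's total.
def should_update(current: int, total: int, refresh_rate: int = 3) -> bool:
    # `is_enabled` from the real callback is simplified to `refresh_rate > 0`.
    return refresh_rate > 0 and (current % refresh_rate == 0 or current == total)

train_batches = val_batches = 7      # the new test's `limit_batches`
total = train_batches + val_batches  # main bar total with one mid-epoch val run
updates = [step for step in range(1, total + 1) if should_update(step, total)]
print(updates)  # [3, 6, 9, 12, 14] -- the expected `main_progress_bar_updates`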
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -102,6 +102,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

+- Run main progress bar updates independent of val progress bar updates in `TQDMProgressBar` ([#12563](https://github.com/PyTorchLightning/pytorch-lightning/pull/12563))


- Avoid calling `average_parameters` multiple times per optimizer step ([#12452](https://github.com/PyTorchLightning/pytorch-lightning/pull/12452))


21 changes: 12 additions & 9 deletions pytorch_lightning/callbacks/progress/tqdm_progress.py
@@ -263,8 +263,9 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
        self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

    def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
-        if self._should_update(self.train_batch_idx, self.total_train_batches):
-            _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)
+        current = self.train_batch_idx + self._val_processed
+        if self._should_update(current, self.main_progress_bar.total):
+            _update_n(self.main_progress_bar, current)
            self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -289,10 +290,12 @@ def on_validation_batch_start(
        self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")

    def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
-        if self._should_update(self.val_batch_idx, self.total_val_batches_current_dataloader):
+        if self._should_update(self.val_batch_idx, self.val_progress_bar.total):
            _update_n(self.val_progress_bar, self.val_batch_idx)
-        if trainer.state.fn == "fit":
-            _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)
+
+        current = self.train_batch_idx + self._val_processed
+        if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total):
+            _update_n(self.main_progress_bar, current)

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if self._main_progress_bar is not None and trainer.state.fn == "fit":
@@ -313,7 +316,7 @@ def on_test_batch_start(
        self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")

    def on_test_batch_end(self, *_: Any) -> None:
-        if self._should_update(self.test_batch_idx, self.total_test_batches_current_dataloader):
+        if self._should_update(self.test_batch_idx, self.test_progress_bar.total):
            _update_n(self.test_progress_bar, self.test_batch_idx)

    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -333,7 +336,7 @@ def on_predict_batch_start(
        self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")

    def on_predict_batch_end(self, *_: Any) -> None:
-        if self._should_update(self.predict_batch_idx, self.total_predict_batches_current_dataloader):
+        if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total):
            _update_n(self.predict_progress_bar, self.predict_batch_idx)

    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -356,8 +359,8 @@ def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None:
        s = sep.join(map(str, args))
        active_progress_bar.write(s, **kwargs)

-    def _should_update(self, current: int, total: Union[int, float]) -> bool:
-        return self.refresh_rate > 0 and (current % self.refresh_rate == 0 or current == total)
+    def _should_update(self, current: int, total: int) -> bool:
+        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

    @staticmethod
    def _resolve_refresh_rate(refresh_rate: int) -> int:
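For context, a configuration that exercises the changed path could look like this (hypothetical usage, not part of the commit; `TinyModel` and the loaders are stand-ins):

# Hypothetical end-to-end usage: with refresh_rate=3 and validation run twice
# within the epoch, the main bar now refreshes on the combined train+val count.
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import TQDMProgressBar

class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def validation_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

def loader():
    return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=4)

trainer = Trainer(
    max_epochs=1,
    limit_train_batches=7,
    limit_val_batches=7,
    val_check_interval=0.5,  # validate at train batches 3 and 6
    callbacks=[TQDMProgressBar(refresh_rate=3)],
)
trainer.fit(TinyModel(), train_dataloaders=loader(), val_dataloaders=loader())
# Main bar total during fit: 7 train + 2 * 7 val = 21 combined steps.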
61 changes: 51 additions & 10 deletions tests/callbacks/test_tqdm_progress_bar.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import math
import os
import pickle
import sys
@@ -347,10 +348,10 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir):
        [2, 3, 1, [1, 2, 3, 4, 5], [1, 2, 3]],
        [0, 0, 3, None, None],
        [1, 0, 3, [1], None],
-        [1, 1, 3, [1, 2], [1]],
+        [1, 1, 3, [2], [1]],
        [5, 0, 3, [3, 5], None],
-        [5, 2, 3, [3, 5, 7], [2]],
-        [5, 2, 6, [5, 7], [2]],
+        [5, 2, 3, [3, 6, 7], [2]],
+        [5, 2, 6, [6, 7], [2]],
    ],
)
def test_main_progress_bar_update_amount(
@@ -549,16 +550,56 @@ def test_tqdm_progress_bar_can_be_pickled():
    pickle.dumps(bar)


-@RunIf(min_gpus=2, standalone=True)
@pytest.mark.parametrize(
-    ["total_train_samples", "train_batch_size", "total_val_samples", "val_batch_size", "val_check_interval"],
-    [(8, 4, 2, 1, 0.2), (8, 4, 2, 1, 0.5)],
+    ["val_check_interval", "main_progress_bar_updates", "val_progress_bar_updates"],
+    [(4, [3, 6, 9, 12, 14], [3, 6, 7]), (0.5, [3, 6, 9, 12, 15, 18, 21], [3, 6, 7])],
)
def test_progress_bar_max_val_check_interval(
-    tmpdir, total_train_samples, train_batch_size, total_val_samples, val_batch_size, val_check_interval
+    tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates
):
+    limit_batches = 7
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=0,
+        max_epochs=1,
+        enable_model_summary=False,
+        val_check_interval=val_check_interval,
+        limit_train_batches=limit_batches,
+        limit_val_batches=limit_batches,
+        callbacks=TQDMProgressBar(refresh_rate=3),
+    )
+    with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
+        trainer.fit(model)
+
+    pbar = trainer.progress_bar_callback
+    assert pbar.main_progress_bar.n_values == main_progress_bar_updates
+    assert pbar.val_progress_bar.n_values == val_progress_bar_updates
+
+    val_check_batch = (
+        max(1, int(limit_batches * val_check_interval)) if isinstance(val_check_interval, float) else val_check_interval
+    )
+    assert trainer.val_check_batch == val_check_batch
+    val_checks_per_epoch = math.ceil(limit_batches // val_check_batch)
+    pbar_callback = trainer.progress_bar_callback
+    total_val_batches = limit_batches * val_checks_per_epoch
+
+    assert pbar_callback.val_progress_bar.n == limit_batches
+    assert pbar_callback.val_progress_bar.total == limit_batches
+    assert pbar_callback.main_progress_bar.n == limit_batches + total_val_batches
+    assert pbar_callback.main_progress_bar.total == limit_batches + total_val_batches
+    assert pbar_callback.is_enabled
+
+
+@RunIf(min_gpus=2, standalone=True)
+@pytest.mark.parametrize("val_check_interval", [0.2, 0.5])
+def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval):
    world_size = 2
-    train_data = DataLoader(RandomDataset(32, total_train_samples), batch_size=train_batch_size)
+    total_train_samples = 16
+    train_batch_size = 4
+    total_val_samples = 2
+    val_batch_size = 1
+    train_data = DataLoader(RandomDataset(32, 8), batch_size=train_batch_size)
    val_data = DataLoader(RandomDataset(32, total_val_samples), batch_size=val_batch_size)

    model = BoringModel()
@@ -585,8 +626,8 @@ def test_progress_bar_max_val_check_interval(
    assert pbar_callback.val_progress_bar.n == total_val_batches
    assert pbar_callback.val_progress_bar.total == total_val_batches
    total_val_batches = total_val_batches * val_checks_per_epoch
-    assert pbar_callback.main_progress_bar.n == total_train_batches + total_val_batches
-    assert pbar_callback.main_progress_bar.total == total_train_batches + total_val_batches
+    assert pbar_callback.main_progress_bar.n == (total_train_batches + total_val_batches) // world_size
+    assert pbar_callback.main_progress_bar.total == (total_train_batches + total_val_batches) // world_size
    assert pbar_callback.is_enabled
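The `n_values` assertions above depend on the suite's `MockTqdm` helper, which records every value assigned to the bar's `n`. A rough sketch of how such a recorder can work (an assumption about the helper's shape, not its verbatim code):

from tqdm import tqdm

class RecordingTqdm(tqdm):
    # Intercept assignments to `n` (which `_update_n` performs) and keep
    # the history, mirroring what the suite's MockTqdm exposes as `n_values`.
    def __init__(self, *args, **kwargs):
        self.n_values = []  # must exist before tqdm's __init__ sets `n`
        super().__init__(*args, **kwargs)

    @property
    def n(self):
        return self.__n

    @n.setter
    def n(self, value):
        self.__n = value
        if value:  # skip the initial n = 0
            self.n_values.append(value)

Patched in place of `Tqdm` (as the `mock.patch` above does), such a class would collect exactly the refresh points asserted against `main_progress_bar_updates` and `val_progress_bar_updates`.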

