Skip to content

Commit

Permalink
Fix rich main progress bar update (#12618)
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitgr7 authored and lexierule committed Apr 13, 2022
1 parent f87cff2 commit 9591747
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 29 deletions.
5 changes: 5 additions & 0 deletions pytorch_lightning/callbacks/progress/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ def test_description(self) -> str:
def predict_description(self) -> str:
return "Predicting"

@property
def _val_processed(self) -> int:
    """Count of validation batches processed so far in this fit.

    Uses the *total* (cross-run) counter so the value keeps growing when
    validation runs more than once per training epoch.
    """
    fit_epoch_loop = self.trainer.fit_loop.epoch_loop
    val_batch_progress = fit_epoch_loop.val_loop.epoch_loop.batch_progress
    return val_batch_progress.total.processed

@property
def train_batch_idx(self) -> int:
"""The number of batches processed during training.
Expand Down
25 changes: 13 additions & 12 deletions pytorch_lightning/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,15 +368,19 @@ def _add_task(self, total_batches: int, description: str, visible: bool = True)
f"[{self.theme.description}]{description}", total=total_batches, visible=visible
)

def _update(self, progress_bar_id: int, current: int, visible: bool = True) -> None:
    """Advance the rich task identified by ``progress_bar_id``.

    The task's ``total`` is read back from rich itself, so callers only pass the
    current counter. Advancing happens in ``refresh_rate`` steps; on the final
    batch any leftover (``current % refresh_rate``) is flushed so the bar always
    ends exactly at ``total``.
    """
    if self.progress is not None and self.is_enabled:
        total = self.progress.tasks[progress_bar_id].total
        if not self._should_update(current, total):
            return

        leftover = current % self.refresh_rate
        # on the last batch, advance only by the remainder so we don't overshoot
        advance = leftover if (current == total and leftover != 0) else self.refresh_rate
        self.progress.update(progress_bar_id, advance=advance, visible=visible)
        self.refresh()

def _should_update(self, current: int, total: Union[int, float]) -> bool:
    """Return True when the bar should refresh: every ``refresh_rate`` batches or on the last one."""
    return current % self.refresh_rate == 0 or current == total

def on_validation_epoch_end(self, trainer, pl_module):
if self.val_progress_bar_id is not None and trainer.state.fn == "fit":
Expand Down Expand Up @@ -419,7 +423,7 @@ def on_predict_batch_start(
self.refresh()

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
    """Advance the main fit bar after each training batch.

    The main bar's total covers both training and validation batches, so the
    position is the train counter plus the validation batches already processed.
    """
    self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed)
    self._update_metrics(trainer, pl_module)
    self.refresh()

Expand All @@ -428,23 +432,20 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
    """Advance the validation bar (and, during fit, the main bar) after each val batch."""
    if trainer.sanity_checking:
        self._update(self.val_sanity_progress_bar_id, self.val_batch_idx)
    elif self.val_progress_bar_id is not None:
        # during fit, validation batches also count toward the main training bar
        if self.main_progress_bar_id is not None:
            self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed)
        self._update(self.val_progress_bar_id, self.val_batch_idx)
    self.refresh()

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
    """Advance the test bar after each test batch; the task's total is read back from rich."""
    self._update(self.test_progress_bar_id, self.test_batch_idx)
    self.refresh()

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
    """Advance the predict bar after each predict batch; the task's total is read back from rich."""
    self._update(self.predict_progress_bar_id, self.predict_batch_idx)
    self.refresh()

def _get_train_description(self, current_epoch: int) -> str:
Expand Down
5 changes: 0 additions & 5 deletions pytorch_lightning/callbacks/progress/tqdm_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,6 @@ def is_enabled(self) -> bool:
def is_disabled(self) -> bool:
return not self.is_enabled

@property
def _val_processed(self) -> int:
# use total in case validation runs more than once per training epoch
return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed

def disable(self) -> None:
self._enabled = False

Expand Down
40 changes: 28 additions & 12 deletions tests/callbacks/test_rich_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,28 @@ def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmpdir):


@RunIf(rich=True)
@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(3, 7), (4, 7), (7, 4)]))
def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, expected_call_count):
@pytest.mark.parametrize(
"refresh_rate,train_batches,val_batches,expected_call_count",
[
(3, 6, 6, 4 + 3),
(4, 6, 6, 3 + 3),
(7, 6, 6, 2 + 2),
(1, 2, 3, 5 + 4),
(1, 0, 0, 0 + 0),
(3, 1, 0, 1 + 0),
(3, 1, 1, 1 + 2),
(3, 5, 0, 2 + 0),
(3, 5, 2, 3 + 2),
(6, 5, 2, 2 + 2),
],
)
def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, train_batches, val_batches, expected_call_count):
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
limit_train_batches=6,
limit_val_batches=6,
limit_train_batches=train_batches,
limit_val_batches=val_batches,
max_epochs=1,
callbacks=RichProgressBar(refresh_rate=refresh_rate),
)
Expand All @@ -224,14 +238,16 @@ def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, expected_call
trainer.fit(model)
assert progress_update.call_count == expected_call_count

fit_main_bar = trainer.progress_bar_callback.progress.tasks[0]
fit_val_bar = trainer.progress_bar_callback.progress.tasks[1]
assert fit_main_bar.completed == 12
assert fit_main_bar.total == 12
assert fit_main_bar.visible
assert fit_val_bar.completed == 6
assert fit_val_bar.total == 6
assert not fit_val_bar.visible
if train_batches > 0:
fit_main_bar = trainer.progress_bar_callback.progress.tasks[0]
assert fit_main_bar.completed == train_batches + val_batches
assert fit_main_bar.total == train_batches + val_batches
assert fit_main_bar.visible
if val_batches > 0:
fit_val_bar = trainer.progress_bar_callback.progress.tasks[1]
assert fit_val_bar.completed == val_batches
assert fit_val_bar.total == val_batches
assert not fit_val_bar.visible


@RunIf(rich=True)
Expand Down

0 comments on commit 9591747

Please sign in to comment.