Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix rich main progress bar update #12618

Merged
merged 3 commits into from
Apr 6, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pytorch_lightning/callbacks/progress/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ def test_description(self) -> str:
def predict_description(self) -> str:
return "Predicting"

@property
def _val_processed(self) -> int:
    """Return the cumulative count of validation batches processed during ``fit``.

    Reads the ``total`` (run-wide) progress tracker rather than a per-epoch one, so the
    value keeps growing when validation runs more than once within a single training
    epoch (e.g. with ``val_check_interval < 1.0``).
    """
    # use total in case validation runs more than once per training epoch
    # NOTE(review): assumes this callback is only queried while a fit loop exists on
    # the trainer — accessing ``fit_loop`` outside ``fit`` may not be valid; confirm.
    return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed

@property
def train_batch_idx(self) -> int:
"""The number of batches processed during training.
Expand Down
23 changes: 12 additions & 11 deletions pytorch_lightning/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,12 @@ def _add_task(self, total_batches: int, description: str, visible: bool = True)
f"[{self.theme.description}]{description}", total=total_batches, visible=visible
)

def _update(self, progress_bar_id: int, current: int, total: Union[int, float], visible: bool = True) -> None:
if self.progress is not None and self._should_update(current, total):
def _update(self, progress_bar_id: int, current: int, visible: bool = True) -> None:
if self.progress is not None and self.is_enabled:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
total = self.progress.tasks[progress_bar_id].total
if not self._should_update(current, total):
return

leftover = current % self.refresh_rate
advance = leftover if (current == total and leftover != 0) else self.refresh_rate
self.progress.update(progress_bar_id, advance=advance, visible=visible)
Expand Down Expand Up @@ -419,7 +423,7 @@ def on_predict_batch_start(
self.refresh()

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches)
self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed)
self._update_metrics(trainer, pl_module)
self.refresh()

Expand All @@ -428,23 +432,20 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
if trainer.sanity_checking:
self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
self._update(self.val_sanity_progress_bar_id, self.val_batch_idx)
elif self.val_progress_bar_id is not None:
# check to see if we should update the main training progress bar
if self.main_progress_bar_id is not None:
# TODO: Use total val_processed here just like TQDM in a follow-up
self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed)
self._update(self.val_progress_bar_id, self.val_batch_idx)
self.refresh()

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches_current_dataloader)
self._update(self.test_progress_bar_id, self.test_batch_idx)
self.refresh()

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._update(
self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches_current_dataloader
)
self._update(self.predict_progress_bar_id, self.predict_batch_idx)
self.refresh()

def _get_train_description(self, current_epoch: int) -> str:
Expand Down
5 changes: 0 additions & 5 deletions pytorch_lightning/callbacks/progress/tqdm_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,6 @@ def is_enabled(self) -> bool:
def is_disabled(self) -> bool:
return not self.is_enabled

@property
def _val_processed(self) -> int:
# use total in case validation runs more than once per training epoch
return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed

def disable(self) -> None:
self._enabled = False

Expand Down
40 changes: 28 additions & 12 deletions tests/callbacks/test_rich_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,28 @@ def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmpdir):


@RunIf(rich=True)
@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(3, 7), (4, 7), (7, 4)]))
def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, expected_call_count):
@pytest.mark.parametrize(
"refresh_rate,train_batches,val_batches,expected_call_count",
[
(3, 6, 6, 4 + 3),
(4, 6, 6, 3 + 3),
(7, 6, 6, 2 + 2),
(1, 2, 3, 5 + 4),
(1, 0, 0, 0 + 0),
(3, 1, 0, 1 + 0),
(3, 1, 1, 1 + 2),
(3, 5, 0, 2 + 0),
(3, 5, 2, 3 + 2),
(6, 5, 2, 2 + 2),
],
)
def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, train_batches, val_batches, expected_call_count):
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
limit_train_batches=6,
limit_val_batches=6,
limit_train_batches=train_batches,
limit_val_batches=val_batches,
max_epochs=1,
callbacks=RichProgressBar(refresh_rate=refresh_rate),
)
Expand All @@ -224,14 +238,16 @@ def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, expected_call
trainer.fit(model)
assert progress_update.call_count == expected_call_count

fit_main_bar = trainer.progress_bar_callback.progress.tasks[0]
fit_val_bar = trainer.progress_bar_callback.progress.tasks[1]
assert fit_main_bar.completed == 12
assert fit_main_bar.total == 12
assert fit_main_bar.visible
assert fit_val_bar.completed == 6
assert fit_val_bar.total == 6
assert not fit_val_bar.visible
if train_batches > 0:
fit_main_bar = trainer.progress_bar_callback.progress.tasks[0]
assert fit_main_bar.completed == train_batches + val_batches
assert fit_main_bar.total == train_batches + val_batches
assert fit_main_bar.visible
if val_batches > 0:
fit_val_bar = trainer.progress_bar_callback.progress.tasks[1]
assert fit_val_bar.completed == val_batches
assert fit_val_bar.total == val_batches
assert not fit_val_bar.visible


@RunIf(rich=True)
Expand Down