Skip to content

Commit

Permalink
Fix rich with uneven refresh rate tracking (#11668)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <[email protected]>
  • Loading branch information
rohitgr7 and carmocca authored Feb 3, 2022
1 parent 7948ed7 commit e9065e9
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 36 deletions.
10 changes: 8 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -509,10 +509,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue in `RichProgressbar` to display the metrics logged only on main progress bar ([#11690](https://github.com/PyTorchLightning/pytorch-lightning/pull/11690))


- Fixed check for available modules ([#11526](https://github.com/PyTorchLightning/pytorch-lightning/pull/11526))
- Fixed `RichProgressBar` progress when refresh rate does not evenly divide the total counter ([#11668](https://github.com/PyTorchLightning/pytorch-lightning/pull/11668))


- Fixed `RichProgressBar` progress validation bar total when using multiple validation runs within a single training epoch ([#11668](https://github.com/PyTorchLightning/pytorch-lightning/pull/11668))


- The Rich progress bar now correctly shows the `on_epoch` logged values on train epoch end ([#11689](https://github.com/PyTorchLightning/pytorch-lightning/pull/11689))
- The `RichProgressBar` now correctly shows the `on_epoch` logged values on train epoch end ([#11689](https://github.com/PyTorchLightning/pytorch-lightning/pull/11689))


- Fixed check for available modules ([#11526](https://github.com/PyTorchLightning/pytorch-lightning/pull/11526))


- Fixed an issue to avoid validation loop run on restart ([#11552](https://github.com/PyTorchLightning/pytorch-lightning/pull/11552))
Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/callbacks/progress/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ def total_val_batches(self) -> Union[int, float]:
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation
dataloader is of infinite size.
"""
if self.trainer.sanity_checking:
return sum(self.trainer.num_sanity_val_batches)

total_val_batches = 0
if self.trainer.enable_validation:
is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
Expand Down
30 changes: 15 additions & 15 deletions pytorch_lightning/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,6 @@ def on_validation_start(self, trainer, pl_module):

def on_sanity_check_start(self, trainer, pl_module):
self._init_progress(trainer)
self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description)
self.refresh()

def on_sanity_check_end(self, trainer, pl_module):
if self.progress is not None:
Expand Down Expand Up @@ -349,32 +347,34 @@ def on_train_epoch_start(self, trainer, pl_module):
self.refresh()

def on_validation_epoch_start(self, trainer, pl_module):
if self.total_val_batches > 0:
total_val_batches = self.total_val_batches
if self.total_train_batches != float("inf") and hasattr(trainer, "val_check_batch"):
# val can be checked multiple times per epoch
val_checks_per_epoch = self.total_train_batches // trainer.val_check_batch
total_val_batches = self.total_val_batches * val_checks_per_epoch
self.val_progress_bar_id = self._add_task(total_val_batches, self.validation_description, visible=False)
self.refresh()
if trainer.sanity_checking:
self.val_sanity_progress_bar_id = self._add_task(self.total_val_batches, self.sanity_check_description)
else:
self.val_progress_bar_id = self._add_task(
self.total_val_batches, self.validation_description, visible=False
)
self.refresh()

def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
if self.progress is not None:
return self.progress.add_task(
f"[{self.theme.description}]{description}", total=total_batches, visible=visible
)

def _update(self, progress_bar_id: int, current: int, total: int, visible: bool = True) -> None:
def _update(self, progress_bar_id: int, current: int, total: Union[int, float], visible: bool = True) -> None:
if self.progress is not None and self._should_update(current, total):
self.progress.update(progress_bar_id, advance=self.refresh_rate, visible=visible)
leftover = current % self.refresh_rate
advance = leftover if (current == total and leftover != 0) else self.refresh_rate
self.progress.update(progress_bar_id, advance=advance, visible=visible)
self.refresh()

def _should_update(self, current: int, total: int) -> bool:
def _should_update(self, current: int, total: Union[int, float]) -> bool:
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

def on_validation_epoch_end(self, trainer, pl_module):
if self.val_progress_bar_id is not None:
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False)
if self.val_progress_bar_id is not None and trainer.state.fn == "fit":
self.progress.update(self.val_progress_bar_id, advance=0, visible=False)
self.refresh()

def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if trainer.state.fn == "fit":
Expand Down
3 changes: 0 additions & 3 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,6 @@ def _get_max_batches(self) -> List[int]:
max_batches = self.trainer.num_test_batches
else:
if self.trainer.sanity_checking:
self.trainer.num_sanity_val_batches = [
min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches
]
max_batches = self.trainer.num_sanity_val_batches
else:
max_batches = self.trainer.num_val_batches
Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,9 @@ def _run_sanity_check(self) -> None:

# reload dataloaders
val_loop._reload_evaluation_dataloaders()
self.num_sanity_val_batches = [
min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches
]

# run eval step
with torch.no_grad():
Expand Down
91 changes: 75 additions & 16 deletions tests/callbacks/test_rich_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,8 @@ def test_rich_progress_bar_refresh_rate_enabled():


@RunIf(rich=True)
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
@pytest.mark.parametrize("dataset", [RandomDataset(32, 64), RandomIterableDataset(32, 64)])
def test_rich_progress_bar(progress_update, tmpdir, dataset):
def test_rich_progress_bar(tmpdir, dataset):
class TestModel(BoringModel):
def train_dataloader(self):
return DataLoader(dataset=dataset)
Expand All @@ -62,25 +61,34 @@ def test_dataloader(self):
def predict_dataloader(self):
return DataLoader(dataset=dataset)

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
limit_train_batches=1,
limit_val_batches=1,
limit_test_batches=1,
limit_predict_batches=1,
max_steps=1,
max_epochs=1,
callbacks=RichProgressBar(),
)
model = TestModel()

trainer.fit(model)
trainer.validate(model)
trainer.test(model)
trainer.predict(model)
with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
trainer.fit(model)
# 3 for main progress bar and 1 for val progress bar
assert mocked.call_count == 4

assert progress_update.call_count == 8
with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
trainer.validate(model)
assert mocked.call_count == 1

with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
trainer.test(model)
assert mocked.call_count == 1

with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
trainer.predict(model)
assert mocked.call_count == 1


def test_rich_progress_bar_import_error(monkeypatch):
Expand Down Expand Up @@ -186,11 +194,20 @@ def test_rich_progress_bar_leave(tmpdir, leave, reset_call_count):

@RunIf(rich=True)
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(0, 0), (3, 7)]))
def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, expected_call_count):
def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=4,
callbacks=RichProgressBar(refresh_rate=0),
)
trainer.fit(BoringModel())
assert progress_update.call_count == 0

model = BoringModel()

@RunIf(rich=True)
@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(3, 7), (4, 7), (7, 4)]))
def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, expected_call_count):
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
Expand All @@ -200,14 +217,26 @@ def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, e
callbacks=RichProgressBar(refresh_rate=refresh_rate),
)

trainer.fit(model)
trainer.progress_bar_callback.on_train_start(trainer, model)
with mock.patch.object(
trainer.progress_bar_callback.progress, "update", wraps=trainer.progress_bar_callback.progress.update
) as progress_update:
trainer.fit(model)
assert progress_update.call_count == expected_call_count

assert progress_update.call_count == expected_call_count
fit_main_bar = trainer.progress_bar_callback.progress.tasks[0]
fit_val_bar = trainer.progress_bar_callback.progress.tasks[1]
assert fit_main_bar.completed == 12
assert fit_main_bar.total == 12
assert fit_main_bar.visible
assert fit_val_bar.completed == 6
assert fit_val_bar.total == 6
assert not fit_val_bar.visible


@RunIf(rich=True)
@pytest.mark.parametrize("limit_val_batches", (1, 5))
def test_rich_progress_bar_num_sanity_val_steps(tmpdir, limit_val_batches: int):
def test_rich_progress_bar_num_sanity_val_steps(tmpdir, limit_val_batches):
model = BoringModel()

progress_bar = RichProgressBar()
Expand All @@ -224,6 +253,36 @@ def test_rich_progress_bar_num_sanity_val_steps(tmpdir, limit_val_batches: int):

trainer.fit(model)
assert progress_bar.progress.tasks[0].completed == min(num_sanity_val_steps, limit_val_batches)
assert progress_bar.progress.tasks[0].total == min(num_sanity_val_steps, limit_val_batches)


@RunIf(rich=True)
def test_rich_progress_bar_counter_with_val_check_interval(tmpdir):
"""Test the completed and total counter for rich progress bar when using val_check_interval."""
progress_bar = RichProgressBar()
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
val_check_interval=2,
max_epochs=1,
limit_train_batches=7,
limit_val_batches=4,
callbacks=[progress_bar],
)
trainer.fit(model)

fit_main_progress_bar = progress_bar.progress.tasks[1]
assert fit_main_progress_bar.completed == 7 + 3 * 4
assert fit_main_progress_bar.total == 7 + 3 * 4

fit_val_bar = progress_bar.progress.tasks[2]
assert fit_val_bar.completed == 4
assert fit_val_bar.total == 4

trainer.validate(model)
val_bar = progress_bar.progress.tasks[0]
assert val_bar.completed == 4
assert val_bar.total == 4


@RunIf(rich=True)
Expand Down

0 comments on commit e9065e9

Please sign in to comment.