Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run main progress bar independent of val progress bar in TQDMProgressBar #12563

Merged
merged 8 commits into from
Apr 11, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Run main progress bar independent of val progress bar in `TQDMProgressBar` ([#12563](https://github.com/PyTorchLightning/pytorch-lightning/pull/12563))
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved


-
Expand Down
21 changes: 12 additions & 9 deletions pytorch_lightning/callbacks/progress/tqdm_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,9 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
if self._should_update(self.train_batch_idx, self.total_train_batches):
_update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)
current = self.train_batch_idx + self._val_processed
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
if self._should_update(current, self.main_progress_bar.total):
_update_n(self.main_progress_bar, current)
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
Expand All @@ -294,10 +295,12 @@ def on_validation_batch_start(
self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")

def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
if self._should_update(self.val_batch_idx, self.total_val_batches_current_dataloader):
if self._should_update(self.val_batch_idx, self.val_progress_bar.total):
_update_n(self.val_progress_bar, self.val_batch_idx)
if trainer.state.fn == "fit":
_update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)

current = self.train_batch_idx + self._val_processed
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total):
_update_n(self.main_progress_bar, current)

def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._main_progress_bar is not None and trainer.state.fn == "fit":
Expand All @@ -318,7 +321,7 @@ def on_test_batch_start(
self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")

def on_test_batch_end(self, *_: Any) -> None:
if self._should_update(self.test_batch_idx, self.total_test_batches_current_dataloader):
carmocca marked this conversation as resolved.
Show resolved Hide resolved
if self._should_update(self.test_batch_idx, self.test_progress_bar.total):
_update_n(self.test_progress_bar, self.test_batch_idx)

def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
Expand All @@ -338,7 +341,7 @@ def on_predict_batch_start(
self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")

def on_predict_batch_end(self, *_: Any) -> None:
if self._should_update(self.predict_batch_idx, self.total_predict_batches_current_dataloader):
if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total):
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
_update_n(self.predict_progress_bar, self.predict_batch_idx)

def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
Expand All @@ -361,8 +364,8 @@ def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None:
s = sep.join(map(str, args))
active_progress_bar.write(s, **kwargs)

def _should_update(self, current: int, total: Union[int, float]) -> bool:
return self.refresh_rate > 0 and (current % self.refresh_rate == 0 or current == total)
def _should_update(self, current: int, total: int) -> bool:
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

@staticmethod
def _resolve_refresh_rate(refresh_rate: int) -> int:
Expand Down
61 changes: 51 additions & 10 deletions tests/callbacks/test_tqdm_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import pickle
import sys
Expand Down Expand Up @@ -347,10 +348,10 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir):
[2, 3, 1, [1, 2, 3, 4, 5], [1, 2, 3]],
[0, 0, 3, None, None],
[1, 0, 3, [1], None],
[1, 1, 3, [1, 2], [1]],
carmocca marked this conversation as resolved.
Show resolved Hide resolved
[1, 1, 3, [2], [1]],
[5, 0, 3, [3, 5], None],
[5, 2, 3, [3, 5, 7], [2]],
[5, 2, 6, [5, 7], [2]],
carmocca marked this conversation as resolved.
Show resolved Hide resolved
[5, 2, 3, [3, 6, 7], [2]],
[5, 2, 6, [6, 7], [2]],
],
)
def test_main_progress_bar_update_amount(
Expand Down Expand Up @@ -549,16 +550,56 @@ def test_tqdm_progress_bar_can_be_pickled():
pickle.dumps(bar)


@RunIf(min_gpus=2, standalone=True)
@pytest.mark.parametrize(
["total_train_samples", "train_batch_size", "total_val_samples", "val_batch_size", "val_check_interval"],
[(8, 4, 2, 1, 0.2), (8, 4, 2, 1, 0.5)],
["val_check_interval", "main_progress_bar_updates", "val_progress_bar_updates"],
[(4, [3, 6, 9, 12, 14], [3, 6, 7]), (0.5, [3, 6, 9, 12, 15, 18, 21], [3, 6, 7])],
)
def test_progress_bar_max_val_check_interval(
tmpdir, total_train_samples, train_batch_size, total_val_samples, val_batch_size, val_check_interval
tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates
):
limit_batches = 7
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
max_epochs=1,
enable_model_summary=False,
val_check_interval=val_check_interval,
limit_train_batches=limit_batches,
limit_val_batches=limit_batches,
callbacks=TQDMProgressBar(refresh_rate=3),
)
with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
trainer.fit(model)

pbar = trainer.progress_bar_callback
assert pbar.main_progress_bar.n_values == main_progress_bar_updates
assert pbar.val_progress_bar.n_values == val_progress_bar_updates

val_check_batch = (
max(1, int(limit_batches * val_check_interval)) if isinstance(val_check_interval, float) else val_check_interval
)
assert trainer.val_check_batch == val_check_batch
val_checks_per_epoch = math.ceil(limit_batches // val_check_batch)
pbar_callback = trainer.progress_bar_callback
total_val_batches = limit_batches * val_checks_per_epoch

assert pbar_callback.val_progress_bar.n == limit_batches
assert pbar_callback.val_progress_bar.total == limit_batches
assert pbar_callback.main_progress_bar.n == limit_batches + total_val_batches
assert pbar_callback.main_progress_bar.total == limit_batches + total_val_batches
assert pbar_callback.is_enabled


@RunIf(min_gpus=2, standalone=True)
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize("val_check_interval", [0.2, 0.5])
def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval):
world_size = 2
train_data = DataLoader(RandomDataset(32, total_train_samples), batch_size=train_batch_size)
total_train_samples = 16
train_batch_size = 4
total_val_samples = 2
val_batch_size = 1
train_data = DataLoader(RandomDataset(32, 8), batch_size=train_batch_size)
val_data = DataLoader(RandomDataset(32, total_val_samples), batch_size=val_batch_size)

model = BoringModel()
Expand All @@ -585,8 +626,8 @@ def test_progress_bar_max_val_check_interval(
assert pbar_callback.val_progress_bar.n == total_val_batches
assert pbar_callback.val_progress_bar.total == total_val_batches
total_val_batches = total_val_batches * val_checks_per_epoch
assert pbar_callback.main_progress_bar.n == total_train_batches + total_val_batches
assert pbar_callback.main_progress_bar.total == total_train_batches + total_val_batches
assert pbar_callback.main_progress_bar.n == (total_train_batches + total_val_batches) // world_size
assert pbar_callback.main_progress_bar.total == (total_train_batches + total_val_batches) // world_size
assert pbar_callback.is_enabled


Expand Down