From ed7ae82bb4b8906308815cce18e8eb6a426b2515 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 28 Mar 2022 18:50:06 +0530 Subject: [PATCH] fix tqdm standalone test --- tests/callbacks/test_tqdm_progress_bar.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index 06af9aaeff555..68ef7d35290e9 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -594,12 +594,14 @@ def test_progress_bar_max_val_check_interval( val_checks_per_epoch = total_train_batches / val_check_batch total_val_batches = total_val_samples // (val_batch_size * world_size) pbar_callback = trainer.progress_bar_callback - assert pbar_callback.val_progress_bar.n == total_val_batches - assert pbar_callback.val_progress_bar.total == total_val_batches - total_val_batches = total_val_batches * val_checks_per_epoch - assert pbar_callback.main_progress_bar.n == total_train_batches + total_val_batches - assert pbar_callback.main_progress_bar.total == total_train_batches + total_val_batches - assert pbar_callback.is_enabled == trainer.is_global_zero + + if trainer.is_global_zero: + assert pbar_callback.val_progress_bar.n == total_val_batches + assert pbar_callback.val_progress_bar.total == total_val_batches + total_val_batches = total_val_batches * val_checks_per_epoch + assert pbar_callback.main_progress_bar.n == total_train_batches + total_val_batches + assert pbar_callback.main_progress_bar.total == total_train_batches + total_val_batches + assert pbar_callback.is_enabled def test_get_progress_bar_metrics(tmpdir: str):