fix tqdm counter for multiple dataloaders
rohitgr7 committed Jan 31, 2022
1 parent 86b177e commit 74ff052
Showing 5 changed files with 144 additions and 57 deletions.
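The heart of the change: the bar previously showed only the current dataloader's `batch_progress.current.processed`, so with multiple val/test/predict dataloaders the counter snapped back to zero at every dataloader boundary. The fix caches, at the first batch of each dataloader, the number of batches contributed by the dataloaders before it, and adds that offset to the live counter. A minimal sketch of the idea (the names and batch counts below are illustrative, not Lightning's code):

```python
# Hypothetical per-dataloader batch counts for one validation run.
num_val_batches = [5, 3, 4]

def global_batch_idx(dataloader_idx: int, current_batch_idx: int) -> int:
    # Offset = batches already consumed by the preceding dataloaders.
    offset = sum(num_val_batches[:dataloader_idx])
    return offset + current_batch_idx

assert global_batch_idx(0, 5) == 5   # first dataloader finishes at 5
assert global_batch_idx(1, 0) == 5   # second continues from 5 instead of resetting
assert global_batch_idx(2, 4) == 12  # the bar ends at the combined total
```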
53 changes: 49 additions & 4 deletions pytorch_lightning/callbacks/progress/base.py
@@ -48,6 +48,9 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch_idx):

     def __init__(self) -> None:
         self._trainer: Optional["pl.Trainer"] = None
+        self._val_progress: Optional[int] = None
+        self._test_progress: Optional[int] = None
+        self._predict_progress: Optional[int] = None

     @property
     def trainer(self) -> "pl.Trainer":
@@ -63,31 +66,70 @@ def train_batch_idx(self) -> int:
"""
return self.trainer.fit_loop.epoch_loop.batch_progress.current.processed

+    def on_validation_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+    ):
+        if self._val_progress is None or batch_idx == 0:
+            max_batches = trainer.num_sanity_val_batches if trainer.sanity_checking else trainer.num_val_batches
+            self._val_progress = sum(max_batches[:dataloader_idx])
+
+    def on_test_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+    ):
+        if self._test_progress is None or batch_idx == 0:
+            self._test_progress = sum(trainer.num_test_batches[:dataloader_idx])
+
+    def on_predict_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+    ):
+        if self._predict_progress is None or batch_idx == 0:
+            self._predict_progress = sum(trainer.num_predict_batches[:dataloader_idx])
+
+    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self._val_progress = None
+
+    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self._test_progress = None
+
+    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self._predict_progress = None

     @property
     def val_batch_idx(self) -> int:
         """The number of batches processed during validation.

         Use this to update your progress bar.
         """
         if self.trainer.state.fn == "fit":
-            return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.current.processed
-        return self.trainer.validate_loop.epoch_loop.batch_progress.current.processed
+            loop = self.trainer.fit_loop.epoch_loop.val_loop
+        else:
+            loop = self.trainer.validate_loop
+
+        current_batch_idx = loop.epoch_loop.batch_progress.current.processed
+        batch_idx = self._val_progress + current_batch_idx
+        return batch_idx

     @property
     def test_batch_idx(self) -> int:
         """The number of batches processed during testing.

         Use this to update your progress bar.
         """
-        return self.trainer.test_loop.epoch_loop.batch_progress.current.processed
+        loop = self.trainer.test_loop
+        current_batch_idx = loop.epoch_loop.batch_progress.current.processed
+        batch_idx = self._test_progress + current_batch_idx
+        return batch_idx

     @property
     def predict_batch_idx(self) -> int:
         """The number of batches processed during prediction.

         Use this to update your progress bar.
         """
-        return self.trainer.predict_loop.epoch_loop.batch_progress.current.processed
+        loop = self.trainer.predict_loop
+        current_batch_idx = loop.epoch_loop.batch_progress.current.processed
+        batch_idx = self._predict_progress + current_batch_idx
+        return batch_idx

     @property
     def total_train_batches(self) -> Union[int, float]:
@@ -105,6 +147,9 @@ def total_val_batches(self) -> Union[int, float]:
         Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation
         dataloader is of infinite size.
         """
+        if self.trainer.sanity_checking:
+            return sum(self.trainer.num_sanity_val_batches)
+
         total_val_batches = 0
         if self.trainer.enable_validation:
             is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
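Taken together, the base-class additions split the displayed index into a cached per-dataloader offset (`_val_progress`, recomputed whenever `batch_idx == 0`) plus the loop's live `current.processed` count. A stripped-down sketch of that cooperation, using a stand-in class rather than the real `Trainer` and loops:

```python
from typing import List, Optional

class ProgressTracker:
    """Toy model of the new ProgressBarBase bookkeeping (illustrative only)."""

    def __init__(self, num_val_batches: List[int]) -> None:
        self.num_val_batches = num_val_batches
        self._val_progress: Optional[int] = None
        self.current_processed = 0  # stands in for batch_progress.current.processed

    def on_validation_batch_start(self, batch_idx: int, dataloader_idx: int) -> None:
        # Recompute the offset at the start of every dataloader.
        if self._val_progress is None or batch_idx == 0:
            self._val_progress = sum(self.num_val_batches[:dataloader_idx])

    @property
    def val_batch_idx(self) -> int:
        return self._val_progress + self.current_processed

tracker = ProgressTracker(num_val_batches=[5, 3])
tracker.on_validation_batch_start(batch_idx=0, dataloader_idx=1)  # second dataloader
tracker.current_processed = 2
print(tracker.val_batch_idx)  # 7 == 5 batches from dataloader 0 + 2 from dataloader 1
```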
51 changes: 27 additions & 24 deletions pytorch_lightning/callbacks/progress/tqdm_progress.py
@@ -173,10 +173,7 @@ def is_disabled(self) -> bool:

     @property
     def _val_processed(self) -> int:
-        if self.trainer.state.fn == "fit":
-            # use total in case validation runs more than once per training epoch
-            return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed
-        return self.trainer.validate_loop.epoch_loop.batch_progress.current.processed
+        return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed

     def disable(self) -> None:
         self._enabled = False
@@ -227,12 +224,12 @@ def init_predict_tqdm(self) -> Tqdm:
     def init_validation_tqdm(self) -> Tqdm:
         """Override this to customize the tqdm bar for validation."""
         # The main progress bar doesn't exist in `trainer.validate()`
-        has_main_bar = self._main_progress_bar is not None
+        has_main_bar = not self.trainer.state.fn == "validate"
         bar = Tqdm(
             desc="Validating",
             position=(2 * self.process_position + has_main_bar),
             disable=self.is_disabled,
-            leave=False,
+            leave=not has_main_bar,
             dynamic_ncols=True,
             file=sys.stdout,
         )
@@ -256,6 +253,7 @@ def on_sanity_check_start(self, *_: Any) -> None:

     def on_sanity_check_end(self, *_: Any) -> None:
         self.main_progress_bar.close()
+        self.main_progress_bar = None
         self.val_progress_bar.close()

def on_train_start(self, *_: Any) -> None:
@@ -273,63 +271,68 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
         self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

     def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
-        if self._should_update(self.train_batch_idx):
+        if self._should_update(self.train_batch_idx, self.total_train_batches):
             _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)
             self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)
         if not self.main_progress_bar.disable:
             self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

     def on_train_end(self, *_: Any) -> None:
         self.main_progress_bar.close()

-    def on_validation_start(self, trainer: "pl.Trainer", *_: Any) -> None:
+    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if trainer.sanity_checking:
             self.val_progress_bar.total = sum(trainer.num_sanity_val_batches)
         else:
             self.val_progress_bar = self.init_validation_tqdm()
             self.val_progress_bar.total = convert_inf(self.total_val_batches)

+    def on_validation_batch_start(self, *args: Any, **kwargs: Any):
+        return super().on_validation_batch_start(*args, **kwargs)
+
+    def on_test_batch_start(self, *args: Any, **kwargs: Any):
+        return super().on_test_batch_start(*args, **kwargs)
+
+    def on_predict_batch_start(self, *args: Any, **kwargs: Any):
+        return super().on_predict_batch_start(*args, **kwargs)

     def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
-        if self._should_update(self.val_batch_idx):
+        if self._should_update(self.val_batch_idx, self.total_val_batches):
             _update_n(self.val_progress_bar, self.val_batch_idx)
             if trainer.state.fn == "fit":
                 _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)

     def on_validation_epoch_end(self, *_: Any) -> None:
         _update_n(self.val_progress_bar, self._val_processed)

     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if self._main_progress_bar is not None and trainer.state.fn == "fit":
             self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
         self.val_progress_bar.close()
+        super().on_validation_end(trainer, pl_module)

-    def on_test_start(self, *_: Any) -> None:
+    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self.test_progress_bar = self.init_test_tqdm()
         self.test_progress_bar.total = convert_inf(self.total_test_batches)

     def on_test_batch_end(self, *_: Any) -> None:
-        if self._should_update(self.test_batch_idx):
+        if self._should_update(self.test_batch_idx, self.total_test_batches):
             _update_n(self.test_progress_bar, self.test_batch_idx)

     def on_test_epoch_end(self, *_: Any) -> None:
         _update_n(self.test_progress_bar, self.test_batch_idx)

-    def on_test_end(self, *_: Any) -> None:
+    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self.test_progress_bar.close()
+        super().on_test_end(trainer, pl_module)

-    def on_predict_epoch_start(self, *_: Any) -> None:
+    def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self.predict_progress_bar = self.init_predict_tqdm()
         self.predict_progress_bar.total = convert_inf(self.total_predict_batches)

     def on_predict_batch_end(self, *_: Any) -> None:
-        if self._should_update(self.predict_batch_idx):
+        if self._should_update(self.predict_batch_idx, self.total_predict_batches):
             _update_n(self.predict_progress_bar, self.predict_batch_idx)

-    def on_predict_end(self, *_: Any) -> None:
+    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self.predict_progress_bar.close()
+        super().on_predict_end(trainer, pl_module)

     def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None:
         active_progress_bar = None
@@ -347,8 +350,8 @@ def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None:
         s = sep.join(map(str, args))
         active_progress_bar.write(s, **kwargs)

-    def _should_update(self, idx: int) -> bool:
-        return self.refresh_rate > 0 and idx % self.refresh_rate == 0
+    def _should_update(self, current: int, total: Union[int, float]) -> bool:
+        return self.refresh_rate > 0 and (current % self.refresh_rate == 0 or current == total)

     @staticmethod
     def _resolve_refresh_rate(refresh_rate: int) -> int:
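The `_should_update` signature change above is what keeps the bar accurate at the tail end: refreshing only every `refresh_rate` batches leaves the counter short whenever the total is not a multiple of the rate, so the extra `current == total` clause forces one final redraw. A standalone sketch (function name and values are illustrative):

```python
from typing import Union

def should_update(current: int, total: Union[int, float], refresh_rate: int) -> bool:
    # Refresh on the usual cadence, and always on the very last batch.
    return refresh_rate > 0 and (current % refresh_rate == 0 or current == total)

# With 13 batches and refresh_rate=5, the bar now redraws at 5, 10 and 13;
# the old `idx % refresh_rate == 0` check would have stopped at 10.
print([i for i in range(1, 14) if should_update(i, 13, refresh_rate=5)])  # [5, 10, 13]
```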
3 changes: 0 additions & 3 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -178,9 +178,6 @@ def _get_max_batches(self) -> List[int]:
             max_batches = self.trainer.num_test_batches
         else:
             if self.trainer.sanity_checking:
-                self.trainer.num_sanity_val_batches = [
-                    min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches
-                ]
                 max_batches = self.trainer.num_sanity_val_batches
             else:
                 max_batches = self.trainer.num_val_batches
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -1349,6 +1349,9 @@ def _run_sanity_check(self) -> None:

         # reload dataloaders
         val_loop._reload_evaluation_dataloaders()
+        self.num_sanity_val_batches = [
+            min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches
+        ]

         # run eval step
         with torch.no_grad():
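The last two hunks relocate the `num_sanity_val_batches` computation from the evaluation loop into `Trainer._run_sanity_check`, right after the dataloaders are reloaded, so the value is already populated when the progress-bar hooks read it. The clamping itself is simple: each dataloader contributes at most `num_sanity_val_steps` batches. A small standalone illustration (numbers are hypothetical):

```python
num_sanity_val_steps = 2
num_val_batches = [5, 1, 4]  # per-dataloader validation batch counts

num_sanity_val_batches = [min(num_sanity_val_steps, b) for b in num_val_batches]
print(num_sanity_val_batches)       # [2, 1, 2]
print(sum(num_sanity_val_batches))  # 5 -> the total the sanity-check bar displays
```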