Fix deadlocks for distributed training for RichProgressBar (#10428)
awaelchli committed Nov 9, 2021
1 parent: 9ad69ab · commit: 0bdcd62
Showing 3 changed files with 37 additions and 29 deletions.
CHANGELOG.md: 3 changes (2 additions & 1 deletion)
@@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [1.5.1] - 2021-MM-DD
## [1.5.1] - 2021-11-09

### Fixed

@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed issue with pickling `CSVLogger` after a call to `CSVLogger.save` ([#10388](https://github.com/PyTorchLightning/pytorch-lightning/pull/10388))
- Fixed an import error being caused by `PostLocalSGD` when `torch.distributed` not available ([#10359](https://github.com/PyTorchLightning/pytorch-lightning/pull/10359))
- Fixed the logging with `on_step=True` in epoch-level hooks causing unintended side-effects. Logging with `on_step=True` in epoch-level hooks will now correctly raise an error ([#10409](https://github.com/PyTorchLightning/pytorch-lightning/pull/10409))
- Fixed deadlocks for distributed training with `RichProgressBar` ([#10428](https://github.com/PyTorchLightning/pytorch-lightning/pull/10428))


## [1.5.0] - 2021-11-02
pytorch_lightning/callbacks/progress/rich_progress.py: 57 changes (32 additions & 25 deletions)
@@ -129,13 +129,19 @@ def render(self, task) -> RenderableType:
class MetricsTextColumn(ProgressColumn):
"""A column containing text."""

def __init__(self, trainer, pl_module):
def __init__(self, trainer):
self._trainer = trainer
self._pl_module = pl_module
self._tasks = {}
self._current_task_id = 0
self._metrics = {}
super().__init__()

def update(self, metrics):
# Called when metrics are ready to be rendered.
# This is to prevent render from causing deadlock issues by requesting metrics
# in separate threads.
self._metrics = metrics

def render(self, task) -> Text:
from pytorch_lightning.trainer.states import TrainerFn

@@ -149,14 +155,8 @@ def render(self, task) -> Text:
if self._trainer.training and task.id != self._current_task_id:
return self._tasks[task.id]
_text = ""
# TODO(@daniellepintz): make this code cleaner
progress_bar_callback = getattr(self._trainer, "progress_bar_callback", None)
if progress_bar_callback:
metrics = self._trainer.progress_bar_callback.get_metrics(self._trainer, self._pl_module)
else:
metrics = self._trainer.progress_bar_metrics

for k, v in metrics.items():

for k, v in self._metrics.items():
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
return Text(_text, justify="left")

@@ -220,9 +220,9 @@ def __init__(
self.progress: Optional[Progress] = None
self.val_sanity_progress_bar_id: Optional[int] = None
self._reset_progress_bar_ids()
self._metric_component = None
self._progress_stopped: bool = False
self.theme = theme
self._console: Console = Console()

@property
def refresh_rate_per_second(self) -> float:
@@ -263,12 +263,15 @@ def test_description(self) -> str:
def predict_description(self) -> str:
return "Predicting"

def _init_progress(self, trainer, pl_module):
if self.progress is None or self._progress_stopped:
def _init_progress(self, trainer):
if self.is_enabled and (self.progress is None or self._progress_stopped):
self._reset_progress_bar_ids()
self._console: Console = Console()
self._console.clear_live()
self._metric_component = MetricsTextColumn(trainer)
self.progress = CustomProgress(
*self.configure_columns(trainer, pl_module),
*self.configure_columns(trainer),
self._metric_component,
refresh_per_second=self.refresh_rate_per_second,
disable=self.is_disabled,
console=self._console,
@@ -279,19 +282,19 @@ def _init_progress(self, trainer, pl_module):

def on_train_start(self, trainer, pl_module):
super().on_train_start(trainer, pl_module)
self._init_progress(trainer, pl_module)
self._init_progress(trainer)

def on_predict_start(self, trainer, pl_module):
super().on_predict_start(trainer, pl_module)
self._init_progress(trainer, pl_module)
self._init_progress(trainer)

def on_test_start(self, trainer, pl_module):
super().on_test_start(trainer, pl_module)
self._init_progress(trainer, pl_module)
self._init_progress(trainer)

def on_validation_start(self, trainer, pl_module):
super().on_validation_start(trainer, pl_module)
self._init_progress(trainer, pl_module)
self._init_progress(trainer)

def __getstate__(self):
# can't pickle the rich progress objects
@@ -302,12 +305,11 @@ def __getstate__(self):

def __setstate__(self, state):
self.__dict__ = state
# reset console reference after loading progress
self._console = Console()
state["_console"] = Console()

def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
self._init_progress(trainer, pl_module)
self._init_progress(trainer)
self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description)

def on_sanity_check_end(self, trainer, pl_module):
@@ -328,10 +330,10 @@ def on_train_epoch_start(self, trainer, pl_module):
train_description = self._get_train_description(trainer.current_epoch)
if self.main_progress_bar_id is not None and self._leave:
self._stop_progress()
self._init_progress(trainer, pl_module)
self._init_progress(trainer)
if self.main_progress_bar_id is None:
self.main_progress_bar_id = self._add_task(total_batches, train_description)
else:
elif self.progress is not None:
self.progress.reset(
self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
)
@@ -372,6 +374,7 @@ def on_predict_epoch_start(self, trainer, pl_module):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
self._update(self.main_progress_bar_id)
self._update_metrics(trainer, pl_module)

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
@@ -414,6 +417,11 @@ def _reset_progress_bar_ids(self):
self.test_progress_bar_id: Optional[int] = None
self.predict_progress_bar_id: Optional[int] = None

def _update_metrics(self, trainer, pl_module) -> None:
metrics = self.get_metrics(trainer, pl_module)
if self._metric_component:
self._metric_component.update(metrics)

def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
self._stop_progress()

@@ -436,7 +444,7 @@ def main_progress_bar(self) -> Task:
def test_progress_bar(self) -> Task:
return self.progress.tasks[self.test_progress_bar_id]

def configure_columns(self, trainer, pl_module) -> list:
def configure_columns(self, trainer) -> list:
return [
TextColumn("[progress.description]{task.description}"),
CustomBarColumn(
@@ -447,5 +455,4 @@ def configure_columns(self, trainer, pl_module) -> list:
BatchesProcessedColumn(style=self.theme.batch_process),
CustomTimeColumn(style=self.theme.time),
ProcessingSpeedColumn(style=self.theme.processing_speed),
MetricsTextColumn(trainer, pl_module),
]
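The shape of the fix: Rich drives its live display from a background refresh thread, so `MetricsTextColumn.render` must not call back into the trainer — under distributed training that call could land in the middle of a collective and hang the ranks. After this commit the column only reads a cached dict, and the training loop pushes fresh values into it from the main thread via `_update_metrics` in `on_train_batch_end`. A minimal standalone sketch of that push-based pattern using plain Rich (the `PushMetricsColumn` name and the toy loop are illustrative, not Lightning code):

```python
import time

from rich.progress import BarColumn, Progress, ProgressColumn, TextColumn
from rich.text import Text


class PushMetricsColumn(ProgressColumn):
    """Renders only a cached metrics dict; render() never queries external state."""

    def __init__(self):
        super().__init__()
        self._metrics = {}

    def update(self, metrics):
        # Called from the main loop; render() below only reads this cache,
        # so the background refresh thread never reaches into the training loop.
        self._metrics = dict(metrics)

    def render(self, task) -> Text:
        text = " ".join(f"{k}: {v}" for k, v in self._metrics.items())
        return Text(text, justify="left")


if __name__ == "__main__":
    metrics_column = PushMetricsColumn()
    with Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        metrics_column,
        refresh_per_second=10,
    ) as progress:
        task = progress.add_task("Training", total=100)
        for step in range(100):
            time.sleep(0.02)  # stand-in for a training step
            # push fresh values from the main thread instead of letting
            # render() pull them on the refresh thread
            metrics_column.update({"loss": round(1.0 / (step + 1), 3)})
            progress.advance(task)
```

The design rule the commit enforces is simple: anything that might touch distributed state runs in the callback hooks on the main thread, and `render` stays a pure read of local state.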
tests/callbacks/test_rich_progress_bar.py: 6 changes (3 additions & 3 deletions)
@@ -150,15 +150,15 @@ def test_rich_progress_bar_configure_columns():
custom_column = TextColumn("[progress.description]Testing Rich!")

class CustomRichProgressBar(RichProgressBar):
def configure_columns(self, trainer, pl_module):
def configure_columns(self, trainer):
return [custom_column]

progress_bar = CustomRichProgressBar()

progress_bar._init_progress(Mock(), Mock())
progress_bar._init_progress(Mock())

assert progress_bar.progress.columns[0] == custom_column
assert len(progress_bar.progress.columns) == 1
assert len(progress_bar.progress.columns) == 2


@RunIf(rich=True)
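Beyond the column-count change above, one way to exercise the new push path directly is a small test in the style of the existing ones. This is a hedged sketch, not part of the commit: it leans on internals this diff introduces (`_init_progress`, `_update_metrics`, `_metric_component`, `_metrics`) and assumes `rich` is installed, which the repository would normally gate with `@RunIf(rich=True)`:

```python
from unittest.mock import Mock

from pytorch_lightning.callbacks import RichProgressBar


def test_rich_progress_bar_pushes_metrics():
    """Metrics reach the column through the push path added in this commit."""
    bar = RichProgressBar()
    bar._init_progress(Mock())  # builds the CustomProgress with the metrics column

    # Shortcut: override get_metrics on the instance so no real Trainer is needed.
    bar.get_metrics = lambda trainer, pl_module: {"loss": 0.123}

    bar._update_metrics(Mock(), Mock())
    assert bar._metric_component._metrics == {"loss": 0.123}
```

Overriding `get_metrics` on the instance is only a shortcut for the sketch; the production path goes through `on_train_batch_end`, which calls `_update_metrics` on the main thread.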
