Commit 7f147ee
Delete TensorBoardLogger experiment before spawning the processes. (#…
tchaton authored and lexierule committed Nov 30, 2021
1 parent a5e8823 commit 7f147ee
Showing 4 changed files with 35 additions and 9 deletions.
CHANGELOG.md (3 additions, 0 deletions)
@@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `_compare_version` for python packages ([#10762](https://github.com/PyTorchLightning/pytorch-lightning/pull/10762))
 
 
+- Fixed `TensorBoardLogger` `SummaryWriter` not being closed before spawning the processes ([#10777](https://github.com/PyTorchLightning/pytorch-lightning/pull/10777))
+
+
 ## [1.5.2] - 2021-11-16
 
 ### Fixed
pytorch_lightning/plugins/training_type/ddp_spawn.py (17 additions, 0 deletions)
@@ -25,6 +25,7 @@
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 import pytorch_lightning as pl
+from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
@@ -170,14 +171,17 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
         return {"nprocs": self.num_processes}
 
     def start_training(self, trainer: "pl.Trainer") -> None:
+        self._clean_logger(trainer)
         self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
         # reset optimizers, since main process is never used for training and thus does not have a valid optim state
         trainer.optimizers = []
 
     def start_evaluating(self, trainer: "pl.Trainer") -> None:
+        self._clean_logger(trainer)
         self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
 
     def start_predicting(self, trainer: "pl.Trainer") -> None:
+        self._clean_logger(trainer)
         self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
 
     def spawn(self, function: Callable, *args: Any, return_result: bool = True, **kwargs: Any) -> Optional[Any]:
@@ -440,3 +444,16 @@ def teardown(self) -> None:
         self.lightning_module.cpu()
         # clean up memory
         torch.cuda.empty_cache()
+
+    @staticmethod
+    def _clean_logger(trainer: "pl.Trainer") -> None:
+        loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
+        for logger in loggers:
+            if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
+                rank_zero_warn(
+                    "When using `ddp_spawn`, the `TensorBoardLogger` experiment should be `None`. Setting it to `None`."
+                )
+                # the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang.
+                # we want to make sure these are closed before we spawn our own threads.
+                # assuming nothing else references the experiment object, python should instantly `__del__` it.
+                logger._experiment = None
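
Why clearing `logger._experiment` is enough: `TensorBoardLogger` creates its `SummaryWriter` lazily on first access of the `experiment` property, so dropping the reference lets the main process garbage-collect (and thereby close) the writer, and each spawned rank recreates its own on demand. A minimal sketch of that lazy-creation pattern, using a hypothetical `LazyTensorBoardLogger` rather than the real implementation:

from typing import Optional

from torch.utils.tensorboard import SummaryWriter


class LazyTensorBoardLogger:
    # Hypothetical stand-in illustrating the lazy `experiment` pattern.

    def __init__(self, save_dir: str) -> None:
        self.save_dir = save_dir
        self._experiment: Optional[SummaryWriter] = None  # no writer, no queue yet

    @property
    def experiment(self) -> SummaryWriter:
        # The SummaryWriter (and the multiprocessing queue it holds) only comes
        # into existence on first access, so resetting `_experiment` to `None`
        # returns the logger to this safe, pre-access state.
        if self._experiment is None:
            self._experiment = SummaryWriter(log_dir=self.save_dir)
        return self._experiment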
pytorch_lightning/plugins/training_type/tpu_spawn.py (0 additions, 9 deletions)
@@ -254,10 +254,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[

     return output
 
-    def _close_logger(self, trainer) -> None:
-        if trainer.logger is not None:
-            trainer.logger.finalize("success")
-
     def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
         return {
             "nprocs": len(self.parallel_devices),
@@ -293,13 +289,8 @@ def start_training(self, trainer: "pl.Trainer") -> None:
         # todo: precision pluging is call in accelerator setup and should be moved
         if "XLA_USE_BF16" in os.environ:
             del os.environ["XLA_USE_BF16"]
-        self._close_logger(trainer)
         return super().start_training(trainer)
 
-    def start_evaluating(self, trainer: "pl.Trainer") -> None:
-        self._close_logger(trainer)
-        return super().start_evaluating(trainer)
-
     def training_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)
 
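These deletions are safe because `TPUSpawnPlugin` subclasses `DDPSpawnPlugin`: `start_training` and `start_evaluating` now fall through to the parent methods shown above, which run `_clean_logger` before spawning. A rough, self-contained sketch of the resulting call order (the method bodies are print stand-ins, not the real plugin code):

class DDPSpawnPlugin:
    def _clean_logger(self, trainer) -> None:
        print("close the TensorBoard SummaryWriter in the main process")

    def start_training(self, trainer) -> None:
        self._clean_logger(trainer)  # the call added by this commit
        print("spawn the worker processes")


class TPUSpawnPlugin(DDPSpawnPlugin):
    def start_training(self, trainer) -> None:
        # XLA-specific environment handling would run here first ...
        return super().start_training(trainer)  # logger cleanup is inherited


TPUSpawnPlugin().start_training(trainer=None)
# close the TensorBoard SummaryWriter in the main process
# spawn the worker processes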
tests/loggers/test_tensorboard.py (15 additions, 0 deletions)
@@ -25,6 +25,7 @@

 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.loggers.base import LoggerCollection
 from pytorch_lightning.utilities.imports import _compare_version
 from tests.helpers import BoringModel

@@ -332,3 +333,17 @@ def test_tensorboard_missing_folder_warning(tmpdir, caplog):
     assert logger.version == 0
 
     assert "Missing logger folder:" in caplog.text
+
+
+@pytest.mark.parametrize("use_list", [False, True])
+def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir):
+    tensorboard_logger = TensorBoardLogger(save_dir=tmpdir)
+    assert tensorboard_logger._experiment is None
+    tensorboard_logger.experiment  # this property access will create the experiment
+    assert tensorboard_logger._experiment is not None
+    logger = [tensorboard_logger] if use_list else tensorboard_logger
+    trainer = Trainer(strategy="ddp_spawn", devices=2, accelerator="auto", logger=logger)
+    trainer.training_type_plugin._clean_logger(trainer)
+    if use_list:
+        assert isinstance(trainer.logger, LoggerCollection)
+    assert tensorboard_logger._experiment is None
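At the user level, the only visible change is the new warning: a run configured like the test above no longer risks hanging because the main process touched the experiment before fitting. A usage sketch, assuming a 2-process CPU run with `max_epochs=1` (both settings are illustrative, not taken from this diff):

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from tests.helpers import BoringModel

logger = TensorBoardLogger(save_dir="lightning_logs")
_ = logger.experiment  # eagerly creates the SummaryWriter in the main process
trainer = Trainer(strategy="ddp_spawn", devices=2, accelerator="cpu", logger=logger, max_epochs=1)
trainer.fit(BoringModel())  # warns, clears the experiment, then spawns cleanly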