Skip to content

Commit

Permalink
Skip hanging spawn tests (#10838)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholi <[email protected]>
  • Loading branch information
2 people authored and lexierule committed Dec 7, 2021
1 parent e0a1f55 commit 1101267
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 34 deletions.
17 changes: 0 additions & 17 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
Expand Down Expand Up @@ -171,17 +170,14 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st
return {"nprocs": self.num_processes}

def start_training(self, trainer: "pl.Trainer") -> None:
self._clean_logger(trainer)
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
# reset optimizers, since main process is never used for training and thus does not have a valid optim state
trainer.optimizers = []

def start_evaluating(self, trainer: "pl.Trainer") -> None:
self._clean_logger(trainer)
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)

def start_predicting(self, trainer: "pl.Trainer") -> None:
self._clean_logger(trainer)
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)

def spawn(self, function: Callable, *args: Any, return_result: bool = True, **kwargs: Any) -> Optional[Any]:
Expand Down Expand Up @@ -444,16 +440,3 @@ def teardown(self) -> None:
self.lightning_module.cpu()
# clean up memory
torch.cuda.empty_cache()

@staticmethod
def _clean_logger(trainer: "pl.Trainer") -> None:
loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
for logger in loggers:
if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
rank_zero_warn(
"When using `ddp_spawn`, the `TensorBoardLogger` experiment should be `None`. Setting it to `None`."
)
# the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang.
# we want to make sure these are closed before we spawn our own threads.
# assuming nothing else references the experiment object, python should instantly `__del__` it.
logger._experiment = None
20 changes: 20 additions & 0 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
Expand Down Expand Up @@ -289,8 +290,17 @@ def start_training(self, trainer: "pl.Trainer") -> None:
# todo: precision pluging is call in accelerator setup and should be moved
if "XLA_USE_BF16" in os.environ:
del os.environ["XLA_USE_BF16"]
self._clean_logger(trainer)
return super().start_training(trainer)

def start_evaluating(self, trainer: "pl.Trainer") -> None:
self._clean_logger(trainer)
return super().start_evaluating(trainer)

def start_predicting(self, trainer: "pl.Trainer") -> None:
self._clean_logger(trainer)
return super().start_predicting(trainer)

def training_step(self, *args, **kwargs):
return self.model(*args, **kwargs)

Expand Down Expand Up @@ -366,3 +376,13 @@ def checkpoint_io(self) -> CheckpointIO:
@checkpoint_io.setter
def checkpoint_io(self, plugin: CheckpointIO) -> None:
raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.")

@staticmethod
def _clean_logger(trainer: "pl.Trainer") -> None:
loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
for logger in loggers:
if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
# the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang.
# we want to make sure these are closed before we spawn our own threads.
# assuming nothing else references the experiment object, python should instantly `__del__` it.
logger._experiment = None
11 changes: 11 additions & 0 deletions tests/helpers/runif.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __new__(
deepspeed: bool = False,
rich: bool = False,
skip_49370: bool = False,
skip_hanging_spawn: bool = False,
**kwargs,
):
"""
Expand All @@ -93,6 +94,7 @@ def __new__(
deepspeed: if `deepspeed` module is required to run the test
rich: if `rich` module is required to run the test
skip_49370: Skip the test as it's impacted by https://github.com/pytorch/pytorch/issues/49370.
skip_hanging_spawn: Skip the test as it's impacted by hanging loggers on spawn.
kwargs: native pytest.mark.skipif keyword arguments
"""
conditions = []
Expand Down Expand Up @@ -178,6 +180,15 @@ def __new__(
conditions.append(ge_3_9 and old_torch)
reasons.append("Impacted by https://github.com/pytorch/pytorch/issues/49370")

if skip_hanging_spawn:
# strategy=ddp_spawn, accelerator=cpu, python>=3.8, torch<1.9 does not work
py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
ge_3_8 = Version(py_version) >= Version("3.8")
torch_version = get_distribution("torch").version
old_torch = Version(torch_version) < Version("1.9")
conditions.append(ge_3_8 and old_torch)
reasons.append("Impacted by hanging DDP spawn")

reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
return pytest.mark.skipif(
*args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs
Expand Down
2 changes: 1 addition & 1 deletion tests/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
assert pl_module.logger.experiment.something(foo="bar") is None


@RunIf(skip_windows=True, skip_49370=True)
@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True)
@pytest.mark.parametrize("logger_class", [CometLogger, CSVLogger, MLFlowLogger, TensorBoardLogger, TestTubeLogger])
def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class):
"""Test that loggers get replaced by dummy loggers on global rank > 0."""
Expand Down
15 changes: 0 additions & 15 deletions tests/loggers/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loggers.base import LoggerCollection
from pytorch_lightning.utilities.imports import _compare_version
from tests.helpers import BoringModel

Expand Down Expand Up @@ -333,17 +332,3 @@ def test_tensorboard_missing_folder_warning(tmpdir, caplog):
assert logger.version == 0

assert "Missing logger folder:" in caplog.text


@pytest.mark.parametrize("use_list", [False, True])
def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir):
tensorboard_logger = TensorBoardLogger(save_dir=tmpdir)
assert tensorboard_logger._experiment is None
tensorboard_logger.experiment # this property access will create the experiment
assert tensorboard_logger._experiment is not None
logger = [tensorboard_logger] if use_list else tensorboard_logger
trainer = Trainer(strategy="ddp_spawn", devices=2, accelerator="auto", logger=logger)
trainer.training_type_plugin._clean_logger(trainer)
if use_list:
assert isinstance(trainer.logger, LoggerCollection)
assert tensorboard_logger._experiment is None
16 changes: 16 additions & 0 deletions tests/plugins/test_tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from pytorch_lightning.plugins.training_type import TPUSpawnPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset
Expand Down Expand Up @@ -102,3 +103,18 @@ def test_model_tpu_one_core():
model = BoringModelTPU()
trainer.fit(model)
assert "PT_XLA_DEBUG" not in os.environ


@RunIf(tpu=True)
@pytest.mark.parametrize("use_list", [False, True])
def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir):
tensorboard_logger = TensorBoardLogger(save_dir=tmpdir)
assert tensorboard_logger._experiment is None
tensorboard_logger.experiment # this property access will create the experiment
assert tensorboard_logger._experiment is not None
logger = [tensorboard_logger] if use_list else tensorboard_logger
trainer = Trainer(strategy="ddp_spawn", accelerator="tpu", devices="auto", logger=logger)
trainer.training_type_plugin._clean_logger(trainer)
if use_list:
assert isinstance(trainer.logger, LoggerCollection)
assert tensorboard_logger._experiment is None
2 changes: 1 addition & 1 deletion tests/utilities/test_all_gather_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _test_all_gather_ddp(rank, world_size):
assert torch.allclose(grad2, tensor2.grad)


@RunIf(skip_windows=True, skip_49370=True)
@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True)
def test_all_gather_ddp_spawn():
world_size = 3
torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size)
Expand Down

0 comments on commit 1101267

Please sign in to comment.