From 11012672433f35477f404ca6bc0d513add79087e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 16:20:03 +0100 Subject: [PATCH] Skip hanging spawn tests (#10838) Co-authored-by: Carlos Mocholi --- .../plugins/training_type/ddp_spawn.py | 17 ---------------- .../plugins/training_type/tpu_spawn.py | 20 +++++++++++++++++++ tests/helpers/runif.py | 11 ++++++++++ tests/loggers/test_all.py | 2 +- tests/loggers/test_tensorboard.py | 15 -------------- tests/plugins/test_tpu_spawn.py | 16 +++++++++++++++ tests/utilities/test_all_gather_grad.py | 2 +- 7 files changed, 49 insertions(+), 34 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 0b503db64e0a4..ff5159f739cdc 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -25,7 +25,6 @@ from torch.nn.parallel.distributed import DistributedDataParallel import pytorch_lightning as pl -from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.overrides.torch_distributed import broadcast_object_list @@ -171,17 +170,14 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st return {"nprocs": self.num_processes} def start_training(self, trainer: "pl.Trainer") -> None: - self._clean_logger(trainer) self.spawn(self.new_process, trainer, self.mp_queue, return_result=False) # reset optimizers, since main process is never used for training and thus does not have a valid optim state trainer.optimizers = [] def start_evaluating(self, trainer: "pl.Trainer") -> None: - self._clean_logger(trainer) self.spawn(self.new_process, trainer, self.mp_queue, return_result=False) def start_predicting(self, trainer: "pl.Trainer") -> None: - self._clean_logger(trainer) self.spawn(self.new_process, trainer, self.mp_queue, return_result=False) def spawn(self, function: Callable, *args: Any, return_result: bool = True, **kwargs: Any) -> Optional[Any]: @@ -444,16 +440,3 @@ def teardown(self) -> None: self.lightning_module.cpu() # clean up memory torch.cuda.empty_cache() - - @staticmethod - def _clean_logger(trainer: "pl.Trainer") -> None: - loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger] - for logger in loggers: - if isinstance(logger, TensorBoardLogger) and logger._experiment is not None: - rank_zero_warn( - "When using `ddp_spawn`, the `TensorBoardLogger` experiment should be `None`. Setting it to `None`." - ) - # the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang. - # we want to make sure these are closed before we spawn our own threads. - # assuming nothing else references the experiment object, python should instantly `__del__` it. - logger._experiment = None diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 4fa0cfda6a859..92bd0f06735d8 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -24,6 +24,7 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl +from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO @@ -289,8 +290,17 @@ def start_training(self, trainer: "pl.Trainer") -> None: # todo: precision pluging is call in accelerator setup and should be moved if "XLA_USE_BF16" in os.environ: del os.environ["XLA_USE_BF16"] + self._clean_logger(trainer) return super().start_training(trainer) + def start_evaluating(self, trainer: "pl.Trainer") -> None: + self._clean_logger(trainer) + return super().start_evaluating(trainer) + + def start_predicting(self, trainer: "pl.Trainer") -> None: + self._clean_logger(trainer) + return super().start_predicting(trainer) + def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) @@ -366,3 +376,13 @@ def checkpoint_io(self) -> CheckpointIO: @checkpoint_io.setter def checkpoint_io(self, plugin: CheckpointIO) -> None: raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.") + + @staticmethod + def _clean_logger(trainer: "pl.Trainer") -> None: + loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger] + for logger in loggers: + if isinstance(logger, TensorBoardLogger) and logger._experiment is not None: + # the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang. + # we want to make sure these are closed before we spawn our own threads. + # assuming nothing else references the experiment object, python should instantly `__del__` it. + logger._experiment = None diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index 4ad6942aa160a..07bd6438da125 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -71,6 +71,7 @@ def __new__( deepspeed: bool = False, rich: bool = False, skip_49370: bool = False, + skip_hanging_spawn: bool = False, **kwargs, ): """ @@ -93,6 +94,7 @@ def __new__( deepspeed: if `deepspeed` module is required to run the test rich: if `rich` module is required to run the test skip_49370: Skip the test as it's impacted by https://github.com/pytorch/pytorch/issues/49370. + skip_hanging_spawn: Skip the test as it's impacted by hanging loggers on spawn. kwargs: native pytest.mark.skipif keyword arguments """ conditions = [] @@ -178,6 +180,15 @@ def __new__( conditions.append(ge_3_9 and old_torch) reasons.append("Impacted by https://github.com/pytorch/pytorch/issues/49370") + if skip_hanging_spawn: + # strategy=ddp_spawn, accelerator=cpu, python>=3.8, torch<1.9 does not work + py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + ge_3_8 = Version(py_version) >= Version("3.8") + torch_version = get_distribution("torch").version + old_torch = Version(torch_version) < Version("1.9") + conditions.append(ge_3_8 and old_torch) + reasons.append("Impacted by hanging DDP spawn") + reasons = [rs for cond, rs in zip(conditions, reasons) if cond] return pytest.mark.skipif( *args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 370b24431b088..d66e77b4cea34 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -321,7 +321,7 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): assert pl_module.logger.experiment.something(foo="bar") is None -@RunIf(skip_windows=True, skip_49370=True) +@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True) @pytest.mark.parametrize("logger_class", [CometLogger, CSVLogger, MLFlowLogger, TensorBoardLogger, TestTubeLogger]) def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class): """Test that loggers get replaced by dummy loggers on global rank > 0.""" diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index 0a99c058ef941..02a809aa2ab30 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -25,7 +25,6 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.loggers.base import LoggerCollection from pytorch_lightning.utilities.imports import _compare_version from tests.helpers import BoringModel @@ -333,17 +332,3 @@ def test_tensorboard_missing_folder_warning(tmpdir, caplog): assert logger.version == 0 assert "Missing logger folder:" in caplog.text - - -@pytest.mark.parametrize("use_list", [False, True]) -def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir): - tensorboard_logger = TensorBoardLogger(save_dir=tmpdir) - assert tensorboard_logger._experiment is None - tensorboard_logger.experiment # this property access will create the experiment - assert tensorboard_logger._experiment is not None - logger = [tensorboard_logger] if use_list else tensorboard_logger - trainer = Trainer(strategy="ddp_spawn", devices=2, accelerator="auto", logger=logger) - trainer.training_type_plugin._clean_logger(trainer) - if use_list: - assert isinstance(trainer.logger, LoggerCollection) - assert tensorboard_logger._experiment is None diff --git a/tests/plugins/test_tpu_spawn.py b/tests/plugins/test_tpu_spawn.py index 3f4ff354e39bb..5f4abf560d6a6 100644 --- a/tests/plugins/test_tpu_spawn.py +++ b/tests/plugins/test_tpu_spawn.py @@ -20,6 +20,7 @@ from torch.utils.data import DataLoader from pytorch_lightning import Trainer +from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.plugins.training_type import TPUSpawnPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset @@ -102,3 +103,18 @@ def test_model_tpu_one_core(): model = BoringModelTPU() trainer.fit(model) assert "PT_XLA_DEBUG" not in os.environ + + +@RunIf(tpu=True) +@pytest.mark.parametrize("use_list", [False, True]) +def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir): + tensorboard_logger = TensorBoardLogger(save_dir=tmpdir) + assert tensorboard_logger._experiment is None + tensorboard_logger.experiment # this property access will create the experiment + assert tensorboard_logger._experiment is not None + logger = [tensorboard_logger] if use_list else tensorboard_logger + trainer = Trainer(strategy="ddp_spawn", accelerator="tpu", devices="auto", logger=logger) + trainer.training_type_plugin._clean_logger(trainer) + if use_list: + assert isinstance(trainer.logger, LoggerCollection) + assert tensorboard_logger._experiment is None diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 63c5c2cfe90fe..01ffd12a0ca62 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -41,7 +41,7 @@ def _test_all_gather_ddp(rank, world_size): assert torch.allclose(grad2, tensor2.grad) -@RunIf(skip_windows=True, skip_49370=True) +@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True) def test_all_gather_ddp_spawn(): world_size = 3 torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size)