From 12a4e71763e8477bbddb5915fcdfe0aed194595f Mon Sep 17 00:00:00 2001
From: Nikhil Shenoy
Date: Thu, 12 Jan 2023 05:10:11 -0800
Subject: [PATCH] Error handling for `accelerator="mps"` and `ddp` strategy pairing (#16153)

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Nikhil Shenoy
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: awaelchli

Fixes https://github.com/Lightning-AI/lightning/issues/16148
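Illustrative sketch of the behavior this patch introduces (not part of the
commit; assumes a machine where `MPSAccelerator.is_available()` returns True):

    from pytorch_lightning import Trainer

    # Any DDP-family strategy paired with the MPS accelerator now fails fast
    # at Trainer construction instead of crashing later:
    Trainer(accelerator="mps", strategy="ddp")
    # ValueError: You set `strategy=ddp` but strategies from the DDP family
    # are not supported on the MPS accelerator. Either explicitly set
    # `accelerator='cpu'` or change the strategy.

    # Both of these remain valid: single-device MPS, or DDP on CPU.
    Trainer(accelerator="mps", devices=1)
    Trainer(accelerator="cpu", strategy="ddp")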
---
 .../connectors/accelerator_connector.py       | 15 ++++++
 .../plugins/test_cluster_integration.py       |  6 ++-
 .../connectors/test_accelerator_connector.py  | 48 ++++++++++++-------
 .../tests_pytorch/trainer/test_supporters.py  |  2 +-
 tests/tests_pytorch/trainer/test_trainer.py   | 12 +++--
 5 files changed, 59 insertions(+), 24 deletions(-)

diff --git a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py
index f43f9fb96fce7..2e74ff45fe105 100644
--- a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -69,6 +69,7 @@
     HorovodStrategy,
     HPUParallelStrategy,
     IPUStrategy,
+    ParallelStrategy,
     SingleDeviceStrategy,
     SingleHPUStrategy,
     SingleTPUStrategy,
@@ -284,6 +285,20 @@ def _check_config_and_set_final_flags(
                     f" Available names are: {', '.join(self._accelerator_types)}."
                 )
 
+        # MPS accelerator is incompatible with DDP family of strategies. It supports single-device operation only.
+        is_ddp_str = isinstance(strategy, str) and "ddp" in strategy
+        is_dp_str = isinstance(strategy, str) and "dp" in strategy
+        is_deepspeed_str = isinstance(strategy, str) and "deepspeed" in strategy
+        is_parallel_strategy = isinstance(strategy, ParallelStrategy) or is_ddp_str or is_dp_str or is_deepspeed_str
+        is_mps_accelerator = MPSAccelerator.is_available() and (
+            accelerator in ("mps", "auto", "gpu", None) or isinstance(accelerator, MPSAccelerator)
+        )
+        if is_mps_accelerator and is_parallel_strategy:
+            raise ValueError(
+                f"You set `strategy={strategy}` but strategies from the DDP family are not supported on the"
+                f" MPS accelerator. Either explicitly set `accelerator='cpu'` or change the strategy."
+            )
+
         self._accelerator_flag = accelerator
 
         supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
diff --git a/tests/tests_pytorch/plugins/test_cluster_integration.py b/tests/tests_pytorch/plugins/test_cluster_integration.py
index c6542f0797743..e8beecf15020a 100644
--- a/tests/tests_pytorch/plugins/test_cluster_integration.py
+++ b/tests/tests_pytorch/plugins/test_cluster_integration.py
@@ -56,12 +56,13 @@ def environment_combinations():
         yield environment, variables, expected
 
 
+@RunIf(mps=False)
 @pytest.mark.parametrize(
     "strategy_cls",
     [DDPStrategy, DDPShardedStrategy, pytest.param(DeepSpeedStrategy, marks=RunIf(deepspeed=True))],
 )
 @mock.patch("pytorch_lightning.accelerators.cuda.CUDAAccelerator.is_available", return_value=True)
-def test_ranks_available_manual_strategy_selection(mock_gpu_acc_available, strategy_cls):
+def test_ranks_available_manual_strategy_selection(_, strategy_cls):
     """Test that the rank information is readily available after Trainer initialization."""
     num_nodes = 2
     for cluster, variables, expected in environment_combinations():
@@ -77,6 +78,7 @@ def test_ranks_available_manual_strategy_selection(mock_gpu_acc_available, strat
         assert trainer.world_size == expected["world_size"]
 
 
+@RunIf(mps=False)
 @pytest.mark.parametrize(
     "trainer_kwargs",
     [
@@ -86,7 +88,7 @@ def test_ranks_available_manual_strategy_selection(mock_gpu_acc_available, strat
         dict(strategy="ddp_spawn", accelerator="gpu", devices=[1, 2]),
     ],
 )
-def test_ranks_available_automatic_strategy_selection(mps_count_4, cuda_count_4, trainer_kwargs):
+def test_ranks_available_automatic_strategy_selection(cuda_count_4, trainer_kwargs):
     """Test that the rank information is readily available after Trainer initialization."""
     num_nodes = 2
     trainer_kwargs.update(num_nodes=num_nodes)
diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
index b35ee92ff1852..c0da8086a8b84 100644
--- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
@@ -126,6 +126,7 @@ def creates_processes_externally(self) -> bool:
     assert isinstance(trainer.strategy.cluster_environment, CustomCluster)
 
 
+@RunIf(mps=False)
 @mock.patch.dict(
     os.environ,
     {
@@ -231,7 +232,8 @@ def test_fallback_from_ddp_spawn_to_ddp_on_cluster(_, __, env_vars, expected_env
     assert isinstance(trainer.strategy.cluster_environment, expected_environment)
 
 
-def test_interactive_incompatible_backend_error(mps_count_2, cuda_count_2, monkeypatch):
+@RunIf(mps=False)
+def test_interactive_incompatible_backend_error(cuda_count_2, monkeypatch):
     monkeypatch.setattr(pytorch_lightning.trainer.connectors.accelerator_connector, "_IS_INTERACTIVE", True)
     with pytest.raises(MisconfigurationException, match=r"strategy='ddp'\)`.*is not compatible"):
         Trainer(strategy="ddp", accelerator="gpu", devices=2)
@@ -247,7 +249,7 @@ def test_interactive_incompatible_backend_error(mps_count_2, cuda_count_2, monke
         Trainer(strategy="dp")
 
 
-def test_interactive_compatible_dp_strategy_gpu(cuda_count_2, monkeypatch):
+def test_interactive_compatible_dp_strategy_gpu(mps_count_0, cuda_count_2, monkeypatch):
    monkeypatch.setattr(pytorch_lightning.trainer.connectors.accelerator_connector, "_IS_INTERACTIVE", True)
     trainer = Trainer(strategy="dp", accelerator="gpu")
     assert trainer.strategy.launcher is None
@@ -358,7 +360,7 @@ def test_set_devices_if_none_cpu():
 
 
 def test_unsupported_strategy_types_on_cpu_and_fallback():
     with pytest.warns(UserWarning, match="is not supported on CPUs, hence setting `strategy='ddp"):
-        trainer = Trainer(strategy="dp", num_processes=2)
+        trainer = Trainer(accelerator="cpu", strategy="dp", num_processes=2)
     assert isinstance(trainer.strategy, DDPStrategy)
@@ -369,6 +371,28 @@ def test_exception_invalid_strategy():
         Trainer(strategy="tpu_spawn")
 
 
+@pytest.mark.parametrize(
+    ["strategy", "strategy_class"],
+    (
+        ("ddp_spawn", DDPSpawnStrategy),
+        ("ddp_spawn_find_unused_parameters_false", DDPSpawnStrategy),
+        ("ddp", DDPStrategy),
+        ("ddp_find_unused_parameters_false", DDPStrategy),
+        ("dp", DataParallelStrategy),
+        ("ddp_sharded", DDPShardedStrategy),
+        ("ddp_sharded_spawn", DDPSpawnShardedStrategy),
+        pytest.param("deepspeed", DeepSpeedStrategy, marks=RunIf(deepspeed=True)),
+    ),
+)
+@pytest.mark.parametrize("accelerator", ["mps", "auto", "gpu", None, MPSAccelerator()])
+def test_invalid_ddp_strategy_with_mps(accelerator, strategy, strategy_class, mps_count_1, cuda_count_0):
+    with pytest.raises(ValueError, match="strategies from the DDP family are not supported"):
+        Trainer(accelerator=accelerator, strategy=strategy)
+
+    with pytest.raises(ValueError, match="strategies from the DDP family are not supported"):
+        Trainer(accelerator="mps", strategy=strategy_class())
+
+
 @pytest.mark.parametrize(
     ["strategy", "strategy_class"],
     [
@@ -475,14 +499,6 @@ def test_strategy_choice_ddp_cuda(strategy, expected_cls, mps_count_0, cuda_coun
     assert isinstance(trainer.strategy.cluster_environment, LightningEnvironment)
 
 
-@pytest.mark.parametrize("strategy,expected_cls", [("ddp", DDPStrategy), ("ddp_spawn", DDPSpawnStrategy)])
-def test_strategy_choice_ddp_mps(strategy, expected_cls, mps_count_1, cuda_count_0):
-    trainer = Trainer(fast_dev_run=True, strategy=strategy, accelerator="gpu", devices=1)
-    assert isinstance(trainer.accelerator, MPSAccelerator)
-    assert isinstance(trainer.strategy, expected_cls)
-    assert isinstance(trainer.strategy.cluster_environment, LightningEnvironment)
-
-
 @pytest.mark.parametrize("job_name,expected_env", [("some_name", SLURMEnvironment), ("bash", LightningEnvironment)])
 @pytest.mark.parametrize("strategy", ["ddp", DDPStrategy])
 def test_strategy_choice_ddp_slurm(cuda_count_2, strategy, job_name, expected_env):
@@ -704,9 +720,9 @@ def test_deterministic_init(deterministic):
         (False, [Mock(spec=LayerSync)], LayerSync),
     ],
 )
-def test_sync_batchnorm_set(tmpdir, sync_batchnorm, plugins, expected):
+def test_sync_batchnorm_set(sync_batchnorm, plugins, expected):
     """Test valid combinations of the sync_batchnorm Trainer flag and the plugins list of layer-sync plugins."""
-    trainer = Trainer(sync_batchnorm=sync_batchnorm, plugins=plugins, strategy="ddp")
+    trainer = Trainer(accelerator="cpu", sync_batchnorm=sync_batchnorm, plugins=plugins, strategy="ddp")
     assert isinstance(trainer._accelerator_connector._layer_sync, expected)
     assert isinstance(trainer.strategy._layer_sync, expected)
@@ -733,7 +749,7 @@ def __init__(self, **kwargs):
     strategy = CustomParallelStrategy()
     assert strategy._layer_sync is None
 
-    Trainer(strategy=strategy, sync_batchnorm=True)
+    Trainer(accelerator="cpu", strategy=strategy, sync_batchnorm=True)
     assert isinstance(strategy._layer_sync, NativeSyncBatchNorm)
 
 
@@ -809,12 +825,12 @@ def test_accelerator_specific_checkpoint_io(*_):
 )
 def test_ddp_fork_on_unsupported_platform(_, strategy):
     with pytest.raises(ValueError, match="process forking is not supported on this platform"):
-        Trainer(strategy=strategy)
+        Trainer(accelerator="cpu", strategy=strategy)
 
 
 @pytest.mark.parametrize(
     ["strategy", "strategy_cls"], [("DDP", DDPStrategy), ("DDP_FIND_UNUSED_PARAMETERS_FALSE", DDPStrategy)]
 )
 def test_strategy_str_passed_being_case_insensitive(strategy, strategy_cls):
-    trainer = Trainer(strategy=strategy)
+    trainer = Trainer(accelerator="cpu", strategy=strategy)
     assert isinstance(trainer.strategy, strategy_cls)
diff --git a/tests/tests_pytorch/trainer/test_supporters.py b/tests/tests_pytorch/trainer/test_supporters.py
index 15958500c2dec..8533ad5fdb467 100644
--- a/tests/tests_pytorch/trainer/test_supporters.py
+++ b/tests/tests_pytorch/trainer/test_supporters.py
@@ -316,7 +316,7 @@ def test_nested_calc_num_data(input_data, compute_func, expected_length):
 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
 @pytest.mark.parametrize("use_fault_tolerant", [False, True])
 @pytest.mark.parametrize("replace_sampler_ddp", [False, True])
-def test_combined_data_loader_validation_test(mps_count_2, cuda_count_2, use_fault_tolerant, replace_sampler_ddp):
+def test_combined_data_loader_validation_test(mps_count_0, cuda_count_2, use_fault_tolerant, replace_sampler_ddp):
     """This test makes sure distributed sampler has been properly injected in dataloaders when using
     CombinedLoader."""
diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py
index 3b4af83b13cfd..edace5429a531 100644
--- a/tests/tests_pytorch/trainer/test_trainer.py
+++ b/tests/tests_pytorch/trainer/test_trainer.py
@@ -1921,7 +1921,7 @@ def on_exception(self, *_):
         self.exceptions += 1
 
 
-@pytest.mark.parametrize("strategy", [None, pytest.param("ddp_spawn", marks=RunIf(skip_windows=True))])
+@pytest.mark.parametrize("strategy", [None, pytest.param("ddp_spawn", marks=RunIf(skip_windows=True, mps=False))])
 def test_error_handling_all_stages(tmpdir, strategy):
     model = TrainerStagesErrorsModel()
     counter = ExceptionCounter()
@@ -2017,9 +2017,11 @@ def training_step(self, batch, batch_idx):
     ["trainer_kwargs", "strategy_cls", "strategy_name", "accelerator_cls", "devices"],
     [
         ({"strategy": None}, SingleDeviceStrategy, "single_device", CPUAccelerator, 1),
-        ({"strategy": "dp"}, DDPStrategy, "ddp", CPUAccelerator, 1),
-        ({"strategy": "ddp"}, DDPStrategy, "ddp", CPUAccelerator, 1),
-        ({"strategy": "ddp", "num_nodes": 2}, DDPStrategy, "ddp", CPUAccelerator, 1),
+        pytest.param({"strategy": "dp"}, DDPStrategy, "ddp", CPUAccelerator, 1, marks=RunIf(mps=False)),
+        pytest.param({"strategy": "ddp"}, DDPStrategy, "ddp", CPUAccelerator, 1, marks=RunIf(mps=False)),
+        pytest.param(
+            {"strategy": "ddp", "num_nodes": 2}, DDPStrategy, "ddp", CPUAccelerator, 1, marks=RunIf(mps=False)
+        ),
         (
             {"strategy": None, "accelerator": "cuda", "devices": 1},
             SingleDeviceStrategy,
@@ -2075,7 +2077,7 @@ def training_step(self, batch, batch_idx):
             CUDAAccelerator,
             2,
         ),
-        ({"strategy": DDPStrategy()}, DDPStrategy, "ddp", CPUAccelerator, 1),
+        pytest.param({"strategy": DDPStrategy()}, DDPStrategy, "ddp", CPUAccelerator, 1, marks=RunIf(mps=False)),
         ({"strategy": DDPStrategy(), "accelerator": "cuda", "devices": 2}, DDPStrategy, "ddp", CUDAAccelerator, 2),
         (
             {"strategy": DataParallelStrategy(), "accelerator": "cuda", "devices": 2},