Error handling for accelerator="mps" and ddp strategy pairing (#16153)

Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: Nikhil Shenoy <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: awaelchli <[email protected]>
Fixes #16148
shenoynikhil authored Jan 12, 2023
1 parent 426c463 commit 12a4e71
Showing 5 changed files with 59 additions and 24 deletions.
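
For context, a minimal sketch of the user-facing behavior this commit introduces. It assumes a machine where MPSAccelerator.is_available() returns True; the error text is the one added in the diff below, and the surrounding calls are illustrative rather than part of the commit.

    import pytest
    from pytorch_lightning import Trainer

    # Pairing the MPS accelerator with any DDP-family strategy now fails fast
    # at Trainer construction time rather than surfacing a less obvious failure later.
    with pytest.raises(ValueError, match="strategies from the DDP family are not supported"):
        Trainer(accelerator="mps", strategy="ddp")

    # Single-device MPS and DDP on the CPU accelerator remain valid.
    Trainer(accelerator="mps", devices=1)
    Trainer(accelerator="cpu", strategy="ddp")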
15 changes: 15 additions & 0 deletions src/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -69,6 +69,7 @@
HorovodStrategy,
HPUParallelStrategy,
IPUStrategy,
ParallelStrategy,
SingleDeviceStrategy,
SingleHPUStrategy,
SingleTPUStrategy,
@@ -284,6 +285,20 @@ def _check_config_and_set_final_flags(
f" Available names are: {', '.join(self._accelerator_types)}."
)

# MPS accelerator is incompatible with DDP family of strategies. It supports single-device operation only.
is_ddp_str = isinstance(strategy, str) and "ddp" in strategy
is_dp_str = isinstance(strategy, str) and "dp" in strategy
is_deepspeed_str = isinstance(strategy, str) and "deepspeed" in strategy
is_parallel_strategy = isinstance(strategy, ParallelStrategy) or is_ddp_str or is_dp_str or is_deepspeed_str
is_mps_accelerator = MPSAccelerator.is_available() and (
accelerator in ("mps", "auto", "gpu", None) or isinstance(accelerator, MPSAccelerator)
)
if is_mps_accelerator and is_parallel_strategy:
raise ValueError(
f"You set `strategy={strategy}` but strategies from the DDP family are not supported on the"
f" MPS accelerator. Either explicitly set `accelerator='cpu'` or change the strategy."
)

self._accelerator_flag = accelerator

supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
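
Because the new check treats accelerator values of "auto", "gpu", and None as MPS whenever MPSAccelerator.is_available() is true, DDP-based tests in the files below must pin a non-MPS accelerator or be skipped on Apple hardware, which is what the added RunIf(mps=False) markers and mps_count_0 fixtures do. A rough illustration of that resolution behavior, again assuming MPS is available:

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import MPSAccelerator

    # Illustration only: with MPS available, these accelerator flags all resolve
    # towards the MPS accelerator, so each pairing raises the new ValueError.
    if MPSAccelerator.is_available():
        for accelerator in ("mps", "auto", "gpu", None):
            try:
                Trainer(accelerator=accelerator, strategy="ddp_spawn")
            except ValueError as err:
                print(f"{accelerator!r}: {err}")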
6 changes: 4 additions & 2 deletions tests/tests_pytorch/plugins/test_cluster_integration.py
@@ -56,12 +56,13 @@ def environment_combinations():
yield environment, variables, expected


@RunIf(mps=False)
@pytest.mark.parametrize(
"strategy_cls",
[DDPStrategy, DDPShardedStrategy, pytest.param(DeepSpeedStrategy, marks=RunIf(deepspeed=True))],
)
@mock.patch("pytorch_lightning.accelerators.cuda.CUDAAccelerator.is_available", return_value=True)
def test_ranks_available_manual_strategy_selection(mock_gpu_acc_available, strategy_cls):
def test_ranks_available_manual_strategy_selection(_, strategy_cls):
"""Test that the rank information is readily available after Trainer initialization."""
num_nodes = 2
for cluster, variables, expected in environment_combinations():
@@ -77,6 +78,7 @@ def test_ranks_available_manual_strategy_selection(mock_gpu_acc_available, strat
assert trainer.world_size == expected["world_size"]


@RunIf(mps=False)
@pytest.mark.parametrize(
"trainer_kwargs",
[
@@ -86,7 +88,7 @@ def test_ranks_available_manual_strategy_selection(mock_gpu_acc_available, strat
dict(strategy="ddp_spawn", accelerator="gpu", devices=[1, 2]),
],
)
def test_ranks_available_automatic_strategy_selection(mps_count_4, cuda_count_4, trainer_kwargs):
def test_ranks_available_automatic_strategy_selection(cuda_count_4, trainer_kwargs):
"""Test that the rank information is readily available after Trainer initialization."""
num_nodes = 2
trainer_kwargs.update(num_nodes=num_nodes)
48 changes: 32 additions & 16 deletions tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
@@ -126,6 +126,7 @@ def creates_processes_externally(self) -> bool:
assert isinstance(trainer.strategy.cluster_environment, CustomCluster)


@RunIf(mps=False)
@mock.patch.dict(
os.environ,
{
@@ -231,7 +232,8 @@ def test_fallback_from_ddp_spawn_to_ddp_on_cluster(_, __, env_vars, expected_env
assert isinstance(trainer.strategy.cluster_environment, expected_environment)


def test_interactive_incompatible_backend_error(mps_count_2, cuda_count_2, monkeypatch):
@RunIf(mps=False)
def test_interactive_incompatible_backend_error(cuda_count_2, monkeypatch):
monkeypatch.setattr(pytorch_lightning.trainer.connectors.accelerator_connector, "_IS_INTERACTIVE", True)
with pytest.raises(MisconfigurationException, match=r"strategy='ddp'\)`.*is not compatible"):
Trainer(strategy="ddp", accelerator="gpu", devices=2)
@@ -247,7 +249,7 @@ def test_interactive_incompatible_backend_error(mps_count_2, cuda_count_2, monke
Trainer(strategy="dp")


def test_interactive_compatible_dp_strategy_gpu(cuda_count_2, monkeypatch):
def test_interactive_compatible_dp_strategy_gpu(mps_count_0, cuda_count_2, monkeypatch):
monkeypatch.setattr(pytorch_lightning.trainer.connectors.accelerator_connector, "_IS_INTERACTIVE", True)
trainer = Trainer(strategy="dp", accelerator="gpu")
assert trainer.strategy.launcher is None
@@ -358,7 +360,7 @@ def test_set_devices_if_none_cpu():

def test_unsupported_strategy_types_on_cpu_and_fallback():
with pytest.warns(UserWarning, match="is not supported on CPUs, hence setting `strategy='ddp"):
trainer = Trainer(strategy="dp", num_processes=2)
trainer = Trainer(accelerator="cpu", strategy="dp", num_processes=2)
assert isinstance(trainer.strategy, DDPStrategy)


@@ -369,6 +371,28 @@ def test_exception_invalid_strategy():
Trainer(strategy="tpu_spawn")


@pytest.mark.parametrize(
["strategy", "strategy_class"],
(
("ddp_spawn", DDPSpawnStrategy),
("ddp_spawn_find_unused_parameters_false", DDPSpawnStrategy),
("ddp", DDPStrategy),
("ddp_find_unused_parameters_false", DDPStrategy),
("dp", DataParallelStrategy),
("ddp_sharded", DDPShardedStrategy),
("ddp_sharded_spawn", DDPSpawnShardedStrategy),
pytest.param("deepspeed", DeepSpeedStrategy, marks=RunIf(deepspeed=True)),
),
)
@pytest.mark.parametrize("accelerator", ["mps", "auto", "gpu", None, MPSAccelerator()])
def test_invalid_ddp_strategy_with_mps(accelerator, strategy, strategy_class, mps_count_1, cuda_count_0):
with pytest.raises(ValueError, match="strategies from the DDP family are not supported"):
Trainer(accelerator=accelerator, strategy=strategy)

with pytest.raises(ValueError, match="strategies from the DDP family are not supported"):
Trainer(accelerator="mps", strategy=strategy_class())


@pytest.mark.parametrize(
["strategy", "strategy_class"],
[
@@ -475,14 +499,6 @@ def test_strategy_choice_ddp_cuda(strategy, expected_cls, mps_count_0, cuda_coun
assert isinstance(trainer.strategy.cluster_environment, LightningEnvironment)


@pytest.mark.parametrize("strategy,expected_cls", [("ddp", DDPStrategy), ("ddp_spawn", DDPSpawnStrategy)])
def test_strategy_choice_ddp_mps(strategy, expected_cls, mps_count_1, cuda_count_0):
trainer = Trainer(fast_dev_run=True, strategy=strategy, accelerator="gpu", devices=1)
assert isinstance(trainer.accelerator, MPSAccelerator)
assert isinstance(trainer.strategy, expected_cls)
assert isinstance(trainer.strategy.cluster_environment, LightningEnvironment)


@pytest.mark.parametrize("job_name,expected_env", [("some_name", SLURMEnvironment), ("bash", LightningEnvironment)])
@pytest.mark.parametrize("strategy", ["ddp", DDPStrategy])
def test_strategy_choice_ddp_slurm(cuda_count_2, strategy, job_name, expected_env):
@@ -704,9 +720,9 @@ def test_deterministic_init(deterministic):
(False, [Mock(spec=LayerSync)], LayerSync),
],
)
def test_sync_batchnorm_set(tmpdir, sync_batchnorm, plugins, expected):
def test_sync_batchnorm_set(sync_batchnorm, plugins, expected):
"""Test valid combinations of the sync_batchnorm Trainer flag and the plugins list of layer-sync plugins."""
trainer = Trainer(sync_batchnorm=sync_batchnorm, plugins=plugins, strategy="ddp")
trainer = Trainer(accelerator="cpu", sync_batchnorm=sync_batchnorm, plugins=plugins, strategy="ddp")
assert isinstance(trainer._accelerator_connector._layer_sync, expected)
assert isinstance(trainer.strategy._layer_sync, expected)

@@ -733,7 +749,7 @@ def __init__(self, **kwargs):

strategy = CustomParallelStrategy()
assert strategy._layer_sync is None
Trainer(strategy=strategy, sync_batchnorm=True)
Trainer(accelerator="cpu", strategy=strategy, sync_batchnorm=True)
assert isinstance(strategy._layer_sync, NativeSyncBatchNorm)


@@ -809,12 +825,12 @@ def test_accelerator_specific_checkpoint_io(*_):
)
def test_ddp_fork_on_unsupported_platform(_, strategy):
with pytest.raises(ValueError, match="process forking is not supported on this platform"):
Trainer(strategy=strategy)
Trainer(accelerator="cpu", strategy=strategy)


@pytest.mark.parametrize(
["strategy", "strategy_cls"], [("DDP", DDPStrategy), ("DDP_FIND_UNUSED_PARAMETERS_FALSE", DDPStrategy)]
)
def test_strategy_str_passed_being_case_insensitive(strategy, strategy_cls):
trainer = Trainer(strategy=strategy)
trainer = Trainer(accelerator="cpu", strategy=strategy)
assert isinstance(trainer.strategy, strategy_cls)
2 changes: 1 addition & 1 deletion tests/tests_pytorch/trainer/test_supporters.py
@@ -316,7 +316,7 @@ def test_nested_calc_num_data(input_data, compute_func, expected_length):
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
@pytest.mark.parametrize("use_fault_tolerant", [False, True])
@pytest.mark.parametrize("replace_sampler_ddp", [False, True])
def test_combined_data_loader_validation_test(mps_count_2, cuda_count_2, use_fault_tolerant, replace_sampler_ddp):
def test_combined_data_loader_validation_test(mps_count_0, cuda_count_2, use_fault_tolerant, replace_sampler_ddp):
"""This test makes sure distributed sampler has been properly injected in dataloaders when using
CombinedLoader."""

12 changes: 7 additions & 5 deletions tests/tests_pytorch/trainer/test_trainer.py
@@ -1921,7 +1921,7 @@ def on_exception(self, *_):
self.exceptions += 1


@pytest.mark.parametrize("strategy", [None, pytest.param("ddp_spawn", marks=RunIf(skip_windows=True))])
@pytest.mark.parametrize("strategy", [None, pytest.param("ddp_spawn", marks=RunIf(skip_windows=True, mps=False))])
def test_error_handling_all_stages(tmpdir, strategy):
model = TrainerStagesErrorsModel()
counter = ExceptionCounter()
@@ -2017,9 +2017,11 @@ def training_step(self, batch, batch_idx):
["trainer_kwargs", "strategy_cls", "strategy_name", "accelerator_cls", "devices"],
[
({"strategy": None}, SingleDeviceStrategy, "single_device", CPUAccelerator, 1),
({"strategy": "dp"}, DDPStrategy, "ddp", CPUAccelerator, 1),
({"strategy": "ddp"}, DDPStrategy, "ddp", CPUAccelerator, 1),
({"strategy": "ddp", "num_nodes": 2}, DDPStrategy, "ddp", CPUAccelerator, 1),
pytest.param({"strategy": "dp"}, DDPStrategy, "ddp", CPUAccelerator, 1, marks=RunIf(mps=False)),
pytest.param({"strategy": "ddp"}, DDPStrategy, "ddp", CPUAccelerator, 1, marks=RunIf(mps=False)),
pytest.param(
{"strategy": "ddp", "num_nodes": 2}, DDPStrategy, "ddp", CPUAccelerator, 1, marks=RunIf(mps=False)
),
(
{"strategy": None, "accelerator": "cuda", "devices": 1},
SingleDeviceStrategy,
@@ -2075,7 +2077,7 @@ def training_step(self, batch, batch_idx):
CUDAAccelerator,
2,
),
({"strategy": DDPStrategy()}, DDPStrategy, "ddp", CPUAccelerator, 1),
pytest.param({"strategy": DDPStrategy()}, DDPStrategy, "ddp", CPUAccelerator, 1, marks=RunIf(mps=False)),
({"strategy": DDPStrategy(), "accelerator": "cuda", "devices": 2}, DDPStrategy, "ddp", CUDAAccelerator, 2),
(
{"strategy": DataParallelStrategy(), "accelerator": "cuda", "devices": 2},
