LightningFabric: Error handling for accelerator="mps" and ddp strategy pairing #16455

Merged · 7 commits · Jan 22, 2023
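In short: the Fabric connector now fails fast with a `ValueError` when the MPS accelerator is selected, explicitly or implicitly via `"auto"`, `"gpu"`, or `None` on a machine where MPS is available, together with a DDP-family strategy (`ddp`, `dp`, `deepspeed`), since MPS supports single-device operation only. A minimal sketch of the user-facing behaviour on an Apple Silicon machine (the `Fabric` import path is assumed from this repository layout):

```python
from lightning_fabric import Fabric

try:
    # Any of accelerator="mps"/"auto"/"gpu"/None combined with a DDP-family
    # strategy now raises at construction time instead of failing later.
    Fabric(accelerator="mps", strategy="ddp")
except ValueError as err:
    print(err)  # "...strategies from the DDP family are not supported on the MPS accelerator..."
```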
3 changes: 2 additions & 1 deletion src/lightning_fabric/CHANGELOG.md
@@ -35,7 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed error handling for `accelerator="mps"` and `ddp` strategy pairing ([#16455](https://github.com/Lightning-AI/lightning/pull/16455))



## [1.9.0] - 2023-01-17
15 changes: 15 additions & 0 deletions src/lightning_fabric/connector.py
@@ -44,6 +44,7 @@
from lightning_fabric.plugins.precision.precision import _PRECISION_INPUT, _PRECISION_INPUT_INT, _PRECISION_INPUT_STR
from lightning_fabric.strategies import (
DeepSpeedStrategy,
ParallelStrategy,
SingleDeviceStrategy,
SingleTPUStrategy,
Strategy,
@@ -201,6 +202,20 @@ def _check_config_and_set_final_flags(
f" Available names are: {', '.join(self._registered_accelerators)}."
)

# MPS accelerator is incompatible with DDP family of strategies. It supports single-device operation only.
is_ddp_str = isinstance(strategy, str) and "ddp" in strategy
is_dp_str = isinstance(strategy, str) and "dp" in strategy
is_deepspeed_str = isinstance(strategy, str) and "deepspeed" in strategy
is_parallel_strategy = isinstance(strategy, ParallelStrategy) or is_ddp_str or is_dp_str or is_deepspeed_str
is_mps_accelerator = MPSAccelerator.is_available() and (
accelerator in ("mps", "auto", "gpu", None) or isinstance(accelerator, MPSAccelerator)
)
if is_mps_accelerator and is_parallel_strategy:
raise ValueError(
f"You set `strategy={strategy}` but strategies from the DDP family are not supported on the"
f" MPS accelerator. Either explicitly set `accelerator='cpu'` or change the strategy."
)

self._accelerator_flag = accelerator

supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
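As the new error message suggests, there are two valid ways out of this configuration. A hedged sketch of both, assuming a machine where the respective accelerators are available:

```python
from lightning_fabric import Fabric

# Option 1: keep the DDP-family strategy, but run it on CPU processes.
fabric_cpu = Fabric(accelerator="cpu", strategy="ddp", devices=2)

# Option 2: keep the MPS accelerator and drop the parallel strategy
# (MPS is single-device, so one device is the only supported layout).
fabric_mps = Fabric(accelerator="mps", devices=1)
```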
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock

import pytest
from tests_fabric.helpers.runif import RunIf

@@ -21,7 +23,8 @@

@RunIf(deepspeed=True)
@pytest.mark.parametrize("precision", ["bf16", 16, 32])
def test_deepspeed_precision_choice(precision):
@mock.patch("lightning_fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
def test_deepspeed_precision_choice(_, precision):
"""Test to ensure precision plugin is correctly chosen.

DeepSpeed handles precision via custom DeepSpeedPrecision.
@@ -254,8 +254,9 @@ def test_deepspeed_multigpu_stage_3(tmpdir):

@RunIf(deepspeed=True)
@mock.patch("deepspeed.init_distributed", autospec=True)
@mock.patch("lightning_fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
@pytest.mark.parametrize("platform", ["Linux", "Windows"])
def test_deepspeed_env_variables_on_platforms(deepspeed_dist_mock, tmpdir, platform):
def test_deepspeed_env_variables_on_platforms(_, deepspeed_dist_mock, platform):
"""Test to ensure that we set up distributed communication correctly.

When using Windows, ranks environment variables should not be set, and DeepSpeed should handle this.
35 changes: 30 additions & 5 deletions tests/tests_fabric/test_connector.py
@@ -125,6 +125,7 @@ def creates_processes_externally(self) -> bool:
assert isinstance(connector.strategy.cluster_environment, CustomCluster)


@RunIf(mps=False)
@mock.patch.dict(
os.environ,
{
@@ -246,7 +247,8 @@ def test_interactive_incompatible_backend_error(_, monkeypatch):


@mock.patch("lightning_fabric.accelerators.cuda.num_cuda_devices", return_value=2)
def test_interactive_compatible_dp_strategy_gpu(_, monkeypatch):
@mock.patch("lightning_fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
def test_interactive_compatible_dp_strategy_gpu(_, __, monkeypatch):
monkeypatch.setattr(lightning_fabric.utilities.imports, "_IS_INTERACTIVE", True)
connector = _Connector(strategy="dp", accelerator="gpu")
assert connector.strategy.launcher is None
@@ -266,6 +268,24 @@ def test_interactive_compatible_strategy_ddp_fork(monkeypatch):
assert connector.strategy.launcher.is_interactive_compatible


@RunIf(mps=True)
@pytest.mark.parametrize(
["strategy", "strategy_class"],
(
("ddp", DDPStrategy),
("dp", DataParallelStrategy),
pytest.param("deepspeed", DeepSpeedStrategy, marks=RunIf(deepspeed=True)),
),
)
@pytest.mark.parametrize("accelerator", ["mps", "auto", "gpu", None, MPSAccelerator()])
def test_invalid_ddp_strategy_with_mps(accelerator, strategy, strategy_class):
with pytest.raises(ValueError, match="strategies from the DDP family are not supported"):
_Connector(accelerator=accelerator, strategy=strategy)

with pytest.raises(ValueError, match="strategies from the DDP family are not supported"):
_Connector(accelerator="mps", strategy=strategy_class())


@RunIf(mps=False)
@pytest.mark.parametrize(
["strategy", "strategy_class"],
@@ -353,6 +373,7 @@ def test_set_devices_if_none_cpu():
assert connector._parallel_devices == [torch.device("cpu")] * 3


@RunIf(mps=False)
def test_unsupported_strategy_types_on_cpu_and_fallback():
with pytest.warns(UserWarning, match="is not supported on CPUs, hence setting `strategy='ddp"):
connector = _Connector(strategy="dp", devices=2)
@@ -607,7 +628,8 @@ def test_strategy_choice_ddp_cpu_slurm(strategy):


@mock.patch.dict(os.environ, {}, clear=True)
def test_unsupported_tpu_choice(tpu_available):
@mock.patch("lightning_fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
def test_unsupported_tpu_choice(_, tpu_available):
with pytest.raises(NotImplementedError, match=r"accelerator='tpu', precision=64\)` is not implemented"):
_Connector(accelerator="tpu", precision=64)

@@ -729,7 +751,8 @@ def test_gpu_accelerator_no_gpu_backend_found_error(*_):
"lightning_fabric.connector.torch.multiprocessing.get_all_start_methods",
return_value=[],
)
def test_ddp_fork_on_unsupported_platform(_, strategy):
@mock.patch("lightning_fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
def test_ddp_fork_on_unsupported_platform(_, __, strategy):
with pytest.raises(ValueError, match="process forking is not supported on this platform"):
_Connector(strategy=strategy)

@@ -765,7 +788,8 @@ def test_precision_selection_amp_ddp(strategy, devices, is_custom_plugin, plugin


@pytest.mark.parametrize(["strategy", "strategy_cls"], [("DDP", DDPStrategy), ("Ddp", DDPStrategy)])
def test_strategy_str_passed_being_case_insensitive(strategy, strategy_cls):
@mock.patch("lightning_fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
def test_strategy_str_passed_being_case_insensitive(_, strategy, strategy_cls):
connector = _Connector(strategy=strategy)
assert isinstance(connector.strategy, strategy_cls)

@@ -838,7 +862,8 @@ def test_arguments_from_environment_collision():


@RunIf(min_torch="1.12")
def test_fsdp_unsupported_on_cpu():
@mock.patch("lightning_fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
def test_fsdp_unsupported_on_cpu(_):
"""Test that we raise an error if attempting to run FSDP without GPU."""
with pytest.raises(ValueError, match="You selected the FSDP strategy but FSDP is only available on GPU"):
_Connector(strategy="fsdp")
2 changes: 1 addition & 1 deletion tests/tests_fabric/test_fabric.py
@@ -556,7 +556,7 @@ def test_backward():
fabric._precision.backward.assert_called_with(loss, None, "arg", keyword="kwarg")


@RunIf(deepspeed=True)
@RunIf(deepspeed=True, mps=False)
def test_backward_model_input_required():
"""Test that when using deepspeed and multiple models, backward() requires the model as input."""
fabric = EmptyFabric(strategy="deepspeed")
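A note on the test changes: every existing test that builds a `_Connector` with a DDP-family strategy but is not MPS-specific now neutralises MPS detection, either with the `RunIf(mps=False)` marker or by patching availability, so the new guard does not fire on Apple Silicon CI machines. A sketch of the patching variant for a hypothetical new test (the test name is illustrative, not part of this PR):

```python
from unittest import mock

from lightning_fabric.connector import _Connector
from lightning_fabric.strategies import DDPStrategy


@mock.patch("lightning_fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
def test_ddp_selection_without_mps(_):
    # With MPS detection disabled, selecting a DDP-family strategy no longer
    # trips the new MPS/DDP incompatibility check, regardless of the host.
    connector = _Connector(strategy="ddp")
    assert isinstance(connector.strategy, DDPStrategy)
```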