diff --git a/src/lightning_lite/cli.py b/src/lightning_lite/cli.py index 6d18d52789f7d..6359cba46867b 100644 --- a/src/lightning_lite/cli.py +++ b/src/lightning_lite/cli.py @@ -20,7 +20,6 @@ from lightning_lite.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator from lightning_lite.utilities.device_parser import _parse_gpu_ids -from lightning_lite.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_13 _log = logging.getLogger(__name__) @@ -148,15 +147,6 @@ def _get_num_processes(accelerator: str, devices: str) -> int: def _torchrun_launch(args: Namespace, script_args: List[str]) -> None: """This will invoke `torchrun` programmatically to launch the given script in new processes.""" - - if _IS_WINDOWS and _TORCH_GREATER_EQUAL_1_13: # pragma: no cover - # TODO: remove once import issue is resolved: https://github.com/pytorch/pytorch/issues/85427 - _log.error( - "On the Windows platform, this launcher is currently only supported on torch < 1.13 due to a bug" - " upstream: https://github.com/pytorch/pytorch/issues/85427" - ) - raise SystemExit(1) - import torch.distributed.run as torchrun if args.strategy == "dp": diff --git a/tests/tests_lite/test_cli.py b/tests/tests_lite/test_cli.py index ad8d0ddd3240b..403d4a8e2fbb3 100644 --- a/tests/tests_lite/test_cli.py +++ b/tests/tests_lite/test_cli.py @@ -16,21 +16,10 @@ from unittest.mock import Mock import pytest +import torch.distributed.run from tests_lite.helpers.runif import RunIf from lightning_lite.cli import _run_model -from lightning_lite.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_13 - -if not (_IS_WINDOWS and _TORCH_GREATER_EQUAL_1_13): - import torch.distributed.run - - -def skip_windows_pt_1_13(): - # https://github.com/pytorch/pytorch/issues/85427 - return pytest.mark.skipif( - condition=(_IS_WINDOWS and _TORCH_GREATER_EQUAL_1_13), - reason="Torchelastic import bug in 1.13 affecting Windows", - ) @pytest.fixture @@ -40,7 +29,6 @@ def fake_script(tmp_path): return str(script) -@skip_windows_pt_1_13() @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_cli_env_vars_defaults(monkeypatch, fake_script): monkeypatch.setattr(torch.distributed, "run", Mock()) @@ -55,7 +43,6 @@ def test_cli_env_vars_defaults(monkeypatch, fake_script): assert os.environ["LT_PRECISION"] == "32" -@skip_windows_pt_1_13() @pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", pytest.param("mps", marks=RunIf(mps=True))]) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) @mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2) @@ -67,7 +54,6 @@ def test_cli_env_vars_accelerator(_, accelerator, monkeypatch, fake_script): assert os.environ["LT_ACCELERATOR"] == accelerator -@skip_windows_pt_1_13() @pytest.mark.parametrize("strategy", ["dp", "ddp", "deepspeed"]) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) @mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2) @@ -79,7 +65,6 @@ def test_cli_env_vars_strategy(_, strategy, monkeypatch, fake_script): assert os.environ["LT_STRATEGY"] == strategy -@skip_windows_pt_1_13() @pytest.mark.parametrize("devices", ["1", "2", "0,", "1,0", "-1"]) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) @mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2) @@ -92,7 +77,6 @@ def test_cli_env_vars_devices_cuda(_, devices, monkeypatch, fake_script): @RunIf(mps=True) -@skip_windows_pt_1_13() @pytest.mark.parametrize("accelerator", ["mps", "gpu"]) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script): @@ -103,7 +87,6 @@ def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script): assert os.environ["LT_DEVICES"] == "1" -@skip_windows_pt_1_13() @pytest.mark.parametrize("num_nodes", ["1", "2", "3"]) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script): @@ -114,7 +97,6 @@ def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script): assert os.environ["LT_NUM_NODES"] == num_nodes -@skip_windows_pt_1_13() @pytest.mark.parametrize("precision", ["64", "32", "16", "bf16"]) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_cli_env_vars_precision(precision, monkeypatch, fake_script): @@ -125,7 +107,6 @@ def test_cli_env_vars_precision(precision, monkeypatch, fake_script): assert os.environ["LT_PRECISION"] == precision -@skip_windows_pt_1_13() @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_cli_torchrun_defaults(monkeypatch, fake_script): torchrun_mock = Mock() @@ -145,7 +126,6 @@ def test_cli_torchrun_defaults(monkeypatch, fake_script): ) -@skip_windows_pt_1_13() @pytest.mark.parametrize( "devices,expected", [