Skip to content

Commit

Permalink
Re-enable Lite CLI on Windows + PyTorch 1.13 (#15645)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <[email protected]>
  • Loading branch information
3 people committed Dec 19, 2022
1 parent dc8cbcd commit 256213d
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 29 deletions.
9 changes: 0 additions & 9 deletions src/lightning_fabric/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,6 @@ def _get_num_processes(accelerator: str, devices: str) -> int:

def _torchrun_launch(args: Namespace, script_args: List[str]) -> None:
"""This will invoke `torchrun` programmatically to launch the given script in new processes."""

if _IS_WINDOWS and _TORCH_GREATER_EQUAL_1_13: # pragma: no cover
# TODO: remove once import issue is resolved: https://github.com/pytorch/pytorch/issues/85427
_log.error(
"On the Windows platform, this launcher is currently only supported on torch < 1.13 due to a bug"
" upstream: https://github.com/pytorch/pytorch/issues/85427"
)
raise SystemExit(1)

import torch.distributed.run as torchrun

if args.strategy == "dp":
Expand Down
21 changes: 1 addition & 20 deletions tests/tests_fabric/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,12 @@
from unittest.mock import Mock

import pytest
import torch.distributed.run
from tests_fabric.helpers.runif import RunIf

from lightning_fabric.cli import _run_model
from lightning_fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_13

if not (_IS_WINDOWS and _TORCH_GREATER_EQUAL_1_13):
import torch.distributed.run


def skip_windows_pt_1_13():
# https://github.com/pytorch/pytorch/issues/85427
return pytest.mark.skipif(
condition=(_IS_WINDOWS and _TORCH_GREATER_EQUAL_1_13),
reason="Torchelastic import bug in 1.13 affecting Windows",
)


@pytest.fixture
def fake_script(tmp_path):
Expand All @@ -40,7 +30,6 @@ def fake_script(tmp_path):
return str(script)


@skip_windows_pt_1_13()
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_env_vars_defaults(monkeypatch, fake_script):
monkeypatch.setattr(torch.distributed, "run", Mock())
Expand All @@ -55,7 +44,6 @@ def test_cli_env_vars_defaults(monkeypatch, fake_script):
assert os.environ["LT_PRECISION"] == "32"


@skip_windows_pt_1_13()
@pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", pytest.param("mps", marks=RunIf(mps=True))])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
@mock.patch("lightning_fabric.accelerators.cuda.num_cuda_devices", return_value=2)
Expand All @@ -67,7 +55,6 @@ def test_cli_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
assert os.environ["LT_ACCELERATOR"] == accelerator


@skip_windows_pt_1_13()
@pytest.mark.parametrize("strategy", ["dp", "ddp", "deepspeed"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
@mock.patch("lightning_fabric.accelerators.cuda.num_cuda_devices", return_value=2)
Expand All @@ -79,7 +66,6 @@ def test_cli_env_vars_strategy(_, strategy, monkeypatch, fake_script):
assert os.environ["LT_STRATEGY"] == strategy


@skip_windows_pt_1_13()
@pytest.mark.parametrize("devices", ["1", "2", "0,", "1,0", "-1"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
@mock.patch("lightning_fabric.accelerators.cuda.num_cuda_devices", return_value=2)
Expand All @@ -92,7 +78,6 @@ def test_cli_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):


@RunIf(mps=True)
@skip_windows_pt_1_13()
@pytest.mark.parametrize("accelerator", ["mps", "gpu"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
Expand All @@ -103,7 +88,6 @@ def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
assert os.environ["LT_DEVICES"] == "1"


@skip_windows_pt_1_13()
@pytest.mark.parametrize("num_nodes", ["1", "2", "3"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
Expand All @@ -114,7 +98,6 @@ def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
assert os.environ["LT_NUM_NODES"] == num_nodes


@skip_windows_pt_1_13()
@pytest.mark.parametrize("precision", ["64", "32", "16", "bf16"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_env_vars_precision(precision, monkeypatch, fake_script):
Expand All @@ -125,7 +108,6 @@ def test_cli_env_vars_precision(precision, monkeypatch, fake_script):
assert os.environ["LT_PRECISION"] == precision


@skip_windows_pt_1_13()
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_torchrun_defaults(monkeypatch, fake_script):
torchrun_mock = Mock()
Expand All @@ -145,7 +127,6 @@ def test_cli_torchrun_defaults(monkeypatch, fake_script):
)


@skip_windows_pt_1_13()
@pytest.mark.parametrize(
"devices,expected",
[
Expand Down

0 comments on commit 256213d

Please sign in to comment.