Multinode on MPS (#15748)
* Fix restarting attribute for lr finder
* update lite executor
* update trainer executor
* update spawn executor
* add multinode component tests
* add testing helpers
* add lite tests
* add trainer tests
* update changelog
* update trainer
* update workflow
* update tests
* debug
* add reason for skipif
* Apply suggestions from code review
* switch skipif

Co-authored-by: Jirka <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
6 people authored Dec 8, 2022
1 parent 8475f85 commit 36aecde
Showing 9 changed files with 268 additions and 20 deletions.
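At a glance, the change teaches the Lite and Trainer multinode executors to fall back to the CPU accelerator when a machine only offers MPS, because PyTorch's distributed collectives are not supported on the MPS backend. A minimal, hedged sketch of that fallback idea follows; `_force_cpu_on_mps` is a hypothetical helper name, not a function from this commit, and only `MPSAccelerator.is_available()` is taken from the diff below.

```python
# Hedged sketch of the MPS -> CPU fallback added to the executors in this commit.
import warnings

from pytorch_lightning.accelerators import MPSAccelerator


def _force_cpu_on_mps(kwargs: dict) -> dict:
    """Rewrite Trainer/Lite kwargs so MPS-only machines run distributed jobs on CPU."""
    if MPSAccelerator.is_available():
        requested = kwargs.get("accelerator", "auto")
        kwargs["accelerator"] = "cpu"
        if requested != "cpu":
            warnings.warn(
                "Forcing accelerator=cpu as other accelerators (specifically MPS) are not "
                "supported by PyTorch for distributed training on mps capable devices"
            )
    else:
        kwargs["accelerator"] = "auto"
    return kwargs
```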
2 changes: 1 addition & 1 deletion .github/workflows/ci-app-tests.yml
@@ -94,7 +94,7 @@ jobs:

- name: Adjust tests
if: ${{ matrix.pkg-name == 'lightning' }}
run: python .actions/assistant.py copy_replace_imports --source_dir="./tests" --source_import="lightning_app" --target_import="lightning.app"
run: python .actions/assistant.py copy_replace_imports --source_dir="./tests" --source_import="lightning_app,lightning_lite,pytorch_lightning" --target_import="lightning.app,lightning.lite,lightning.pytorch"

- name: Adjust examples
if: ${{ matrix.pkg-name != 'lightning' }}
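With the widened source/target lists, the "Adjust tests" step now rewrites all three standalone-package imports in the copied tests to the unified namespace, not only `lightning_app`. An illustrative before/after under that mapping (each import of course only works where the corresponding package is installed):

```python
# Before copy_replace_imports rewrites the copied tests:
import lightning_lite
import pytorch_lightning
from lightning_app.components import MultiNode

# After, with the expanded source/target import mapping:
import lightning.lite
import lightning.pytorch
from lightning.app.components import MultiNode
```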
4 changes: 4 additions & 0 deletions src/lightning_app/CHANGELOG.md
@@ -54,6 +54,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed SSH CLI command listing stopped components ([#15810](https://github.com/Lightning-AI/lightning/pull/15810))

- Fixed MPS error for multinode component (defaults to CPU on MPS devices now, as distributed operations are not supported by PyTorch on MPS) ([#15748](https://github.com/Lightning-AI/lightning/pull/15748))



- Fixed the work not stopped when successful when passed directly to the LightningApp ([#15801](https://github.com/Lightning-AI/lightning/pull/15801))

@@ -111,6 +114,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed bi-directional queues sending delta with Drive Component name changes ([#15642](https://github.com/Lightning-AI/lightning/pull/15642))
- Fixed CloudRuntime works collection with structures and accelerated multi node startup time ([#15650](https://github.com/Lightning-AI/lightning/pull/15650))
- Fixed catimage import ([#15712](https://github.com/Lightning-AI/lightning/pull/15712))
- Fixed setting property to the LightningFlow ([#15750](https://github.com/Lightning-AI/lightning/pull/15750))
- Parse all lines in app file looking for shebangs to run commands ([#15714](https://github.com/Lightning-AI/lightning/pull/15714))


37 changes: 30 additions & 7 deletions src/lightning_app/components/multi_node/lite.py
@@ -1,4 +1,6 @@
import importlib
import os
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Type

@@ -30,8 +32,16 @@ def run(
node_rank: int,
nprocs: int,
):
from lightning.lite import LightningLite
from lightning.lite.strategies import DDPSpawnShardedStrategy, DDPSpawnStrategy
lites = []
strategies = []
mps_accelerators = []

for pkg_name in ("lightning.lite", "lightning_" + "lite"):
pkg = importlib.import_module(pkg_name)
lites.append(pkg.LightningLite)
strategies.append(pkg.strategies.DDPSpawnShardedStrategy)
strategies.append(pkg.strategies.DDPSpawnStrategy)
mps_accelerators.append(pkg.accelerators.MPSAccelerator)

# Used to configure the PyTorch process group
os.environ["MASTER_ADDR"] = main_address
@@ -52,23 +62,36 @@ def run(
def pre_fn(lite, *args, **kwargs):
kwargs["devices"] = nprocs
kwargs["num_nodes"] = num_nodes
kwargs["accelerator"] = "auto"

if any(acc.is_available() for acc in mps_accelerators):
old_acc_value = kwargs.get("accelerator", "auto")
kwargs["accelerator"] = "cpu"

if old_acc_value != kwargs["accelerator"]:
warnings.warn("Forcing accelerator=cpu as other accelerators (specifically MPS) are not supported by PyTorch for distributed training on mps capable devices")
else:
kwargs["accelerator"] = "auto"
strategy = kwargs.get("strategy", None)
if strategy:
if isinstance(strategy, str):
if strategy == "ddp_spawn":
strategy = "ddp"
elif strategy == "ddp_sharded_spawn":
strategy = "ddp_sharded"
elif isinstance(strategy, (DDPSpawnStrategy, DDPSpawnShardedStrategy)):
raise Exception("DDP Spawned strategies aren't supported yet.")
elif isinstance(strategy, tuple(strategies)):
raise ValueError("DDP Spawned strategies aren't supported yet.")

kwargs["strategy"] = strategy

return {}, args, kwargs

tracer = Tracer()
tracer.add_traced(LightningLite, "__init__", pre_fn=pre_fn)
for ll in lites:
tracer.add_traced(ll, "__init__", pre_fn=pre_fn)
tracer._instrument()
work_run()
ret_val = work_run()
tracer._restore()
return ret_val


class LiteMultiNode(MultiNode):
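For context, here is a hypothetical app wiring for the patched `LiteMultiNode` component; the constructor arguments shown (work class, `num_nodes`, `cloud_compute`) follow the existing MultiNode component's API and are assumptions, not part of this diff:

```python
# Hypothetical usage sketch for LiteMultiNode; argument names are assumptions.
from lightning_app import CloudCompute, LightningApp, LightningWork
from lightning_app.components import LiteMultiNode


class DistributedLiteWork(LightningWork):
    def run(self):
        # Instantiate LightningLite here. The tracer installed by _LiteRunExecutor
        # rewrites devices/num_nodes/accelerator before __init__ runs, so on an
        # MPS-only machine the work transparently falls back to accelerator="cpu".
        ...


app = LightningApp(
    LiteMultiNode(
        DistributedLiteWork,
        num_nodes=2,
        cloud_compute=CloudCompute("gpu"),
    )
)
```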
2 changes: 1 addition & 1 deletion src/lightning_app/components/multi_node/pytorch_spawn.py
@@ -88,7 +88,7 @@ def run(
elif world_size > 1:
raise Exception("Torch distributed should be available.")

work_run(world_size, node_rank, global_rank, local_rank)
return work_run(world_size, node_rank, global_rank, local_rank)


class PyTorchSpawnMultiNode(MultiNode):
37 changes: 28 additions & 9 deletions src/lightning_app/components/multi_node/trainer.py
@@ -1,4 +1,6 @@
import importlib
import os
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Type

@@ -30,9 +32,16 @@ def run(
node_rank: int,
nprocs: int,
):
from lightning.lite.strategies import DDPSpawnShardedStrategy, DDPSpawnStrategy
from lightning.pytorch import Trainer as LTrainer
from pytorch_lightning import Trainer as PLTrainer
trainers = []
strategies = []
mps_accelerators = []

for pkg_name in ("lightning.pytorch", "pytorch_" + "lightning"):
pkg = importlib.import_module(pkg_name)
trainers.append(pkg.Trainer)
strategies.append(pkg.strategies.DDPSpawnShardedStrategy)
strategies.append(pkg.strategies.DDPSpawnStrategy)
mps_accelerators.append(pkg.accelerators.MPSAccelerator)

# Used to configure the PyTorch process group
os.environ["MASTER_ADDR"] = main_address
@@ -50,24 +59,34 @@ def run(
def pre_fn(trainer, *args, **kwargs):
kwargs["devices"] = nprocs
kwargs["num_nodes"] = num_nodes
kwargs["accelerator"] = "auto"
if any(acc.is_available() for acc in mps_accelerators):
old_acc_value = kwargs.get("accelerator", "auto")
kwargs["accelerator"] = "cpu"

if old_acc_value != kwargs["accelerator"]:
warnings.warn("Forcing accelerator=cpu as other accelerators (specifically MPS) are not supported by PyTorch for distributed training on mps capable devices")
else:
kwargs["accelerator"] = "auto"

strategy = kwargs.get("strategy", None)
if strategy:
if isinstance(strategy, str):
if strategy == "ddp_spawn":
strategy = "ddp"
elif strategy == "ddp_sharded_spawn":
strategy = "ddp_sharded"
elif isinstance(strategy, (DDPSpawnStrategy, DDPSpawnShardedStrategy)):
raise Exception("DDP Spawned strategies aren't supported yet.")
elif isinstance(strategy, tuple(strategies)):
raise ValueError("DDP Spawned strategies aren't supported yet.")
kwargs["strategy"] = strategy
return {}, args, kwargs

tracer = Tracer()
tracer.add_traced(PLTrainer, "__init__", pre_fn=pre_fn)
tracer.add_traced(LTrainer, "__init__", pre_fn=pre_fn)
for trainer in trainers:
tracer.add_traced(trainer, "__init__", pre_fn=pre_fn)
tracer._instrument()
work_run()
ret_val = work_run()
tracer._restore()
return ret_val


class LightningTrainerMultiNode(MultiNode):
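Both executors rely on the app's Tracer to rewrite the keyword arguments that reach `Trainer.__init__` (or `LightningLite.__init__`) before user code sees them. A toy, hedged sketch of that pattern: the `Tracer` import path and its ability to instrument an arbitrary class are assumptions here, while the `pre_fn` contract of returning `{}, args, kwargs` mirrors the code above.

```python
# Toy sketch of the tracer/pre_fn pattern used by the executors (assumptions noted above).
from lightning_app.utilities.tracer import Tracer


class ToyTrainer:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


def pre_fn(trainer, *args, **kwargs):
    # Same contract as the executors: return (state, args, kwargs); kwargs are overridden.
    kwargs["devices"] = 8
    kwargs["num_nodes"] = 7
    return {}, args, kwargs


tracer = Tracer()
tracer.add_traced(ToyTrainer, "__init__", pre_fn=pre_fn)
tracer._instrument()
t = ToyTrainer(devices=1)  # pre_fn rewrites this to devices=8, num_nodes=7
tracer._restore()
assert t.kwargs == {"devices": 8, "num_nodes": 7}
```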
Empty file.
103 changes: 103 additions & 0 deletions tests/tests_app/components/multi_node/test_lite.py
@@ -0,0 +1,103 @@
import os
from copy import deepcopy
from functools import partial
from unittest import mock

import pytest
from lightning_utilities.core.imports import module_available
from tests_app.helpers.utils import no_warning_call

import lightning_lite as ll
from lightning_app.components.multi_node.lite import _LiteRunExecutor


class DummyLite(ll.LightningLite):
def run(self):
pass


def dummy_callable(**kwargs):
lite = DummyLite(**kwargs)
return lite._all_passed_kwargs


def dummy_init(self, **kwargs):
self._all_passed_kwargs = kwargs


def _get_args_after_tracer_injection(**kwargs):
with mock.patch.object(ll.LightningLite, "__init__", dummy_init):
ret_val = _LiteRunExecutor.run(
local_rank=0,
work_run=partial(dummy_callable, **kwargs),
main_address="1.2.3.4",
main_port=5,
node_rank=6,
num_nodes=7,
nprocs=8,
)
env_vars = deepcopy(os.environ)
return ret_val, env_vars


def check_lightning_lite_mps():
if module_available("lightning_lite"):
return ll.accelerators.MPSAccelerator.is_available()
return False


@pytest.mark.skipif(not check_lightning_lite_mps(), reason="Lightning lite not available or mps not available")
@pytest.mark.parametrize("accelerator_given,accelerator_expected", [("cpu", "cpu"), ("auto", "cpu"), ("gpu", "cpu")])
def test_lite_run_executor_mps_forced_cpu(accelerator_given, accelerator_expected):
warning_str = (
r"Forcing accelerator=cpu as other accelerators \(specifically MPS\) are not supported "
+ "by PyTorch for distributed training on mps capable devices"
)
if accelerator_expected != accelerator_given:
warning_context = pytest.warns(UserWarning, match=warning_str)
else:
warning_context = no_warning_call(match=warning_str + "*")

with warning_context:
ret_val, env_vars = _get_args_after_tracer_injection(accelerator=accelerator_given)
assert ret_val["accelerator"] == accelerator_expected


@pytest.mark.parametrize(
"args_given,args_expected",
[
({"devices": 1, "num_nodes": 1, "accelerator": "gpu"}, {"devices": 8, "num_nodes": 7, "accelerator": "auto"}),
({"strategy": "ddp_spawn"}, {"strategy": "ddp"}),
({"strategy": "ddp_sharded_spawn"}, {"strategy": "ddp_sharded"}),
],
)
@pytest.mark.skipif(not module_available("lightning"), reason="Lightning is required for this test")
def test_lite_run_executor_arguments_choices(args_given: dict, args_expected: dict):

# ddp with mps devices not available (tested separately, just patching here for cross-os testing of other args)
if ll.accelerators.MPSAccelerator.is_available():
args_expected["accelerator"] = "cpu"

ret_val, env_vars = _get_args_after_tracer_injection(**args_given)

for k, v in args_expected.items():
assert ret_val[k] == v

assert env_vars["MASTER_ADDR"] == "1.2.3.4"
assert env_vars["MASTER_PORT"] == "5"
assert env_vars["GROUP_RANK"] == "6"
assert env_vars["RANK"] == str(0 + 6 * 8)
assert env_vars["LOCAL_RANK"] == "0"
assert env_vars["WORLD_SIZE"] == str(7 * 8)
assert env_vars["LOCAL_WORLD_SIZE"] == "8"
assert env_vars["TORCHELASTIC_RUN_ID"] == "1"
assert env_vars["LT_CLI_USED"] == "1"


@pytest.mark.skipif(not module_available("lightning"), reason="Lightning not available")
def test_lite_run_executor_invalid_strategy_instances():
with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
_, _ = _get_args_after_tracer_injection(strategy=ll.strategies.DDPSpawnStrategy())

with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
_, _ = _get_args_after_tracer_injection(strategy=ll.strategies.DDPSpawnShardedStrategy())
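The environment-variable assertions above (and the matching ones in the trainer test below) restate the standard torchelastic bookkeeping for `num_nodes=7`, `nprocs=8`, `node_rank=6`, `local_rank=0`:

```python
# Rank/world-size arithmetic behind the RANK/WORLD_SIZE assertions in these tests.
num_nodes, nprocs, node_rank, local_rank = 7, 8, 6, 0

global_rank = node_rank * nprocs + local_rank   # 6 * 8 + 0 = 48  -> RANK
world_size = num_nodes * nprocs                 # 7 * 8 = 56      -> WORLD_SIZE
local_world_size = nprocs                       # 8               -> LOCAL_WORLD_SIZE

assert (global_rank, world_size, local_world_size) == (48, 56, 8)
```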
99 changes: 99 additions & 0 deletions tests/tests_app/components/multi_node/test_trainer.py
@@ -0,0 +1,99 @@
import os
from copy import deepcopy
from functools import partial
from unittest import mock

import pytest
from lightning_utilities.core.imports import module_available
from tests_app.helpers.utils import no_warning_call

import pytorch_lightning as pl
from lightning_app.components.multi_node.trainer import _LightningTrainerRunExecutor


def dummy_callable(**kwargs):
t = pl.Trainer(**kwargs)
return t._all_passed_kwargs


def dummy_init(self, **kwargs):
self._all_passed_kwargs = kwargs


def _get_args_after_tracer_injection(**kwargs):
with mock.patch.object(pl.Trainer, "__init__", dummy_init):
ret_val = _LightningTrainerRunExecutor.run(
local_rank=0,
work_run=partial(dummy_callable, **kwargs),
main_address="1.2.3.4",
main_port=5,
node_rank=6,
num_nodes=7,
nprocs=8,
)
env_vars = deepcopy(os.environ)
return ret_val, env_vars


def check_lightning_pytorch_and_mps():
if module_available("pytorch_lightning"):
return pl.accelerators.MPSAccelerator.is_available()
return False


@pytest.mark.skipif(not check_lightning_pytorch_and_mps(), reason="pytorch_lightning and mps are required")
@pytest.mark.parametrize("accelerator_given,accelerator_expected", [("cpu", "cpu"), ("auto", "cpu"), ("gpu", "cpu")])
def test_trainer_run_executor_mps_forced_cpu(accelerator_given, accelerator_expected):
warning_str = (
r"Forcing accelerator=cpu as other accelerators \(specifically MPS\) are not supported "
+ "by PyTorch for distributed training on mps capable devices"
)
if accelerator_expected != accelerator_given:
warning_context = pytest.warns(UserWarning, match=warning_str)
else:
warning_context = no_warning_call(match=warning_str + "*")

with warning_context:
ret_val, env_vars = _get_args_after_tracer_injection(accelerator=accelerator_given)
assert ret_val["accelerator"] == accelerator_expected


@pytest.mark.parametrize(
"args_given,args_expected",
[
({"devices": 1, "num_nodes": 1, "accelerator": "gpu"}, {"devices": 8, "num_nodes": 7, "accelerator": "auto"}),
({"strategy": "ddp_spawn"}, {"strategy": "ddp"}),
({"strategy": "ddp_sharded_spawn"}, {"strategy": "ddp_sharded"}),
],
)
@pytest.mark.skipif(not module_available("pytorch"), reason="Lightning is not available")
def test_trainer_run_executor_arguments_choices(
args_given: dict,
args_expected: dict,
):

if pl.accelerators.MPSAccelerator.is_available():
args_expected.pop("accelerator", None) # Cross platform tests -> MPS is tested separately

ret_val, env_vars = _get_args_after_tracer_injection(**args_given)

for k, v in args_expected.items():
assert ret_val[k] == v

assert env_vars["MASTER_ADDR"] == "1.2.3.4"
assert env_vars["MASTER_PORT"] == "5"
assert env_vars["GROUP_RANK"] == "6"
assert env_vars["RANK"] == str(0 + 6 * 8)
assert env_vars["LOCAL_RANK"] == "0"
assert env_vars["WORLD_SIZE"] == str(7 * 8)
assert env_vars["LOCAL_WORLD_SIZE"] == "8"
assert env_vars["TORCHELASTIC_RUN_ID"] == "1"


@pytest.mark.skipif(not module_available("lightning"), reason="lightning not available")
def test_trainer_run_executor_invalid_strategy_instances():
with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
_, _ = _get_args_after_tracer_injection(strategy=pl.strategies.DDPSpawnStrategy())

with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
_, _ = _get_args_after_tracer_injection(strategy=pl.strategies.DDPSpawnShardedStrategy())
4 changes: 2 additions & 2 deletions tests/tests_app/utilities/packaging/test_build_spec.py
@@ -29,7 +29,7 @@ def test_build_config_requirements_provided():
assert spec.requirements == [
"dask",
"pandas",
"pytorch_" + "lightning==1.5.9", # ugly hack due to replacing `pytorch_lightning string`
"pytorch_lightning==1.5.9",
"git+https://github.com/mit-han-lab/[email protected]",
]
assert spec == BuildConfig.from_dict(spec.to_dict())
@@ -50,7 +50,7 @@ def test_build_config_dockerfile_provided():
spec = BuildConfig(dockerfile="./projects/Dockerfile.cpu")
assert not spec.requirements
# ugly hack due to replacing `pytorch_lightning string
assert "pytorchlightning/pytorch_" + "lightning" in spec.dockerfile.data[0]
assert "pytorchlightning/pytorch_lightning" in spec.dockerfile.data[0]


class DockerfileLightningTestApp(LightningTestApp):
