-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Fix restarting attribute for lr finder * update lite executor * update trainer executor * update spawn executor * add multinode component tests * add testing helpers * add lite tests * add trainer tests * update changelog * update trainer * update workflow * update tests * debug * add reason for skipif * Apply suggestions from code review * switch skipif Co-authored-by: Jirka <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: Jirka Borovec <[email protected]>
- Loading branch information
1 parent
8475f85
commit 36aecde
Showing
9 changed files
with
268 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import os | ||
from copy import deepcopy | ||
from functools import partial | ||
from unittest import mock | ||
|
||
import pytest | ||
from lightning_utilities.core.imports import module_available | ||
from tests_app.helpers.utils import no_warning_call | ||
|
||
import lightning_lite as ll | ||
from lightning_app.components.multi_node.lite import _LiteRunExecutor | ||
|
||
|
||
class DummyLite(ll.LightningLite):
    """Minimal ``LightningLite`` subclass used as a tracer target; its ``run`` is a no-op."""

    def run(self):
        """Intentionally empty — only the constructor's captured kwargs matter in these tests."""
|
||
|
||
def dummy_callable(**kwargs):
    """Instantiate ``DummyLite`` and return the kwargs recorded by the patched ``__init__``."""
    return DummyLite(**kwargs)._all_passed_kwargs
|
||
|
||
def dummy_init(self, **kwargs):
    """Replacement ``__init__`` that simply records every keyword argument it received."""
    setattr(self, "_all_passed_kwargs", kwargs)
|
||
|
||
def _get_args_after_tracer_injection(**kwargs):
    """Run the lite executor with ``LightningLite.__init__`` patched out.

    Returns a tuple ``(captured_kwargs, env_vars)``: the keyword arguments the executor
    ultimately forwarded to ``LightningLite``, and a snapshot of ``os.environ`` taken
    right after the run (to inspect the torchelastic variables the executor exported).
    """
    with mock.patch.object(ll.LightningLite, "__init__", dummy_init):
        captured_kwargs = _LiteRunExecutor.run(
            local_rank=0,
            work_run=partial(dummy_callable, **kwargs),
            main_address="1.2.3.4",
            main_port=5,
            node_rank=6,
            num_nodes=7,
            nprocs=8,
        )
        environment_snapshot = deepcopy(os.environ)
    return captured_kwargs, environment_snapshot
|
||
|
||
def check_lightning_lite_mps():
    """Return True only when ``lightning_lite`` is importable and an MPS accelerator exists."""
    return module_available("lightning_lite") and ll.accelerators.MPSAccelerator.is_available()
|
||
|
||
@pytest.mark.skipif(not check_lightning_lite_mps(), reason="Lightning lite not available or mps not available")
@pytest.mark.parametrize("accelerator_given,accelerator_expected", [("cpu", "cpu"), ("auto", "cpu"), ("gpu", "cpu")])
def test_lite_run_executor_mps_forced_cpu(accelerator_given, accelerator_expected):
    """On MPS machines the executor must force ``accelerator=cpu`` and warn for non-cpu requests."""
    warning_str = (
        r"Forcing accelerator=cpu as other accelerators \(specifically MPS\) are not supported "
        + "by PyTorch for distributed training on mps capable devices"
    )
    # A warning is expected exactly when the executor had to override the requested accelerator.
    if accelerator_given == accelerator_expected:
        warning_context = no_warning_call(match=warning_str + "*")
    else:
        warning_context = pytest.warns(UserWarning, match=warning_str)

    with warning_context:
        ret_val, _ = _get_args_after_tracer_injection(accelerator=accelerator_given)
        assert ret_val["accelerator"] == accelerator_expected
|
||
|
||
@pytest.mark.parametrize(
    "args_given,args_expected",
    [
        ({"devices": 1, "num_nodes": 1, "accelerator": "gpu"}, {"devices": 8, "num_nodes": 7, "accelerator": "auto"}),
        ({"strategy": "ddp_spawn"}, {"strategy": "ddp"}),
        ({"strategy": "ddp_sharded_spawn"}, {"strategy": "ddp_sharded"}),
    ],
)
@pytest.mark.skipif(not module_available("lightning"), reason="Lightning is required for this test")
def test_trainer_run_executor_arguments_choices(args_given: dict, args_expected: dict):
    """The lite executor rewrites user args for multi-node runs and exports torchelastic env vars."""
    # ddp with mps devices not available (tested separately, just patching here for cross-os testing of other args)
    if ll.accelerators.MPSAccelerator.is_available():
        args_expected["accelerator"] = "cpu"

    ret_val, env_vars = _get_args_after_tracer_injection(**args_given)

    for key, expected in args_expected.items():
        assert ret_val[key] == expected

    # Values derived from the fixed executor call in _get_args_after_tracer_injection.
    expected_environment = {
        "MASTER_ADDR": "1.2.3.4",
        "MASTER_PORT": "5",
        "GROUP_RANK": "6",
        "RANK": str(0 + 6 * 8),
        "LOCAL_RANK": "0",
        "WORLD_SIZE": str(7 * 8),
        "LOCAL_WORLD_SIZE": "8",
        "TORCHELASTIC_RUN_ID": "1",
        "LT_CLI_USED": "1",
    }
    for name, value in expected_environment.items():
        assert env_vars[name] == value
|
||
|
||
@pytest.mark.skipif(not module_available("lightning"), reason="Lightning not available")
def test_lite_run_executor_invalid_strategy_instances():
    """Spawn-based strategy instances must be rejected by the lite executor."""
    for spawn_strategy in (ll.strategies.DDPSpawnStrategy, ll.strategies.DDPSpawnShardedStrategy):
        with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
            _get_args_after_tracer_injection(strategy=spawn_strategy())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import os | ||
from copy import deepcopy | ||
from functools import partial | ||
from unittest import mock | ||
|
||
import pytest | ||
from lightning_utilities.core.imports import module_available | ||
from tests_app.helpers.utils import no_warning_call | ||
|
||
import pytorch_lightning as pl | ||
from lightning_app.components.multi_node.trainer import _LightningTrainerRunExecutor | ||
|
||
|
||
def dummy_callable(**kwargs):
    """Build a ``Trainer`` and return the kwargs recorded by the patched ``__init__``."""
    return pl.Trainer(**kwargs)._all_passed_kwargs
|
||
|
||
def dummy_init(self, **kwargs):
    """Stand-in ``Trainer.__init__`` that just stashes the received keyword arguments."""
    setattr(self, "_all_passed_kwargs", kwargs)
|
||
|
||
def _get_args_after_tracer_injection(**kwargs):
    """Run the trainer executor with ``pl.Trainer.__init__`` patched out.

    Returns a tuple ``(captured_kwargs, env_vars)``: the keyword arguments the executor
    ultimately forwarded to ``pl.Trainer``, and a snapshot of ``os.environ`` taken right
    after the run (to inspect the torchelastic variables the executor exported).
    """
    with mock.patch.object(pl.Trainer, "__init__", dummy_init):
        captured_kwargs = _LightningTrainerRunExecutor.run(
            local_rank=0,
            work_run=partial(dummy_callable, **kwargs),
            main_address="1.2.3.4",
            main_port=5,
            node_rank=6,
            num_nodes=7,
            nprocs=8,
        )
        environment_snapshot = deepcopy(os.environ)
    return captured_kwargs, environment_snapshot
|
||
|
||
def check_lightning_pytorch_and_mps():
    """Return True only when ``pytorch_lightning`` is importable and an MPS accelerator exists."""
    return module_available("pytorch_lightning") and pl.accelerators.MPSAccelerator.is_available()
|
||
|
||
@pytest.mark.skipif(not check_lightning_pytorch_and_mps(), reason="pytorch_lightning and mps are required")
@pytest.mark.parametrize("accelerator_given,accelerator_expected", [("cpu", "cpu"), ("auto", "cpu"), ("gpu", "cpu")])
def test_trainer_run_executor_mps_forced_cpu(accelerator_given, accelerator_expected):
    """On MPS machines the executor must force ``accelerator=cpu`` and warn for non-cpu requests."""
    warning_str = (
        r"Forcing accelerator=cpu as other accelerators \(specifically MPS\) are not supported "
        + "by PyTorch for distributed training on mps capable devices"
    )
    # A warning is expected exactly when the executor had to override the requested accelerator.
    if accelerator_given == accelerator_expected:
        warning_context = no_warning_call(match=warning_str + "*")
    else:
        warning_context = pytest.warns(UserWarning, match=warning_str)

    with warning_context:
        ret_val, _ = _get_args_after_tracer_injection(accelerator=accelerator_given)
        assert ret_val["accelerator"] == accelerator_expected
|
||
|
||
@pytest.mark.parametrize(
    "args_given,args_expected",
    [
        ({"devices": 1, "num_nodes": 1, "accelerator": "gpu"}, {"devices": 8, "num_nodes": 7, "accelerator": "auto"}),
        ({"strategy": "ddp_spawn"}, {"strategy": "ddp"}),
        ({"strategy": "ddp_sharded_spawn"}, {"strategy": "ddp_sharded"}),
    ],
)
# FIX: was `module_available("pytorch")` — "pytorch" is never an importable module (the import
# name is `torch`), so the skipif was always true and this test was unconditionally skipped.
# Probe "pytorch_lightning" instead, consistent with check_lightning_pytorch_and_mps above.
@pytest.mark.skipif(not module_available("pytorch_lightning"), reason="Lightning is not available")
def test_trainer_run_executor_arguments_choices(
    args_given: dict,
    args_expected: dict,
):
    """The trainer executor rewrites user args for multi-node runs and exports torchelastic env vars."""
    if pl.accelerators.MPSAccelerator.is_available():
        args_expected.pop("accelerator", None)  # Cross platform tests -> MPS is tested separately

    ret_val, env_vars = _get_args_after_tracer_injection(**args_given)

    # Every rewritten trainer kwarg must match the expectation for this parametrization.
    for k, v in args_expected.items():
        assert ret_val[k] == v

    # Torchelastic environment derived from the fixed executor call (rank 0, node 6, 7x8 world).
    assert env_vars["MASTER_ADDR"] == "1.2.3.4"
    assert env_vars["MASTER_PORT"] == "5"
    assert env_vars["GROUP_RANK"] == "6"
    assert env_vars["RANK"] == str(0 + 6 * 8)
    assert env_vars["LOCAL_RANK"] == "0"
    assert env_vars["WORLD_SIZE"] == str(7 * 8)
    assert env_vars["LOCAL_WORLD_SIZE"] == "8"
    assert env_vars["TORCHELASTIC_RUN_ID"] == "1"
|
||
|
||
@pytest.mark.skipif(not module_available("lightning"), reason="lightning not available")
def test_trainer_run_executor_invalid_strategy_instances():
    """Spawn-based strategy instances must be rejected by the trainer executor."""
    for spawn_strategy in (pl.strategies.DDPSpawnStrategy, pl.strategies.DDPSpawnShardedStrategy):
        with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
            _get_args_after_tracer_injection(strategy=spawn_strategy())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,7 +29,7 @@ def test_build_config_requirements_provided(): | |
assert spec.requirements == [ | ||
"dask", | ||
"pandas", | ||
"pytorch_" + "lightning==1.5.9", # ugly hack due to replacing `pytorch_lightning string` | ||
"pytorch_lightning==1.5.9", | ||
"git+https://github.com/mit-han-lab/[email protected]", | ||
] | ||
assert spec == BuildConfig.from_dict(spec.to_dict()) | ||
|
@@ -50,7 +50,7 @@ def test_build_config_dockerfile_provided(): | |
spec = BuildConfig(dockerfile="./projects/Dockerfile.cpu") | ||
assert not spec.requirements | ||
# ugly hack due to replacing `pytorch_lightning string | ||
assert "pytorchlightning/pytorch_" + "lightning" in spec.dockerfile.data[0] | ||
assert "pytorchlightning/pytorch_lightning" in spec.dockerfile.data[0] | ||
|
||
|
||
class DockerfileLightningTestApp(LightningTestApp): | ||
|