From a970f090a09e7e9e345bedd4c4ab8e436f9f6d1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 24 Nov 2022 18:21:32 +0100 Subject: [PATCH] Ignore `num_nodes` when running MultiNode components locally (#15806) --- src/lightning_app/CHANGELOG.md | 1 + .../components/multi_node/base.py | 15 ++++++++-- .../components/multi_node/test_base.py | 19 ++++++++++++ tests/tests_app/helpers/__init__.py | 0 tests/tests_app/helpers/utils.py | 30 +++++++++++++++++++ .../public/test_multi_node.py | 7 +++-- 6 files changed, 68 insertions(+), 4 deletions(-) create mode 100644 tests/tests_app/components/multi_node/test_base.py create mode 100644 tests/tests_app/helpers/__init__.py create mode 100644 tests/tests_app/helpers/utils.py diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index 971d0ee9b19d7..305ee591b0257 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `lightning add ssh-key` CLI command has been transitioned to `lightning create ssh-key` with the same calling signature ([#15761](https://github.com/Lightning-AI/lightning/pull/15761)) - `lightning remove ssh-key` CLI command has been transitioned to `lightning delete ssh-key` with the same calling signature ([#15761](https://github.com/Lightning-AI/lightning/pull/15761)) +- The `MultiNode` components now warn the user when running with `num_nodes > 1` locally ([#15806](https://github.com/Lightning-AI/lightning/pull/15806)) ### Deprecated diff --git a/src/lightning_app/components/multi_node/base.py b/src/lightning_app/components/multi_node/base.py index 4f2005771212a..ee4f2b3abd4fb 100644 --- a/src/lightning_app/components/multi_node/base.py +++ b/src/lightning_app/components/multi_node/base.py @@ -1,8 +1,10 @@ +import warnings from typing import Any, Type from lightning_app import structures from lightning_app.core.flow import LightningFlow from lightning_app.core.work import LightningWork +from lightning_app.utilities.cloud import is_running_in_cloud from lightning_app.utilities.packaging.cloud_compute import CloudCompute @@ -45,12 +47,21 @@ def run( Arguments: work_cls: The work to be executed - num_nodes: Number of nodes. - cloud_compute: The cloud compute object used in the cloud. + num_nodes: Number of nodes. Gets ignored when running locally. Launch the app with --cloud to run on + multiple cloud machines. + cloud_compute: The cloud compute object used in the cloud. The value provided here gets ignored when + running locally. work_args: Arguments to be provided to the work on instantiation. work_kwargs: Keywords arguments to be provided to the work on instantiation. """ super().__init__() + if num_nodes > 1 and not is_running_in_cloud(): + num_nodes = 1 + warnings.warn( + f"You set {type(self).__name__}(num_nodes={num_nodes}, ...)` but this app is running locally." + " We assume you are debugging and will ignore the `num_nodes` argument." + " To run on multiple nodes in the cloud, launch your app with `--cloud`." + ) self.ws = structures.List( *[ work_cls( diff --git a/tests/tests_app/components/multi_node/test_base.py b/tests/tests_app/components/multi_node/test_base.py new file mode 100644 index 0000000000000..e23535fbfe970 --- /dev/null +++ b/tests/tests_app/components/multi_node/test_base.py @@ -0,0 +1,19 @@ +from re import escape + +import pytest +from tests_app.helpers.utils import no_warning_call + +from lightning_app import CloudCompute, LightningWork +from lightning_app.components import MultiNode + + +def test_multi_node_warn_running_locally(): + class Work(LightningWork): + def run(self): + pass + + with pytest.warns(UserWarning, match=escape("You set MultiNode(num_nodes=1, ...)` but ")): + MultiNode(Work, num_nodes=2, cloud_compute=CloudCompute("gpu")) + + with no_warning_call(UserWarning, match=escape("You set MultiNode(num_nodes=1, ...)` but ")): + MultiNode(Work, num_nodes=1, cloud_compute=CloudCompute("gpu")) diff --git a/tests/tests_app/helpers/__init__.py b/tests/tests_app/helpers/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tests_app/helpers/utils.py b/tests/tests_app/helpers/utils.py new file mode 100644 index 0000000000000..00868f799a952 --- /dev/null +++ b/tests/tests_app/helpers/utils.py @@ -0,0 +1,30 @@ +import re +from contextlib import contextmanager +from typing import Optional, Type + +import pytest + + +@contextmanager +def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None): + # TODO: Replace with `lightning_utilities.test.warning.no_warning_call` + # https://github.com/Lightning-AI/utilities/issues/57 + + with pytest.warns(None) as record: + yield + + if match is None: + try: + w = record.pop(expected_warning) + except AssertionError: + # no warning raised + return + else: + for w in record.list: + if w.category is expected_warning and re.compile(match).search(w.message.args[0]): + break + else: + return + + msg = "A warning" if expected_warning is None else f"`{expected_warning.__name__}`" + raise AssertionError(f"{msg} was raised: {w}") diff --git a/tests/tests_examples_app/public/test_multi_node.py b/tests/tests_examples_app/public/test_multi_node.py index 79fd491b98751..a5cd2de40811f 100644 --- a/tests/tests_examples_app/public/test_multi_node.py +++ b/tests/tests_examples_app/public/test_multi_node.py @@ -1,5 +1,6 @@ import os import sys +from unittest import mock import pytest from tests_examples_app.public import _PATH_EXAMPLES @@ -17,7 +18,8 @@ def on_before_run_once(self): @pytest.mark.skip(reason="flaky") -def test_multi_node_example(monkeypatch): +@mock.patch("lightning_app.components.multi_node.base.is_running_in_cloud", return_value=True) +def test_multi_node_example(_, monkeypatch): monkeypatch.chdir(os.path.join(_PATH_EXAMPLES, "app_multi_node")) command_line = [ "app.py", @@ -50,7 +52,8 @@ def on_before_run_once(self): ], ) @pytest.mark.skipif(sys.platform == "win32", reason="flaky") -def test_multi_node_examples(app_name, monkeypatch): +@mock.patch("lightning_app.components.multi_node.base.is_running_in_cloud", return_value=True) +def test_multi_node_examples(_, app_name, monkeypatch): monkeypatch.chdir(os.path.join(_PATH_EXAMPLES, "app_multi_node")) command_line = [ app_name,