diff --git a/src/lightning_app/cli/commands/connection.py b/src/lightning_app/cli/commands/connection.py index 5027b33e51f0f..ee0bf7edc5d67 100644 --- a/src/lightning_app/cli/commands/connection.py +++ b/src/lightning_app/cli/commands/connection.py @@ -34,6 +34,8 @@ def connect(app_name_or_id: str, yes: bool = False): connected_file = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "connect.txt") + matched_connection_path = _scan_lightning_connections(app_name_or_id) + if os.path.exists(connected_file): with open(connected_file) as f: result = f.readlines()[0].replace("\n", "") @@ -79,17 +81,36 @@ def connect(app_name_or_id: str, yes: bool = False): target_file=target_file, ) repr_command_name = command_name.replace("_", " ") - click.echo(f"Find the `{repr_command_name}` command under {target_file}.") + click.echo(f"Storing `{repr_command_name}` at {target_file}") else: with open(os.path.join(commands_folder, f"{command_name}.txt"), "w") as f: f.write(command_name) - click.echo(f"You can review all the downloaded commands under {commands_folder} folder.") + click.echo(f"You can review all the downloaded commands at {commands_folder}") with open(connected_file, "w") as f: f.write(app_name_or_id + "\n") click.echo("You are connected to the local Lightning App.") + + elif matched_connection_path: + + matched_connected_file = os.path.join(matched_connection_path, "connect.txt") + matched_commands = os.path.join(matched_connection_path, "commands") + if os.path.isdir(matched_commands): + commands = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "commands") + shutil.copytree(matched_commands, commands) + shutil.copy(matched_connected_file, connected_file) + copied_files = [el for el in os.listdir(commands) if os.path.splitext(el)[1] == ".py"] + click.echo("Found existing connection, reusing cached commands") + for target_file in copied_files: + pretty_command_name = os.path.splitext(target_file)[0].replace("_", " ") + click.echo(f"Storing `{pretty_command_name}` at {os.path.join(commands, target_file)}") + + click.echo(f"You can review all the commands at {commands}") + click.echo(" ") + click.echo(f"You are connected to the cloud Lightning App: {app_name_or_id}.") + else: retriever = _LightningAppOpenAPIRetriever(app_name_or_id) @@ -131,12 +152,12 @@ def connect(app_name_or_id: str, yes: bool = False): target_file=target_file, ) pretty_command_name = command_name.replace("_", " ") - click.echo(f"Storing `{pretty_command_name}` under {target_file}") + click.echo(f"Storing `{pretty_command_name}` at {target_file}") else: with open(os.path.join(commands_folder, f"{command_name}.txt"), "w") as f: f.write(command_name) - click.echo(f"You can review all the downloaded commands under {commands_folder} folder.") + click.echo(f"You can review all the downloaded commands at {commands_folder}") click.echo(" ") click.echo("The client interface has been successfully installed. ") @@ -178,9 +199,7 @@ def disconnect(logout: bool = False): ) -def _retrieve_connection_to_an_app() -> Tuple[Optional[str], Optional[str]]: - connected_file = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "connect.txt") - +def _read_connected_file(connected_file): if os.path.exists(connected_file): with open(connected_file) as f: lines = [line.replace("\n", "") for line in f.readlines()] @@ -190,6 +209,11 @@ def _retrieve_connection_to_an_app() -> Tuple[Optional[str], Optional[str]]: return None, None +def _retrieve_connection_to_an_app() -> Tuple[Optional[str], Optional[str]]: + connected_file = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "connect.txt") + return _read_connected_file(connected_file) + + def _get_commands_folder() -> str: return os.path.join(_LIGHTNING_CONNECTION_FOLDER, "commands") @@ -284,3 +308,27 @@ def _clean_lightning_connection(): connection = os.path.join(_LIGHTNING_CONNECTION, str(ppid)) if os.path.exists(connection): shutil.rmtree(connection) + + +def _scan_lightning_connections(app_name_or_id): + if not os.path.exists(_LIGHTNING_CONNECTION): + return + + for ppid in os.listdir(_LIGHTNING_CONNECTION): + try: + psutil.Process(int(ppid)) + except (psutil.NoSuchProcess, ValueError): + continue + + connection_path = os.path.join(_LIGHTNING_CONNECTION, str(ppid)) + + connected_file = os.path.join(connection_path, "connect.txt") + curr_app_name, curr_app_id = _read_connected_file(connected_file) + + if not curr_app_name: + continue + + if app_name_or_id == curr_app_name or app_name_or_id == curr_app_id: + return connection_path + + return None diff --git a/tests/tests_app/cli/test_connect.py b/tests/tests_app/cli/test_connect.py index 605e1f97af7b5..a8924ab375db2 100644 --- a/tests/tests_app/cli/test_connect.py +++ b/tests/tests_app/cli/test_connect.py @@ -4,12 +4,12 @@ from unittest.mock import MagicMock import click +import psutil import pytest from lightning_app import _PROJECT_ROOT from lightning_app.cli.commands.connection import ( _list_app_commands, - _PPID, _resolve_command_path, _retrieve_connection_to_an_app, connect, @@ -19,10 +19,27 @@ from lightning_app.utilities.commands import base -def test_connect_disconnect_local(monkeypatch): +def monkeypatch_connection(monkeypatch, tmpdir, ppid): + connection_path = os.path.join(tmpdir, ppid) + try: + monkeypatch.setattr("lightning_app.cli.commands.connection._clean_lightning_connection", MagicMock()) + monkeypatch.setattr("lightning_app.cli.commands.connection._PPID", ppid) + monkeypatch.setattr("lightning_app.cli.commands.connection._LIGHTNING_CONNECTION", tmpdir) + monkeypatch.setattr("lightning_app.cli.commands.connection._LIGHTNING_CONNECTION_FOLDER", connection_path) + except ModuleNotFoundError: + monkeypatch.setattr("lightning.app.cli.commands.connection._clean_lightning_connection", MagicMock()) + monkeypatch.setattr("lightning_app.cli.commands.connection._PPID", ppid) + monkeypatch.setattr("lightning.app.cli.commands.connection._LIGHTNING_CONNECTION", tmpdir) + monkeypatch.setattr("lightning.app.cli.commands.connection._LIGHTNING_CONNECTION_FOLDER", connection_path) + return connection_path + +def test_connect_disconnect_local(tmpdir, monkeypatch): disconnect() + ppid = str(psutil.Process(os.getpid()).ppid()) + connection_path = monkeypatch_connection(monkeypatch, tmpdir, ppid=ppid) + with pytest.raises(Exception, match="The commands weren't found. Is your app localhost running ?"): connect("localhost", True) @@ -53,12 +70,11 @@ def fn(msg): assert not os.path.exists(command_path) command_path = _resolve_command_path("command_with_client") assert os.path.exists(command_path) - home = os.path.expanduser("~") s = "/" if sys.platform != "win32" else "\\" - command_folder_path = f"{home}{s}.lightning{s}lightning_connection{s}{_PPID}{s}commands" + command_folder_path = f"{connection_path}{s}commands" expected = [ - f"Find the `command with client` command under {command_folder_path}{s}command_with_client.py.", - f"You can review all the downloaded commands under {command_folder_path} folder.", + f"Storing `command with client` at {command_folder_path}{s}command_with_client.py", + f"You can review all the downloaded commands at {command_folder_path}", "You are connected to the local Lightning App.", ] assert messages == expected @@ -79,10 +95,14 @@ def fn(msg): assert _retrieve_connection_to_an_app() == (None, None) -def test_connect_disconnect_cloud(monkeypatch): - +def test_connect_disconnect_cloud(tmpdir, monkeypatch): disconnect() + ppid_1 = str(psutil.Process(os.getpid()).ppid()) + ppid_2 = "222" + + connection_path = monkeypatch_connection(monkeypatch, tmpdir, ppid=ppid_1) + target_file = _resolve_command_path("command_with_client") if os.path.exists(target_file): @@ -141,12 +161,11 @@ def fn(msg): assert not os.path.exists(command_path) command_path = _resolve_command_path("command_with_client") assert os.path.exists(command_path) - home = os.path.expanduser("~") s = "/" if sys.platform != "win32" else "\\" - command_folder_path = f"{home}{s}.lightning{s}lightning_connection{s}{_PPID}{s}commands" + command_folder_path = f"{connection_path}{s}commands" expected = [ - f"Storing `command with client` under {command_folder_path}{s}command_with_client.py", - f"You can review all the downloaded commands under {command_folder_path} folder.", + f"Storing `command with client` at {command_folder_path}{s}command_with_client.py", + f"You can review all the downloaded commands at {command_folder_path}", " ", "The client interface has been successfully installed. ", "You can now run the following commands:", @@ -170,9 +189,23 @@ def fn(msg): connect("example", True) assert messages == ["You are already connected to the cloud Lightning App: example."] + _ = monkeypatch_connection(monkeypatch, tmpdir, ppid=ppid_2) + + messages = [] + connect("example", True) + assert messages[0] == "Found existing connection, reusing cached commands" + messages = [] disconnect() + print(messages) assert messages == ["You are disconnected from the cloud Lightning App: example."] + + _ = monkeypatch_connection(monkeypatch, tmpdir, ppid=ppid_1) + + messages = [] + disconnect() + assert messages == ["You are disconnected from the cloud Lightning App: example."] + messages = [] disconnect() assert messages == [