Skip to content

Commit

Permalink
Reuse existing commands when running connect more than once (#15471)
Browse files Browse the repository at this point in the history
* Reuse connection if it matches a connection from an active terminal
* Remove unused import
* Include both name and id in the check
* Fix messages and tests
* Add test
* Handle monkeypatching more cleanly
* Remove unused imports

Co-authored-by: Luca Antiga <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
5 people authored Nov 7, 2022
1 parent 94c300c commit 01f57a9
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 19 deletions.
62 changes: 55 additions & 7 deletions src/lightning_app/cli/commands/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def connect(app_name_or_id: str, yes: bool = False):

connected_file = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "connect.txt")

matched_connection_path = _scan_lightning_connections(app_name_or_id)

if os.path.exists(connected_file):
with open(connected_file) as f:
result = f.readlines()[0].replace("\n", "")
Expand Down Expand Up @@ -79,17 +81,36 @@ def connect(app_name_or_id: str, yes: bool = False):
target_file=target_file,
)
repr_command_name = command_name.replace("_", " ")
click.echo(f"Find the `{repr_command_name}` command under {target_file}.")
click.echo(f"Storing `{repr_command_name}` at {target_file}")
else:
with open(os.path.join(commands_folder, f"{command_name}.txt"), "w") as f:
f.write(command_name)

click.echo(f"You can review all the downloaded commands under {commands_folder} folder.")
click.echo(f"You can review all the downloaded commands at {commands_folder}")

with open(connected_file, "w") as f:
f.write(app_name_or_id + "\n")

click.echo("You are connected to the local Lightning App.")

elif matched_connection_path:

matched_connected_file = os.path.join(matched_connection_path, "connect.txt")
matched_commands = os.path.join(matched_connection_path, "commands")
if os.path.isdir(matched_commands):
commands = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "commands")
shutil.copytree(matched_commands, commands)
shutil.copy(matched_connected_file, connected_file)
copied_files = [el for el in os.listdir(commands) if os.path.splitext(el)[1] == ".py"]
click.echo("Found existing connection, reusing cached commands")
for target_file in copied_files:
pretty_command_name = os.path.splitext(target_file)[0].replace("_", " ")
click.echo(f"Storing `{pretty_command_name}` at {os.path.join(commands, target_file)}")

click.echo(f"You can review all the commands at {commands}")
click.echo(" ")
click.echo(f"You are connected to the cloud Lightning App: {app_name_or_id}.")

else:

retriever = _LightningAppOpenAPIRetriever(app_name_or_id)
Expand Down Expand Up @@ -131,12 +152,12 @@ def connect(app_name_or_id: str, yes: bool = False):
target_file=target_file,
)
pretty_command_name = command_name.replace("_", " ")
click.echo(f"Storing `{pretty_command_name}` under {target_file}")
click.echo(f"Storing `{pretty_command_name}` at {target_file}")
else:
with open(os.path.join(commands_folder, f"{command_name}.txt"), "w") as f:
f.write(command_name)

click.echo(f"You can review all the downloaded commands under {commands_folder} folder.")
click.echo(f"You can review all the downloaded commands at {commands_folder}")

click.echo(" ")
click.echo("The client interface has been successfully installed. ")
Expand Down Expand Up @@ -178,9 +199,7 @@ def disconnect(logout: bool = False):
)


def _retrieve_connection_to_an_app() -> Tuple[Optional[str], Optional[str]]:
connected_file = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "connect.txt")

def _read_connected_file(connected_file):
if os.path.exists(connected_file):
with open(connected_file) as f:
lines = [line.replace("\n", "") for line in f.readlines()]
Expand All @@ -190,6 +209,11 @@ def _retrieve_connection_to_an_app() -> Tuple[Optional[str], Optional[str]]:
return None, None


def _retrieve_connection_to_an_app() -> Tuple[Optional[str], Optional[str]]:
connected_file = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "connect.txt")
return _read_connected_file(connected_file)


def _get_commands_folder() -> str:
return os.path.join(_LIGHTNING_CONNECTION_FOLDER, "commands")

Expand Down Expand Up @@ -284,3 +308,27 @@ def _clean_lightning_connection():
connection = os.path.join(_LIGHTNING_CONNECTION, str(ppid))
if os.path.exists(connection):
shutil.rmtree(connection)


def _scan_lightning_connections(app_name_or_id):
if not os.path.exists(_LIGHTNING_CONNECTION):
return

for ppid in os.listdir(_LIGHTNING_CONNECTION):
try:
psutil.Process(int(ppid))
except (psutil.NoSuchProcess, ValueError):
continue

connection_path = os.path.join(_LIGHTNING_CONNECTION, str(ppid))

connected_file = os.path.join(connection_path, "connect.txt")
curr_app_name, curr_app_id = _read_connected_file(connected_file)

if not curr_app_name:
continue

if app_name_or_id == curr_app_name or app_name_or_id == curr_app_id:
return connection_path

return None
57 changes: 45 additions & 12 deletions tests/tests_app/cli/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from unittest.mock import MagicMock

import click
import psutil
import pytest

from lightning_app import _PROJECT_ROOT
from lightning_app.cli.commands.connection import (
_list_app_commands,
_PPID,
_resolve_command_path,
_retrieve_connection_to_an_app,
connect,
Expand All @@ -19,10 +19,27 @@
from lightning_app.utilities.commands import base


def test_connect_disconnect_local(monkeypatch):
def monkeypatch_connection(monkeypatch, tmpdir, ppid):
connection_path = os.path.join(tmpdir, ppid)
try:
monkeypatch.setattr("lightning_app.cli.commands.connection._clean_lightning_connection", MagicMock())
monkeypatch.setattr("lightning_app.cli.commands.connection._PPID", ppid)
monkeypatch.setattr("lightning_app.cli.commands.connection._LIGHTNING_CONNECTION", tmpdir)
monkeypatch.setattr("lightning_app.cli.commands.connection._LIGHTNING_CONNECTION_FOLDER", connection_path)
except ModuleNotFoundError:
monkeypatch.setattr("lightning.app.cli.commands.connection._clean_lightning_connection", MagicMock())
monkeypatch.setattr("lightning_app.cli.commands.connection._PPID", ppid)
monkeypatch.setattr("lightning.app.cli.commands.connection._LIGHTNING_CONNECTION", tmpdir)
monkeypatch.setattr("lightning.app.cli.commands.connection._LIGHTNING_CONNECTION_FOLDER", connection_path)
return connection_path


def test_connect_disconnect_local(tmpdir, monkeypatch):
disconnect()

ppid = str(psutil.Process(os.getpid()).ppid())
connection_path = monkeypatch_connection(monkeypatch, tmpdir, ppid=ppid)

with pytest.raises(Exception, match="The commands weren't found. Is your app localhost running ?"):
connect("localhost", True)

Expand Down Expand Up @@ -53,12 +70,11 @@ def fn(msg):
assert not os.path.exists(command_path)
command_path = _resolve_command_path("command_with_client")
assert os.path.exists(command_path)
home = os.path.expanduser("~")
s = "/" if sys.platform != "win32" else "\\"
command_folder_path = f"{home}{s}.lightning{s}lightning_connection{s}{_PPID}{s}commands"
command_folder_path = f"{connection_path}{s}commands"
expected = [
f"Find the `command with client` command under {command_folder_path}{s}command_with_client.py.",
f"You can review all the downloaded commands under {command_folder_path} folder.",
f"Storing `command with client` at {command_folder_path}{s}command_with_client.py",
f"You can review all the downloaded commands at {command_folder_path}",
"You are connected to the local Lightning App.",
]
assert messages == expected
Expand All @@ -79,10 +95,14 @@ def fn(msg):
assert _retrieve_connection_to_an_app() == (None, None)


def test_connect_disconnect_cloud(monkeypatch):

def test_connect_disconnect_cloud(tmpdir, monkeypatch):
disconnect()

ppid_1 = str(psutil.Process(os.getpid()).ppid())
ppid_2 = "222"

connection_path = monkeypatch_connection(monkeypatch, tmpdir, ppid=ppid_1)

target_file = _resolve_command_path("command_with_client")

if os.path.exists(target_file):
Expand Down Expand Up @@ -141,12 +161,11 @@ def fn(msg):
assert not os.path.exists(command_path)
command_path = _resolve_command_path("command_with_client")
assert os.path.exists(command_path)
home = os.path.expanduser("~")
s = "/" if sys.platform != "win32" else "\\"
command_folder_path = f"{home}{s}.lightning{s}lightning_connection{s}{_PPID}{s}commands"
command_folder_path = f"{connection_path}{s}commands"
expected = [
f"Storing `command with client` under {command_folder_path}{s}command_with_client.py",
f"You can review all the downloaded commands under {command_folder_path} folder.",
f"Storing `command with client` at {command_folder_path}{s}command_with_client.py",
f"You can review all the downloaded commands at {command_folder_path}",
" ",
"The client interface has been successfully installed. ",
"You can now run the following commands:",
Expand All @@ -170,9 +189,23 @@ def fn(msg):
connect("example", True)
assert messages == ["You are already connected to the cloud Lightning App: example."]

_ = monkeypatch_connection(monkeypatch, tmpdir, ppid=ppid_2)

messages = []
connect("example", True)
assert messages[0] == "Found existing connection, reusing cached commands"

messages = []
disconnect()
print(messages)
assert messages == ["You are disconnected from the cloud Lightning App: example."]

_ = monkeypatch_connection(monkeypatch, tmpdir, ppid=ppid_1)

messages = []
disconnect()
assert messages == ["You are disconnected from the cloud Lightning App: example."]

messages = []
disconnect()
assert messages == [
Expand Down

0 comments on commit 01f57a9

Please sign in to comment.