diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index 305ee591b0257..053d7b1c62ea1 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -36,8 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - Fixed debugging with VSCode IDE ([#15747](https://github.com/Lightning-AI/lightning/pull/15747)) - - +- Fixed SSH CLI command listing stopped components ([#15810](https://github.com/Lightning-AI/lightning/pull/15810)) - Fixed the work not stopped when successful when passed directly to the LightningApp ([#15801](https://github.com/Lightning-AI/lightning/pull/15801)) diff --git a/src/lightning_app/cli/cmd_apps.py b/src/lightning_app/cli/cmd_apps.py index ec691f07a0142..3edc97449836b 100644 --- a/src/lightning_app/cli/cmd_apps.py +++ b/src/lightning_app/cli/cmd_apps.py @@ -54,9 +54,13 @@ def list_apps( apps = apps + resp.lightningapps return apps - def list_components(self, app_id: str) -> List[Externalv1Lightningwork]: + def list_components(self, app_id: str, phase_in: List[str] = []) -> List[Externalv1Lightningwork]: project = _get_project(self.api_client) - resp = self.api_client.lightningwork_service_list_lightningwork(project_id=project.project_id, app_id=app_id) + resp = self.api_client.lightningwork_service_list_lightningwork( + project_id=project.project_id, + app_id=app_id, + phase_in=phase_in, + ) return resp.lightningworks def list(self, cluster_id: str = None, limit: int = 100) -> None: diff --git a/src/lightning_app/cli/lightning_cli.py b/src/lightning_app/cli/lightning_cli.py index 9f6682bce6a75..0205405c20560 100644 --- a/src/lightning_app/cli/lightning_cli.py +++ b/src/lightning_app/cli/lightning_cli.py @@ -8,7 +8,7 @@ import click import inquirer import rich -from lightning_cloud.openapi import Externalv1LightningappInstance, V1LightningappInstanceState +from lightning_cloud.openapi import Externalv1LightningappInstance, V1LightningappInstanceState, V1LightningworkState from lightning_cloud.openapi.rest import ApiException from lightning_utilities.core.imports import RequirementCache from requests.exceptions import ConnectionError @@ -420,7 +420,7 @@ def ssh(app_name: str = None, component_name: str = None) -> None: except ApiException: raise click.ClickException("failed fetching app instance") - components = app_manager.list_components(app_id=app_id) + components = app_manager.list_components(app_id=app_id, phase_in=[V1LightningworkState.RUNNING]) available_component_names = [work.name for work in components] + ["flow"] if component_name is None: available_components = [ diff --git a/tests/tests_app/cli/test_cmd_apps.py b/tests/tests_app/cli/test_cmd_apps.py index 1cfb35893cc7d..e579c673ac3d6 100644 --- a/tests/tests_app/cli/test_cmd_apps.py +++ b/tests/tests_app/cli/test_cmd_apps.py @@ -7,7 +7,9 @@ V1LightningappInstanceSpec, V1LightningappInstanceState, V1LightningappInstanceStatus, + V1LightningworkState, V1ListLightningappInstancesResponse, + V1ListLightningworkResponse, V1ListMembershipsResponse, V1Membership, ) @@ -97,6 +99,36 @@ def test_list_all_apps(list_memberships: mock.MagicMock, list_instances: mock.Ma list_instances.assert_called_once_with(project_id="default-project", limit=100, phase_in=[]) +@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock()) +@mock.patch("lightning_app.utilities.network.LightningClient.lightningwork_service_list_lightningwork") +@mock.patch("lightning_app.utilities.network.LightningClient.projects_service_list_memberships") +def test_list_components(list_memberships: mock.MagicMock, list_components: mock.MagicMock): + list_memberships.return_value = V1ListMembershipsResponse(memberships=[V1Membership(project_id="default-project")]) + list_components.return_value = V1ListLightningworkResponse(lightningworks=[]) + + cluster_manager = _AppManager() + cluster_manager.list_components(app_id="cheese") + + list_memberships.assert_called_once() + list_components.assert_called_once_with(project_id="default-project", app_id="cheese", phase_in=[]) + + +@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock()) +@mock.patch("lightning_app.utilities.network.LightningClient.lightningwork_service_list_lightningwork") +@mock.patch("lightning_app.utilities.network.LightningClient.projects_service_list_memberships") +def test_list_components_with_phase(list_memberships: mock.MagicMock, list_components: mock.MagicMock): + list_memberships.return_value = V1ListMembershipsResponse(memberships=[V1Membership(project_id="default-project")]) + list_components.return_value = V1ListLightningworkResponse(lightningworks=[]) + + cluster_manager = _AppManager() + cluster_manager.list_components(app_id="cheese", phase_in=[V1LightningworkState.RUNNING]) + + list_memberships.assert_called_once() + list_components.assert_called_once_with( + project_id="default-project", app_id="cheese", phase_in=[V1LightningworkState.RUNNING] + ) + + @mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock()) @mock.patch("lightning_app.utilities.network.LightningClient.lightningapp_instance_service_list_lightningapp_instances") @mock.patch("lightning_app.utilities.network.LightningClient.projects_service_list_memberships")