Skip to content

Commit

Permalink
App: Fix dispatch return value (#18674)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Sep 29, 2023
1 parent bd784d3 commit 02780f2
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
6 changes: 4 additions & 2 deletions src/lightning/app/plugin/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ class _Run(BaseModel):


def _run_plugin(run: _Run) -> Dict[str, Any]:
from lightning.app.runners.cloud import _to_clean_dict

"""Create a run with the given name and entrypoint under the cloudspace with the given ID."""
with tempfile.TemporaryDirectory() as tmpdir:
download_path = os.path.join(tmpdir, "source.tar.gz")
Expand Down Expand Up @@ -184,8 +186,8 @@ def _run_plugin(run: _Run) -> Dict[str, Any]:
cluster_id=run.cluster_id,
source_app=run.source_app,
)
appInstance = plugin.run(**run.plugin_arguments)
return {"appInstance": appInstance.to_dict()}
app_instance = plugin.run(**run.plugin_arguments)
return _to_clean_dict(app_instance, True)
except Exception as ex:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error running plugin: {str(ex)}."
Expand Down
9 changes: 3 additions & 6 deletions src/lightning/app/runners/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def cloudspace_dispatch(
name: str,
cluster_id: str,
source_app: Optional[str] = None,
) -> str:
) -> Externalv1LightningappInstance:
"""Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties such
as the project and cluster IDs that are instead passed directly.
Expand All @@ -213,7 +213,7 @@ def cloudspace_dispatch(
ValueError: If there are validation errors.
Returns:
The URL of the created job.
The spec the created app instance.
"""
# Dispatch in four phases: resolution, validation, spec creation, API transactions
Expand All @@ -230,7 +230,6 @@ def cloudspace_dispatch(
package_source=not absolute_entrypoint,
sys_customizations_root=sys_customizations_root,
)
project = self._resolve_project(project_id=project_id)
existing_instances = self._resolve_run_instances_by_name(project_id, name)
name = self._resolve_run_name(name, existing_instances)
cloudspace = self._resolve_cloudspace(project_id, cloudspace_id)
Expand Down Expand Up @@ -269,7 +268,7 @@ def cloudspace_dispatch(
self._api_package_and_upload_repo(repo, run)

logger.info(f"Creating cloudspace run instance. name: {name}")
run_instance = self._api_create_run_instance(
return self._api_create_run_instance(
cluster_id,
project_id,
name,
Expand All @@ -281,8 +280,6 @@ def cloudspace_dispatch(
source_app=source_app,
)

return self._get_app_url(project, run_instance, "logs" if run.is_headless else "web-ui")

def dispatch(
self,
name: str = "",
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_app/plugin/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ def test_run_job(mock_requests, mock_cloud_runtime, mock_cloud_backend, mock_plu

response = mock_plugin_server.post("/v1/runs", json=body.dict(exclude_none=True))

assert response.status_code == status.HTTP_200_OK
assert json.loads(response.text)["appInstance"]["id"] == "created_app_id"
assert response.status_code == status.HTTP_200_OK, response.json()
assert json.loads(response.text)["id"] == "created_app_id"

mock_cloud_runtime.load_app_from_file.assert_called_once()
assert "test_entrypoint" in mock_cloud_runtime.load_app_from_file.call_args[0][0]
Expand Down
3 changes: 2 additions & 1 deletion tests/tests_app/runners/test_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -1623,7 +1623,8 @@ def test_cloudspace_dispatch(self, custom_env_sync_root, custom_env_sync_path_va
mock_app.works = [mock.MagicMock()]
cloud_runtime = cloud.CloudRuntime(app=mock_app, entrypoint=Path("."))

cloud_runtime.cloudspace_dispatch("project_id", "cloudspace_id", "run_name", "cluster_id")
app = cloud_runtime.cloudspace_dispatch("project_id", "cloudspace_id", "run_name", "cluster_id")
assert app.id == "instance_id"

mock_client.cloud_space_service_get_cloud_space.assert_called_once_with(
project_id="project_id", id="cloudspace_id"
Expand Down

0 comments on commit 02780f2

Please sign in to comment.