Skip to content

Commit

Permalink
[App] Add configure_layout method for works (#15926)
Browse files Browse the repository at this point in the history
* Add `configure_layout` method for works
* Check for api access availability
* Updates from review
* Update CHANGELOG.md
* Apply suggestions from code review

Co-authored-by: Sherin Thomas <[email protected]>
  • Loading branch information
ethanwharris and Sherin Thomas authored Dec 8, 2022
1 parent 73a6dbe commit d5b9c67
Show file tree
Hide file tree
Showing 10 changed files with 362 additions and 71 deletions.
2 changes: 2 additions & 0 deletions src/lightning_app/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added private work attributed `_start_method` to customize how to start the works ([#15923](https://github.com/Lightning-AI/lightning/pull/15923))

- Added a `configure_layout` method to the `LightningWork` which can be used to control how the work is handled in the layout of a parent flow ([#15926](https://github.com/Lightning-AI/lightning/pull/15926))


### Changed

Expand Down
3 changes: 3 additions & 0 deletions src/lightning_app/components/serve/gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,6 @@ def run(self, *args, **kwargs):
server_port=self.port,
enable_queue=self.enable_queue,
)

def configure_layout(self) -> str:
return self.url
69 changes: 24 additions & 45 deletions src/lightning_app/components/serve/python_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from fastapi import FastAPI
from lightning_utilities.core.imports import module_available
from pydantic import BaseModel
from starlette.staticfiles import StaticFiles

from lightning_app.core.queues import MultiProcessQueue
from lightning_app.core.work import LightningWork
Expand Down Expand Up @@ -222,49 +221,30 @@ def predict_fn(request: input_type): # type: ignore

fastapi_app.post("/predict", response_model=output_type)(predict_fn)

def _attach_frontend(self, fastapi_app: FastAPI) -> None:
from lightning_api_access import APIAccessFrontend

class_name = self.__class__.__name__
url = self._future_url if self._future_url else self.url
if not url:
# if the url is still empty, point it to localhost
url = f"http://127.0.0.1:{self.port}"
url = f"{url}/predict"
datatype_parse_error = False
try:
request = self._get_sample_dict_from_datatype(self.configure_input_type())
except TypeError:
datatype_parse_error = True

try:
response = self._get_sample_dict_from_datatype(self.configure_output_type())
except TypeError:
datatype_parse_error = True

if datatype_parse_error:

@fastapi_app.get("/")
def index() -> str:
return (
"Automatic generation of the UI is only supported for simple, "
"non-nested datatype with types string, integer, float and boolean"
)

return

frontend = APIAccessFrontend(
apis=[
{
"name": class_name,
"url": url,
"method": "POST",
"request": request,
"response": response,
}
]
)
fastapi_app.mount("/", StaticFiles(directory=frontend.serve_dir, html=True), name="static")
def configure_layout(self) -> None:
if module_available("lightning_api_access"):
from lightning_api_access import APIAccessFrontend

class_name = self.__class__.__name__
url = f"{self.url}/predict"

try:
request = self._get_sample_dict_from_datatype(self.configure_input_type())
response = self._get_sample_dict_from_datatype(self.configure_output_type())
except TypeError:
return None

return APIAccessFrontend(
apis=[
{
"name": class_name,
"url": url,
"method": "POST",
"request": request,
"response": response,
}
]
)

def run(self, *args: Any, **kwargs: Any) -> Any:
"""Run method takes care of configuring and setting up a FastAPI server behind the scenes.
Expand All @@ -275,7 +255,6 @@ def run(self, *args: Any, **kwargs: Any) -> Any:

fastapi_app = FastAPI()
self._attach_predict_fn(fastapi_app)
self._attach_frontend(fastapi_app)

logger.info(f"Your app has started. View it in your browser: http://{self.host}:{self.port}")
uvicorn.run(app=fastapi_app, host=self.host, port=self.port, log_level="error")
9 changes: 3 additions & 6 deletions src/lightning_app/components/serve/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from starlette.responses import RedirectResponse

from lightning_app.components.serve.types import _DESERIALIZER, _SERIALIZER
from lightning_app.core.work import LightningWork
Expand All @@ -37,10 +36,6 @@ async def run(self, data) -> Any:
return self.serialize(self.predict(self.deserialize(data)))


async def _redirect():
return RedirectResponse("/docs")


class ModelInferenceAPI(LightningWork, abc.ABC):
def __init__(
self,
Expand Down Expand Up @@ -121,7 +116,6 @@ def run(self):
def _populate_app(self, fastapi_service: FastAPI):
self._model = self.build_model()

fastapi_service.get("/")(_redirect)
fastapi_service.post("/predict", response_class=JSONResponse)(
_InferenceCallable(
deserialize=_DESERIALIZER[self.input] if self.input else self.deserialize,
Expand All @@ -134,6 +128,9 @@ def _launch_server(self, fastapi_service: FastAPI):
logger.info(f"Your app has started. View it in your browser: http://{self.host}:{self.port}")
uvicorn.run(app=fastapi_service, host=self.host, port=self.port, log_level="error")

def configure_layout(self) -> str:
return f"{self.url}/docs"


def _maybe_create_instance() -> Optional[ModelInferenceAPI]:
"""This function tries to re-create the user `ModelInferenceAPI` if the environment associated with multi
Expand Down
3 changes: 3 additions & 0 deletions src/lightning_app/components/serve/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def on_exit(self) -> None:
if self._process is not None:
self._process.kill()

def configure_layout(self) -> str:
return self.url


class _PatchedWork:
"""The ``_PatchedWork`` is used to emulate a work instance from a subprocess. This is acheived by patching the
Expand Down
6 changes: 4 additions & 2 deletions src/lightning_app/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from lightning_app.frontend import Frontend
from lightning_app.storage import Path
from lightning_app.storage.drive import _maybe_create_drive, Drive
from lightning_app.utilities.app_helpers import _is_json_serializable, _LightningAppRef, _set_child_name
from lightning_app.utilities.app_helpers import _is_json_serializable, _LightningAppRef, _set_child_name, is_overridden
from lightning_app.utilities.component import _sanitize_state
from lightning_app.utilities.exceptions import ExitAppException
from lightning_app.utilities.introspection import _is_init_context, _is_run_context
Expand Down Expand Up @@ -777,4 +777,6 @@ def run(self):
self.work.run()

def configure_layout(self):
return [{"name": "Main", "content": self.work}]
if is_overridden("configure_layout", self.work):
return [{"name": "Main", "content": self.work}]
return []
47 changes: 46 additions & 1 deletion src/lightning_app/core/work.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import warnings
from copy import deepcopy
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union

from deepdiff import DeepHash, Delta

Expand Down Expand Up @@ -33,6 +33,9 @@
)
from lightning_app.utilities.proxies import Action, LightningWorkSetAttrProxy, ProxyWorkRun, unwrap, WorkRunExecutor

if TYPE_CHECKING:
from lightning_app.frontend import Frontend


class LightningWork:

Expand Down Expand Up @@ -629,3 +632,45 @@ def apply_flow_delta(self, delta: Delta):
property_object.fset(self, value)
else:
self._default_setattr(name, value)

def configure_layout(self) -> Union[None, str, "Frontend"]:
"""Configure the UI of this LightningWork.
You can either
1. Return a single :class:`~lightning_app.frontend.frontend.Frontend` object to serve a user interface
for this Work.
2. Return a string containing a URL to act as the user interface for this Work.
3. Return ``None`` to indicate that this Work doesn't currently have a user interface.
**Example:** Serve a static directory (with at least a file index.html inside).
.. code-block:: python
from lightning_app.frontend import StaticWebFrontend
class Work(LightningWork):
def configure_layout(self):
return StaticWebFrontend("path/to/folder/to/serve")
**Example:** Arrange the UI of my children in tabs (default UI by Lightning).
.. code-block:: python
class Work(LightningWork):
def configure_layout(self):
return [
dict(name="First Tab", content=self.child0),
dict(name="Second Tab", content=self.child1),
dict(name="Lightning", content="https://lightning.ai"),
]
If you don't implement ``configure_layout``, Lightning will use ``self.url``.
Note:
This hook gets called at the time of app creation and then again as part of the loop. If desired, a
returned URL can depend on the state. This is not the case if the work returns a
:class:`~lightning_app.frontend.frontend.Frontend`. These need to be provided at the time of app creation
in order for the runtime to start the server.
"""
91 changes: 82 additions & 9 deletions src/lightning_app/utilities/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import lightning_app
from lightning_app.frontend.frontend import Frontend
from lightning_app.utilities.app_helpers import _MagicMockJsonSerializable
from lightning_app.utilities.app_helpers import _MagicMockJsonSerializable, is_overridden
from lightning_app.utilities.cloud import is_running_in_cloud


Expand Down Expand Up @@ -45,9 +45,9 @@ def _collect_layout(app: "lightning_app.LightningApp", flow: "lightning_app.Ligh
app.frontends.setdefault(flow.name, "mock")
return flow._layout
elif isinstance(layout, dict):
layout = _collect_content_layout([layout], flow)
layout = _collect_content_layout([layout], app, flow)
elif isinstance(layout, (list, tuple)) and all(isinstance(item, dict) for item in layout):
layout = _collect_content_layout(layout, flow)
layout = _collect_content_layout(layout, app, flow)
else:
lines = _add_comment_to_literal_code(flow.configure_layout, contains="return", comment=" <------- this guy")
m = f"""
Expand Down Expand Up @@ -76,7 +76,9 @@ def configure_layout(self):
return layout


def _collect_content_layout(layout: List[Dict], flow: "lightning_app.LightningFlow") -> List[Dict]:
def _collect_content_layout(
layout: List[Dict], app: "lightning_app.LightningApp", flow: "lightning_app.LightningFlow"
) -> Union[List[Dict], Dict]:
"""Process the layout returned by the ``configure_layout()`` method if the returned format represents an
aggregation of child layouts."""
for entry in layout:
Expand All @@ -102,12 +104,43 @@ def _collect_content_layout(layout: List[Dict], flow: "lightning_app.LightningFl
entry["content"] = entry["content"].name

elif isinstance(entry["content"], lightning_app.LightningWork):
if entry["content"].url and not entry["content"].url.startswith("/"):
entry["content"] = entry["content"].url
entry["target"] = entry["content"]
else:
work = entry["content"]
work_layout = _collect_work_layout(work)

if work_layout is None:
entry["content"] = ""
entry["target"] = ""
elif isinstance(work_layout, str):
entry["content"] = work_layout
entry["target"] = work_layout
elif isinstance(work_layout, (Frontend, _MagicMockJsonSerializable)):
if len(layout) > 1:
lines = _add_comment_to_literal_code(
flow.configure_layout, contains="return", comment=" <------- this guy"
)
m = f"""
The return value of configure_layout() in `{flow.__class__.__name__}` is an
unsupported format:
\n{lines}
The tab containing a `{work.__class__.__name__}` must be the only tab in the
layout of this flow.
(see the docs for `LightningWork.configure_layout`).
"""
raise TypeError(m)

if isinstance(work_layout, Frontend):
# If the work returned a frontend, treat it as belonging to the flow.
# NOTE: This could evolve in the future to run the Frontend directly in the work machine.
frontend = work_layout
frontend.flow = flow
elif isinstance(work_layout, _MagicMockJsonSerializable):
# The import was mocked, we set a dummy `Frontend` so that `is_headless` knows there is a UI.
frontend = "mock"

app.frontends.setdefault(flow.name, frontend)
return flow._layout

elif isinstance(entry["content"], _MagicMockJsonSerializable):
# The import was mocked, we just record dummy content so that `is_headless` knows there is a UI
entry["content"] = "mock"
Expand All @@ -126,3 +159,43 @@ def configure_layout(self):
"""
raise ValueError(m)
return layout


def _collect_work_layout(work: "lightning_app.LightningWork") -> Union[None, str, Frontend, _MagicMockJsonSerializable]:
"""Check if ``configure_layout`` is overridden on the given work and return the work layout (either a string, a
``Frontend`` object, or an instance of a mocked import).
Args:
work: The work to collect the layout for.
Raises:
TypeError: If the value returned by ``configure_layout`` is not of a supported format.
"""
if is_overridden("configure_layout", work):
work_layout = work.configure_layout()
else:
work_layout = work.url

if work_layout is None:
return None
elif isinstance(work_layout, str):
url = work_layout
# The URL isn't fully defined yet. Looks something like ``self.work.url + /something``.
if url and not url.startswith("/"):
return url
return ""
elif isinstance(work_layout, (Frontend, _MagicMockJsonSerializable)):
return work_layout
else:
m = f"""
The value returned by `{work.__class__.__name__}.configure_layout()` is of an unsupported type.
{repr(work_layout)}
Return a `Frontend` or a URL string, for example:
class {work.__class__.__name__}(LightningWork):
def configure_layout(self):
return MyFrontend() OR 'http://some/url'
"""
raise TypeError(m)
Loading

0 comments on commit d5b9c67

Please sign in to comment.