diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index 132706e539837..b427988b92400 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -12,9 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a progres bar while connecting to an app through the CLI ([#16035](https://github.com/Lightning-AI/lightning/pull/16035)) - - Added partial support for fastapi `Request` annotation in `configure_api` handlers ([#16047](https://github.com/Lightning-AI/lightning/pull/16047)) +- Added more datatypes to serving component ([#16018](https://github.com/Lightning-AI/lightning/pull/16018)) ### Changed diff --git a/src/lightning_app/components/__init__.py b/src/lightning_app/components/__init__.py index ca47c36071dae..18208aa316f2d 100644 --- a/src/lightning_app/components/__init__.py +++ b/src/lightning_app/components/__init__.py @@ -10,7 +10,7 @@ from lightning_app.components.python.popen import PopenPythonScript from lightning_app.components.python.tracer import Code, TracerPythonScript from lightning_app.components.serve.gradio import ServeGradio -from lightning_app.components.serve.python_server import Image, Number, PythonServer +from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text from lightning_app.components.serve.serve import ModelInferenceAPI from lightning_app.components.serve.streamlit import ServeStreamlit from lightning_app.components.training import LightningTrainerScript, PyTorchLightningScriptRunner @@ -28,6 +28,8 @@ "PythonServer", "Image", "Number", + "Category", + "Text", "MultiNode", "LiteMultiNode", "LightningTrainerScript", diff --git a/src/lightning_app/components/serve/__init__.py b/src/lightning_app/components/serve/__init__.py index cb46a71bf9ea5..a12cb1b45ee71 100644 --- a/src/lightning_app/components/serve/__init__.py +++ b/src/lightning_app/components/serve/__init__.py @@ -1,5 +1,5 @@ from lightning_app.components.serve.gradio import ServeGradio -from lightning_app.components.serve.python_server import Image, Number, PythonServer +from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text from lightning_app.components.serve.streamlit import ServeStreamlit -__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number"] +__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number", "Category", "Text"] diff --git a/src/lightning_app/components/serve/python_server.py b/src/lightning_app/components/serve/python_server.py index 40b7e83a3bdca..ee958b30625fd 100644 --- a/src/lightning_app/components/serve/python_server.py +++ b/src/lightning_app/components/serve/python_server.py @@ -2,9 +2,9 @@ import base64 import os import platform -from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, TYPE_CHECKING +import requests import uvicorn from fastapi import FastAPI from lightning_utilities.core.imports import compare_version, module_available @@ -14,6 +14,9 @@ from lightning_app.utilities.app_helpers import Logger from lightning_app.utilities.imports import _is_torch_available, requires +if TYPE_CHECKING: + from lightning_app.frontend.frontend import Frontend + logger = Logger(__name__) # Skip doctests if requirements aren't available @@ -48,18 +51,80 @@ class Image(BaseModel): image: Optional[str] @staticmethod - def _get_sample_data() -> Dict[Any, Any]: - imagepath = Path(__file__).parent / "catimage.png" - with open(imagepath, "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return {"image": encoded_string.decode("UTF-8")} + def get_sample_data() -> Dict[Any, Any]: + url = "https://raw.githubusercontent.com/Lightning-AI/LAI-Triton-Server-Component/main/catimage.png" + img = requests.get(url).content + img = base64.b64encode(img).decode("UTF-8") + return {"image": img} + + @staticmethod + def request_code_sample(url: str) -> str: + return ( + """import base64 +from pathlib import Path +import requests + +imgurl = "https://raw.githubusercontent.com/Lightning-AI/LAI-Triton-Server-Component/main/catimage.png" +img = requests.get(imgurl).content +img = base64.b64encode(img).decode("UTF-8") +response = requests.post('""" + + url + + """', json={ + "image": img +})""" + ) + + @staticmethod + def response_code_sample() -> str: + return """img = response.json()["image"] +img = base64.b64decode(img.encode("utf-8")) +Path("response.png").write_bytes(img) +""" + + +class Category(BaseModel): + category: Optional[int] + + @staticmethod + def get_sample_data() -> Dict[Any, Any]: + return {"prediction": 463} + + @staticmethod + def response_code_sample() -> str: + return """print("Predicted category is: ", response.json()["category"]) +""" + + +class Text(BaseModel): + text: Optional[str] + + @staticmethod + def get_sample_data() -> Dict[Any, Any]: + return {"text": "A portrait of a person looking away from the camera"} + + @staticmethod + def request_code_sample(url: str) -> str: + return ( + """import base64 +from pathlib import Path +import requests + +response = requests.post('""" + + url + + """', json={ + "text": "A portrait of a person looking away from the camera" +}) +""" + ) class Number(BaseModel): + # deprecated + # TODO remove this in favour of Category prediction: Optional[int] @staticmethod - def _get_sample_data() -> Dict[Any, Any]: + def get_sample_data() -> Dict[Any, Any]: return {"prediction": 463} @@ -154,8 +219,8 @@ def predict(self, request: Any) -> Any: @staticmethod def _get_sample_dict_from_datatype(datatype: Any) -> dict: - if hasattr(datatype, "_get_sample_data"): - return datatype._get_sample_data() + if hasattr(datatype, "get_sample_data"): + return datatype.get_sample_data() datatype_props = datatype.schema()["properties"] out: Dict[str, Any] = {} @@ -187,7 +252,15 @@ def predict_fn(request: input_type): # type: ignore fastapi_app.post("/predict", response_model=output_type)(predict_fn) - def configure_layout(self) -> None: + def get_code_sample(self, url: str) -> Optional[str]: + input_type: Any = self.configure_input_type() + output_type: Any = self.configure_output_type() + + if not (hasattr(input_type, "request_code_sample") and hasattr(output_type, "response_code_sample")): + return None + return f"{input_type.request_code_sample(url)}\n{output_type.response_code_sample()}" + + def configure_layout(self) -> Optional["Frontend"]: try: from lightning_api_access import APIAccessFrontend except ModuleNotFoundError: @@ -203,17 +276,19 @@ def configure_layout(self) -> None: except TypeError: return None - return APIAccessFrontend( - apis=[ - { - "name": class_name, - "url": url, - "method": "POST", - "request": request, - "response": response, - } - ] - ) + frontend_payload = { + "name": class_name, + "url": url, + "method": "POST", + "request": request, + "response": response, + } + + code_sample = self.get_code_sample(url) + if code_sample: + frontend_payload["code_sample"] = code_sample + + return APIAccessFrontend(apis=[frontend_payload]) def run(self, *args: Any, **kwargs: Any) -> Any: """Run method takes care of configuring and setting up a FastAPI server behind the scenes. diff --git a/tests/tests_app/components/serve/test_python_server.py b/tests/tests_app/components/serve/test_python_server.py index 313638e9ec42a..45275af9f87b7 100644 --- a/tests/tests_app/components/serve/test_python_server.py +++ b/tests/tests_app/components/serve/test_python_server.py @@ -32,14 +32,14 @@ def test_python_server_component(): def test_image_sample_data(): - data = Image()._get_sample_data() + data = Image().get_sample_data() assert isinstance(data, dict) assert "image" in data assert len(data["image"]) > 100 def test_number_sample_data(): - data = Number()._get_sample_data() + data = Number().get_sample_data() assert isinstance(data, dict) assert "prediction" in data assert data["prediction"] == 463