Skip to content

Commit

Permalink
[App] Serve datatypes with better client code (#16018)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sherin Thomas authored Dec 16, 2022
1 parent 3d509f6 commit 23013be
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/lightning_app/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a progres bar while connecting to an app through the CLI ([#16035](https://github.com/Lightning-AI/lightning/pull/16035))


- Added partial support for fastapi `Request` annotation in `configure_api` handlers ([#16047](https://github.com/Lightning-AI/lightning/pull/16047))

- Added more datatypes to serving component ([#16018](https://github.com/Lightning-AI/lightning/pull/16018))

### Changed

Expand Down
4 changes: 3 additions & 1 deletion src/lightning_app/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from lightning_app.components.python.popen import PopenPythonScript
from lightning_app.components.python.tracer import Code, TracerPythonScript
from lightning_app.components.serve.gradio import ServeGradio
from lightning_app.components.serve.python_server import Image, Number, PythonServer
from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text
from lightning_app.components.serve.serve import ModelInferenceAPI
from lightning_app.components.serve.streamlit import ServeStreamlit
from lightning_app.components.training import LightningTrainerScript, PyTorchLightningScriptRunner
Expand All @@ -28,6 +28,8 @@
"PythonServer",
"Image",
"Number",
"Category",
"Text",
"MultiNode",
"LiteMultiNode",
"LightningTrainerScript",
Expand Down
4 changes: 2 additions & 2 deletions src/lightning_app/components/serve/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from lightning_app.components.serve.gradio import ServeGradio
from lightning_app.components.serve.python_server import Image, Number, PythonServer
from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text
from lightning_app.components.serve.streamlit import ServeStreamlit

__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number"]
__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number", "Category", "Text"]
119 changes: 97 additions & 22 deletions src/lightning_app/components/serve/python_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import base64
import os
import platform
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, TYPE_CHECKING

import requests
import uvicorn
from fastapi import FastAPI
from lightning_utilities.core.imports import compare_version, module_available
Expand All @@ -14,6 +14,9 @@
from lightning_app.utilities.app_helpers import Logger
from lightning_app.utilities.imports import _is_torch_available, requires

if TYPE_CHECKING:
from lightning_app.frontend.frontend import Frontend

logger = Logger(__name__)

# Skip doctests if requirements aren't available
Expand Down Expand Up @@ -48,18 +51,80 @@ class Image(BaseModel):
image: Optional[str]

@staticmethod
def _get_sample_data() -> Dict[Any, Any]:
imagepath = Path(__file__).parent / "catimage.png"
with open(imagepath, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return {"image": encoded_string.decode("UTF-8")}
def get_sample_data() -> Dict[Any, Any]:
url = "https://raw.githubusercontent.com/Lightning-AI/LAI-Triton-Server-Component/main/catimage.png"
img = requests.get(url).content
img = base64.b64encode(img).decode("UTF-8")
return {"image": img}

@staticmethod
def request_code_sample(url: str) -> str:
return (
"""import base64
from pathlib import Path
import requests
imgurl = "https://raw.githubusercontent.com/Lightning-AI/LAI-Triton-Server-Component/main/catimage.png"
img = requests.get(imgurl).content
img = base64.b64encode(img).decode("UTF-8")
response = requests.post('"""
+ url
+ """', json={
"image": img
})"""
)

@staticmethod
def response_code_sample() -> str:
return """img = response.json()["image"]
img = base64.b64decode(img.encode("utf-8"))
Path("response.png").write_bytes(img)
"""


class Category(BaseModel):
category: Optional[int]

@staticmethod
def get_sample_data() -> Dict[Any, Any]:
return {"prediction": 463}

@staticmethod
def response_code_sample() -> str:
return """print("Predicted category is: ", response.json()["category"])
"""


class Text(BaseModel):
text: Optional[str]

@staticmethod
def get_sample_data() -> Dict[Any, Any]:
return {"text": "A portrait of a person looking away from the camera"}

@staticmethod
def request_code_sample(url: str) -> str:
return (
"""import base64
from pathlib import Path
import requests
response = requests.post('"""
+ url
+ """', json={
"text": "A portrait of a person looking away from the camera"
})
"""
)


class Number(BaseModel):
# deprecated
# TODO remove this in favour of Category
prediction: Optional[int]

@staticmethod
def _get_sample_data() -> Dict[Any, Any]:
def get_sample_data() -> Dict[Any, Any]:
return {"prediction": 463}


Expand Down Expand Up @@ -154,8 +219,8 @@ def predict(self, request: Any) -> Any:

@staticmethod
def _get_sample_dict_from_datatype(datatype: Any) -> dict:
if hasattr(datatype, "_get_sample_data"):
return datatype._get_sample_data()
if hasattr(datatype, "get_sample_data"):
return datatype.get_sample_data()

datatype_props = datatype.schema()["properties"]
out: Dict[str, Any] = {}
Expand Down Expand Up @@ -187,7 +252,15 @@ def predict_fn(request: input_type): # type: ignore

fastapi_app.post("/predict", response_model=output_type)(predict_fn)

def configure_layout(self) -> None:
def get_code_sample(self, url: str) -> Optional[str]:
input_type: Any = self.configure_input_type()
output_type: Any = self.configure_output_type()

if not (hasattr(input_type, "request_code_sample") and hasattr(output_type, "response_code_sample")):
return None
return f"{input_type.request_code_sample(url)}\n{output_type.response_code_sample()}"

def configure_layout(self) -> Optional["Frontend"]:
try:
from lightning_api_access import APIAccessFrontend
except ModuleNotFoundError:
Expand All @@ -203,17 +276,19 @@ def configure_layout(self) -> None:
except TypeError:
return None

return APIAccessFrontend(
apis=[
{
"name": class_name,
"url": url,
"method": "POST",
"request": request,
"response": response,
}
]
)
frontend_payload = {
"name": class_name,
"url": url,
"method": "POST",
"request": request,
"response": response,
}

code_sample = self.get_code_sample(url)
if code_sample:
frontend_payload["code_sample"] = code_sample

return APIAccessFrontend(apis=[frontend_payload])

def run(self, *args: Any, **kwargs: Any) -> Any:
"""Run method takes care of configuring and setting up a FastAPI server behind the scenes.
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_app/components/serve/test_python_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ def test_python_server_component():


def test_image_sample_data():
data = Image()._get_sample_data()
data = Image().get_sample_data()
assert isinstance(data, dict)
assert "image" in data
assert len(data["image"]) > 100


def test_number_sample_data():
data = Number()._get_sample_data()
data = Number().get_sample_data()
assert isinstance(data, dict)
assert "prediction" in data
assert data["prediction"] == 463

0 comments on commit 23013be

Please sign in to comment.