Skip to content

Commit

Permalink
Sample datatype for Serve Component (#15623)
Browse files Browse the repository at this point in the history
* introducing serve component

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* clean up tests

* clean up tests

* doctest

* mypy

* structure-fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleanup

* cleanup

* test fix

* addition

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* requirements

* getting future url

* url for local

* sample data typeg

* changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* prediction

* updates

* updates

* manifest

* fix type error

* fixed test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rick Izzo <[email protected]>
Co-authored-by: Jirka <[email protected]>
  • Loading branch information
4 people authored Nov 10, 2022
1 parent 61d3253 commit 136a090
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 6 deletions.
42 changes: 42 additions & 0 deletions examples/app_server/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# !pip install torchvision pydantic
import base64
import io

import torch
import torchvision
from PIL import Image
from pydantic import BaseModel

import lightning as L
from lightning.app.components.serve import Image as InputImage
from lightning.app.components.serve import PythonServer


class PyTorchServer(PythonServer):
def setup(self):
self._model = torchvision.models.resnet18(pretrained=True)
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self._model.to(self._device)

def predict(self, request):
image = base64.b64decode(request.image.encode("utf-8"))
image = Image.open(io.BytesIO(image))
transforms = torchvision.transforms.Compose(
[
torchvision.transforms.Resize(224),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
image = transforms(image)
image = image.to(self._device)
prediction = self._model(image.unsqueeze(0))
return {"prediction": prediction.argmax().item()}


class OutputData(BaseModel):
prediction: int


component = PyTorchServer(input_type=InputImage, output_type=OutputData, cloud_compute=L.CloudCompute("gpu"))
app = L.LightningApp(component)
1 change: 1 addition & 0 deletions src/lightning/__setup__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def _adjust_manifest(**kwargs: Any) -> None:
"recursive-include requirements *.txt",
"recursive-include src/lightning/app/ui *",
"recursive-include src/lightning/cli/*-template *", # Add templates as build-in
"include src/lightning/app/components/serve/catimage.png" + os.linesep,
# fixme: this is strange, this shall work with setup find package - include
"prune src/lightning_app",
"prune src/lightning_lite",
Expand Down
1 change: 1 addition & 0 deletions src/lightning_app/__setup__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def _adjust_manifest(**__: Any) -> None:
"recursive-exclude src *.md" + os.linesep,
"recursive-exclude requirements *.txt" + os.linesep,
"recursive-include src/lightning_app *.md" + os.linesep,
"include src/lightning_app/components/serve/catimage.png" + os.linesep,
"recursive-include requirements/app *.txt" + os.linesep,
"recursive-include src/lightning_app/cli/*-template *" + os.linesep, # Add templates
]
Expand Down
4 changes: 3 additions & 1 deletion src/lightning_app/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from lightning_app.components.python.popen import PopenPythonScript
from lightning_app.components.python.tracer import Code, TracerPythonScript
from lightning_app.components.serve.gradio import ServeGradio
from lightning_app.components.serve.python_server import PythonServer
from lightning_app.components.serve.python_server import Image, Number, PythonServer
from lightning_app.components.serve.serve import ModelInferenceAPI
from lightning_app.components.serve.streamlit import ServeStreamlit
from lightning_app.components.training import LightningTrainingComponent, PyTorchLightningScriptRunner
Expand All @@ -24,6 +24,8 @@
"ServeStreamlit",
"ModelInferenceAPI",
"PythonServer",
"Image",
"Number",
"MultiNode",
"LiteMultiNode",
"LightningTrainingComponent",
Expand Down
4 changes: 2 additions & 2 deletions src/lightning_app/components/serve/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from lightning_app.components.serve.gradio import ServeGradio
from lightning_app.components.serve.python_server import PythonServer
from lightning_app.components.serve.python_server import Image, Number, PythonServer
from lightning_app.components.serve.streamlit import ServeStreamlit

__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer"]
__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number"]
Binary file added src/lightning_app/components/serve/catimage.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
34 changes: 32 additions & 2 deletions src/lightning_app/components/serve/python_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import abc
from typing import Any, Dict
import base64
from pathlib import Path
from typing import Any, Dict, Optional

import uvicorn
from fastapi import FastAPI
Expand All @@ -12,6 +14,12 @@
logger = Logger(__name__)


def image_to_base64(image_path):
with open(image_path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return encoded_string.decode("UTF-8")


class _DefaultInputData(BaseModel):
payload: str

Expand All @@ -20,6 +28,25 @@ class _DefaultOutputData(BaseModel):
prediction: str


class Image(BaseModel):
image: Optional[str]

@staticmethod
def _get_sample_data() -> Dict[Any, Any]:
imagepath = Path(__file__).absolute().parent / "catimage.png"
with open(imagepath, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return {"image": encoded_string.decode("UTF-8")}


class Number(BaseModel):
prediction: Optional[int]

@staticmethod
def _get_sample_data() -> Dict[Any, Any]:
return {"prediction": 463}


class PythonServer(LightningWork, abc.ABC):
def __init__( # type: ignore
self,
Expand Down Expand Up @@ -110,6 +137,9 @@ def predict(self, request: Any) -> Any:

@staticmethod
def _get_sample_dict_from_datatype(datatype: Any) -> dict:
if hasattr(datatype, "_get_sample_data"):
return datatype._get_sample_data()

datatype_props = datatype.schema()["properties"]
out: Dict[str, Any] = {}
for k, v in datatype_props.items():
Expand Down Expand Up @@ -141,7 +171,7 @@ def _attach_frontend(self, fastapi_app: FastAPI) -> None:
url = self._future_url if self._future_url else self.url
if not url:
# if the url is still empty, point it to localhost
url = f"http://127.0.0.1{self.port}"
url = f"http://127.0.0.1:{self.port}"
url = f"{url}/predict"
datatype_parse_error = False
try:
Expand Down
16 changes: 15 additions & 1 deletion tests/tests_app/components/serve/test_python_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import multiprocessing as mp

from lightning_app.components import PythonServer
from lightning_app.components import Image, Number, PythonServer
from lightning_app.utilities.network import _configure_session, find_free_network_port


Expand Down Expand Up @@ -29,3 +29,17 @@ def test_python_server_component():
res = session.post(f"http://127.0.0.1:{port}/predict", json={"payload": "test"})
process.terminate()
assert res.json()["prediction"] == "test"


def test_image_sample_data():
data = Image()._get_sample_data()
assert isinstance(data, dict)
assert "image" in data
assert len(data["image"]) > 100


def test_number_sample_data():
data = Number()._get_sample_data()
assert isinstance(data, dict)
assert "prediction" in data
assert data["prediction"] == 463

0 comments on commit 136a090

Please sign in to comment.