Skip to content

Commit

Permalink
wip clean up autoscaler ui
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta committed Dec 15, 2022
1 parent 6745531 commit 065db7e
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 11 deletions.
14 changes: 5 additions & 9 deletions examples/app_server_with_auto_scaler/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ! pip install torch torchvision
from typing import Any, List
from typing import List

import torch
import torchvision
Expand All @@ -8,16 +8,12 @@
import lightning as L


class RequestModel(BaseModel):
image: str # bytecode


class BatchRequestModel(BaseModel):
inputs: List[RequestModel]
inputs: List[L.app.components.Image]


class BatchResponse(BaseModel):
outputs: List[Any]
outputs: List[L.app.components.Number]


class PyTorchServer(L.app.components.PythonServer):
Expand Down Expand Up @@ -81,8 +77,8 @@ def scale(self, replicas: int, metrics: dict) -> int:
max_replicas=4,
autoscale_interval=10,
endpoint="predict",
input_type=RequestModel,
output_type=Any,
input_type=L.app.components.Image,
output_type=L.app.components.Number,
timeout_batching=1,
max_batch_size=8,
)
Expand Down
51 changes: 49 additions & 2 deletions src/lightning_app/components/auto_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ async def update_servers(servers: List[str], authenticated: bool = Depends(authe
async def balance_api(inputs: self._input_type):
return await self.process_request(inputs)

logger.info(f"Your load balancer has started. The endpoint is 'http://{self.host}:{self.port}{self.endpoint}'")

uvicorn.run(
fastapi_app,
host=self.host,
Expand Down Expand Up @@ -332,6 +334,51 @@ def send_request_to_update_servers(self, servers: List[str]):
response = requests.put(f"{self.url}/system/update-servers", json=servers, headers=headers, timeout=10)
response.raise_for_status()

@staticmethod
def _get_sample_dict_from_datatype(datatype: Any) -> dict:
if hasattr(datatype, "_get_sample_data"):
return datatype._get_sample_data()

datatype_props = datatype.schema()["properties"]
out: Dict[str, Any] = {}
for k, v in datatype_props.items():
if v["type"] == "string":
out[k] = "data string"
elif v["type"] == "number":
out[k] = 0.0
elif v["type"] == "integer":
out[k] = 0
elif v["type"] == "boolean":
out[k] = False
else:
raise TypeError("Unsupported type")
return out

def configure_layout(self) -> None:
try:
from lightning_api_access import APIAccessFrontend
except ModuleNotFoundError:
logger.warn("APIAccessFrontend not found. Please install lightning-api-access to enable the UI")
return

try:
request = self._get_sample_dict_from_datatype(self._input_type)
response = self._get_sample_dict_from_datatype(self._output_type)
except (AttributeError, TypeError):
return

return APIAccessFrontend(
apis=[
{
"name": self.__class__.__name__,
"url": f"{self.url}{self.endpoint}",
"method": "POST",
"request": request,
"response": response,
}
]
)


class AutoScaler(LightningFlow):
"""The ``AutoScaler`` can be used to automatically change the number of replicas of the given server in
Expand Down Expand Up @@ -574,5 +621,5 @@ def autoscale(self) -> None:
self._last_autoscale = time.time()

def configure_layout(self):
tabs = [{"name": "Swagger", "content": self.load_balancer.url}]
return tabs
layout = self.load_balancer.configure_layout()
return layout if layout else super().configure_layout()

0 comments on commit 065db7e

Please sign in to comment.