Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[App] Cold start proxy in autoscaler #16094

Merged
merged 32 commits into from
Dec 20, 2022
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
4bafc59
wip clean up autoscaler ui
akihironitta Dec 15, 2022
1cefa05
Revert "wip clean up autoscaler ui"
akihironitta Dec 15, 2022
1e69092
Apply sherin's suggestion
akihironitta Dec 15, 2022
f9406cc
update example
akihironitta Dec 15, 2022
694627f
print endpoint in the log
akihironitta Dec 15, 2022
96b77ea
Fix import
akihironitta Dec 15, 2022
44cbec2
revert irrelevant change
akihironitta Dec 15, 2022
5d8af44
Merge branch 'master' into feat/autoscaler-ui
Dec 16, 2022
82fef89
Update src/lightning_app/components/auto_scaler.py
Dec 16, 2022
d8f4778
clean up
Dec 16, 2022
c7443b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 16, 2022
26f5f4b
test rename
Dec 16, 2022
1e7199c
Merge branch 'feat/autoscaler-ui' of github.com:Lightning-AI/lightnin…
Dec 16, 2022
4f3365c
Changelog
Dec 16, 2022
8501bd4
cold start proxy
Dec 16, 2022
0e2cee0
cold-start-proxy
Dec 16, 2022
7254e99
master
Dec 19, 2022
85e0205
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 19, 2022
5d04a21
Merge branch 'master' into cold-start-proxy
Dec 19, 2022
08643b2
Merge branch 'master' into cold-start-proxy
Dec 19, 2022
d62dd8b
merge
Dec 19, 2022
8dad987
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 19, 2022
901402a
Merge branch 'master' into cold-start-proxy
Dec 19, 2022
4d2596d
Merge branch 'master' into cold-start-proxy
Dec 20, 2022
eda6be5
docs and tests
Dec 20, 2022
b05c9aa
Merge branch 'master' into cold-start-proxy
Dec 20, 2022
5d727e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2022
21d325d
[App] Fixing auto batching in Autoscaler (#16110)
Dec 20, 2022
0af707f
Update src/lightning_app/components/serve/auto_scaler.py
Dec 20, 2022
77db466
changelog
Dec 20, 2022
6167943
better-doc
Dec 20, 2022
c1f7d70
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source-app/api_references.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ ___________________
~multi_node.pytorch_spawn.PyTorchSpawnMultiNode
~multi_node.trainer.LightningTrainerMultiNode
~serve.auto_scaler.AutoScaler
~serve.auto_scaler.ColdStartProxy

----

Expand Down
1 change: 1 addition & 0 deletions requirements/app/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pytest==7.2.0
pytest-timeout==2.1.0
pytest-cov==4.0.0
pytest-doctestplus>=0.9.0
pytest-asyncio==0.20.3
playwright==1.28.0
httpx
trio<0.22.0
Expand Down
3 changes: 2 additions & 1 deletion src/lightning_app/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
)
from lightning_app.components.python.popen import PopenPythonScript
from lightning_app.components.python.tracer import Code, TracerPythonScript
from lightning_app.components.serve.auto_scaler import AutoScaler
from lightning_app.components.serve.auto_scaler import AutoScaler, ColdStartProxy
from lightning_app.components.serve.gradio import ServeGradio
from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text
from lightning_app.components.serve.serve import ModelInferenceAPI
Expand All @@ -17,6 +17,7 @@

__all__ = [
"AutoScaler",
"ColdStartProxy",
"DatabaseClient",
"Database",
"PopenPythonScript",
Expand Down
14 changes: 12 additions & 2 deletions src/lightning_app/components/serve/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
from lightning_app.components.serve.auto_scaler import AutoScaler
from lightning_app.components.serve.auto_scaler import AutoScaler, ColdStartProxy
from lightning_app.components.serve.gradio import ServeGradio
from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text
from lightning_app.components.serve.streamlit import ServeStreamlit

__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number", "Category", "Text", "AutoScaler"]
__all__ = [
"ServeGradio",
"ServeStreamlit",
"PythonServer",
"Image",
"Number",
"Category",
"Text",
"AutoScaler",
"ColdStartProxy",
]
123 changes: 102 additions & 21 deletions src/lightning_app/components/serve/auto_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import uuid
from base64 import b64encode
from itertools import cycle
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import requests
import uvicorn
Expand All @@ -32,7 +32,53 @@
logger = Logger(__name__)


def _raise_granular_exception(exception: Exception) -> None:
class ColdStartProxy:
    """ColdStartProxy allows users to configure the load balancer to use a proxy service while the work is cold
    starting. This is useful for services that get realtime requests but whose workers take a long time to start.

    If the proxy service accepts the same request body via the POST method,
    then the default implementation of `handle_request` can be used. In that case,
    initialize the proxy with the proxy url. Otherwise, the user can override `handle_request`.

    Args:
        proxy_url (str): The url of the proxy service

    Raises:
        TypeError: If ``handle_request`` (possibly overridden in a subclass) is not an ``async`` function.
    """

    def __init__(self, proxy_url: str) -> None:
        self.proxy_url = proxy_url
        # forwarded as the ``timeout`` argument of the POST call made in ``handle_request``
        self.proxy_timeout = 50
        # checking `asyncio.iscoroutinefunction` instead of `inspect.iscoroutinefunction`
        # because AsyncMock in the tests requires the former to pass
        if not asyncio.iscoroutinefunction(self.handle_request):
            raise TypeError("handle_request must be an `async` function")

    async def handle_request(self, request: BaseModel) -> Any:
        """This method is called when a request is received while the work is cold starting. The default
        implementation forwards the request body to the proxy service with the POST method, but the user can
        override this method to handle the request in any way.

        Args:
            request (BaseModel): The request body, a pydantic model that is being
                forwarded by the load balancer, which is a FastAPI service

        Returns:
            The JSON-decoded body of the proxy service's response.

        Raises:
            HTTPException: With status code 500 if anything goes wrong while calling the proxy;
                the original error is attached as the exception's cause.
        """
        try:
            async with aiohttp.ClientSession() as session:
                headers = {
                    "accept": "application/json",
                    "Content-Type": "application/json",
                }
                async with session.post(
                    self.proxy_url,
                    json=request.dict(),
                    timeout=self.proxy_timeout,
                    headers=headers,
                ) as response:
                    return await response.json()
        except Exception as ex:
            # chain the original error so the underlying failure stays visible in tracebacks
            raise HTTPException(status_code=500, detail=f"Error in proxy: {ex}") from ex


def _maybe_raise_granular_exception(exception: Exception) -> None:
"""Handle an exception from hitting the model servers."""
if not isinstance(exception, Exception):
return
Expand Down Expand Up @@ -116,6 +162,8 @@ class _LoadBalancer(LightningWork):
requests to be batched. In any case, requests are processed as soon as `max_batch_size` is reached.
timeout_keep_alive: The number of seconds until it closes Keep-Alive connections if no new data is received.
timeout_inference_request: The number of seconds to wait for inference.
api_name: The name to be displayed on the UI. Normally, it is the name of the work class
cold_start_proxy: The proxy service to use while the work is cold starting.
**kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc.
"""

Expand All @@ -130,28 +178,40 @@ def __init__(
timeout_batching: float = 1,
timeout_keep_alive: int = 60,
timeout_inference_request: int = 60,
work_name: Optional[str] = "API", # used for displaying the name in the UI
api_name: Optional[str] = "API", # used for displaying the name in the UI
cold_start_proxy: Union[ColdStartProxy, str, None] = None,
**kwargs: Any,
) -> None:
super().__init__(cloud_compute=CloudCompute("default"), **kwargs)
self._input_type = input_type
self._output_type = output_type
self._timeout_keep_alive = timeout_keep_alive
self._timeout_inference_request = timeout_inference_request
self.servers = []
self._servers = []
self.max_batch_size = max_batch_size
self.timeout_batching = timeout_batching
self._iter = None
self._batch = []
self._responses = {} # {request_id: response}
self._last_batch_sent = 0
self._work_name = work_name
self._api_name = api_name

if not endpoint.startswith("/"):
endpoint = "/" + endpoint

self.endpoint = endpoint

self._fastapi_app = None

self._cold_start_proxy = None
if cold_start_proxy:
if isinstance(cold_start_proxy, str):
self._cold_start_proxy = ColdStartProxy(proxy_url=cold_start_proxy)
elif isinstance(cold_start_proxy, ColdStartProxy):
self._cold_start_proxy = cold_start_proxy
else:
raise ValueError("cold_start_proxy must be of type ColdStartProxy or str")

async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]]):
server = next(self._iter) # round-robin
request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch]
Expand Down Expand Up @@ -196,33 +256,51 @@ async def consumer(self):
batch = self._batch[: self.max_batch_size]
self._last_batch_sent = time.time()

async def process_request(self, data: BaseModel, request_id: Optional[str] = None):
    """Route a single request to the cold start proxy or queue it for batched processing.

    Args:
        data: The request payload.
        request_id: Identifier used to match the queued request with its response. When omitted, a fresh
            id is generated for every call.

    Returns:
        The cold start proxy's result if the request was proxied, otherwise the entry stored in
        ``self._responses`` for this request id.

    Raises:
        HTTPException: With status code 500 if there are no servers and no cold start proxy is configured.
    """
    # A default of ``uuid.uuid4().hex`` in the signature would be evaluated only once, at definition
    # time, making every call share the same id; generate it per call instead.
    if request_id is None:
        request_id = uuid.uuid4().hex

    if not self._servers and not self._cold_start_proxy:
        raise HTTPException(500, "None of the workers are healthy!")

    # if no servers are available, proxy the request to cold start proxy handler
    if not self._servers and self._cold_start_proxy:
        return await self._cold_start_proxy.handle_request(data)

    # if out of capacity, proxy the request to cold start proxy handler
    if not self._has_processing_capacity() and self._cold_start_proxy:
        return await self._cold_start_proxy.handle_request(data)

    # if we have capacity, process the request
    self._batch.append((request_id, data))
    while True:
        # poll until a response for this request id has been stored
        await asyncio.sleep(0.05)

        if request_id in self._responses:
            result = self._responses[request_id]
            del self._responses[request_id]
            # may raise if the stored result is an exception
            _maybe_raise_granular_exception(result)
            return result

def _has_processing_capacity(self):
"""this function checks if currently have processing capacity for one more request or not.
hhsecond marked this conversation as resolved.
Show resolved Hide resolved

Depends on the value from here, we decide whether we should proxy the request or not
"""
if not self._fastapi_app:
return False
active_server_count = len(self._servers)
max_processable = self.max_batch_size * active_server_count
current_req_count = self._fastapi_app.num_current_requests
return current_req_count < max_processable

def run(self):
logger.info(f"servers: {self.servers}")
logger.info(f"servers: {self._servers}")
lock = asyncio.Lock()

self._iter = cycle(self.servers)
self._iter = cycle(self._servers)
self._last_batch_sent = time.time()

fastapi_app = _create_fastapi("Load Balancer")
security = HTTPBasic()
fastapi_app.SEND_TASK = None
self._fastapi_app = fastapi_app

input_type = self._input_type

Expand Down Expand Up @@ -269,8 +347,8 @@ def authenticate_private_endpoint(credentials: HTTPBasicCredentials = Depends(se
@fastapi_app.get("/system/info", response_model=_SysInfo)
async def sys_info(authenticated: bool = Depends(authenticate_private_endpoint)):
return _SysInfo(
num_workers=len(self.servers),
servers=self.servers,
num_workers=len(self._servers),
servers=self._servers,
num_requests=fastapi_app.num_current_requests,
processing_time=fastapi_app.last_processing_time,
global_request_count=fastapi_app.global_request_count,
Expand All @@ -279,8 +357,8 @@ async def sys_info(authenticated: bool = Depends(authenticate_private_endpoint))
@fastapi_app.put("/system/update-servers")
async def update_servers(servers: List[str], authenticated: bool = Depends(authenticate_private_endpoint)):
async with lock:
self.servers = servers
self._iter = cycle(self.servers)
self._servers = servers
self._iter = cycle(self._servers)

@fastapi_app.post(self.endpoint, response_model=self._output_type)
async def balance_api(inputs: input_type):
Expand Down Expand Up @@ -308,7 +386,7 @@ def update_servers(self, server_works: List[LightningWork]):

AutoScaler uses this method to increase/decrease the number of works.
"""
old_servers = set(self.servers)
old_servers = set(self._servers)
server_urls: List[str] = [server.url for server in server_works if server.url]
new_servers = set(server_urls)

Expand Down Expand Up @@ -384,7 +462,7 @@ def _get_endpoint_info_page(self) -> Optional["APIAccessFrontend"]: # noqa: F82
else:
url = f"http://localhost:{self.port}{self.endpoint}"

frontend_objects = {"name": self._work_name, "url": url, "method": "POST", "request": None, "response": None}
frontend_objects = {"name": self._api_name, "url": url, "method": "POST", "request": None, "response": None}
code_samples = self.get_code_sample(url)
if code_samples:
frontend_objects["code_samples"] = code_samples
Expand Down Expand Up @@ -416,6 +494,7 @@ class AutoScaler(LightningFlow):
timeout_batching: (auto-batching) The number of seconds to wait before sending the requests to process.
input_type: Input type.
output_type: Output type.
cold_start_proxy: If provided, the proxy will be used while the worker machines are warming up.

.. testcode::

Expand Down Expand Up @@ -477,6 +556,7 @@ def __init__(
endpoint: str = "api/predict",
input_type: Type[BaseModel] = Dict,
output_type: Type[BaseModel] = Dict,
cold_start_proxy: Union[ColdStartProxy, str, None] = None,
*work_args: Any,
**work_kwargs: Any,
) -> None:
Expand Down Expand Up @@ -511,7 +591,8 @@ def __init__(
timeout_batching=timeout_batching,
cache_calls=True,
parallel=True,
work_name=self._work_cls.__name__,
api_name=self._work_cls.__name__,
cold_start_proxy=cold_start_proxy,
)
for _ in range(min_replicas):
work = self.create_work()
Expand Down
59 changes: 57 additions & 2 deletions tests/tests_app/components/serve/test_auto_scaler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import time
import uuid
from unittest import mock
from unittest.mock import patch

import pytest
from fastapi import HTTPException

from lightning_app import CloudCompute, LightningWork
from lightning_app.components import AutoScaler, Text
from lightning_app.components import AutoScaler, ColdStartProxy, Text
from lightning_app.components.serve.auto_scaler import _LoadBalancer


class EmptyWork(LightningWork):
Expand Down Expand Up @@ -129,7 +132,7 @@ def test_create_work_cloud_compute_cloned():
@patch("lightning_app.components.serve.auto_scaler.uvicorn.run", mock.MagicMock())
def test_API_ACCESS_ENDPOINT_creation():
auto_scaler = AutoScaler(EmptyWork, input_type=Text, output_type=Text)
assert auto_scaler.load_balancer._work_name == "EmptyWork"
assert auto_scaler.load_balancer._api_name == "EmptyWork"

auto_scaler.load_balancer.run()
fastapi_mock.mount.assert_called_once_with("/endpoint-info", mock.ANY, name="static")
Expand Down Expand Up @@ -173,3 +176,55 @@ def test_autoscaler_scale_down(monkeypatch):
auto_scaler.autoscale()
auto_scaler.scale.assert_called_once()
auto_scaler.remove_work.assert_called_once()


class TestLoadBalancerProcessRequest:
    """Tests for the routing decisions made by ``_LoadBalancer.process_request``."""

    @pytest.mark.asyncio
    async def test_workers_not_ready_with_cold_start_proxy(self, monkeypatch):
        # no servers registered + proxy configured -> the request must go to the proxy
        monkeypatch.setattr(ColdStartProxy, "handle_request", mock.AsyncMock())
        load_balancer = _LoadBalancer(
            input_type=Text, output_type=Text, endpoint="/predict", cold_start_proxy=ColdStartProxy("url")
        )
        req_id = uuid.uuid4().hex
        await load_balancer.process_request("test", req_id)
        load_balancer._cold_start_proxy.handle_request.assert_called_once_with("test")

    @pytest.mark.asyncio
    async def test_workers_not_ready_without_cold_start_proxy(self):
        # no servers registered + no proxy -> the request must be rejected
        load_balancer = _LoadBalancer(
            input_type=Text,
            output_type=Text,
            endpoint="/predict",
        )
        req_id = uuid.uuid4().hex
        # populating the responses so the polling loop would exit instead of hanging
        # if the expected exception were not raised
        load_balancer._responses = {req_id: "Dummy"}
        with pytest.raises(HTTPException):
            await load_balancer.process_request("test", req_id)

    @pytest.mark.asyncio
    async def test_workers_have_no_capacity_with_cold_start_proxy(self, monkeypatch):
        # one server registered but far more current requests than it can process -> proxy
        monkeypatch.setattr(ColdStartProxy, "handle_request", mock.AsyncMock())
        load_balancer = _LoadBalancer(
            input_type=Text, output_type=Text, endpoint="/predict", cold_start_proxy=ColdStartProxy("url")
        )
        load_balancer._fastapi_app = mock.MagicMock()
        load_balancer._fastapi_app.num_current_requests = 1000
        load_balancer._servers.append(mock.MagicMock())
        req_id = uuid.uuid4().hex
        await load_balancer.process_request("test", req_id)
        load_balancer._cold_start_proxy.handle_request.assert_called_once_with("test")

    @pytest.mark.asyncio
    async def test_workers_are_free(self):
        # a server is registered and no proxy is configured -> the request is queued for batching
        load_balancer = _LoadBalancer(
            input_type=Text,
            output_type=Text,
            endpoint="/predict",
        )
        load_balancer._servers.append(mock.MagicMock())
        req_id = uuid.uuid4().hex
        # populating the responses so the while loop exits
        load_balancer._responses = {req_id: "Dummy"}
        await load_balancer.process_request("test", req_id)
        assert load_balancer._batch == [(req_id, "test")]