diff --git a/docs/source-app/api_references.rst b/docs/source-app/api_references.rst index 931a9864d261f..9bb5874b533e4 100644 --- a/docs/source-app/api_references.rst +++ b/docs/source-app/api_references.rst @@ -46,6 +46,7 @@ ___________________ ~multi_node.pytorch_spawn.PyTorchSpawnMultiNode ~multi_node.trainer.LightningTrainerMultiNode ~serve.auto_scaler.AutoScaler + ~serve.auto_scaler.ColdStartProxy ---- diff --git a/requirements/app/test.txt b/requirements/app/test.txt index 5d000ce1ef625..2d46e72f74085 100644 --- a/requirements/app/test.txt +++ b/requirements/app/test.txt @@ -4,6 +4,7 @@ pytest==7.2.0 pytest-timeout==2.1.0 pytest-cov==4.0.0 pytest-doctestplus>=0.9.0 +pytest-asyncio==0.20.3 playwright==1.28.0 httpx trio<0.22.0 diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index e92e80c71deb0..8fd0b855440b8 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -24,6 +24,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `display_name` property to LightningWork for the cloud ([#16095](https://github.com/Lightning-AI/lightning/pull/16095)) +- Added `ColdStartProxy` to the AutoScaler ([#16094](https://github.com/Lightning-AI/lightning/pull/16094)) + ### Changed @@ -62,6 +64,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where `AutoScaler` would fail with min_replica=0 ([#16092](https://github.com/Lightning-AI/lightning/pull/16092) +- Fixed auto-batching to enable batching for requests coming even after batch interval but is in the queue ([#16110](https://github.com/Lightning-AI/lightning/pull/16110)) + - Fixed a non-thread safe deepcopy in the scheduler ([#16114](https://github.com/Lightning-AI/lightning/pull/16114)) diff --git a/src/lightning_app/components/__init__.py b/src/lightning_app/components/__init__.py index 5fd8af6b055de..0275596288ff0 100644 --- a/src/lightning_app/components/__init__.py +++ b/src/lightning_app/components/__init__.py @@ -8,7 +8,7 @@ ) from lightning_app.components.python.popen import PopenPythonScript from lightning_app.components.python.tracer import Code, TracerPythonScript -from lightning_app.components.serve.auto_scaler import AutoScaler +from lightning_app.components.serve.auto_scaler import AutoScaler, ColdStartProxy from lightning_app.components.serve.gradio import ServeGradio from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text from lightning_app.components.serve.serve import ModelInferenceAPI @@ -17,6 +17,7 @@ __all__ = [ "AutoScaler", + "ColdStartProxy", "DatabaseClient", "Database", "PopenPythonScript", diff --git a/src/lightning_app/components/serve/__init__.py b/src/lightning_app/components/serve/__init__.py index ac02e69c4f2ab..39dafe2f7ff1b 100644 --- a/src/lightning_app/components/serve/__init__.py +++ b/src/lightning_app/components/serve/__init__.py @@ -1,6 +1,16 @@ -from lightning_app.components.serve.auto_scaler import AutoScaler +from lightning_app.components.serve.auto_scaler import AutoScaler, ColdStartProxy from lightning_app.components.serve.gradio import ServeGradio from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text from lightning_app.components.serve.streamlit import ServeStreamlit -__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number", "Category", "Text", "AutoScaler"] +__all__ = [ + "ServeGradio", + "ServeStreamlit", + "PythonServer", + "Image", + "Number", + "Category", + "Text", + "AutoScaler", + "ColdStartProxy", +] diff --git a/src/lightning_app/components/serve/auto_scaler.py b/src/lightning_app/components/serve/auto_scaler.py index 2493f63048e60..a8af51ca79913 100644 --- a/src/lightning_app/components/serve/auto_scaler.py +++ b/src/lightning_app/components/serve/auto_scaler.py @@ -6,7 +6,7 @@ import uuid from base64 import b64encode from itertools import cycle -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type, Union import requests import uvicorn @@ -32,7 +32,53 @@ logger = Logger(__name__) -def _raise_granular_exception(exception: Exception) -> None: +class ColdStartProxy: + """ColdStartProxy allows users to configure the load balancer to use a proxy service while the work is cold + starting. This is useful with services that gets realtime requests but startup time for workers is high. + + If the request body is same and the method is POST for the proxy service, + then the default implementation of `handle_request` can be used. In that case + initialize the proxy with the proxy url. Otherwise, the user can override the `handle_request` + + Args: + proxy_url (str): The url of the proxy service + """ + + def __init__(self, proxy_url): + self.proxy_url = proxy_url + self.proxy_timeout = 50 + # checking `asyncio.iscoroutinefunction` instead of `inspect.iscoroutinefunction` + # because AsyncMock in the tests requres the former to pass + if not asyncio.iscoroutinefunction(self.handle_request): + raise TypeError("handle_request must be an `async` function") + + async def handle_request(self, request: BaseModel) -> Any: + """This method is called when the request is received while the work is cold starting. The default + implementation of this method is to forward the request body to the proxy service with POST method but the + user can override this method to handle the request in any way. + + Args: + request (BaseModel): The request body, a pydantic model that is being + forwarded by load balancer which is a FastAPI service + """ + try: + async with aiohttp.ClientSession() as session: + headers = { + "accept": "application/json", + "Content-Type": "application/json", + } + async with session.post( + self.proxy_url, + json=request.dict(), + timeout=self.proxy_timeout, + headers=headers, + ) as response: + return await response.json() + except Exception as ex: + raise HTTPException(status_code=500, detail=f"Error in proxy: {ex}") + + +def _maybe_raise_granular_exception(exception: Exception) -> None: """Handle an exception from hitting the model servers.""" if not isinstance(exception, Exception): return @@ -116,6 +162,8 @@ class _LoadBalancer(LightningWork): requests to be batched. In any case, requests are processed as soon as `max_batch_size` is reached. timeout_keep_alive: The number of seconds until it closes Keep-Alive connections if no new data is received. timeout_inference_request: The number of seconds to wait for inference. + api_name: The name to be displayed on the UI. Normally, it is the name of the work class + cold_start_proxy: The proxy service to use while the work is cold starting. **kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc. """ @@ -130,7 +178,8 @@ def __init__( timeout_batching: float = 1, timeout_keep_alive: int = 60, timeout_inference_request: int = 60, - work_name: Optional[str] = "API", # used for displaying the name in the UI + api_name: Optional[str] = "API", # used for displaying the name in the UI + cold_start_proxy: Union[ColdStartProxy, str, None] = None, **kwargs: Any, ) -> None: super().__init__(cloud_compute=CloudCompute("default"), **kwargs) @@ -138,37 +187,54 @@ def __init__( self._output_type = output_type self._timeout_keep_alive = timeout_keep_alive self._timeout_inference_request = timeout_inference_request - self.servers = [] + self._servers = [] self.max_batch_size = max_batch_size self.timeout_batching = timeout_batching self._iter = None self._batch = [] self._responses = {} # {request_id: response} self._last_batch_sent = 0 - self._work_name = work_name + self._server_status = {} + self._api_name = api_name if not endpoint.startswith("/"): endpoint = "/" + endpoint self.endpoint = endpoint + self._fastapi_app = None + + self._cold_start_proxy = None + if cold_start_proxy: + if isinstance(cold_start_proxy, str): + self._cold_start_proxy = ColdStartProxy(proxy_url=cold_start_proxy) + elif isinstance(cold_start_proxy, ColdStartProxy): + self._cold_start_proxy = cold_start_proxy + else: + raise ValueError("cold_start_proxy must be of type ColdStartProxy or str") - async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]]): - server = next(self._iter) # round-robin + async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]], server_url: str): request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch] batch_request_data = _BatchRequestModel(inputs=request_data) try: + self._server_status[server_url] = False async with aiohttp.ClientSession() as session: headers = { "accept": "application/json", "Content-Type": "application/json", } async with session.post( - f"{server}{self.endpoint}", + f"{server_url}{self.endpoint}", json=batch_request_data.dict(), timeout=self._timeout_inference_request, headers=headers, ) as response: + # resetting the server status so other requests can be + # scheduled on this node + if server_url in self._server_status: + # TODO - if the server returns an error, track that so + # we don't send more requests to it + self._server_status[server_url] = True if response.status == 408: raise HTTPException(408, "Request timed out") response.raise_for_status() @@ -181,48 +247,87 @@ async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]]): except Exception as ex: result = {request[0]: ex for request in batch} self._responses.update(result) + finally: + self._server_status[server_url] = True + + def _find_free_server(self) -> Optional[str]: + existing = set(self._server_status.keys()) + for server in existing: + status = self._server_status.get(server, None) + if status is None: + logger.error("Server is not found in the status list. This should not happen.") + if status: + return server async def consumer(self): + """The consumer process that continuously checks for new requests and sends them to the API. + + Two instances of this function should not be running with shared `_state_server` as that would create race + conditions + """ + self._last_batch_sent = time.time() while True: await asyncio.sleep(0.05) - batch = self._batch[: self.max_batch_size] - while batch and ( - (len(batch) == self.max_batch_size) or ((time.time() - self._last_batch_sent) > self.timeout_batching) - ): - asyncio.create_task(self.send_batch(batch)) - - self._batch = self._batch[self.max_batch_size :] - batch = self._batch[: self.max_batch_size] + is_batch_ready = len(batch) == self.max_batch_size + is_batch_timeout = time.time() - self._last_batch_sent > self.timeout_batching + server_url = self._find_free_server() + # setting the server status to be busy! This will be reset by + # the send_batch function after the server responds + if server_url is None: + continue + if batch and (is_batch_ready or is_batch_timeout): + # find server with capacity + asyncio.create_task(self.send_batch(batch, server_url)) + # resetting the batch array, TODO - not locking the array + self._batch = self._batch[len(batch) :] self._last_batch_sent = time.time() - async def process_request(self, data: BaseModel): - if not self.servers: + async def process_request(self, data: BaseModel, request_id=uuid.uuid4().hex): + if not self._servers and not self._cold_start_proxy: raise HTTPException(500, "None of the workers are healthy!") - request_id = uuid.uuid4().hex - request: Tuple = (request_id, data) - self._batch.append(request) + # if no servers are available, proxy the request to cold start proxy handler + if not self._servers and self._cold_start_proxy: + return await self._cold_start_proxy.handle_request(data) + + # if out of capacity, proxy the request to cold start proxy handler + if not self._has_processing_capacity() and self._cold_start_proxy: + return await self._cold_start_proxy.handle_request(data) + # if we have capacity, process the request + self._batch.append((request_id, data)) while True: await asyncio.sleep(0.05) - if request_id in self._responses: result = self._responses[request_id] del self._responses[request_id] - _raise_granular_exception(result) + _maybe_raise_granular_exception(result) return result + def _has_processing_capacity(self): + """This function checks if we have processing capacity for one more request or not. + + Depends on the value from here, we decide whether we should proxy the request or not + """ + if not self._fastapi_app: + return False + active_server_count = len(self._servers) + max_processable = self.max_batch_size * active_server_count + current_req_count = self._fastapi_app.num_current_requests + return current_req_count < max_processable + def run(self): - logger.info(f"servers: {self.servers}") + logger.info(f"servers: {self._servers}") lock = asyncio.Lock() - self._iter = cycle(self.servers) + self._iter = cycle(self._servers) self._last_batch_sent = time.time() fastapi_app = _create_fastapi("Load Balancer") security = HTTPBasic() fastapi_app.SEND_TASK = None + self._fastapi_app = fastapi_app input_type = self._input_type @@ -269,8 +374,8 @@ def authenticate_private_endpoint(credentials: HTTPBasicCredentials = Depends(se @fastapi_app.get("/system/info", response_model=_SysInfo) async def sys_info(authenticated: bool = Depends(authenticate_private_endpoint)): return _SysInfo( - num_workers=len(self.servers), - servers=self.servers, + num_workers=len(self._servers), + servers=self._servers, num_requests=fastapi_app.num_current_requests, processing_time=fastapi_app.last_processing_time, global_request_count=fastapi_app.global_request_count, @@ -279,8 +384,20 @@ async def sys_info(authenticated: bool = Depends(authenticate_private_endpoint)) @fastapi_app.put("/system/update-servers") async def update_servers(servers: List[str], authenticated: bool = Depends(authenticate_private_endpoint)): async with lock: - self.servers = servers - self._iter = cycle(self.servers) + self._servers = servers + self._iter = cycle(self._servers) + updated_servers = set() + # do not try to loop over the dict keys as the dict might change from other places + existing_servers = list(self._server_status.keys()) + for server in servers: + updated_servers.add(server) + if server not in existing_servers: + self._server_status[server] = True + logger.info(f"Registering server {server}", self._server_status) + for existing in existing_servers: + if existing not in updated_servers: + logger.info(f"De-Registering server {existing}", self._server_status) + del self._server_status[existing] @fastapi_app.post(self.endpoint, response_model=self._output_type) async def balance_api(inputs: input_type): @@ -308,7 +425,7 @@ def update_servers(self, server_works: List[LightningWork]): AutoScaler uses this method to increase/decrease the number of works. """ - old_servers = set(self.servers) + old_servers = set(self._servers) server_urls: List[str] = [server.url for server in server_works if server.url] new_servers = set(server_urls) @@ -384,7 +501,7 @@ def _get_endpoint_info_page(self) -> Optional["APIAccessFrontend"]: # noqa: F82 else: url = f"http://localhost:{self.port}{self.endpoint}" - frontend_objects = {"name": self._work_name, "url": url, "method": "POST", "request": None, "response": None} + frontend_objects = {"name": self._api_name, "url": url, "method": "POST", "request": None, "response": None} code_samples = self.get_code_sample(url) if code_samples: frontend_objects["code_samples"] = code_samples @@ -416,6 +533,7 @@ class AutoScaler(LightningFlow): timeout_batching: (auto-batching) The number of seconds to wait before sending the requests to process. input_type: Input type. output_type: Output type. + cold_start_proxy: If provided, the proxy will be used while the worker machines are warming up. .. testcode:: @@ -477,6 +595,7 @@ def __init__( endpoint: str = "api/predict", input_type: Type[BaseModel] = Dict, output_type: Type[BaseModel] = Dict, + cold_start_proxy: Union[ColdStartProxy, str, None] = None, *work_args: Any, **work_kwargs: Any, ) -> None: @@ -511,7 +630,8 @@ def __init__( timeout_batching=timeout_batching, cache_calls=True, parallel=True, - work_name=self._work_cls.__name__, + api_name=self._work_cls.__name__, + cold_start_proxy=cold_start_proxy, ) for _ in range(min_replicas): work = self.create_work() diff --git a/tests/tests_app/components/serve/test_auto_scaler.py b/tests/tests_app/components/serve/test_auto_scaler.py index c3cfa99c9d69b..e53c7890696a4 100644 --- a/tests/tests_app/components/serve/test_auto_scaler.py +++ b/tests/tests_app/components/serve/test_auto_scaler.py @@ -1,11 +1,14 @@ import time +import uuid from unittest import mock from unittest.mock import patch import pytest +from fastapi import HTTPException from lightning_app import CloudCompute, LightningWork -from lightning_app.components import AutoScaler, Text +from lightning_app.components import AutoScaler, ColdStartProxy, Text +from lightning_app.components.serve.auto_scaler import _LoadBalancer class EmptyWork(LightningWork): @@ -129,7 +132,7 @@ def test_create_work_cloud_compute_cloned(): @patch("lightning_app.components.serve.auto_scaler.uvicorn.run", mock.MagicMock()) def test_API_ACCESS_ENDPOINT_creation(): auto_scaler = AutoScaler(EmptyWork, input_type=Text, output_type=Text) - assert auto_scaler.load_balancer._work_name == "EmptyWork" + assert auto_scaler.load_balancer._api_name == "EmptyWork" auto_scaler.load_balancer.run() fastapi_mock.mount.assert_called_once_with("/endpoint-info", mock.ANY, name="static") @@ -173,3 +176,55 @@ def test_autoscaler_scale_down(monkeypatch): auto_scaler.autoscale() auto_scaler.scale.assert_called_once() auto_scaler.remove_work.assert_called_once() + + +class TestLoadBalancerProcessRequest: + @pytest.mark.asyncio + async def test_workers_not_ready_with_cold_start_proxy(self, monkeypatch): + monkeypatch.setattr(ColdStartProxy, "handle_request", mock.AsyncMock()) + load_balancer = _LoadBalancer( + input_type=Text, output_type=Text, endpoint="/predict", cold_start_proxy=ColdStartProxy("url") + ) + req_id = uuid.uuid4().hex + await load_balancer.process_request("test", req_id) + load_balancer._cold_start_proxy.handle_request.assert_called_once_with("test") + + @pytest.mark.asyncio + async def test_workers_not_ready_without_cold_start_proxy(self, monkeypatch): + load_balancer = _LoadBalancer( + input_type=Text, + output_type=Text, + endpoint="/predict", + ) + req_id = uuid.uuid4().hex + # populating the responses so the while loop exists + load_balancer._responses = {req_id: "Dummy"} + with pytest.raises(HTTPException): + await load_balancer.process_request("test", req_id) + + @pytest.mark.asyncio + async def test_workers_have_no_capacity_with_cold_start_proxy(self, monkeypatch): + monkeypatch.setattr(ColdStartProxy, "handle_request", mock.AsyncMock()) + load_balancer = _LoadBalancer( + input_type=Text, output_type=Text, endpoint="/predict", cold_start_proxy=ColdStartProxy("url") + ) + load_balancer._fastapi_app = mock.MagicMock() + load_balancer._fastapi_app.num_current_requests = 1000 + load_balancer._servers.append(mock.MagicMock()) + req_id = uuid.uuid4().hex + await load_balancer.process_request("test", req_id) + load_balancer._cold_start_proxy.handle_request.assert_called_once_with("test") + + @pytest.mark.asyncio + async def test_workers_are_free(self): + load_balancer = _LoadBalancer( + input_type=Text, + output_type=Text, + endpoint="/predict", + ) + load_balancer._servers.append(mock.MagicMock()) + req_id = uuid.uuid4().hex + # populating the responses so the while loop exists + load_balancer._responses = {req_id: "Dummy"} + await load_balancer.process_request("test", req_id) + assert load_balancer._batch == [(req_id, "test")]