diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index d3f4ce5bd4724..eb0262ade8132 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -51,6 +51,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where apps that had previously been deleted could not be run again from the CLI ([#16082](https://github.com/Lightning-AI/lightning/pull/16082)) +- Fixed a bug where `AutoScaler` would fail with `min_replicas=0` ([#16092](https://github.com/Lightning-AI/lightning/pull/16092)) + ## [1.8.4] - 2022-12-08 diff --git a/src/lightning_app/components/serve/auto_scaler.py b/src/lightning_app/components/serve/auto_scaler.py index 6027249de850f..6b25c5fce3860 100644 --- a/src/lightning_app/components/serve/auto_scaler.py +++ b/src/lightning_app/components/serve/auto_scaler.py @@ -224,6 +224,8 @@ def run(self): security = HTTPBasic() fastapi_app.SEND_TASK = None + input_type = self._input_type + @fastapi_app.middleware("http") async def current_request_counter(request: Request, call_next): if not request.scope["path"] == self.endpoint: @@ -281,7 +283,7 @@ async def update_servers(servers: List[str], authenticated: bool = Depends(authe self._iter = cycle(self.servers) @fastapi_app.post(self.endpoint, response_model=self._output_type) - async def balance_api(inputs: self._input_type): + async def balance_api(inputs: input_type): return await self.process_request(inputs) endpoint_info_page = self._get_endpoint_info_page() @@ -578,9 +580,13 @@ def scale(self, replicas: int, metrics: dict) -> int: The target number of running works. The value will be adjusted after this method runs so that it satisfies ``min_replicas<=replicas<=max_replicas``. 
""" - pending_requests_per_running_or_pending_work = metrics["pending_requests"] / ( - replicas + metrics["pending_works"] - ) + pending_requests = metrics["pending_requests"] + active_or_pending_works = replicas + metrics["pending_works"] + + if active_or_pending_works == 0: + return 1 if pending_requests > 0 else 0 + + pending_requests_per_running_or_pending_work = pending_requests / active_or_pending_works # scale out if the number of pending requests exceeds max batch size. max_requests_per_work = self.max_batch_size diff --git a/tests/tests_app/components/serve/test_auto_scaler.py b/tests/tests_app/components/serve/test_auto_scaler.py index 6bd5aa958b6bf..52ffa02f162c6 100644 --- a/tests/tests_app/components/serve/test_auto_scaler.py +++ b/tests/tests_app/components/serve/test_auto_scaler.py @@ -93,6 +93,24 @@ def test_scale(replicas, metrics, expected_replicas): assert auto_scaler.scale(replicas, metrics) == expected_replicas +def test_scale_from_zero_min_replica(): + auto_scaler = AutoScaler( + EmptyWork, + min_replicas=0, + max_replicas=2, + max_batch_size=10, + ) + + resp = auto_scaler.scale(0, {"pending_requests": 0, "pending_works": 0}) + assert resp == 0 + + resp = auto_scaler.scale(0, {"pending_requests": 1, "pending_works": 0}) + assert resp == 1 + + resp = auto_scaler.scale(0, {"pending_requests": 1, "pending_works": 1}) + assert resp <= 0 + + def test_create_work_cloud_compute_cloned(): """Test CloudCompute is cloned to avoid creating multiple works in a single machine.""" cloud_compute = CloudCompute("gpu")