[App] Introduce auto scaler (#15769)
* Exclude __pycache__ in setuptools
* Add load balancer example
* wip
* Update example
* rename
* remove prints
* _LoadBalancer -> LoadBalancer
* AutoScaler(work)
* change var name
* remove locust
* Update docs
* include autoscaler in api ref
* docs typo
* docs typo
* docs typo
* docs typo
* remove unused loadtest
* remove unused device_type
* clean up
* clean up
* clean up
* Add docstring
* type
* env vars to args
* expose an API for users to override to customise autoscaling logic
* update example
* comment
* update var name
* fix scale mechanism and clean up
* Update example
* ignore mypy
* Add test file
* .
* update impl and update tests
* Update changelog
* .
* revert docs
* update test
* update state to keep calling 'flow.run()'

Co-authored-by: Aniket Maurya <theaniketmaurya@gmail.com>

* Add aiohttp to base requirements
* Update docs

Co-authored-by: Luca Antiga <luca.antiga@gmail.com>

* Use deserializer utility
* fake trigger
* wip: protect /system/* with basic auth
* read password at runtime
* Change env var name
* import torch as optional
* Don't overcreate works
* simplify imports
* Update example
* aiohttp
* Add work_args work_kwargs
* More docs
* remove FIXME
* Apply Jirka's suggestions

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)
* clean example device
* add comment on init threshold value
* bad merge
* nit: logging format
* {in,out}put_schema -> {in,out}put_type
* lowercase
* docs on seconds
* process_time -> processing_time
* Don't modify work state from flow
* Update tests
* worker_url -> endpoint
* fix example
* Fix default scale logic
* Fix default scale logic
* Fix num_pending_works
* Update num_pending_works
* Fix bug creating too many works
* Remove up/downscale_threshold args
* Update example
* Add typing
* Fix example in docstring
* Fix default scale logic
* Update src/lightning_app/components/auto_scaler.py

Co-authored-by: Noha Alon <nohalon@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)
* rename method
* rename local var
* Add todo
* docs ci
* docs ci
* retrigger docs ci
* Apply suggestions from code review

Co-authored-by: Ethan Harris <ethanwharris@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)
* .
* doc
* Update src/lightning_app/components/auto_scaler.py

Co-authored-by: Noha Alon <nohalon@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)
* Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks" (reverts commit 24983a0)
* Revert "Update src/lightning_app/components/auto_scaler.py" (reverts commit 56ea78b)
* Remove redefinition
* Remove load balancer run blocker
* raise RuntimeError
* remove has_sent
* lower the default timeout_batching from 10 to 1
* remove debug
* update the default timeout_batching
* .
* tighten condition
* fix endpoint
* typo in RuntimeError condition
* async lock update servers
* add a test
* {in,out}put_type typing
* Update examples/app_server_with_auto_scaler/app.py

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>

* Update .actions/setup_tools.py

Co-authored-by: Aniket Maurya <theaniketmaurya@gmail.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Noha Alon <nohalon@gmail.com>
Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
Co-authored-by: Akihiro Nitta <aki@pop-os.localdomain>
Co-authored-by: thomas chaton <thomas@grid.ai>

(cherry picked from commit 64b19fb)
1 parent: d90f624
Commit: 6bdb7cb
Showing 9 changed files with 754 additions and 1 deletion.
examples/app_server_with_auto_scaler/app.py (new file)
@@ -0,0 +1,86 @@
from typing import Any, List

import torch
import torchvision
from pydantic import BaseModel

import lightning as L


class RequestModel(BaseModel):
    image: str  # bytecode


class BatchRequestModel(BaseModel):
    inputs: List[RequestModel]


class BatchResponse(BaseModel):
    outputs: List[Any]


class PyTorchServer(L.app.components.PythonServer):
    def __init__(self, *args, **kwargs):
        super().__init__(
            port=L.app.utilities.network.find_free_network_port(),
            input_type=BatchRequestModel,
            output_type=BatchResponse,
            cloud_compute=L.CloudCompute("gpu"),
        )

    def setup(self):
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self._model = torchvision.models.resnet18(pretrained=True).to(self._device)

    def predict(self, requests: BatchRequestModel):
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(224),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        images = []
        for request in requests.inputs:
            image = L.app.components.serve.types.image.Image.deserialize(request.image)
            image = transforms(image).unsqueeze(0)
            images.append(image)
        images = torch.cat(images)
        images = images.to(self._device)
        predictions = self._model(images)
        results = predictions.argmax(1).cpu().numpy().tolist()
        return BatchResponse(outputs=[{"prediction": pred} for pred in results])


class MyAutoScaler(L.app.components.AutoScaler):
    def scale(self, replicas: int, metrics: dict) -> int:
        """The default scaling logic that users can override."""
        # scale out if the number of pending requests exceeds max batch size.
        max_requests_per_work = self.max_batch_size
        pending_requests_per_running_or_pending_work = metrics["pending_requests"] / (
            replicas + metrics["pending_works"]
        )
        if pending_requests_per_running_or_pending_work >= max_requests_per_work:
            return replicas + 1

        # scale in if the number of pending requests is below 25% of max_requests_per_work
        min_requests_per_work = max_requests_per_work * 0.25
        pending_requests_per_running_work = metrics["pending_requests"] / replicas
        if pending_requests_per_running_work < min_requests_per_work:
            return replicas - 1

        return replicas


app = L.LightningApp(
    MyAutoScaler(
        PyTorchServer,
        min_replicas=2,
        max_replicas=4,
        autoscale_interval=10,
        endpoint="predict",
        input_type=RequestModel,
        output_type=Any,
        timeout_batching=1,
    )
)
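Once the app is running, a client call might look like the sketch below. The host, port, and file name here are assumptions (the running app reports its actual URL); the payload shape mirrors RequestModel above.

# Hypothetical client for the example app above. The URL is an assumption --
# substitute whatever endpoint the running app actually reports.
import base64

import requests

with open("image.png", "rb") as f:
    # RequestModel.image carries the image as a base64-encoded string
    payload = {"image": base64.b64encode(f.read()).decode("utf-8")}

response = requests.post("http://127.0.0.1:7501/predict", json=payload)
print(response.json())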
App base requirements (adds aiohttp)
@@ -13,3 +13,4 @@ inquirer>=2.10.0
psutil<5.9.4
click<=8.1.3
s3fs>=2022.5.0, <2022.8.3
aiohttp>=3.8.0, <=3.8.3
src/lightning_app/components/auto_scaler.py (large diff not rendered by default)
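Since that diff isn't rendered here, the outline below sketches the component's surface as it can be inferred from the example app above and the tests below. Parameter and attribute names come from call sites in this commit; the base class, the types, and the name work_cls are assumptions, not the actual implementation.

from typing import Any, Dict, Type

from lightning_app import LightningFlow, LightningWork


class AutoScalerSketch(LightningFlow):  # real base class is an assumption
    """Inferred outline of the new component, not its actual source."""

    num_replicas: int          # asserted on directly in the tests below
    num_pending_requests: int  # patched in the tests below

    def __init__(
        self,
        work_cls: Type[LightningWork],  # e.g. PyTorchServer in the example
        min_replicas: int,
        max_replicas: int,
        autoscale_interval: float,  # seconds between scaling decisions
        max_batch_size: int,
        timeout_batching: float,    # seconds to wait while a batch fills
        endpoint: str,
        input_type: Any,
        output_type: Any,
        *work_args: Any,            # per "Add work_args work_kwargs" above
        **work_kwargs: Any,
    ) -> None:
        super().__init__()

    def scale(self, replicas: int, metrics: Dict[str, int]) -> int:
        """The hook users override to customise the autoscaling logic."""
        return replicas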
Tests for the AutoScaler component (new file)
@@ -0,0 +1,92 @@
import time
from unittest.mock import patch

import pytest

from lightning_app import LightningWork
from lightning_app.components import AutoScaler


class EmptyWork(LightningWork):
    def run(self):
        pass


class AutoScaler1(AutoScaler):
    def scale(self, replicas: int, metrics) -> int:
        # only upscale
        return replicas + 1


class AutoScaler2(AutoScaler):
    def scale(self, replicas: int, metrics) -> int:
        # only downscale
        return replicas - 1


def test_num_replicas_after_init():
    """Test the number of works is the same as min_replicas after initialization."""
    min_replicas = 2
    auto_scaler = AutoScaler(EmptyWork, min_replicas=min_replicas)
    assert auto_scaler.num_replicas == min_replicas


@patch("uvicorn.run")
@patch("lightning_app.components.auto_scaler._LoadBalancer.url")
@patch("lightning_app.components.auto_scaler.AutoScaler.num_pending_requests")
def test_num_replicas_not_above_max_replicas(*_):
    """Test self.num_replicas doesn't exceed max_replicas."""
    max_replicas = 6
    auto_scaler = AutoScaler1(
        EmptyWork,
        min_replicas=1,
        max_replicas=max_replicas,
        autoscale_interval=0.001,
    )

    for _ in range(max_replicas + 1):
        time.sleep(0.002)
        auto_scaler.run()

    assert auto_scaler.num_replicas == max_replicas


@patch("uvicorn.run")
@patch("lightning_app.components.auto_scaler._LoadBalancer.url")
@patch("lightning_app.components.auto_scaler.AutoScaler.num_pending_requests")
def test_num_replicas_not_below_min_replicas(*_):
    """Test self.num_replicas doesn't go below min_replicas."""
    min_replicas = 1
    auto_scaler = AutoScaler2(
        EmptyWork,
        min_replicas=min_replicas,
        max_replicas=4,
        autoscale_interval=0.001,
    )

    for _ in range(3):
        time.sleep(0.002)
        auto_scaler.run()

    assert auto_scaler.num_replicas == min_replicas


@pytest.mark.parametrize(
    "replicas, metrics, expected_replicas",
    [
        pytest.param(1, {"pending_requests": 1, "pending_works": 0}, 2, id="increase if no pending work"),
        pytest.param(1, {"pending_requests": 1, "pending_works": 1}, 1, id="dont increase if pending works"),
        pytest.param(8, {"pending_requests": 1, "pending_works": 0}, 7, id="reduce if requests < 25% capacity"),
        pytest.param(8, {"pending_requests": 2, "pending_works": 0}, 8, id="dont reduce if requests >= 25% capacity"),
    ],
)
def test_scale(replicas, metrics, expected_replicas):
    """Test `scale()`, the default scaling strategy."""
    auto_scaler = AutoScaler(
        EmptyWork,
        min_replicas=1,
        max_replicas=8,
        max_batch_size=1,
    )

    assert auto_scaler.scale(replicas, metrics) == expected_replicas
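As a quick sanity check on the parametrized cases, the "reduce if requests < 25% capacity" case can be replayed against the default scale() arithmetic shown in the example app:

# With max_batch_size=1: replicas=8, pending_requests=1, pending_works=0
replicas, pending_requests, pending_works = 8, 1, 0
max_requests_per_work = 1  # == max_batch_size in test_scale
# no scale-out: 1 / (8 + 0) = 0.125 < 1
assert pending_requests / (replicas + pending_works) < max_requests_per_work
# scale-in: 1 / 8 = 0.125 < 0.25 * 1, so scale() returns 8 - 1 == 7
assert pending_requests / replicas < max_requests_per_work * 0.25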