Skip to content

Commit

Permalink
App: Limit rate of requests to http queue (#18981)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Nov 10, 2023
1 parent a9d427c commit 9085db4
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/lightning/app/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def get_lightning_cloud_url() -> str:
HTTP_QUEUE_URL = os.getenv("LIGHTNING_HTTP_QUEUE_URL", "http://localhost:9801")
HTTP_QUEUE_REFRESH_INTERVAL = float(os.getenv("LIGHTNING_HTTP_QUEUE_REFRESH_INTERVAL", "1"))
HTTP_QUEUE_TOKEN = os.getenv("LIGHTNING_HTTP_QUEUE_TOKEN", None)
HTTP_QUEUE_REQUESTS_PER_SECOND = float(os.getenv("LIGHTNING_HTTP_QUEUE_REQUESTS_PER_SECOND", "0.5"))

USER_ID = os.getenv("USER_ID", "1234")
FRONTEND_DIR = str(Path(__file__).parent.parent / "ui")
Expand Down
44 changes: 43 additions & 1 deletion src/lightning/app/core/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from lightning.app.core.constants import (
HTTP_QUEUE_REFRESH_INTERVAL,
HTTP_QUEUE_REQUESTS_PER_SECOND,
HTTP_QUEUE_TOKEN,
HTTP_QUEUE_URL,
LIGHTNING_DIR,
Expand Down Expand Up @@ -77,7 +78,9 @@ def get_queue(self, queue_name: str) -> "BaseQueue":
return MultiProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
if self == QueuingSystem.REDIS:
return RedisQueue(queue_name, default_timeout=REDIS_QUEUES_READ_DEFAULT_TIMEOUT)
return HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
return RateLimitedQueue(
HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT), HTTP_QUEUE_REQUESTS_PER_SECOND
)

def get_api_response_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = f"{queue_id}_{API_RESPONSE_QUEUE_CONSTANT}" if queue_id else API_RESPONSE_QUEUE_CONSTANT
Expand Down Expand Up @@ -347,6 +350,45 @@ def from_dict(cls, state: dict) -> "RedisQueue":
return cls(**state)


class RateLimitedQueue(BaseQueue):
def __init__(self, queue: BaseQueue, requests_per_second: float):
"""This is a queue wrapper that will block on get or put calls if they are made too quickly.
Args:
queue: The queue to wrap.
requests_per_second: The target number of get or put requests per second.
"""
self.name = queue.name
self.default_timeout = queue.default_timeout

self._queue = queue
self._seconds_per_request = 1 / requests_per_second

self._last_get = 0.0
self._last_put = 0.0

@property
def is_running(self) -> bool:
return self._queue.is_running

def _wait_until_allowed(self, last_time: float) -> None:
t = time.time()
diff = t - last_time
if diff < self._seconds_per_request:
time.sleep(self._seconds_per_request - diff)

def get(self, timeout: Optional[float] = None) -> Any:
self._wait_until_allowed(self._last_get)
self._last_get = time.time()
return self._queue.get(timeout=timeout)

def put(self, item: Any) -> None:
self._wait_until_allowed(self._last_put)
self._last_put = time.time()
return self._queue.put(item)


class HTTPQueue(BaseQueue):
def __init__(self, name: str, default_timeout: float) -> None:
"""
Expand Down
42 changes: 35 additions & 7 deletions tests/tests_app/core/test_queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,15 @@
import requests_mock
from lightning.app import LightningFlow
from lightning.app.core import queues
from lightning.app.core.constants import HTTP_QUEUE_URL
from lightning.app.core.queues import READINESS_QUEUE_CONSTANT, BaseQueue, QueuingSystem, RedisQueue
from lightning.app.core.constants import HTTP_QUEUE_URL, STATE_UPDATE_TIMEOUT
from lightning.app.core.queues import (
READINESS_QUEUE_CONSTANT,
BaseQueue,
HTTPQueue,
QueuingSystem,
RateLimitedQueue,
RedisQueue,
)
from lightning.app.utilities.imports import _is_redis_available
from lightning.app.utilities.redis import check_if_redis_running

Expand Down Expand Up @@ -162,7 +169,7 @@ def test_redis_raises_error_if_failing(redis_mock):

class TestHTTPQueue:
def test_http_queue_failure_on_queue_name(self):
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test")
test_queue = HTTPQueue("test", STATE_UPDATE_TIMEOUT)
with pytest.raises(ValueError, match="App ID couldn't be extracted"):
test_queue.put("test")

Expand All @@ -174,7 +181,7 @@ def test_http_queue_failure_on_queue_name(self):

def test_http_queue_put(self, monkeypatch):
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")
test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
test_obj = LightningFlow()

# mocking requests and responses
Expand All @@ -200,8 +207,7 @@ def test_http_queue_put(self, monkeypatch):

def test_http_queue_get(self, monkeypatch):
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")

test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
adapter = requests_mock.Adapter()
test_queue.client.session.mount("http://", adapter)

Expand All @@ -218,7 +224,7 @@ def test_http_queue_get(self, monkeypatch):
def test_unreachable_queue(monkeypatch):
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")

test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")
test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)

resp1 = mock.MagicMock()
resp1.status_code = 204
Expand All @@ -235,3 +241,25 @@ def test_unreachable_queue(monkeypatch):
# Test backoff on queue.put
test_queue.put("foo")
assert test_queue.client.post.call_count == 3


@mock.patch("lightning.app.core.queues.time.sleep")
def test_rate_limited_queue(mock_sleep):
sleeps = []
mock_sleep.side_effect = lambda sleep_time: sleeps.append(sleep_time)

mock_queue = mock.MagicMock()

mock_queue.name = "inner_queue"
mock_queue.default_timeout = 10.0

rate_limited_queue = RateLimitedQueue(mock_queue, requests_per_second=1)

assert rate_limited_queue.name == "inner_queue"
assert rate_limited_queue.default_timeout == 10.0

timeout = time.perf_counter() + 1
while time.perf_counter() + sum(sleeps) < timeout:
rate_limited_queue.get()

assert mock_queue.get.call_count == 2

0 comments on commit 9085db4

Please sign in to comment.