Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

App: Limit rate of requests to http queue #18981

Merged
merged 7 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/lightning/app/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def get_lightning_cloud_url() -> str:
HTTP_QUEUE_URL = os.getenv("LIGHTNING_HTTP_QUEUE_URL", "http://localhost:9801")
HTTP_QUEUE_REFRESH_INTERVAL = float(os.getenv("LIGHTNING_HTTP_QUEUE_REFRESH_INTERVAL", "1"))
HTTP_QUEUE_TOKEN = os.getenv("LIGHTNING_HTTP_QUEUE_TOKEN", None)
HTTP_QUEUE_REQUESTS_PER_SECOND = float(os.getenv("LIGHTNING_HTTP_QUEUE_REQUESTS_PER_SECOND", "1"))

USER_ID = os.getenv("USER_ID", "1234")
FRONTEND_DIR = str(Path(__file__).parent.parent / "ui")
Expand Down
44 changes: 43 additions & 1 deletion src/lightning/app/core/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from lightning.app.core.constants import (
HTTP_QUEUE_REFRESH_INTERVAL,
HTTP_QUEUE_REQUESTS_PER_SECOND,
HTTP_QUEUE_TOKEN,
HTTP_QUEUE_URL,
LIGHTNING_DIR,
Expand Down Expand Up @@ -77,7 +78,9 @@ def get_queue(self, queue_name: str) -> "BaseQueue":
return MultiProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
if self == QueuingSystem.REDIS:
return RedisQueue(queue_name, default_timeout=REDIS_QUEUES_READ_DEFAULT_TIMEOUT)
return HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
return RateLimitedQueue(
HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT), HTTP_QUEUE_REQUESTS_PER_SECOND
)

def get_api_response_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = f"{queue_id}_{API_RESPONSE_QUEUE_CONSTANT}" if queue_id else API_RESPONSE_QUEUE_CONSTANT
Expand Down Expand Up @@ -347,6 +350,45 @@ def from_dict(cls, state: dict) -> "RedisQueue":
return cls(**state)


class RateLimitedQueue(BaseQueue):
def __init__(self, queue: BaseQueue, requests_per_second: float):
"""This is a queue wrapper that will block on get or put calls if they are made too quickly.

Args:
queue: The queue to wrap.
requests_per_second: The target number of get or put requests per second.

"""
self.name = queue.name
self.default_timeout = queue.default_timeout

self._queue = queue
self._seconds_per_request = 1 / requests_per_second

self._last_get = 0.0
self._last_put = 0.0

@property
def is_running(self) -> bool:
return self._queue.is_running

def _wait_until_allowed(self, last_time: float) -> None:
t = time.time()
diff = t - last_time
if diff < self._seconds_per_request:
time.sleep(self._seconds_per_request - diff)

def get(self, timeout: Optional[float] = None) -> Any:
self._wait_until_allowed(self._last_get)
self._last_get = time.time()
return self._queue.get(timeout=timeout)

def put(self, item: Any) -> None:
self._wait_until_allowed(self._last_put)
self._last_put = time.time()
return self._queue.put(item)


class HTTPQueue(BaseQueue):
def __init__(self, name: str, default_timeout: float) -> None:
"""
Expand Down
40 changes: 33 additions & 7 deletions tests/tests_app/core/test_queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,15 @@
import requests_mock
from lightning.app import LightningFlow
from lightning.app.core import queues
from lightning.app.core.constants import HTTP_QUEUE_URL
from lightning.app.core.queues import READINESS_QUEUE_CONSTANT, BaseQueue, QueuingSystem, RedisQueue
from lightning.app.core.constants import HTTP_QUEUE_URL, STATE_UPDATE_TIMEOUT
from lightning.app.core.queues import (
READINESS_QUEUE_CONSTANT,
BaseQueue,
HTTPQueue,
QueuingSystem,
RateLimitedQueue,
RedisQueue,
)
from lightning.app.utilities.imports import _is_redis_available
from lightning.app.utilities.redis import check_if_redis_running

Expand Down Expand Up @@ -162,7 +169,7 @@ def test_redis_raises_error_if_failing(redis_mock):

class TestHTTPQueue:
def test_http_queue_failure_on_queue_name(self):
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test")
test_queue = HTTPQueue("test", STATE_UPDATE_TIMEOUT)
with pytest.raises(ValueError, match="App ID couldn't be extracted"):
test_queue.put("test")

Expand All @@ -174,7 +181,7 @@ def test_http_queue_failure_on_queue_name(self):

def test_http_queue_put(self, monkeypatch):
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")
test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
test_obj = LightningFlow()

# mocking requests and responses
Expand All @@ -200,8 +207,7 @@ def test_http_queue_put(self, monkeypatch):

def test_http_queue_get(self, monkeypatch):
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")

test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
adapter = requests_mock.Adapter()
test_queue.client.session.mount("http://", adapter)

Expand All @@ -218,7 +224,7 @@ def test_http_queue_get(self, monkeypatch):
def test_unreachable_queue(monkeypatch):
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")

test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")
test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)

resp1 = mock.MagicMock()
resp1.status_code = 204
Expand All @@ -235,3 +241,23 @@ def test_unreachable_queue(monkeypatch):
# Test backoff on queue.put
test_queue.put("foo")
assert test_queue.client.post.call_count == 3


def test_rate_limited_queue():
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
mock_queue = mock.MagicMock()

mock_queue.name = "inner_queue"
mock_queue.default_timeout = 10.0

rate_limited_queue = RateLimitedQueue(mock_queue, requests_per_second=1)

assert rate_limited_queue.name == "inner_queue"
assert rate_limited_queue.default_timeout == 10.0

timeout = time.perf_counter() + 1
while time.perf_counter() < timeout:
rate_limited_queue.get()

assert (
mock_queue.get.call_count == 2
), f"the inner queue should have been called exactly twice but was called {mock_queue.get.call_count} times"
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
Loading