class RateLimitedQueue(BaseQueue):
    """Queue wrapper that throttles ``get`` and ``put`` calls on the wrapped queue.

    ``get`` and ``put`` are throttled independently of each other: each keeps its own timestamp of the last call.

    """

    def __init__(self, queue: BaseQueue, requests_per_second: float):
        """This is a queue wrapper that will block on get or put calls if they are made too quickly.

        Args:
            queue: The queue to wrap.
            requests_per_second: The target number of get or put requests per second. Must be strictly positive.

        Raises:
            ValueError: If ``requests_per_second`` is zero, negative or NaN.

        """
        # ``not (x > 0)`` also rejects NaN, which would otherwise silently disable the rate limit.
        if not requests_per_second > 0:
            raise ValueError(f"`requests_per_second` must be a positive number, got {requests_per_second!r}.")

        self.name = queue.name
        self.default_timeout = queue.default_timeout

        self._queue = queue
        self._seconds_per_request = 1 / requests_per_second

        # ``-inf`` means "never called yet": the first get/put is always allowed immediately, whatever the
        # reference point of the monotonic clock is.
        self._last_get = float("-inf")
        self._last_put = float("-inf")

    @property
    def is_running(self) -> bool:
        """Whether the wrapped queue reports itself as running."""
        return self._queue.is_running

    def _wait_until_allowed(self, last_time: float) -> None:
        """Sleep until at least one request interval has elapsed since ``last_time``.

        Args:
            last_time: The ``time.monotonic()`` reading taken at the previous request of the same kind.

        """
        # A monotonic clock is used so that wall-clock adjustments cannot shorten or lengthen the wait.
        remaining = self._seconds_per_request - (time.monotonic() - last_time)
        if remaining > 0:
            # Never sleep for more than one interval, even if ``last_time`` appears to lie in the future
            # (e.g. the timestamp was taken against a clock with a different reference point).
            time.sleep(min(remaining, self._seconds_per_request))

    def get(self, timeout: Optional[float] = None) -> Any:
        """Wait for the rate limit if necessary, then read from the wrapped queue.

        Args:
            timeout: Forwarded unchanged to the wrapped queue's ``get``.

        Returns:
            Whatever the wrapped queue's ``get`` returns.

        """
        self._wait_until_allowed(self._last_get)
        # Stamp before delegating so a slow underlying ``get`` does not distort the spacing between requests.
        self._last_get = time.monotonic()
        return self._queue.get(timeout=timeout)

    def put(self, item: Any) -> None:
        """Wait for the rate limit if necessary, then write ``item`` to the wrapped queue.

        Args:
            item: The object forwarded unchanged to the wrapped queue's ``put``.

        """
        self._wait_until_allowed(self._last_put)
        self._last_put = time.monotonic()
        return self._queue.put(item)
@mock.patch("lightning.app.core.queues.time.sleep")
def test_rate_limited_queue(mock_sleep):
    """Check that ``RateLimitedQueue`` mirrors the inner queue's attributes and throttles ``get``.

    ``time.sleep`` is patched so no real waiting happens; the requested sleep durations are recorded and added
    to the real elapsed time to form the clock that drives the loop.

    """
    recorded_sleeps = []
    # Record every requested sleep duration instead of actually sleeping.
    mock_sleep.side_effect = recorded_sleeps.append

    inner = mock.MagicMock()
    inner.name = "inner_queue"
    inner.default_timeout = 10.0

    wrapper = RateLimitedQueue(inner, requests_per_second=1)

    # The wrapper exposes the wrapped queue's identifying attributes.
    assert wrapper.name == "inner_queue"
    assert wrapper.default_timeout == 10.0

    # Keep reading until one second (real time plus recorded sleeps) has passed.
    deadline = time.perf_counter() + 1
    while time.perf_counter() + sum(recorded_sleeps) < deadline:
        wrapper.get()

    # At one request per second, only two reads are expected within that window.
    assert inner.get.call_count == 2