Skip to content

Commit

Permalink
Reduce queue fetching (#19856)
Browse files Browse the repository at this point in the history
* update

* update
  • Loading branch information
tchaton authored May 9, 2024
1 parent e030727 commit 8453e31
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 14 deletions.
12 changes: 8 additions & 4 deletions src/lightning/app/core/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from lightning.app.api.request_types import _APIRequest, _CommandRequest, _DeltaRequest
from lightning.app.core.constants import (
BATCH_DELTA_COUNT,
CHECK_ERROR_QUEUE_INTERVAL,
DEBUG_ENABLED,
FLOW_DURATION_SAMPLES,
FLOW_DURATION_THRESHOLD,
Expand Down Expand Up @@ -165,6 +166,7 @@ def __init__(

self._last_run_time: float = 0.0
self._run_times: list = []
self._last_check_error_queue: float = 0.0

# Path attributes can't get properly attached during the initialization, because the full name
# is only available after all Flows and Works have been instantiated.
Expand Down Expand Up @@ -318,10 +320,12 @@ def batch_get_state_changed_from_queue(q: BaseQueue, timeout: Optional[float] =
return []

def check_error_queue(self) -> None:
exception: Exception = self.get_state_changed_from_queue(self.error_queue) # type: ignore[assignment,arg-type]
if isinstance(exception, Exception):
self.exception = exception
self.stage = AppStage.FAILED
if (time() - self._last_check_error_queue) > CHECK_ERROR_QUEUE_INTERVAL:
exception: Exception = self.get_state_changed_from_queue(self.error_queue) # type: ignore[assignment,arg-type]
if isinstance(exception, Exception):
self.exception = exception
self.stage = AppStage.FAILED
self._last_check_error_queue = time()

@property
def flows(self) -> List[Union[LightningWork, "LightningFlow"]]:
Expand Down
2 changes: 2 additions & 0 deletions src/lightning/app/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def get_lightning_cloud_url() -> str:
LIGHTNING_COMPONENT_PUBLIC_REGISTRY = "https://lightning.ai/v1/components"
LIGHTNING_APPS_PUBLIC_REGISTRY = "https://lightning.ai/v1/apps"
LIGHTNING_MODELS_PUBLIC_REGISTRY = "https://lightning.ai/v1/models"
ENABLE_ORCHESTRATOR = bool(int(os.getenv("ENABLE_ORCHESTRATOR", "1")))

LIGHTNING_CLOUDSPACE_HOST = os.getenv("LIGHTNING_CLOUDSPACE_HOST")
LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT = int(os.getenv("LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT", "0"))
Expand Down Expand Up @@ -99,6 +100,7 @@ def get_lightning_cloud_url() -> str:
SYS_CUSTOMIZATIONS_SYNC_PATH = ".sys-customizations-sync"

BATCH_DELTA_COUNT = int(os.getenv("BATCH_DELTA_COUNT", "128"))
CHECK_ERROR_QUEUE_INTERVAL = float(os.getenv("CHECK_ERROR_QUEUE_INTERVAL", "30"))


def enable_multiple_works_in_default_container() -> bool:
Expand Down
21 changes: 11 additions & 10 deletions src/lightning/app/runners/multiprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,17 @@ def dispatch(self, *args: Any, open_ui: bool = True, **kwargs: Any):

_set_flow_context()

storage_orchestrator = StorageOrchestrator(
self.app,
self.app.request_queues,
self.app.response_queues,
self.app.copy_request_queues,
self.app.copy_response_queues,
)
self.threads.append(storage_orchestrator)
storage_orchestrator.setDaemon(True)
storage_orchestrator.start()
if constants.ENABLE_ORCHESTRATOR:
storage_orchestrator = StorageOrchestrator(
self.app,
self.app.request_queues,
self.app.response_queues,
self.app.copy_request_queues,
self.app.copy_response_queues,
)
self.threads.append(storage_orchestrator)
storage_orchestrator.setDaemon(True)
storage_orchestrator.start()

if self.start_server:
self.app.should_publish_changes_to_api = True
Expand Down
24 changes: 24 additions & 0 deletions tests/tests_app/core/test_lightning_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1188,3 +1188,27 @@ def run(self):
def test_lightning_work_stopped():
app = LightningApp(SimpleWork2())
MultiProcessRuntime(app, start_server=False).dispatch()


class FailedWork(LightningWork):
def run(self):
raise Exception


class CheckErrorQueueLightningApp(LightningApp):
def check_error_queue(self):
super().check_error_queue()


def test_error_queue_check(monkeypatch):
import sys

from lightning.app.core import app as app_module

sys_mock = mock.MagicMock()
monkeypatch.setattr(app_module, "CHECK_ERROR_QUEUE_INTERVAL", 0)
monkeypatch.setattr(sys, "exit", sys_mock)
app = LightningApp(FailedWork())
MultiProcessRuntime(app, start_server=False).dispatch()
assert app.stage == AppStage.FAILED
assert app._last_check_error_queue != 0.0

0 comments on commit 8453e31

Please sign in to comment.