diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index fe494258c1319..b08c7edae7bf7 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -71,6 +71,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed race condition to over-write the frontend with app infos ([#15398](https://github.com/Lightning-AI/lightning/pull/15398)) +- Fixed bi-directional queues sending delta with Drive Component name changes ([#15642](https://github.com/Lightning-AI/lightning/pull/15642)) + + ## [1.8.0] - 2022-11-01 diff --git a/src/lightning_app/cli/commands/logs.py b/src/lightning_app/cli/commands/logs.py index 9d53601da0698..fb0746dd50fff 100644 --- a/src/lightning_app/cli/commands/logs.py +++ b/src/lightning_app/cli/commands/logs.py @@ -71,6 +71,7 @@ def _show_logs(app_name: str, components: List[str], follow: bool) -> None: works = client.lightningwork_service_list_lightningwork( project_id=project.project_id, app_id=apps[app_name].id ).lightningworks + app_component_names = ["flow"] + [f.name for f in apps[app_name].spec.flow_servers] + [w.name for w in works] if not components: diff --git a/src/lightning_app/components/serve/python_server.py b/src/lightning_app/components/serve/python_server.py index 03b0ceb26058f..9cdf73bfb7446 100644 --- a/src/lightning_app/components/serve/python_server.py +++ b/src/lightning_app/components/serve/python_server.py @@ -1,5 +1,6 @@ import abc import base64 +import os from pathlib import Path from typing import Any, Dict, Optional @@ -14,12 +15,6 @@ logger = Logger(__name__) -def image_to_base64(image_path): - with open(image_path, "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return encoded_string.decode("UTF-8") - - class _DefaultInputData(BaseModel): payload: str @@ -33,7 +28,8 @@ class Image(BaseModel): @staticmethod def _get_sample_data() -> Dict[Any, Any]: - imagepath = Path(__file__).absolute().parent / "catimage.png" + name = "lightning" + "_" + "app" + imagepath = Path(__file__.replace(f"lightning{os.sep}app", name)).absolute().parent / "catimage.png" with open(imagepath, "rb") as image_file: encoded_string = base64.b64encode(image_file.read()) return {"image": encoded_string.decode("UTF-8")} diff --git a/src/lightning_app/core/app.py b/src/lightning_app/core/app.py index 17e91d565d6de..9620f4bb96cc6 100644 --- a/src/lightning_app/core/app.py +++ b/src/lightning_app/core/app.py @@ -24,7 +24,7 @@ from lightning_app.core.queues import BaseQueue, SingleProcessQueue from lightning_app.core.work import LightningWork from lightning_app.frontend import Frontend -from lightning_app.storage import Drive, Path +from lightning_app.storage import Drive, Path, Payload from lightning_app.storage.path import _storage_root_dir from lightning_app.utilities import frontend from lightning_app.utilities.app_helpers import ( @@ -630,8 +630,16 @@ def _extract_vars_from_component_name(component_name: str, state): else: return None - # Note: Remove private keys - return {k: v for k, v in child["vars"].items() if not k.startswith("_")} + # Filter private keys and drives + return { + k: v + for k, v in child["vars"].items() + if ( + not k.startswith("_") + and not (isinstance(v, dict) and v.get("type", None) == "__drive__") + and not (isinstance(v, (Payload, Path))) + ) + } def _send_flow_to_work_deltas(self, state) -> None: if not self.flow_to_work_delta_queues: @@ -652,10 +660,6 @@ def _send_flow_to_work_deltas(self, state) -> None: if state_work is None or last_state_work is None: continue - # Note: The flow shouldn't update path or drive manually. - last_state_work = apply_to_collection(last_state_work, (Path, Drive), lambda x: None) - state_work = apply_to_collection(state_work, (Path, Drive), lambda x: None) - deep_diff = DeepDiff(last_state_work, state_work, verbose_level=2).to_dict() if "unprocessed" in deep_diff: diff --git a/src/lightning_app/testing/testing.py b/src/lightning_app/testing/testing.py index f4c8c001acad7..9004f7d1c7302 100644 --- a/src/lightning_app/testing/testing.py +++ b/src/lightning_app/testing/testing.py @@ -431,7 +431,18 @@ def fetch_logs(component_names: Optional[List[str]] = None) -> Generator: project_id=project.project_id, app_id=app_id, ).lightningworks + component_names = ["flow"] + [w.name for w in works] + else: + + def add_prefix(c: str) -> str: + if c == "flow": + return c + if not c.startswith("root."): + return "root." + c + return c + + component_names = [add_prefix(c) for c in component_names] gen = _app_logs_reader( logs_api_client=logs_api_client, diff --git a/tests/tests_app/utilities/test_proxies.py b/tests/tests_app/utilities/test_proxies.py index fccbaaa671588..682138d20654e 100644 --- a/tests/tests_app/utilities/test_proxies.py +++ b/tests/tests_app/utilities/test_proxies.py @@ -14,7 +14,7 @@ from lightning_app import LightningApp, LightningFlow, LightningWork from lightning_app.runners import MultiProcessRuntime -from lightning_app.storage import Path +from lightning_app.storage import Drive, Path from lightning_app.storage.path import _artifacts_path from lightning_app.storage.requests import _GetRequest from lightning_app.testing.helpers import _MockQueue, EmptyFlow @@ -761,3 +761,31 @@ def test_bi_directional_proxy_forbidden(monkeypatch): MultiProcessRuntime(app, start_server=False).dispatch() assert app.stage == AppStage.FAILED assert "A forbidden operation to update the work" in str(app.exception) + + +class WorkDrive(LightningFlow): + def __init__(self, drive): + super().__init__() + self.drive = drive + self.path = Path("data") + + def run(self): + pass + + +class FlowDrive(LightningFlow): + def __init__(self): + super().__init__() + self.data = Drive("lit://data") + self.counter = 0 + + def run(self): + if not hasattr(self, "w"): + self.w = WorkDrive(self.data) + self.counter += 1 + + +def test_bi_directional_proxy_filtering(): + app = LightningApp(FlowDrive()) + app.root.run() + assert app._extract_vars_from_component_name(app.root.w.name, app.state) == {}