Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[App] Resolve bi-directional queue bug #15642

Merged
merged 15 commits into from
Nov 11, 2022
3 changes: 3 additions & 0 deletions src/lightning_app/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed race condition to over-write the frontend with app infos ([#15398](https://github.com/Lightning-AI/lightning/pull/15398))


- Fixed bi-directional queues sending delta with Drive Component name changes ([#15642](https://github.com/Lightning-AI/lightning/pull/15642))



## [1.8.0] - 2022-11-01

Expand Down
11 changes: 4 additions & 7 deletions src/lightning_app/components/serve/python_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import base64
import sys
from pathlib import Path
from typing import Any, Dict, Optional

Expand All @@ -14,12 +15,6 @@
logger = Logger(__name__)


def image_to_base64(image_path):
with open(image_path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return encoded_string.decode("UTF-8")


class _DefaultInputData(BaseModel):
payload: str

Expand All @@ -33,7 +28,9 @@ class Image(BaseModel):

@staticmethod
def _get_sample_data() -> Dict[Any, Any]:
imagepath = Path(__file__).absolute().parent / "catimage.png"
sep = "\\" if sys.platform == "win32" else "/"
tchaton marked this conversation as resolved.
Show resolved Hide resolved
name = "lightning" + "_" + "app"
imagepath = Path(__file__.replace(f"lightning{sep}app", name)).absolute().parent / "catimage.png"
with open(imagepath, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return {"image": encoded_string.decode("UTF-8")}
Expand Down
18 changes: 11 additions & 7 deletions src/lightning_app/core/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from lightning_app.core.queues import BaseQueue, SingleProcessQueue
from lightning_app.core.work import LightningWork
from lightning_app.frontend import Frontend
from lightning_app.storage import Drive, Path
from lightning_app.storage import Drive, Path, Payload
from lightning_app.storage.path import _storage_root_dir
from lightning_app.utilities import frontend
from lightning_app.utilities.app_helpers import (
Expand Down Expand Up @@ -630,8 +630,16 @@ def _extract_vars_from_component_name(component_name: str, state):
else:
return None

# Note: Remove private keys
return {k: v for k, v in child["vars"].items() if not k.startswith("_")}
# Filter private keys and drives
return {
k: v
for k, v in child["vars"].items()
if (
not k.startswith("_")
and not (isinstance(v, dict) and v.get("type", None) == "__drive__")
and not (isinstance(v, (Payload, Path)))
)
}

def _send_flow_to_work_deltas(self, state) -> None:
if not self.flow_to_work_delta_queues:
Expand All @@ -652,10 +660,6 @@ def _send_flow_to_work_deltas(self, state) -> None:
if state_work is None or last_state_work is None:
continue

# Note: The flow shouldn't update path or drive manually.
last_state_work = apply_to_collection(last_state_work, (Path, Drive), lambda x: None)
state_work = apply_to_collection(state_work, (Path, Drive), lambda x: None)

deep_diff = DeepDiff(last_state_work, state_work, verbose_level=2).to_dict()

if "unprocessed" in deep_diff:
Expand Down
30 changes: 29 additions & 1 deletion tests/tests_app/utilities/test_proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from lightning_app import LightningApp, LightningFlow, LightningWork
from lightning_app.runners import MultiProcessRuntime
from lightning_app.storage import Path
from lightning_app.storage import Drive, Path
from lightning_app.storage.path import _artifacts_path
from lightning_app.storage.requests import _GetRequest
from lightning_app.testing.helpers import _MockQueue, EmptyFlow
Expand Down Expand Up @@ -761,3 +761,31 @@ def test_bi_directional_proxy_forbidden(monkeypatch):
MultiProcessRuntime(app, start_server=False).dispatch()
assert app.stage == AppStage.FAILED
assert "A forbidden operation to update the work" in str(app.exception)


class WorkDrive(LightningFlow):
def __init__(self, drive):
super().__init__()
self.drive = drive
self.path = Path("data")

def run(self):
pass


class FlowDrive(LightningFlow):
def __init__(self):
super().__init__()
self.data = Drive("lit://data")
self.counter = 0

def run(self):
if not hasattr(self, "w"):
self.w = WorkDrive(self.data)
self.counter += 1


def test_bi_directional_proxy_filtering():
app = LightningApp(FlowDrive())
app.root.run()
assert app._extract_vars_from_component_name(app.root.w.name, app.state) == {}