Skip to content

Commit

Permalink
[App] Resolve inconsistency where the flow.flows property isn't rec…
Browse files Browse the repository at this point in the history
…ursive leading to flow overrides (#15466)

* update

* update

* update

* update

* update

* resolve attachment

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
  • Loading branch information
tchaton authored Nov 3, 2022
1 parent 1c26c41 commit 921dc1c
Show file tree
Hide file tree
Showing 13 changed files with 195 additions and 160 deletions.
3 changes: 2 additions & 1 deletion src/lightning_app/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

-
- Changed the `flow.flows` to be recursive wont to align the behavior with the `flow.works` ([#15466](https://github.com/Lightning-AI/lightning/pull/15466))

-

Expand All @@ -46,6 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed


- Fixed writing app name and id in connect.txt file for the command CLI ([#15443](https://github.com/Lightning-AI/lightning/pull/15443))

-
Expand Down
1 change: 1 addition & 0 deletions src/lightning_app/cli/commands/app_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _handle_command_without_client(command: str, metadata: Dict, url: str) -> No
query_parameters = "&".join(provided_params)
resp = requests.post(url + f"/command/{command}?{query_parameters}")
assert resp.status_code == 200, resp.json()
print(resp.json())


def _handle_command_with_client(command: str, metadata: Dict, app_name: str, app_id: Optional[str], url: str):
Expand Down
2 changes: 1 addition & 1 deletion src/lightning_app/core/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def check_error_queue(self) -> None:
@property
def flows(self) -> List["LightningFlow"]:
"""Returns all the flows defined within this application."""
return [self.root] + self.root.get_all_children()
return list(self.root.flows.values())

@property
def works(self) -> List[LightningWork]:
Expand Down
57 changes: 18 additions & 39 deletions src/lightning_app/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,22 +207,19 @@ def _attach_backend(flow: "LightningFlow", backend):
"""Attach the backend to all flows and its children."""
flow._backend = backend

for child_flow in flow.flows.values():
LightningFlow._attach_backend(child_flow, backend)

for struct_name in flow._structures:
structure = getattr(flow, struct_name)
for flow in structure.flows:
LightningFlow._attach_backend(flow, backend)
for work in structure.works:
backend._wrap_run_method(_LightningAppRef().get_current(), work)
work._backend = backend

for name in flow._structures:
getattr(flow, name)._backend = backend

for work in flow.works(recurse=False):
backend._wrap_run_method(_LightningAppRef().get_current(), work)
for child_flow in flow.flows.values():
child_flow._backend = backend
for name in child_flow._structures:
getattr(child_flow, name)._backend = backend

app = _LightningAppRef().get_current()

for child_work in flow.works():
child_work._backend = backend
backend._wrap_run_method(app, child_work)

def __getattr__(self, item):
if item in self.__dict__.get("_paths", {}):
Expand Down Expand Up @@ -274,12 +271,15 @@ def state_with_changes(self):
}

@property
def flows(self):
def flows(self) -> Dict[str, "LightningFlow"]:
"""Return its children LightningFlow."""
flows = {el: getattr(self, el) for el in sorted(self._flows)}
flows = {}
for el in sorted(self._flows):
flow = getattr(self, el)
flows[flow.name] = flow
flows.update(flow.flows)
for struct_name in sorted(self._structures):
for flow in getattr(self, struct_name).flows:
flows[flow.name] = flow
flows.update(getattr(self, struct_name).flows)
return flows

def works(self, recurse: bool = True) -> List[LightningWork]:
Expand All @@ -297,28 +297,7 @@ def works(self, recurse: bool = True) -> List[LightningWork]:

def named_works(self, recurse: bool = True) -> List[Tuple[str, LightningWork]]:
"""Return its :class:`~lightning_app.core.work.LightningWork` with their names."""
named_works = [(el, getattr(self, el)) for el in sorted(self._works)]
if not recurse:
return named_works
for child_name in sorted(self._flows):
for w in getattr(self, child_name).works(recurse=recurse):
named_works.append(w)
for struct_name in sorted(self._structures):
for w in getattr(self, struct_name).works:
named_works.append((w.name, w))
return named_works

def get_all_children_(self, children):
sorted_children = sorted(self._flows)
children.extend([getattr(self, el) for el in sorted_children])
for child in sorted_children:
getattr(self, child).get_all_children_(children)
return children

def get_all_children(self):
children = []
self.get_all_children_(children)
return children
return [(w.name, w) for w in self.works(recurse=recurse)]

def set_state(self, provided_state: Dict, recurse: bool = True) -> None:
"""Method to set the state to this LightningFlow, its children and
Expand Down
16 changes: 16 additions & 0 deletions src/lightning_app/core/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,18 @@ def __init__(
self.redis = redis.Redis(host=host, port=port, password=password)

def put(self, item: Any) -> None:
from lightning_app import LightningWork

is_work = isinstance(item, LightningWork)

# TODO: Be careful to handle with a lock if another thread needs
# to access the work backend one day.
# The backend isn't picklable
# Raises a TypeError: cannot pickle '_thread.RLock' object
if is_work:
backend = item._backend
item._backend = None

value = pickle.dumps(item)
queue_len = self.length()
if queue_len >= WARNING_QUEUE_SIZE:
Expand All @@ -252,6 +264,10 @@ def put(self, item: Any) -> None:
"If the issue persists, please contact [email protected]"
)

# The backend isn't pickable.
if is_work:
item._backend = backend

def get(self, timeout: int = None):
"""Returns the left most element of the redis queue.
Expand Down
17 changes: 11 additions & 6 deletions src/lightning_app/structures/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,17 @@ def works(self):
@property
def flows(self):
from lightning_app import LightningFlow

flows = []
for flow in [item for item in self.values() if isinstance(item, LightningFlow)]:
flows.append(flow)
for child_flow in flow.flows:
flows.append(child_flow)
from lightning_app.structures import Dict, List

flows = {}
for item in self.values():
if isinstance(item, LightningFlow):
flows[item.name] = item
for child_flow in item.flows.values():
flows[child_flow.name] = child_flow
if isinstance(item, (Dict, List)):
for child_flow in item.flows.values():
flows[child_flow.name] = child_flow
return flows

@property
Expand Down
17 changes: 11 additions & 6 deletions src/lightning_app/structures/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,17 @@ def works(self):
@property
def flows(self):
from lightning_app import LightningFlow

flows = []
for flow in [item for item in self if isinstance(item, LightningFlow)]:
flows.append(flow)
for child_flow in flow.flows:
flows.append(child_flow)
from lightning_app.structures import Dict, List

flows = {}
for item in self:
if isinstance(item, LightningFlow):
flows[item.name] = item
for child_flow in item.flows.values():
flows[child_flow.name] = child_flow
if isinstance(item, (Dict, List)):
for child_flow in item.flows.values():
flows[child_flow.name] = child_flow
return flows

@property
Expand Down
8 changes: 5 additions & 3 deletions src/lightning_app/utilities/app_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,12 @@ def _set_child_name(component: "Component", child: "Component", new_name: str) -

# the name changed, so recursively update the names of the children of this child
if isinstance(child, lightning_app.core.LightningFlow):
for n, c in child.flows.items():
for n in child._flows:
c = getattr(child, n)
_set_child_name(child, c, n)
for n in child._works:
c = getattr(child, n)
_set_child_name(child, c, n)
for n, w in child.named_works(recurse=False):
_set_child_name(child, w, n)
for n in child._structures:
s = getattr(child, n)
_set_child_name(child, s, n)
Expand Down
24 changes: 13 additions & 11 deletions src/lightning_app/utilities/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,6 @@ def breadth_first(root: "Component", types: Type["ComponentTuple"] = None):
yield from _BreadthFirstVisitor(root, types)


def depth_first(root: "Component", types: Type["ComponentTuple"] = None):
"""Returns a generator that walks through the tree of components depth-first.
Arguments:
root: The root component of the tree
types: If provided, only the component types in this list will be visited.
"""
yield from _DepthFirstVisitor(root, types)


class _BreadthFirstVisitor:
def __init__(self, root: "Component", types: Type["ComponentTuple"] = None) -> None:
self.queue = [root]
Expand All @@ -36,11 +26,23 @@ def __iter__(self):
return self

def __next__(self):
from lightning_app.structures import Dict

while self.queue:
component = self.queue.pop(0)

if isinstance(component, lightning_app.LightningFlow):
self.queue += list(component.flows.values())
components = [getattr(component, el) for el in sorted(component._flows)]
for struct_name in sorted(component._structures):
structure = getattr(component, struct_name)
if isinstance(structure, Dict):
values = sorted(structure.items(), key=lambda x: x[0])
else:
values = sorted(((v.name, v) for v in structure), key=lambda x: x[0])
for _, value in values:
if isinstance(value, lightning_app.LightningFlow):
components.append(value)
self.queue += components
self.queue += component.works(recurse=False)

if any(isinstance(component, t) for t in self.types):
Expand Down
6 changes: 3 additions & 3 deletions tests/tests_app/core/lightning_app/test_configure_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ def __init__(self):
root = TestContentComponent()
LightningApp(root)
assert root._layout == [
dict(name="component0", content="root.component0"),
dict(name="component1", content="root.component1"),
dict(name="component2", content="root.component2"),
dict(name="root.component0", content="root.component0"),
dict(name="root.component1", content="root.component1"),
dict(name="root.component2", content="root.component2"),
]


Expand Down
81 changes: 80 additions & 1 deletion tests/tests_app/core/test_lightning_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from lightning_app.runners import MultiProcessRuntime, SingleProcessRuntime
from lightning_app.storage import Path
from lightning_app.storage.path import _storage_root_dir
from lightning_app.structures import Dict as LDict
from lightning_app.structures import List as LList
from lightning_app.testing.helpers import EmptyFlow, EmptyWork
from lightning_app.utilities.app_helpers import (
_delta_to_app_state_delta,
Expand Down Expand Up @@ -307,7 +309,7 @@ def run(self):
self._exit()

flow_a = Flow_A()
assert flow_a.named_works() == [("work_a", flow_a.work_a), ("work_b", flow_a.work_b)]
assert flow_a.named_works() == [("root.work_a", flow_a.work_a), ("root.work_b", flow_a.work_b)]
assert flow_a.works() == [flow_a.work_a, flow_a.work_b]
state = {
"vars": {"counter": 0, "_layout": ANY, "_paths": {}},
Expand Down Expand Up @@ -780,3 +782,80 @@ def test_lightning_flow_reload():
flow = RootFlowReload2()
with pytest.raises(ValueError, match="The component flow_2 wasn't instantiated for the component root"):
_load_state_dict(flow, state)


class NestedFlow(LightningFlow):
def __init__(self):
super().__init__()
self.flows_dict = LDict(**{"a": EmptyFlow()})
self.flows_list = LList(*[EmptyFlow()])
self.flow = EmptyFlow()
assert list(self.flows) == ["root.flow", "root.flows_dict.a", "root.flows_list.0"]
self.w = EmptyWork()

def run(self):
pass


class FlowNested2(LightningFlow):
def __init__(self):
super().__init__()
self.flow3 = EmptyFlow()
self.w = EmptyWork()

def run(self):
pass


class FlowCollection(LightningFlow):
def __init__(self):
super().__init__()
self.flow = EmptyFlow()
assert self.flow.name == "root.flow"
self.flow2 = FlowNested2()
assert list(self.flow2.flows) == ["root.flow2.flow3"]
self.flows_dict = LDict(**{"a": NestedFlow()})
assert list(self.flows_dict.flows) == [
"root.flows_dict.a",
"root.flows_dict.a.flow",
"root.flows_dict.a.flows_dict.a",
"root.flows_dict.a.flows_list.0",
]
self.flows_list = LList(*[NestedFlow()])
assert list(self.flows_list.flows) == [
"root.flows_list.0",
"root.flows_list.0.flow",
"root.flows_list.0.flows_dict.a",
"root.flows_list.0.flows_list.0",
]
self.w = EmptyWork()

def run(self):
pass


def test_lightning_flow_flows_and_works():

flow = FlowCollection()
app = LightningApp(flow)

assert list(app.root.flows.keys()) == [
"root.flow",
"root.flow2",
"root.flow2.flow3",
"root.flows_dict.a",
"root.flows_dict.a.flow",
"root.flows_dict.a.flows_dict.a",
"root.flows_dict.a.flows_list.0",
"root.flows_list.0",
"root.flows_list.0.flow",
"root.flows_list.0.flows_dict.a",
"root.flows_list.0.flows_list.0",
]

assert [w[0] for w in app.root.named_works()] == [
"root.w",
"root.flow2.w",
"root.flows_dict.a.w",
"root.flows_list.0.w",
]
Loading

0 comments on commit 921dc1c

Please sign in to comment.