Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add subgraph logic to post and pre order traversal #345

Merged
merged 3 commits into from
Jun 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion merlin/dag/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def _validate_node_schemas(self, root_schema, nodes, strict_dtypes=False):
@property
def input_schema(self):
# leaf_node input and output schemas are the same (aka selection)
return _combine_schemas(self.leaf_nodes)
# subgraphs can also be leaf nodes now, so input and output are different
return _combine_schemas(self.leaf_nodes, input_schemas=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a good place to use two separate methods (`_combine_input_schemas` and `_combine_output_schemas`) instead of a boolean flag.


@property
def leaf_nodes(self):
Expand Down
29 changes: 20 additions & 9 deletions merlin/dag/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,14 +597,17 @@ def construct_from(
)


def iter_nodes(nodes):
def iter_nodes(nodes, flatten_subgraphs=False):
queue = nodes[:]
while queue:
current = queue.pop()
current = queue.pop(0)
if flatten_subgraphs and current.op.is_subgraph:
new_nodes = iter_nodes([current.op.graph.output_node])
for node in new_nodes:
if node not in queue:
queue.append(node)
if isinstance(current, list):
queue.extend(current)
elif current.op.is_subgraph:
queue.extend(iter_nodes([current.op.graph.output_node]))
else:
yield current
for node in current.parents_with_dependencies:
Expand All @@ -613,7 +616,7 @@ def iter_nodes(nodes):


# output node (bottom) -> selection leaf nodes (top)
def preorder_iter_nodes(nodes):
def preorder_iter_nodes(nodes, flatten_subgraphs=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noticing that this boolean param is added to multiple functions makes me wonder if there's a missing concept here. Back in the day (before Graph, Node, etc. existed and everything was just a ColumnGroup), it probably made more sense to have a plain function instead of a method, but now this is starting to look like either:

  • we're iterating through two different kinds of things that require different iteration behavior and could use polymorphism to achieve that
  • we're doing two different kinds of iteration over the same kind of thing and we could use two methods to capture those behaviors

queue = []
if not isinstance(nodes, list):
nodes = [nodes]
Expand All @@ -624,6 +627,8 @@ def traverse(current_nodes):
if node in queue:
queue.remove(node)

if flatten_subgraphs and node.op.is_subgraph:
queue.extend(list(preorder_iter_nodes(node.op.graph.output_node)))
queue.append(node)

for node in current_nodes:
Expand All @@ -635,16 +640,19 @@ def traverse(current_nodes):


# selection leaf nodes (top) -> output node (bottom)
def postorder_iter_nodes(nodes):
def postorder_iter_nodes(nodes, flatten_subgraphs=False):
queue = []
if not isinstance(nodes, list):
nodes = [nodes]

def traverse(current_nodes):
for node in current_nodes:
traverse(node.parents_with_dependencies)

if node not in queue:
queue.append(node)
if flatten_subgraphs and node.op.is_subgraph:
queue.extend(list(postorder_iter_nodes(node.op.graph.output_node)))

traverse(nodes)
for node in queue:
Expand All @@ -663,15 +671,18 @@ def _filter_by_type(elements, type_):
return results


def _combine_schemas(elements):
def _combine_schemas(elements, input_schemas=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like one approach could be to implement methods like this in order to avoid the boolean flag:

  • `_combine_input_schemas(elements)`
  • `_combine_output_schemas(elements)`
  • `_combine_schemas(schemas: List[Schema])`

combined = Schema()
for elem in elements:
if isinstance(elem, Node):
combined += elem.output_schema
if input_schemas:
combined += elem.input_schema
else:
combined += elem.output_schema
elif isinstance(elem, ColumnSelector):
combined += Schema(elem.names)
elif isinstance(elem, list):
combined += _combine_schemas(elem)
combined += _combine_schemas(elem, input_schemas=input_schemas)
return combined


Expand Down
11 changes: 10 additions & 1 deletion tests/unit/dag/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#
import pytest

from merlin.dag import Graph, Node
from merlin.dag import Graph, Node, iter_nodes, postorder_iter_nodes, preorder_iter_nodes
from merlin.dag.base_operator import BaseOperator
from merlin.dag.ops.subgraph import Subgraph
from merlin.dag.selector import ColumnSelector
Expand Down Expand Up @@ -86,3 +86,12 @@ def test_subgraph_with_summed_subgraphs():
assert graph.subgraph("combined1").output_node == combined1
assert graph.subgraph("combined2").output_node == combined2
assert graph.subgraph("combined3").output_node == combined3

post_len = len(list(postorder_iter_nodes(graph.output_node)))
pre_len = len(list(preorder_iter_nodes(graph.output_node)))
iter_node_list = list(iter_nodes([graph.output_node]))
iter_len = len(iter_node_list)

assert post_len == pre_len
assert iter_len == post_len
assert iter_len == pre_len