-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add subgraph logic to post and pre order traversal #345
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -597,14 +597,17 @@ def construct_from( | |
) | ||
|
||
|
||
def iter_nodes(nodes): | ||
def iter_nodes(nodes, flatten_subgraphs=False): | ||
queue = nodes[:] | ||
while queue: | ||
current = queue.pop() | ||
current = queue.pop(0) | ||
if flatten_subgraphs and current.op.is_subgraph: | ||
new_nodes = iter_nodes([current.op.graph.output_node]) | ||
for node in new_nodes: | ||
if node not in queue: | ||
queue.append(node) | ||
if isinstance(current, list): | ||
queue.extend(current) | ||
elif current.op.is_subgraph: | ||
queue.extend(iter_nodes([current.op.graph.output_node])) | ||
else: | ||
yield current | ||
for node in current.parents_with_dependencies: | ||
|
@@ -613,7 +616,7 @@ def iter_nodes(nodes): | |
|
||
|
||
# output node (bottom) -> selection leaf nodes (top) | ||
def preorder_iter_nodes(nodes): | ||
def preorder_iter_nodes(nodes, flatten_subgraphs=False): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Noticing that this boolean param is added to multiple functions makes me wonder if there's a missing concept here. Back in the day (before
|
||
queue = [] | ||
if not isinstance(nodes, list): | ||
nodes = [nodes] | ||
|
@@ -624,6 +627,8 @@ def traverse(current_nodes): | |
if node in queue: | ||
queue.remove(node) | ||
|
||
if flatten_subgraphs and node.op.is_subgraph: | ||
queue.extend(list(preorder_iter_nodes(node.op.graph.output_node))) | ||
queue.append(node) | ||
|
||
for node in current_nodes: | ||
|
@@ -635,16 +640,19 @@ def traverse(current_nodes): | |
|
||
|
||
# selection leaf nodes (top) -> output node (bottom) | ||
def postorder_iter_nodes(nodes): | ||
def postorder_iter_nodes(nodes, flatten_subgraphs=False): | ||
queue = [] | ||
if not isinstance(nodes, list): | ||
nodes = [nodes] | ||
|
||
def traverse(current_nodes): | ||
for node in current_nodes: | ||
traverse(node.parents_with_dependencies) | ||
|
||
if node not in queue: | ||
queue.append(node) | ||
if flatten_subgraphs and node.op.is_subgraph: | ||
queue.extend(list(postorder_iter_nodes(node.op.graph.output_node))) | ||
|
||
traverse(nodes) | ||
for node in queue: | ||
|
@@ -663,15 +671,18 @@ def _filter_by_type(elements, type_): | |
return results | ||
|
||
|
||
def _combine_schemas(elements): | ||
def _combine_schemas(elements, input_schemas=False): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems like one approach could be to implement methods like this in order to avoid the boolean flag:
|
||
combined = Schema() | ||
for elem in elements: | ||
if isinstance(elem, Node): | ||
combined += elem.output_schema | ||
if input_schemas: | ||
combined += elem.input_schema | ||
else: | ||
combined += elem.output_schema | ||
elif isinstance(elem, ColumnSelector): | ||
combined += Schema(elem.names) | ||
elif isinstance(elem, list): | ||
combined += _combine_schemas(elem) | ||
combined += _combine_schemas(elem, input_schemas=input_schemas) | ||
return combined | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems like a good place to use two separate methods (
_combine_input_schemas
andcombine_output_schemas
) instead of a boolean flag