Skip to content

Commit

Permalink
feat(common): use patterns to filter out nodes during graph traversal
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Oct 11, 2023
1 parent 652ceab commit 06a41fb
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 194 deletions.
184 changes: 102 additions & 82 deletions ibis/common/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from ibis.common.bases import Hashable
from ibis.common.collections import frozendict
from ibis.common.patterns import NoMatch, pattern
from ibis.common.patterns import NoMatch, Pattern, pattern
from ibis.util import experimental

if TYPE_CHECKING:
Expand All @@ -17,7 +17,7 @@
N = TypeVar("N")


def _flatten_collections(node: Any, filter: type[N]) -> Iterator[N]:
def _flatten_collections(node: Any) -> Iterator[N]:
"""Flatten collections of nodes into a single iterator.
We treat common collection types inherently traversable (e.g. list, tuple, dict)
Expand All @@ -26,9 +26,7 @@ def _flatten_collections(node: Any, filter: type[N]) -> Iterator[N]:
Parameters
----------
node
Flattaneble object unless it's an instance of the types passed as filter.
filter
Type to filter out for the traversal, e.g. Node.
Flattaneble object.
Returns
-------
Expand All @@ -50,21 +48,21 @@ def _flatten_collections(node: Any, filter: type[N]) -> Iterator[N]:
>>> c = MyNode(2, "c", (a, b))
>>> d = MyNode(1, "d", (c,))
>>>
>>> assert list(_flatten_collections(a, Node)) == [a]
>>> assert list(_flatten_collections((c,), Node)) == [c]
>>> assert list(_flatten_collections([a, b, (c, a)], Node)) == [a, b, c, a]
>>> assert list(_flatten_collections(a)) == [a]
>>> assert list(_flatten_collections((c,))) == [c]
>>> assert list(_flatten_collections([a, b, (c, a)])) == [a, b, c, a]
"""
if isinstance(node, filter):
if isinstance(node, Node):
yield node
elif isinstance(node, (tuple, list)):
for item in node:
yield from _flatten_collections(item, filter)
yield from _flatten_collections(item)
elif isinstance(node, (dict, frozendict)):
for value in node.values():
yield from _flatten_collections(value, filter)
yield from _flatten_collections(value)


def _recursive_get(obj: Any, dct: dict, filter: type) -> Any:
def _recursive_get(obj: Any, dct: dict) -> Any:
"""Recursively replace objects in a nested structure with values from a dict.
Since we treat common collection types inherently traversable, so we need to
Expand All @@ -76,8 +74,6 @@ def _recursive_get(obj: Any, dct: dict, filter: type) -> Any:
Object to replace.
dct
Mapping of objects to replace with their values.
filter
Type to filter out for the traversal, e.g. Node.
Returns
-------
Expand All @@ -88,19 +84,19 @@ def _recursive_get(obj: Any, dct: dict, filter: type) -> Any:
>>> from ibis.common.graph import _recursive_get
>>>
>>> dct = {1: 2, 3: 4}
>>> _recursive_get((1, 3), dct, filter=int)
>>> _recursive_get((1, 3), dct)
(2, 4)
>>> _recursive_get(frozendict({1: 3}), dct, filter=int)
>>> _recursive_get(frozendict({1: 3}), dct)
{1: 4}
>>> _recursive_get(frozendict({1: (1, 3)}), dct, filter=int)
>>> _recursive_get(frozendict({1: (1, 3)}), dct)
{1: (2, 4)}
"""
if isinstance(obj, filter):
return dct[obj]
if isinstance(obj, Node):
return dct.get(obj, obj)
elif isinstance(obj, (tuple, list)):
return tuple(_recursive_get(o, dct, filter) for o in obj)
return tuple(_recursive_get(o, dct) for o in obj)
elif isinstance(obj, (dict, frozendict)):
return {k: _recursive_get(v, dct, filter) for k, v in obj.items()}
return {k: _recursive_get(v, dct) for k, v in obj.items()}
else:
return obj

Expand All @@ -118,70 +114,52 @@ def __args__(self) -> tuple[Any, ...]:
def __argnames__(self) -> tuple[str, ...]:
"""Sequence of argument names."""

def __children__(self, filter: Optional[type] = None) -> tuple[Node, ...]:
"""Return the children of this node.
This method is used to traverse the Node so it returns the children of the node
in the order they should be traversed. We treat common collection types
inherently traversable (e.g. list, tuple, dict), so this method flattens and
optionally filters the arguments of the node.
Parameters
----------
filter : type, default Node
Type to filter out for the traversal, Node is used by default.
Returns
-------
Child nodes of this node.
"""
return tuple(_flatten_collections(self.__args__, filter or Node))

def __rich_repr__(self):
"""Support for rich reprerentation of the node."""
return zip(self.__argnames__, self.__args__)

def map(self, fn: Callable, filter: Optional[type] = None) -> dict[Node, Any]:
def map(self, fn: Callable, filter: Optional[Any] = None) -> dict[Node, Any]:
"""Apply a function to all nodes in the graph.
The traversal is done in a topological order, so the function receives the
results of its immediate children as keyword arguments.
Parameters
----------
fn : Callable
fn
Function to apply to each node. It receives the node as the first argument,
the results as the second and the results of the children as keyword
arguments.
filter : Optional[type], default None
Type to filter out for the traversal, Node is filtered out by default.
filter
Pattern-like object to filter out nodes from the traversal. Essentially
the traversal will only visit nodes that match the given pattern and
stop otherwise.
Returns
-------
A mapping of nodes to their results.
"""
filter = filter or Node
results: dict[Node, Any] = {}
for node in Graph.from_bfs(self, filter=filter).toposort():
# minor optimization to directly recurse into the children
kwargs = {
k: _recursive_get(v, results, filter)
k: _recursive_get(v, results)
for k, v in zip(node.__argnames__, node.__args__)
}
results[node] = fn(node, results, **kwargs)
return results

def find(
self, type: type | tuple[type], filter: Optional[type] = None
) -> set[Node]:
def find(self, type: type | tuple[type], filter: Optional[Any] = None) -> set[Node]:
"""Find all nodes of a given type in the graph.
Parameters
----------
type : type | tuple[type]
type
Type or tuple of types to find.
filter : Optional[type], default None
Type to filter out for the traversal, Node is filtered out by default.
filter
Pattern-like object to filter out nodes from the traversal. Essentially
the traversal will only visit nodes that match the given pattern and
stop otherwise.
Returns
-------
Expand All @@ -192,7 +170,7 @@ def find(

@experimental
def match(
self, pat: Any, filter: Optional[type] = None, context: Optional[dict] = None
self, pat: Any, filter: Optional[Any] = None, context: Optional[dict] = None
) -> set[Node]:
"""Find all nodes matching a given pattern in the graph.
Expand All @@ -201,12 +179,14 @@ def match(
Parameters
----------
pat : Any
pat
Pattern to match. `ibis.common.pattern()` function is used to coerce the
input value into a pattern. See the pattern module for more details.
filter : Optional[type], default None
Type to filter out for the traversal, Node is filtered out by default.
context : Optional[dict], default None
filter
Pattern-like object to filter out nodes from the traversal. Essentially
the traversal will only visit nodes that match the given pattern and
stop otherwise.
context
Optional context to use for the pattern matching.
Returns
Expand Down Expand Up @@ -284,17 +264,19 @@ def __init__(self, mapping=(), /, **kwargs):
super().__init__(mapping, **kwargs)

@classmethod
def from_bfs(cls, root: Node, filter=Node) -> Self:
def from_bfs(cls, root: Node, filter: Optional[Any] = None) -> Self:
"""Construct a graph from a root node using a breadth-first search.
The traversal is implemented in an iterative fashion using a queue.
Parameters
----------
root : Node
root
Root node of the graph.
filter : Optional[type], default None
Type to filter out for the traversal, Node is filtered out by default.
filter
Pattern-like object to filter out nodes from the traversal. Essentially
the traversal will only visit nodes that match the given pattern and
stop otherwise.
Returns
-------
Expand All @@ -306,25 +288,42 @@ def from_bfs(cls, root: Node, filter=Node) -> Self:
queue = deque([root])
graph = cls()

while queue:
if (node := queue.popleft()) not in graph:
graph[node] = deps = node.__children__(filter)
queue.extend(deps)
if filter is None:
# fast path for the default no filter case, according to benchmarks
# this is gives a 10% speedup compared to the filtered version
while queue:
if (node := queue.popleft()) not in graph:
children = tuple(_flatten_collections(node.__args__))
graph[node] = children
queue.extend(children)
else:
filter = pattern(filter)
while queue:
if (node := queue.popleft()) not in graph:
children = tuple(
child
for child in _flatten_collections(node.__args__)
if filter.match(child, {}) is not NoMatch
)
graph[node] = children
queue.extend(children)

return graph

@classmethod
def from_dfs(cls, root: Node, filter=Node) -> Self:
def from_dfs(cls, root: Node, filter: Optional[Any] = None) -> Self:
"""Construct a graph from a root node using a depth-first search.
The traversal is implemented in an iterative fashion using a stack.
Parameters
----------
root : Node
root
Root node of the graph.
filter : Optional[type], default None
Type to filter out for the traversal, Node is filtered out by default.
filter
Pattern-like object to filter out nodes from the traversal. Essentially
the traversal will only visit nodes that match the given pattern and
stop otherwise.
Returns
-------
Expand All @@ -336,10 +335,25 @@ def from_dfs(cls, root: Node, filter=Node) -> Self:
stack = deque([root])
graph = dict()

while stack:
if (node := stack.pop()) not in graph:
graph[node] = deps = node.__children__(filter)
stack.extend(deps)
if filter is None:
# fast path for the default no filter case, according to benchmarks
# this is gives a 10% speedup compared to the filtered version
while stack:
if (node := stack.pop()) not in graph:
children = tuple(_flatten_collections(node.__args__))
graph[node] = children
stack.extend(children)
else:
filter = pattern(filter)
while stack:
if (node := stack.pop()) not in graph:
children = tuple(
child
for child in _flatten_collections(node.__args__)
if filter.match(child, {}) is not NoMatch
)
graph[node] = children
stack.extend(children)

return cls(reversed(graph.items()))

Expand Down Expand Up @@ -452,7 +466,7 @@ def toposort(node: Node) -> Graph:
def traverse(
fn: Callable[[Node], tuple[bool | Iterable, Any]],
node: Iterable[Node] | Node,
filter=Node,
filter: Optional[Any] = None,
) -> Iterator[Any]:
"""Utility for generic expression tree traversal.
Expand All @@ -464,33 +478,39 @@ def traverse(
node
The Node expression or a list of expressions.
filter
Restrict initial traversal to this kind of node
Pattern-like object to filter out nodes from the traversal. Essentially
the traversal will only visit nodes that match the given pattern and
stop otherwise.
"""
args = reversed(node) if isinstance(node, Iterable) else [node]
todo: deque[Node] = deque(arg for arg in args if isinstance(arg, filter))

args = reversed(node) if isinstance(node, Sequence) else [node]
todo: deque[Node] = deque(args)
seen: set[Node] = set()
filter: Pattern = pattern(filter or ...)

while todo:
node = todo.pop()

if node in seen:
continue
else:
seen.add(node)
if filter.match(node, {}) is NoMatch:
continue

seen.add(node)

control, result = fn(node)
if result is not None:
yield result

if control is not halt:
if control is proceed:
args = node.__children__(filter)
children = tuple(_flatten_collections(node.__args__))
elif isinstance(control, Iterable):
args = control
children = control
else:
raise TypeError(
"First item of the returned tuple must be "
"an instance of boolean or iterable"
)

todo.extend(reversed(args))
todo.extend(reversed(children))
Loading

0 comments on commit 06a41fb

Please sign in to comment.