diff --git a/ibis/common/graph.py b/ibis/common/graph.py index 3f656417352f2..6db649a2940be 100644 --- a/ibis/common/graph.py +++ b/ibis/common/graph.py @@ -8,7 +8,7 @@ from ibis.common.bases import Hashable from ibis.common.collections import frozendict -from ibis.common.patterns import NoMatch, pattern +from ibis.common.patterns import NoMatch, Pattern, pattern from ibis.util import experimental if TYPE_CHECKING: @@ -17,7 +17,7 @@ N = TypeVar("N") -def _flatten_collections(node: Any, filter: type[N]) -> Iterator[N]: +def _flatten_collections(node: Any) -> Iterator[N]: """Flatten collections of nodes into a single iterator. We treat common collection types inherently traversable (e.g. list, tuple, dict) @@ -26,9 +26,7 @@ def _flatten_collections(node: Any, filter: type[N]) -> Iterator[N]: Parameters ---------- node - Flattaneble object unless it's an instance of the types passed as filter. - filter - Type to filter out for the traversal, e.g. Node. + Flattaneble object. Returns ------- @@ -50,21 +48,21 @@ def _flatten_collections(node: Any, filter: type[N]) -> Iterator[N]: >>> c = MyNode(2, "c", (a, b)) >>> d = MyNode(1, "d", (c,)) >>> - >>> assert list(_flatten_collections(a, Node)) == [a] - >>> assert list(_flatten_collections((c,), Node)) == [c] - >>> assert list(_flatten_collections([a, b, (c, a)], Node)) == [a, b, c, a] + >>> assert list(_flatten_collections(a)) == [a] + >>> assert list(_flatten_collections((c,))) == [c] + >>> assert list(_flatten_collections([a, b, (c, a)])) == [a, b, c, a] """ - if isinstance(node, filter): + if isinstance(node, Node): yield node elif isinstance(node, (tuple, list)): for item in node: - yield from _flatten_collections(item, filter) + yield from _flatten_collections(item) elif isinstance(node, (dict, frozendict)): for value in node.values(): - yield from _flatten_collections(value, filter) + yield from _flatten_collections(value) -def _recursive_get(obj: Any, dct: dict, filter: type) -> Any: +def _recursive_get(obj: Any, dct: dict) -> Any: """Recursively replace objects in a nested structure with values from a dict. Since we treat common collection types inherently traversable, so we need to @@ -76,8 +74,6 @@ def _recursive_get(obj: Any, dct: dict, filter: type) -> Any: Object to replace. dct Mapping of objects to replace with their values. - filter - Type to filter out for the traversal, e.g. Node. Returns ------- @@ -88,19 +84,19 @@ def _recursive_get(obj: Any, dct: dict, filter: type) -> Any: >>> from ibis.common.graph import _recursive_get >>> >>> dct = {1: 2, 3: 4} - >>> _recursive_get((1, 3), dct, filter=int) + >>> _recursive_get((1, 3), dct) (2, 4) - >>> _recursive_get(frozendict({1: 3}), dct, filter=int) + >>> _recursive_get(frozendict({1: 3}), dct) {1: 4} - >>> _recursive_get(frozendict({1: (1, 3)}), dct, filter=int) + >>> _recursive_get(frozendict({1: (1, 3)}), dct) {1: (2, 4)} """ - if isinstance(obj, filter): - return dct[obj] + if isinstance(obj, Node): + return dct.get(obj, obj) elif isinstance(obj, (tuple, list)): - return tuple(_recursive_get(o, dct, filter) for o in obj) + return tuple(_recursive_get(o, dct) for o in obj) elif isinstance(obj, (dict, frozendict)): - return {k: _recursive_get(v, dct, filter) for k, v in obj.items()} + return {k: _recursive_get(v, dct) for k, v in obj.items()} else: return obj @@ -118,30 +114,11 @@ def __args__(self) -> tuple[Any, ...]: def __argnames__(self) -> tuple[str, ...]: """Sequence of argument names.""" - def __children__(self, filter: Optional[type] = None) -> tuple[Node, ...]: - """Return the children of this node. - - This method is used to traverse the Node so it returns the children of the node - in the order they should be traversed. We treat common collection types - inherently traversable (e.g. list, tuple, dict), so this method flattens and - optionally filters the arguments of the node. - - Parameters - ---------- - filter : type, default Node - Type to filter out for the traversal, Node is used by default. - - Returns - ------- - Child nodes of this node. - """ - return tuple(_flatten_collections(self.__args__, filter or Node)) - def __rich_repr__(self): """Support for rich reprerentation of the node.""" return zip(self.__argnames__, self.__args__) - def map(self, fn: Callable, filter: Optional[type] = None) -> dict[Node, Any]: + def map(self, fn: Callable, filter: Optional[Any] = None) -> dict[Node, Any]: """Apply a function to all nodes in the graph. The traversal is done in a topological order, so the function receives the @@ -149,39 +126,40 @@ def map(self, fn: Callable, filter: Optional[type] = None) -> dict[Node, Any]: Parameters ---------- - fn : Callable + fn Function to apply to each node. It receives the node as the first argument, the results as the second and the results of the children as keyword arguments. - filter : Optional[type], default None - Type to filter out for the traversal, Node is filtered out by default. + filter + Pattern-like object to filter out nodes from the traversal. Essentially + the traversal will only visit nodes that match the given pattern and + stop otherwise. Returns ------- A mapping of nodes to their results. """ - filter = filter or Node results: dict[Node, Any] = {} for node in Graph.from_bfs(self, filter=filter).toposort(): # minor optimization to directly recurse into the children kwargs = { - k: _recursive_get(v, results, filter) + k: _recursive_get(v, results) for k, v in zip(node.__argnames__, node.__args__) } results[node] = fn(node, results, **kwargs) return results - def find( - self, type: type | tuple[type], filter: Optional[type] = None - ) -> set[Node]: + def find(self, type: type | tuple[type], filter: Optional[Any] = None) -> set[Node]: """Find all nodes of a given type in the graph. Parameters ---------- - type : type | tuple[type] + type Type or tuple of types to find. - filter : Optional[type], default None - Type to filter out for the traversal, Node is filtered out by default. + filter + Pattern-like object to filter out nodes from the traversal. Essentially + the traversal will only visit nodes that match the given pattern and + stop otherwise. Returns ------- @@ -192,7 +170,7 @@ def find( @experimental def match( - self, pat: Any, filter: Optional[type] = None, context: Optional[dict] = None + self, pat: Any, filter: Optional[Any] = None, context: Optional[dict] = None ) -> set[Node]: """Find all nodes matching a given pattern in the graph. @@ -201,12 +179,14 @@ def match( Parameters ---------- - pat : Any + pat Pattern to match. `ibis.common.pattern()` function is used to coerce the input value into a pattern. See the pattern module for more details. - filter : Optional[type], default None - Type to filter out for the traversal, Node is filtered out by default. - context : Optional[dict], default None + filter + Pattern-like object to filter out nodes from the traversal. Essentially + the traversal will only visit nodes that match the given pattern and + stop otherwise. + context Optional context to use for the pattern matching. Returns @@ -284,17 +264,19 @@ def __init__(self, mapping=(), /, **kwargs): super().__init__(mapping, **kwargs) @classmethod - def from_bfs(cls, root: Node, filter=Node) -> Self: + def from_bfs(cls, root: Node, filter: Optional[Any] = None) -> Self: """Construct a graph from a root node using a breadth-first search. The traversal is implemented in an iterative fashion using a queue. Parameters ---------- - root : Node + root Root node of the graph. - filter : Optional[type], default None - Type to filter out for the traversal, Node is filtered out by default. + filter + Pattern-like object to filter out nodes from the traversal. Essentially + the traversal will only visit nodes that match the given pattern and + stop otherwise. Returns ------- @@ -306,25 +288,42 @@ def from_bfs(cls, root: Node, filter=Node) -> Self: queue = deque([root]) graph = cls() - while queue: - if (node := queue.popleft()) not in graph: - graph[node] = deps = node.__children__(filter) - queue.extend(deps) + if filter is None: + # fast path for the default no filter case, according to benchmarks + # this is gives a 10% speedup compared to the filtered version + while queue: + if (node := queue.popleft()) not in graph: + children = tuple(_flatten_collections(node.__args__)) + graph[node] = children + queue.extend(children) + else: + filter = pattern(filter) + while queue: + if (node := queue.popleft()) not in graph: + children = tuple( + child + for child in _flatten_collections(node.__args__) + if filter.match(child, {}) is not NoMatch + ) + graph[node] = children + queue.extend(children) return graph @classmethod - def from_dfs(cls, root: Node, filter=Node) -> Self: + def from_dfs(cls, root: Node, filter: Optional[Any] = None) -> Self: """Construct a graph from a root node using a depth-first search. The traversal is implemented in an iterative fashion using a stack. Parameters ---------- - root : Node + root Root node of the graph. - filter : Optional[type], default None - Type to filter out for the traversal, Node is filtered out by default. + filter + Pattern-like object to filter out nodes from the traversal. Essentially + the traversal will only visit nodes that match the given pattern and + stop otherwise. Returns ------- @@ -336,10 +335,25 @@ def from_dfs(cls, root: Node, filter=Node) -> Self: stack = deque([root]) graph = dict() - while stack: - if (node := stack.pop()) not in graph: - graph[node] = deps = node.__children__(filter) - stack.extend(deps) + if filter is None: + # fast path for the default no filter case, according to benchmarks + # this is gives a 10% speedup compared to the filtered version + while stack: + if (node := stack.pop()) not in graph: + children = tuple(_flatten_collections(node.__args__)) + graph[node] = children + stack.extend(children) + else: + filter = pattern(filter) + while stack: + if (node := stack.pop()) not in graph: + children = tuple( + child + for child in _flatten_collections(node.__args__) + if filter.match(child, {}) is not NoMatch + ) + graph[node] = children + stack.extend(children) return cls(reversed(graph.items())) @@ -452,7 +466,7 @@ def toposort(node: Node) -> Graph: def traverse( fn: Callable[[Node], tuple[bool | Iterable, Any]], node: Iterable[Node] | Node, - filter=Node, + filter: Optional[Any] = None, ) -> Iterator[Any]: """Utility for generic expression tree traversal. @@ -464,19 +478,25 @@ def traverse( node The Node expression or a list of expressions. filter - Restrict initial traversal to this kind of node + Pattern-like object to filter out nodes from the traversal. Essentially + the traversal will only visit nodes that match the given pattern and + stop otherwise. """ - args = reversed(node) if isinstance(node, Iterable) else [node] - todo: deque[Node] = deque(arg for arg in args if isinstance(arg, filter)) + + args = reversed(node) if isinstance(node, Sequence) else [node] + todo: deque[Node] = deque(args) seen: set[Node] = set() + filter: Pattern = pattern(filter or ...) while todo: node = todo.pop() if node in seen: continue - else: - seen.add(node) + if filter.match(node, {}) is NoMatch: + continue + + seen.add(node) control, result = fn(node) if result is not None: @@ -484,13 +504,13 @@ def traverse( if control is not halt: if control is proceed: - args = node.__children__(filter) + children = tuple(_flatten_collections(node.__args__)) elif isinstance(control, Iterable): - args = control + children = control else: raise TypeError( "First item of the returned tuple must be " "an instance of boolean or iterable" ) - todo.extend(reversed(args)) + todo.extend(reversed(children)) diff --git a/ibis/common/patterns.py b/ibis/common/patterns.py index 5e0a9b1b18268..97ec9d3f27442 100644 --- a/ibis/common/patterns.py +++ b/ibis/common/patterns.py @@ -210,22 +210,6 @@ def match(self, value: AnyType, context: dict[str, AnyType]) -> AnyType: """ ... - def is_match(self, value: AnyType, context: dict[str, AnyType]) -> bool: - """Check if the value matches the pattern. - - Parameters - ---------- - value - The value to match the pattern against. - context - A dictionary providing arbitrary context for the pattern matching. - - Returns - ------- - Whether the value matches the pattern. - """ - return self.match(value, context) is not NoMatch - def describe(self, plural=False): return f"matching {self!r}" @@ -1555,49 +1539,6 @@ def match(self, value, context): return dict(zip(keys, values)) -class Topmost(Slotted, Pattern): - """Traverse the value tree topmost first and match the first value that matches.""" - - __slots__ = ("pattern", "filter") - pattern: Pattern - filter: AnyType - - def __init__(self, searcher, filter=None): - super().__init__(pattern=pattern(searcher), filter=filter) - - def match(self, value, context): - result = self.pattern.match(value, context) - if result is not NoMatch: - return result - - for child in value.__children__(self.filter): - result = self.match(child, context) - if result is not NoMatch: - return result - - return NoMatch - - -class Innermost(Slotted, Pattern): - # matches items in the innermost layer first, but all matches belong to the same layer - """Traverse the value tree innermost first and match the first value that matches.""" - - __slots__ = ("pattern", "filter") - pattern: Pattern - filter: AnyType - - def __init__(self, searcher, filter=None): - super().__init__(pattern=pattern(searcher), filter=filter) - - def match(self, value, context): - for child in value.__children__(self.filter): - result = self.match(child, context) - if result is not NoMatch: - return result - - return self.pattern.match(value, context) - - def NoneOf(*args) -> Pattern: """Match none of the passed patterns.""" return Not(AnyOf(*args)) diff --git a/ibis/common/tests/test_graph.py b/ibis/common/tests/test_graph.py index a5fc3e25cb6c0..e0fcbc3b7e113 100644 --- a/ibis/common/tests/test_graph.py +++ b/ibis/common/tests/test_graph.py @@ -114,8 +114,13 @@ def test_nested_children(): c = MyNode(name="c", children=[]) d = MyNode(name="d", children=[]) e = MyNode(name="e", children=[[b, c], d]) - - assert e.__children__() == (b, c, d) + assert bfs(e) == { + e: (b, c, d), + b: (a,), + c: (), + d: (), + a: (), + } def test_example(): @@ -179,15 +184,21 @@ class All(Bool): node = All((T, F), strict=True) assert node.__args__ == ((T, F), True) - assert node.__children__() == (T, F) + assert bfs(node) == {node: (T, F), T: (), F: ()} node = Either(T, F) assert node.__args__ == (T, F) - assert node.__children__() == (T, F) + assert bfs(node) == {node: (T, F), T: (), F: ()} node = All((T, Either(T, Either(T, F))), strict=False) assert node.__args__ == ((T, Either(T, Either(T, F))), False) - assert node.__children__() == (T, Either(T, Either(T, F))) + assert bfs(node) == { + node: (T, Either(T, Either(T, F))), + T: (), + F: (), + Either(T, Either(T, F)): (T, Either(T, F)), + Either(T, F): (T, F), + } copied = node.copy(arguments=(T, F)) assert copied == All((T, F), strict=False) @@ -222,42 +233,39 @@ def test_flatten_collections(): # test that flatten collections doesn't recurse into arbitrary mappings # and sequences, just the commonly used builtin ones: list, tuple, dict - result = _flatten_collections( - [0.0, 1, 2, [3, 4, (5, 6)], "7", MySequence(8, 9)], filter=int - ) - assert list(result) == [1, 2, 3, 4, 5, 6] + result = _flatten_collections([0.0, A, B, [C, D, (E, 6)], "7", MySequence(8, A)]) + assert list(result) == [A, B, C, D, E] result = _flatten_collections( { "a": 0.0, - "b": 1, - "c": (MyMapping(d=2, e=3), frozendict(f=4)), - "d": [5, "6", {"e": (7, 8.9)}], - }, - filter=int, + "b": A, + "c": (MyMapping(d=B, e=3), frozendict(f=C)), + "d": [5, "6", {"e": (D, 8.9)}], + } ) - assert list(result) == [1, 4, 5, 7] + assert list(result) == [A, C, D] -def test_recurse_get(): - results = {"a": "A", "b": "B", "c": "C", "d": "D"} +def test_recursive_get(): + results = {A: "A", B: "B", C: "C", D: "D"} - assert _recursive_get((0, 1, "a", {"b": "c"}), results, filter=str) == ( - 0, - 1, + assert _recursive_get((A, B, "a", {"b": C}), results) == ( "A", + "B", + "a", {"b": "C"}, ) - assert _recursive_get({"a": "b", "c": "d"}, results, filter=str) == { + assert _recursive_get({"a": B, "c": D}, results) == { "a": "B", "c": "D", } - assert _recursive_get(["a", "b", "c"], results, filter=str) == ("A", "B", "C") - assert _recursive_get("a", results, filter=str) == "A" + assert _recursive_get([A, B, "c"], results) == ("A", "B", "c") + assert _recursive_get(A, results) == "A" - my_seq = MySequence("a", "b", "c") - my_map = MyMapping(a="a", b="b", c="c") - assert _recursive_get(("a", my_seq, ["b", "a"], my_map), results, filter=str) == ( + my_seq = MySequence(A, "b", "c") + my_map = MyMapping(a="a", b=B, c="c") + assert _recursive_get((A, my_seq, [B, A], my_map), results) == ( "A", my_seq, ("B", "A"), diff --git a/ibis/common/tests/test_graph_benchmarks.py b/ibis/common/tests/test_graph_benchmarks.py index 12529f894f551..f0d11ff5c9ed1 100644 --- a/ibis/common/tests/test_graph_benchmarks.py +++ b/ibis/common/tests/test_graph_benchmarks.py @@ -1,19 +1,49 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import Optional -from ibis.common.collections import frozendict # noqa: TCH001 -from ibis.common.graph import Node -from ibis.common.grounds import Concrete +from typing_extensions import Self # noqa: TCH002 -if TYPE_CHECKING: - from typing_extensions import Self +from ibis.common.collections import frozendict +from ibis.common.graph import Graph, Node +from ibis.common.grounds import Concrete -class MyNode(Node, Concrete): +class MyNode(Concrete, Node): a: int b: str c: tuple[int, ...] d: frozendict[str, int] - e: Self - f: tuple[Self, ...] + e: Optional[Self] = None + f: tuple[Self, ...] = () + + +def generate_node(depth): + # generate a nested node object with the given depth + if depth == 0: + return MyNode(10, "20", c=(30, 40), d=frozendict(e=50, f=60)) + return MyNode( + 1, + "2", + c=(3, 4), + d=frozendict(e=5, f=6), + e=generate_node(0), + f=(generate_node(depth - 1), generate_node(0)), + ) + + +def test_generate_node(): + for depth in [0, 1, 2, 10, 100]: + n = generate_node(depth) + assert isinstance(n, MyNode) + assert len(Graph.from_bfs(n).nodes()) == depth + 1 + + +def test_bfs(benchmark): + node = generate_node(500) + benchmark(Graph.from_bfs, node) + + +def test_dfs(benchmark): + node = generate_node(500) + benchmark(Graph.from_dfs, node) diff --git a/ibis/common/tests/test_patterns.py b/ibis/common/tests/test_patterns.py index c6b4ebe54abdc..091230bc8e29c 100644 --- a/ibis/common/tests/test_patterns.py +++ b/ibis/common/tests/test_patterns.py @@ -43,7 +43,6 @@ FrozenDictOf, GenericInstanceOf, GenericSequenceOf, - Innermost, InstanceOf, IsIn, LazyInstanceOf, @@ -63,7 +62,6 @@ Replace, SequenceOf, SubclassOf, - Topmost, TupleOf, TypeOf, Variable, @@ -1243,22 +1241,6 @@ class Mul(Binary): fourteen = Add(seven, seven) -def test_topmost_innermost(): - inner = Object(Mul, Capture("a"), Capture("b")) - assert inner.match(six, {}) is six - - context = {} - p = Topmost(inner) - m = p.match(seven, context) - assert m is six - assert context == {"a": two, "b": three} - - p = Innermost(inner) - m = p.match(seven, context) - assert m is two - assert context == {"a": Lit(2), "b": one} - - def test_node(): pat = Node( InstanceOf(Add),