diff --git a/ibis/common/graph.py b/ibis/common/graph.py index c2aaf5d7e046..f2f810077716 100644 --- a/ibis/common/graph.py +++ b/ibis/common/graph.py @@ -11,7 +11,7 @@ from ibis.common.collections import frozendict from ibis.common.patterns import NoMatch, Pattern from ibis.common.typing import _ClassInfo -from ibis.util import experimental +from ibis.util import experimental, promote_list if TYPE_CHECKING: from typing_extensions import Self @@ -340,9 +340,39 @@ def find( determined by a breadth-first search. """ - nodes = Graph.from_bfs(self, filter=filter, context=context).nodes() + graph = Graph.from_bfs(self, filter=filter, context=context) finder = _coerce_finder(finder, context) - return [node for node in nodes if finder(node)] + return [node for node in graph.nodes() if finder(node)] + + @experimental + def find_below( + self, + finder: FinderLike, + filter: Optional[FinderLike] = None, + context: Optional[dict] = None, + ) -> list[Node]: + """Find all nodes below the current node matching a given pattern in the graph. + + A variant of find() that only returns nodes below the current node in the graph. + + Parameters + ---------- + finder + A type, tuple of types, a pattern or a callable to match upon. + filter + A type, tuple of types, a pattern or a callable to filter out nodes + from the traversal. The traversal will only visit nodes that match + the given filter and stop otherwise. + context + Optional context to use if `finder` or `filter` is a pattern. + + Returns + ------- + The list of nodes matching the given pattern. + """ + graph = Graph.from_bfs(self.__children__, filter=filter, context=context) + finder = _coerce_finder(finder, context) + return [node for node in graph.nodes() if finder(node)] @experimental def find_topmost( @@ -620,10 +650,8 @@ def bfs(root: Node) -> Graph: """ # fast path for the default no filter case, according to benchmarks # this is gives a 10% speedup compared to the filtered version - if not isinstance(root, Node): - raise TypeError("node must be an instance of ibis.common.graph.Node") - - queue = deque([root]) + nodes = _flatten_collections(promote_list(root)) + queue = deque(nodes) graph = Graph() while queue: @@ -651,15 +679,10 @@ def bfs_while(root: Node, filter: Finder) -> Graph: A graph constructed from the root node. """ - if not isinstance(root, Node): - raise TypeError("node must be an instance of ibis.common.graph.Node") - - queue = deque() + nodes = _flatten_collections(promote_list(root)) + queue = deque(node for node in nodes if filter(node)) graph = Graph() - if filter(root): - queue.append(root) - while queue: if (node := queue.popleft()) not in graph: children = tuple(child for child in node.__children__ if filter(child)) @@ -684,10 +707,8 @@ def dfs(root: Node) -> Graph: """ # fast path for the default no filter case, according to benchmarks # this is gives a 10% speedup compared to the filtered version - if not isinstance(root, Node): - raise TypeError("node must be an instance of ibis.common.graph.Node") - - stack = deque([root]) + nodes = _flatten_collections(promote_list(root)) + stack = deque(nodes) graph = {} while stack: @@ -715,15 +736,10 @@ def dfs_while(root: Node, filter: Finder) -> Graph: A graph constructed from the root node. """ - if not isinstance(root, Node): - raise TypeError("node must be an instance of ibis.common.graph.Node") - - stack = deque() + nodes = _flatten_collections(promote_list(root)) + stack = deque(node for node in nodes if filter(node)) graph = {} - if filter(root): - stack.append(root) - while stack: if (node := stack.pop()) not in graph: children = tuple(child for child in node.__children__ if filter(child)) diff --git a/ibis/common/tests/test_graph.py b/ibis/common/tests/test_graph.py index 677f9f0f204d..b0da3bdf2a0f 100644 --- a/ibis/common/tests/test_graph.py +++ b/ibis/common/tests/test_graph.py @@ -59,11 +59,8 @@ def copy(self, name=None, children=None): def test_bfs(): assert list(bfs(A).keys()) == [A, B, C, D, E] - - with pytest.raises( - TypeError, match="must be an instance of ibis.common.graph.Node" - ): - bfs(1) + assert list(bfs([D, E, B])) == [D, E, B] + assert bfs(1) == {} def test_construction(): @@ -82,11 +79,8 @@ def test_graph_repr(): def test_dfs(): assert list(dfs(A).keys()) == [D, E, B, C, A] - - with pytest.raises( - TypeError, match="must be an instance of ibis.common.graph.Node" - ): - dfs(1) + assert list(dfs([D, E, B])) == [D, E, B] + assert dfs(1) == {} def test_invert(): @@ -393,6 +387,16 @@ def test_node_find_using_pattern(): assert result == [A, B] +def test_node_find_below(): + lowercase = MyNode(name="lowercase", children=[]) + root = MyNode(name="root", children=[A, B, lowercase]) + result = root.find_below(MyNode) + assert result == [A, B, lowercase, C, D, E] + + result = root.find_below(lambda x: x.name.islower(), filter=lambda x: x != root) + assert result == [lowercase] + + def test_node_find_topmost_using_type(): class FooNode(MyNode): pass