Skip to content

Commit

Permalink
feat(common): add Node.find_below() methods to exclude the root nod…
Browse files Browse the repository at this point in the history
…e from filtering (#8861)
  • Loading branch information
kszucs authored Apr 2, 2024
1 parent a5de9ed commit 80d12a2
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 35 deletions.
66 changes: 41 additions & 25 deletions ibis/common/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ibis.common.collections import frozendict
from ibis.common.patterns import NoMatch, Pattern
from ibis.common.typing import _ClassInfo
from ibis.util import experimental
from ibis.util import experimental, promote_list

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down Expand Up @@ -340,9 +340,39 @@ def find(
determined by a breadth-first search.
"""
nodes = Graph.from_bfs(self, filter=filter, context=context).nodes()
graph = Graph.from_bfs(self, filter=filter, context=context)
finder = _coerce_finder(finder, context)
return [node for node in nodes if finder(node)]
return [node for node in graph.nodes() if finder(node)]

@experimental
def find_below(
self,
finder: FinderLike,
filter: Optional[FinderLike] = None,
context: Optional[dict] = None,
) -> list[Node]:
"""Find all nodes below the current node matching a given pattern in the graph.
A variant of find() that only returns nodes below the current node in the graph.
Parameters
----------
finder
A type, tuple of types, a pattern or a callable to match upon.
filter
A type, tuple of types, a pattern or a callable to filter out nodes
from the traversal. The traversal will only visit nodes that match
the given filter and stop otherwise.
context
Optional context to use if `finder` or `filter` is a pattern.
Returns
-------
The list of nodes matching the given pattern.
"""
graph = Graph.from_bfs(self.__children__, filter=filter, context=context)
finder = _coerce_finder(finder, context)
return [node for node in graph.nodes() if finder(node)]

@experimental
def find_topmost(
Expand Down Expand Up @@ -620,10 +650,8 @@ def bfs(root: Node) -> Graph:
"""
# fast path for the default no filter case, according to benchmarks
# this is gives a 10% speedup compared to the filtered version
if not isinstance(root, Node):
raise TypeError("node must be an instance of ibis.common.graph.Node")

queue = deque([root])
nodes = _flatten_collections(promote_list(root))
queue = deque(nodes)
graph = Graph()

while queue:
Expand Down Expand Up @@ -651,15 +679,10 @@ def bfs_while(root: Node, filter: Finder) -> Graph:
A graph constructed from the root node.
"""
if not isinstance(root, Node):
raise TypeError("node must be an instance of ibis.common.graph.Node")

queue = deque()
nodes = _flatten_collections(promote_list(root))
queue = deque(node for node in nodes if filter(node))
graph = Graph()

if filter(root):
queue.append(root)

while queue:
if (node := queue.popleft()) not in graph:
children = tuple(child for child in node.__children__ if filter(child))
Expand All @@ -684,10 +707,8 @@ def dfs(root: Node) -> Graph:
"""
# fast path for the default no filter case, according to benchmarks
# this is gives a 10% speedup compared to the filtered version
if not isinstance(root, Node):
raise TypeError("node must be an instance of ibis.common.graph.Node")

stack = deque([root])
nodes = _flatten_collections(promote_list(root))
stack = deque(nodes)
graph = {}

while stack:
Expand Down Expand Up @@ -715,15 +736,10 @@ def dfs_while(root: Node, filter: Finder) -> Graph:
A graph constructed from the root node.
"""
if not isinstance(root, Node):
raise TypeError("node must be an instance of ibis.common.graph.Node")

stack = deque()
nodes = _flatten_collections(promote_list(root))
stack = deque(node for node in nodes if filter(node))
graph = {}

if filter(root):
stack.append(root)

while stack:
if (node := stack.pop()) not in graph:
children = tuple(child for child in node.__children__ if filter(child))
Expand Down
24 changes: 14 additions & 10 deletions ibis/common/tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,8 @@ def copy(self, name=None, children=None):

def test_bfs():
assert list(bfs(A).keys()) == [A, B, C, D, E]

with pytest.raises(
TypeError, match="must be an instance of ibis.common.graph.Node"
):
bfs(1)
assert list(bfs([D, E, B])) == [D, E, B]
assert bfs(1) == {}


def test_construction():
Expand All @@ -82,11 +79,8 @@ def test_graph_repr():

def test_dfs():
assert list(dfs(A).keys()) == [D, E, B, C, A]

with pytest.raises(
TypeError, match="must be an instance of ibis.common.graph.Node"
):
dfs(1)
assert list(dfs([D, E, B])) == [D, E, B]
assert dfs(1) == {}


def test_invert():
Expand Down Expand Up @@ -393,6 +387,16 @@ def test_node_find_using_pattern():
assert result == [A, B]


def test_node_find_below():
lowercase = MyNode(name="lowercase", children=[])
root = MyNode(name="root", children=[A, B, lowercase])
result = root.find_below(MyNode)
assert result == [A, B, lowercase, C, D, E]

result = root.find_below(lambda x: x.name.islower(), filter=lambda x: x != root)
assert result == [lowercase]


def test_node_find_topmost_using_type():
class FooNode(MyNode):
pass
Expand Down

0 comments on commit 80d12a2

Please sign in to comment.