Skip to content

Commit

Permalink
[ir 186/186] perf(common): improve the performance of replacing nodes…
Browse files Browse the repository at this point in the history
… by using a specialized `__recreate__` method
  • Loading branch information
kszucs committed Dec 14, 2023
1 parent 12016ed commit bf841d7
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
10 changes: 8 additions & 2 deletions ibis/common/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,9 @@ def fn(node, _, **kwargs):
# need to first reconstruct the node from the possible rewritten
# children, so we can match on the new node containing the rewritten
# child arguments, this way we can propagate the rewritten nodes
# upward in the hierarchy
recreated = node.__class__(**kwargs)
# upward in the hierarchy, using a specialized __recreate__ method
# improves the performance by 17% compared node.__class__(**kwargs)
recreated = node.__recreate__(kwargs)
if (result := obj.match(recreated, ctx)) is NoMatch:
return recreated
else:
Expand All @@ -172,6 +173,11 @@ def fn(node, _, **kwargs):
class Node(Hashable):
__slots__ = ()

@classmethod
def __recreate__(cls, kwargs: Any) -> Self:
"""Reconstruct the node from the given arguments."""
return cls(**kwargs)

@property
@abstractmethod
def __args__(self) -> tuple[Any, ...]:
Expand Down
10 changes: 9 additions & 1 deletion ibis/common/tests/test_graph_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from typing_extensions import Self # noqa: TCH002

from ibis.common.collections import frozendict
from ibis.common.deferred import _
from ibis.common.graph import Graph, Node
from ibis.common.grounds import Concrete
from ibis.common.patterns import Between, Object


class MyNode(Concrete, Node):
Expand All @@ -24,7 +26,7 @@ def generate_node(depth):
if depth == 0:
return MyNode(10, "20", c=(30, 40), d=frozendict(e=50, f=60))
return MyNode(
1,
depth,
"2",
c=(3, 4),
d=frozendict(e=5, f=6),
Expand All @@ -48,3 +50,9 @@ def test_bfs(benchmark):
def test_dfs(benchmark):
node = generate_node(500)
benchmark(Graph.from_dfs, node)


def test_replace(benchmark):
node = generate_node(500)
pattern = Object(MyNode, a=Between(lower=100)) >> _.copy(a=_.a + 1)
benchmark(node.replace, pattern)

0 comments on commit bf841d7

Please sign in to comment.