Skip to content

Commit

Permalink
Add zip_subtrees for paired iteration over DataTrees (#9623)
Browse files Browse the repository at this point in the history
* Add zip_subtrees for paired iteration over DataTrees

This should be used for implementing DataTree arithmetic inside
map_over_datasets, so the result does not depend on the order in which
child nodes are defined.

I have also added a minimal implementation of breadth-first-search with
an explicit queue the current recursion based solution in
xarray.core.iterators (which has been removed). The new implementation
is also slightly faster in my microbenchmark:

    In [1]: import xarray as xr

    In [2]: tree = xr.DataTree.from_dict({f"/x{i}": None for i in range(100)})

    In [3]: %timeit _ = list(tree.subtree)
    # on main
    87.2 μs ± 394 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

    # with this branch
    55.1 μs ± 294 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

* fix pytype error

* Tweaks per review
  • Loading branch information
shoyer authored Oct 16, 2024
1 parent 88a95cf commit 0c1d02e
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 181 deletions.
5 changes: 4 additions & 1 deletion xarray/core/datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def check_isomorphic(
Also optionally raised if their structure is isomorphic, but the names of any two
respective nodes are not equal.
"""
# TODO: remove require_names_equal and check_from_root. Instead, check that
# all child nodes match, in any order, which will suffice once
# map_over_datasets switches to use zip_subtrees.

if not isinstance(a, TreeNode):
raise TypeError(f"Argument `a` is not a tree, it is of type {type(a)}")
Expand All @@ -68,7 +71,7 @@ def check_isomorphic(

diff = diff_treestructure(a, b, require_names_equal=require_names_equal)

if diff:
if diff is not None:
raise TreeIsomorphismError("DataTree objects are not isomorphic:\n" + diff)


Expand Down
31 changes: 21 additions & 10 deletions xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from xarray.core.datatree_render import RenderDataTree
from xarray.core.duck_array_ops import array_equiv, astype
from xarray.core.indexing import MemoryCachedArray
from xarray.core.iterators import LevelOrderIter
from xarray.core.options import OPTIONS, _get_boolean_with_default
from xarray.core.utils import is_duck_array
from xarray.namedarray.pycompat import array_type, to_duck_array, to_numpy
Expand Down Expand Up @@ -981,16 +980,28 @@ def diff_array_repr(a, b, compat):
return "\n".join(summary)


def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str:
def diff_treestructure(
a: DataTree, b: DataTree, require_names_equal: bool
) -> str | None:
"""
Return a summary of why two trees are not isomorphic.
If they are isomorphic return an empty string.
If they are isomorphic return None.
"""
# .subtrees walks nodes in breadth-first-order, in order to produce as
# shallow of a diff as possible

# TODO: switch zip(a.subtree, b.subtree) to zip_subtrees(a, b), and only
# check that child node names match, e.g.,
# for node_a, node_b in zip_subtrees(a, b):
# if node_a.children.keys() != node_b.children.keys():
# diff = dedent(
# f"""\
# Node {node_a.path!r} in the left object has children {list(node_a.children.keys())}
# Node {node_b.path!r} in the right object has children {list(node_b.children.keys())}"""
# )
# return diff

# Walking nodes in "level-order" fashion means walking down from the root breadth-first.
# Checking for isomorphism by walking in this way implicitly assumes that the tree is an ordered tree
# (which it is so long as children are stored in a tuple or list rather than in a set).
for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b), strict=True):
for node_a, node_b in zip(a.subtree, b.subtree, strict=True):
path_a, path_b = node_a.path, node_b.path

if require_names_equal and node_a.name != node_b.name:
Expand All @@ -1009,7 +1020,7 @@ def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> s
)
return diff

return ""
return None


def diff_dataset_repr(a, b, compat):
Expand Down Expand Up @@ -1063,9 +1074,9 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat):

# If the trees structures are different there is no point comparing each node
# TODO we could show any differences in nodes up to the first place that structure differs?
if treestructure_diff or compat == "isomorphic":
if treestructure_diff is not None:
summary.append("\n" + treestructure_diff)
else:
elif compat != "isomorphic":
nodewise_diff = diff_nodewise_summary(a, b, compat)
summary.append("\n" + nodewise_diff)

Expand Down
124 changes: 0 additions & 124 deletions xarray/core/iterators.py

This file was deleted.

53 changes: 49 additions & 4 deletions xarray/core/treenode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import collections
import sys
from collections.abc import Iterator, Mapping
from pathlib import PurePosixPath
Expand Down Expand Up @@ -400,15 +401,18 @@ def subtree(self: Tree) -> Iterator[Tree]:
"""
An iterator over all nodes in this tree, including both self and all descendants.
Iterates depth-first.
Iterates breadth-first.
See Also
--------
DataTree.descendants
"""
from xarray.core.iterators import LevelOrderIter

return LevelOrderIter(self)
# https://en.wikipedia.org/wiki/Breadth-first_search#Pseudocode
queue = collections.deque([self])
while queue:
node = queue.popleft()
yield node
queue.extend(node.children.values())

@property
def descendants(self: Tree) -> tuple[Tree, ...]:
Expand Down Expand Up @@ -772,3 +776,44 @@ def _path_to_ancestor(self, ancestor: NamedNode) -> NodePath:
generation_gap = list(parents_paths).index(ancestor.path)
path_upwards = "../" * generation_gap if generation_gap > 0 else "."
return NodePath(path_upwards)


def zip_subtrees(*trees: AnyNamedNode) -> Iterator[tuple[AnyNamedNode, ...]]:
"""Iterate over aligned subtrees in breadth-first order.
Parameters:
-----------
*trees : Tree
Trees to iterate over.
Yields
------
Tuples of matching subtrees.
"""
if not trees:
raise TypeError("must pass at least one tree object")

# https://en.wikipedia.org/wiki/Breadth-first_search#Pseudocode
queue = collections.deque([trees])

while queue:
active_nodes = queue.popleft()

# yield before raising an error, in case the caller chooses to exit
# iteration early
yield active_nodes

first_node = active_nodes[0]
if any(
sibling.children.keys() != first_node.children.keys()
for sibling in active_nodes[1:]
):
child_summary = " vs ".join(
str(list(node.children)) for node in active_nodes
)
raise ValueError(
f"children at {first_node.path!r} do not match: {child_summary}"
)

for name in first_node.children:
queue.append(tuple(node.children[name] for node in active_nodes))
Loading

0 comments on commit 0c1d02e

Please sign in to comment.