diff --git a/xarray/core/datatree_mapping.py b/xarray/core/datatree_mapping.py index fc26556b0e4..2817effa856 100644 --- a/xarray/core/datatree_mapping.py +++ b/xarray/core/datatree_mapping.py @@ -56,6 +56,9 @@ def check_isomorphic( Also optionally raised if their structure is isomorphic, but the names of any two respective nodes are not equal. """ + # TODO: remove require_names_equal and check_from_root. Instead, check that + # all child nodes match, in any order, which will suffice once + # map_over_datasets switches to use zip_subtrees. if not isinstance(a, TreeNode): raise TypeError(f"Argument `a` is not a tree, it is of type {type(a)}") @@ -68,7 +71,7 @@ def check_isomorphic( diff = diff_treestructure(a, b, require_names_equal=require_names_equal) - if diff: + if diff is not None: raise TreeIsomorphismError("DataTree objects are not isomorphic:\n" + diff) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 570892bcb6b..5ef3b9924a0 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -21,7 +21,6 @@ from xarray.core.datatree_render import RenderDataTree from xarray.core.duck_array_ops import array_equiv, astype from xarray.core.indexing import MemoryCachedArray -from xarray.core.iterators import LevelOrderIter from xarray.core.options import OPTIONS, _get_boolean_with_default from xarray.core.utils import is_duck_array from xarray.namedarray.pycompat import array_type, to_duck_array, to_numpy @@ -981,16 +980,28 @@ def diff_array_repr(a, b, compat): return "\n".join(summary) -def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str: +def diff_treestructure( + a: DataTree, b: DataTree, require_names_equal: bool +) -> str | None: """ Return a summary of why two trees are not isomorphic. - If they are isomorphic return an empty string. + If they are isomorphic return None. """ + # .subtrees walks nodes in breadth-first-order, in order to produce as + # shallow of a diff as possible + + # TODO: switch zip(a.subtree, b.subtree) to zip_subtrees(a, b), and only + # check that child node names match, e.g., + # for node_a, node_b in zip_subtrees(a, b): + # if node_a.children.keys() != node_b.children.keys(): + # diff = dedent( + # f"""\ + # Node {node_a.path!r} in the left object has children {list(node_a.children.keys())} + # Node {node_b.path!r} in the right object has children {list(node_b.children.keys())}""" + # ) + # return diff - # Walking nodes in "level-order" fashion means walking down from the root breadth-first. - # Checking for isomorphism by walking in this way implicitly assumes that the tree is an ordered tree - # (which it is so long as children are stored in a tuple or list rather than in a set). - for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b), strict=True): + for node_a, node_b in zip(a.subtree, b.subtree, strict=True): path_a, path_b = node_a.path, node_b.path if require_names_equal and node_a.name != node_b.name: @@ -1009,7 +1020,7 @@ def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> s ) return diff - return "" + return None def diff_dataset_repr(a, b, compat): @@ -1063,9 +1074,9 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat): # If the trees structures are different there is no point comparing each node # TODO we could show any differences in nodes up to the first place that structure differs? - if treestructure_diff or compat == "isomorphic": + if treestructure_diff is not None: summary.append("\n" + treestructure_diff) - else: + elif compat != "isomorphic": nodewise_diff = diff_nodewise_summary(a, b, compat) summary.append("\n" + nodewise_diff) diff --git a/xarray/core/iterators.py b/xarray/core/iterators.py deleted file mode 100644 index eeaeb35aa9c..00000000000 --- a/xarray/core/iterators.py +++ /dev/null @@ -1,124 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable, Iterator - -from xarray.core.treenode import Tree - -"""These iterators are copied from anytree.iterators, with minor modifications.""" - - -class LevelOrderIter(Iterator): - """Iterate over tree applying level-order strategy starting at `node`. - This is the iterator used by `DataTree` to traverse nodes. - - Parameters - ---------- - node : Tree - Node in a tree to begin iteration at. - filter_ : Callable, optional - Function called with every `node` as argument, `node` is returned if `True`. - Default is to iterate through all ``node`` objects in the tree. - stop : Callable, optional - Function that will cause iteration to stop if ``stop`` returns ``True`` - for ``node``. - maxlevel : int, optional - Maximum level to descend in the node hierarchy. - - Examples - -------- - >>> from xarray.core.datatree import DataTree - >>> from xarray.core.iterators import LevelOrderIter - >>> f = DataTree.from_dict( - ... {"/b/a": None, "/b/d/c": None, "/b/d/e": None, "/g/h/i": None}, name="f" - ... ) - >>> print(f) - - Group: / - ├── Group: /b - │ ├── Group: /b/a - │ └── Group: /b/d - │ ├── Group: /b/d/c - │ └── Group: /b/d/e - └── Group: /g - └── Group: /g/h - └── Group: /g/h/i - >>> [node.name for node in LevelOrderIter(f)] - ['f', 'b', 'g', 'a', 'd', 'h', 'c', 'e', 'i'] - >>> [node.name for node in LevelOrderIter(f, maxlevel=3)] - ['f', 'b', 'g', 'a', 'd', 'h'] - >>> [ - ... node.name - ... for node in LevelOrderIter(f, filter_=lambda n: n.name not in ("e", "g")) - ... ] - ['f', 'b', 'a', 'd', 'h', 'c', 'i'] - >>> [node.name for node in LevelOrderIter(f, stop=lambda n: n.name == "d")] - ['f', 'b', 'g', 'a', 'h', 'i'] - """ - - def __init__( - self, - node: Tree, - filter_: Callable | None = None, - stop: Callable | None = None, - maxlevel: int | None = None, - ): - self.node = node - self.filter_ = filter_ - self.stop = stop - self.maxlevel = maxlevel - self.__iter = None - - def __init(self): - node = self.node - maxlevel = self.maxlevel - filter_ = self.filter_ or LevelOrderIter.__default_filter - stop = self.stop or LevelOrderIter.__default_stop - children = ( - [] - if LevelOrderIter._abort_at_level(1, maxlevel) - else LevelOrderIter._get_children([node], stop) - ) - return self._iter(children, filter_, stop, maxlevel) - - @staticmethod - def __default_filter(node: Tree) -> bool: - return True - - @staticmethod - def __default_stop(node: Tree) -> bool: - return False - - def __iter__(self) -> Iterator[Tree]: - return self - - def __next__(self) -> Iterator[Tree]: - if self.__iter is None: - self.__iter = self.__init() - item = next(self.__iter) # type: ignore[call-overload] - return item - - @staticmethod - def _abort_at_level(level: int, maxlevel: int | None) -> bool: - return maxlevel is not None and level > maxlevel - - @staticmethod - def _get_children(children: list[Tree], stop: Callable) -> list[Tree]: - return [child for child in children if not stop(child)] - - @staticmethod - def _iter( - children: list[Tree], filter_: Callable, stop: Callable, maxlevel: int | None - ) -> Iterator[Tree]: - level = 1 - while children: - next_children = [] - for child in children: - if filter_(child): - yield child - next_children += LevelOrderIter._get_children( - list(child.children.values()), stop - ) - children = next_children - level += 1 - if LevelOrderIter._abort_at_level(level, maxlevel): - break diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index b28b66cf48e..30646d476f8 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -1,5 +1,6 @@ from __future__ import annotations +import collections import sys from collections.abc import Iterator, Mapping from pathlib import PurePosixPath @@ -400,15 +401,18 @@ def subtree(self: Tree) -> Iterator[Tree]: """ An iterator over all nodes in this tree, including both self and all descendants. - Iterates depth-first. + Iterates breadth-first. See Also -------- DataTree.descendants """ - from xarray.core.iterators import LevelOrderIter - - return LevelOrderIter(self) + # https://en.wikipedia.org/wiki/Breadth-first_search#Pseudocode + queue = collections.deque([self]) + while queue: + node = queue.popleft() + yield node + queue.extend(node.children.values()) @property def descendants(self: Tree) -> tuple[Tree, ...]: @@ -772,3 +776,44 @@ def _path_to_ancestor(self, ancestor: NamedNode) -> NodePath: generation_gap = list(parents_paths).index(ancestor.path) path_upwards = "../" * generation_gap if generation_gap > 0 else "." return NodePath(path_upwards) + + +def zip_subtrees(*trees: AnyNamedNode) -> Iterator[tuple[AnyNamedNode, ...]]: + """Iterate over aligned subtrees in breadth-first order. + + Parameters: + ----------- + *trees : Tree + Trees to iterate over. + + Yields + ------ + Tuples of matching subtrees. + """ + if not trees: + raise TypeError("must pass at least one tree object") + + # https://en.wikipedia.org/wiki/Breadth-first_search#Pseudocode + queue = collections.deque([trees]) + + while queue: + active_nodes = queue.popleft() + + # yield before raising an error, in case the caller chooses to exit + # iteration early + yield active_nodes + + first_node = active_nodes[0] + if any( + sibling.children.keys() != first_node.children.keys() + for sibling in active_nodes[1:] + ): + child_summary = " vs ".join( + str(list(node.children)) for node in active_nodes + ) + raise ValueError( + f"children at {first_node.path!r} do not match: {child_summary}" + ) + + for name in first_node.children: + queue.append(tuple(node.children[name] for node in active_nodes)) diff --git a/xarray/tests/test_treenode.py b/xarray/tests/test_treenode.py index 1db9c594247..6a50d8ec8e5 100644 --- a/xarray/tests/test_treenode.py +++ b/xarray/tests/test_treenode.py @@ -1,21 +1,25 @@ from __future__ import annotations -from collections.abc import Iterator -from typing import cast +import re import pytest -from xarray.core.iterators import LevelOrderIter -from xarray.core.treenode import InvalidTreeError, NamedNode, NodePath, TreeNode +from xarray.core.treenode import ( + InvalidTreeError, + NamedNode, + NodePath, + TreeNode, + zip_subtrees, +) class TestFamilyTree: - def test_lonely(self): + def test_lonely(self) -> None: root: TreeNode = TreeNode() assert root.parent is None assert root.children == {} - def test_parenting(self): + def test_parenting(self) -> None: john: TreeNode = TreeNode() mary: TreeNode = TreeNode() mary._set_parent(john, "Mary") @@ -23,7 +27,7 @@ def test_parenting(self): assert mary.parent == john assert john.children["Mary"] is mary - def test_no_time_traveller_loops(self): + def test_no_time_traveller_loops(self) -> None: john: TreeNode = TreeNode() with pytest.raises(InvalidTreeError, match="cannot be a parent of itself"): @@ -43,7 +47,7 @@ def test_no_time_traveller_loops(self): with pytest.raises(InvalidTreeError, match="is already a descendant"): rose.children = {"John": john} - def test_parent_swap(self): + def test_parent_swap(self) -> None: john: TreeNode = TreeNode() mary: TreeNode = TreeNode() mary._set_parent(john, "Mary") @@ -55,7 +59,7 @@ def test_parent_swap(self): assert steve.children["Mary"] is mary assert "Mary" not in john.children - def test_forbid_setting_parent_directly(self): + def test_forbid_setting_parent_directly(self) -> None: john: TreeNode = TreeNode() mary: TreeNode = TreeNode() @@ -64,13 +68,13 @@ def test_forbid_setting_parent_directly(self): ): mary.parent = john - def test_dont_modify_children_inplace(self): + def test_dont_modify_children_inplace(self) -> None: # GH issue 9196 child: TreeNode = TreeNode() TreeNode(children={"child": child}) assert child.parent is None - def test_multi_child_family(self): + def test_multi_child_family(self) -> None: john: TreeNode = TreeNode(children={"Mary": TreeNode(), "Kate": TreeNode()}) assert "Mary" in john.children @@ -83,14 +87,14 @@ def test_multi_child_family(self): assert isinstance(kate, TreeNode) assert kate.parent is john - def test_disown_child(self): + def test_disown_child(self) -> None: john: TreeNode = TreeNode(children={"Mary": TreeNode()}) mary = john.children["Mary"] mary.orphan() assert mary.parent is None assert "Mary" not in john.children - def test_doppelganger_child(self): + def test_doppelganger_child(self) -> None: kate: TreeNode = TreeNode() john: TreeNode = TreeNode() @@ -105,7 +109,7 @@ def test_doppelganger_child(self): evil_kate._set_parent(john, "Kate") assert john.children["Kate"] is evil_kate - def test_sibling_relationships(self): + def test_sibling_relationships(self) -> None: john: TreeNode = TreeNode( children={"Mary": TreeNode(), "Kate": TreeNode(), "Ashley": TreeNode()} ) @@ -113,7 +117,7 @@ def test_sibling_relationships(self): assert list(kate.siblings) == ["Mary", "Ashley"] assert "Kate" not in kate.siblings - def test_copy_subtree(self): + def test_copy_subtree(self) -> None: tony: TreeNode = TreeNode() michael: TreeNode = TreeNode(children={"Tony": tony}) vito = TreeNode(children={"Michael": michael}) @@ -122,7 +126,7 @@ def test_copy_subtree(self): copied_tony = vito.children["Michael"].children["Tony"] assert copied_tony is not tony - def test_parents(self): + def test_parents(self) -> None: vito: TreeNode = TreeNode( children={"Michael": TreeNode(children={"Tony": TreeNode()})}, ) @@ -134,7 +138,7 @@ def test_parents(self): class TestGetNodes: - def test_get_child(self): + def test_get_child(self) -> None: john: TreeNode = TreeNode( children={ "Mary": TreeNode( @@ -163,7 +167,7 @@ def test_get_child(self): # get from middle of tree assert mary._get_item("Sue/Steven") is steven - def test_get_upwards(self): + def test_get_upwards(self) -> None: john: TreeNode = TreeNode( children={ "Mary": TreeNode(children={"Sue": TreeNode(), "Kate": TreeNode()}) @@ -179,7 +183,7 @@ def test_get_upwards(self): # relative path assert sue._get_item("../Kate") is kate - def test_get_from_root(self): + def test_get_from_root(self) -> None: john: TreeNode = TreeNode( children={"Mary": TreeNode(children={"Sue": TreeNode()})} ) @@ -190,7 +194,7 @@ def test_get_from_root(self): class TestSetNodes: - def test_set_child_node(self): + def test_set_child_node(self) -> None: john: TreeNode = TreeNode() mary: TreeNode = TreeNode() john._set_item("Mary", mary) @@ -200,14 +204,14 @@ def test_set_child_node(self): assert mary.children == {} assert mary.parent is john - def test_child_already_exists(self): + def test_child_already_exists(self) -> None: mary: TreeNode = TreeNode() john: TreeNode = TreeNode(children={"Mary": mary}) mary_2: TreeNode = TreeNode() with pytest.raises(KeyError): john._set_item("Mary", mary_2, allow_overwrite=False) - def test_set_grandchild(self): + def test_set_grandchild(self) -> None: rose: TreeNode = TreeNode() mary: TreeNode = TreeNode() john: TreeNode = TreeNode() @@ -220,7 +224,7 @@ def test_set_grandchild(self): assert "Rose" in mary.children assert rose.parent is mary - def test_create_intermediate_child(self): + def test_create_intermediate_child(self) -> None: john: TreeNode = TreeNode() rose: TreeNode = TreeNode() @@ -237,7 +241,7 @@ def test_create_intermediate_child(self): assert rose.parent == mary assert rose.parent == mary - def test_overwrite_child(self): + def test_overwrite_child(self) -> None: john: TreeNode = TreeNode() mary: TreeNode = TreeNode() john._set_item("Mary", mary) @@ -257,7 +261,7 @@ def test_overwrite_child(self): class TestPruning: - def test_del_child(self): + def test_del_child(self) -> None: john: TreeNode = TreeNode() mary: TreeNode = TreeNode() john._set_item("Mary", mary) @@ -299,15 +303,12 @@ def create_test_tree() -> tuple[NamedNode, NamedNode]: return a, f -class TestIterators: +class TestZipSubtrees: - def test_levelorderiter(self): + def test_one_tree(self) -> None: root, _ = create_test_tree() - result: list[str | None] = [ - node.name for node in cast(Iterator[NamedNode], LevelOrderIter(root)) - ] expected = [ - "a", # root Node is unnamed + "a", "b", "c", "d", @@ -317,23 +318,52 @@ def test_levelorderiter(self): "g", "i", ] + result = [node[0].name for node in zip_subtrees(root)] assert result == expected + def test_different_order(self) -> None: + first: NamedNode = NamedNode( + name="a", children={"b": NamedNode(), "c": NamedNode()} + ) + second: NamedNode = NamedNode( + name="a", children={"c": NamedNode(), "b": NamedNode()} + ) + assert [node.name for node in first.subtree] == ["a", "b", "c"] + assert [node.name for node in second.subtree] == ["a", "c", "b"] + assert [(x.name, y.name) for x, y in zip_subtrees(first, second)] == [ + ("a", "a"), + ("b", "b"), + ("c", "c"), + ] + + def test_different_structure(self) -> None: + first: NamedNode = NamedNode(name="a", children={"b": NamedNode()}) + second: NamedNode = NamedNode(name="a", children={"c": NamedNode()}) + it = zip_subtrees(first, second) + + x, y = next(it) + assert x.name == y.name == "a" + + with pytest.raises( + ValueError, match=re.escape(r"children at '/' do not match: ['b'] vs ['c']") + ): + next(it) + class TestAncestry: - def test_parents(self): + def test_parents(self) -> None: _, leaf_f = create_test_tree() expected = ["e", "b", "a"] assert [node.name for node in leaf_f.parents] == expected - def test_lineage(self): + def test_lineage(self) -> None: _, leaf_f = create_test_tree() expected = ["f", "e", "b", "a"] with pytest.warns(DeprecationWarning): assert [node.name for node in leaf_f.lineage] == expected - def test_ancestors(self): + def test_ancestors(self) -> None: _, leaf_f = create_test_tree() with pytest.warns(DeprecationWarning): ancestors = leaf_f.ancestors @@ -341,9 +371,8 @@ def test_ancestors(self): for node, expected_name in zip(ancestors, expected, strict=True): assert node.name == expected_name - def test_subtree(self): + def test_subtree(self) -> None: root, _ = create_test_tree() - subtree = root.subtree expected = [ "a", "b", @@ -355,10 +384,10 @@ def test_subtree(self): "g", "i", ] - for node, expected_name in zip(subtree, expected, strict=True): - assert node.name == expected_name + actual = [node.name for node in root.subtree] + assert expected == actual - def test_descendants(self): + def test_descendants(self) -> None: root, _ = create_test_tree() descendants = root.descendants expected = [ @@ -374,7 +403,7 @@ def test_descendants(self): for node, expected_name in zip(descendants, expected, strict=True): assert node.name == expected_name - def test_leaves(self): + def test_leaves(self) -> None: tree, _ = create_test_tree() leaves = tree.leaves expected = [ @@ -386,7 +415,7 @@ def test_leaves(self): for node, expected_name in zip(leaves, expected, strict=True): assert node.name == expected_name - def test_levels(self): + def test_levels(self) -> None: a, f = create_test_tree() assert a.level == 0 @@ -400,7 +429,7 @@ def test_levels(self): class TestRenderTree: - def test_render_nodetree(self): + def test_render_nodetree(self) -> None: john: NamedNode = NamedNode( children={ "Mary": NamedNode(children={"Sam": NamedNode(), "Ben": NamedNode()}),