diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 1d25b1bf88c..1c1b5042caa 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1104,10 +1104,12 @@ def from_dict( d : dict-like A mapping from path names to xarray.Dataset or DataTree objects. - Path names are to be given as unix-like path. If path names containing more than one - part are given, new tree nodes will be constructed as necessary. + Path names are to be given as unix-like path. If path names + containing more than one part are given, new tree nodes will be + constructed as necessary. - To assign data to the root node of the tree use "/" as the path. + To assign data to the root node of the tree use "", ".", "/" or "./" + as the path. name : Hashable | None, optional Name for the root node of the tree. Default is None. @@ -1119,17 +1121,27 @@ def from_dict( ----- If your dictionary is nested you will need to flatten it before using this method. """ - - # First create the root node + # Find any values corresponding to the root d_cast = dict(d) - root_data = d_cast.pop("/", None) + root_data = None + for key in ("", ".", "/", "./"): + if key in d_cast: + if root_data is not None: + raise ValueError( + "multiple entries found corresponding to the root node" + ) + root_data = d_cast.pop(key) + + # Create the root node if isinstance(root_data, DataTree): obj = root_data.copy() + obj.name = name elif root_data is None or isinstance(root_data, Dataset): obj = cls(name=name, dataset=root_data, children=None) else: raise TypeError( - f'root node data (at "/") must be a Dataset or DataTree, got {type(root_data)}' + f'root node data (at "", ".", "/" or "./") must be a Dataset ' + f"or DataTree, got {type(root_data)}" ) def depth(item) -> int: @@ -1141,11 +1153,10 @@ def depth(item) -> int: # Sort keys by depth so as to insert nodes from root first (see GH issue #9276) for path, data in sorted(d_cast.items(), key=depth): # Create and set new node - node_name = NodePath(path).name if isinstance(data, DataTree): new_node = data.copy() elif isinstance(data, Dataset) or data is None: - new_node = cls(name=node_name, dataset=data) + new_node = cls(dataset=data) else: raise TypeError(f"invalid values: {data}") obj._set_item( @@ -1683,7 +1694,7 @@ def reduce( numeric_only=numeric_only, **kwargs, ) - path = "/" if node is self else node.relative_to(self) + path = node.relative_to(self) result[path] = node_result return type(self).from_dict(result, name=self.name) @@ -1718,7 +1729,7 @@ def _selective_indexing( # with a scalar) can also create scalar coordinates, which # need to be explicitly removed. del node_result.coords[k] - path = "/" if node is self else node.relative_to(self) + path = node.relative_to(self) result[path] = node_result return type(self).from_dict(result, name=self.name) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index a710fbfafa0..686b45968a0 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -883,6 +883,48 @@ def test_array_values(self) -> None: with pytest.raises(TypeError): DataTree.from_dict(data) # type: ignore[arg-type] + def test_relative_paths(self) -> None: + tree = DataTree.from_dict({".": None, "foo": None, "./bar": None, "x/y": None}) + paths = [node.path for node in tree.subtree] + assert paths == [ + "/", + "/foo", + "/bar", + "/x", + "/x/y", + ] + + def test_root_keys(self): + ds = Dataset({"x": 1}) + expected = DataTree(dataset=ds) + + actual = DataTree.from_dict({"": ds}) + assert_identical(actual, expected) + + actual = DataTree.from_dict({".": ds}) + assert_identical(actual, expected) + + actual = DataTree.from_dict({"/": ds}) + assert_identical(actual, expected) + + actual = DataTree.from_dict({"./": ds}) + assert_identical(actual, expected) + + with pytest.raises( + ValueError, match="multiple entries found corresponding to the root node" + ): + DataTree.from_dict({"": ds, "/": ds}) + + def test_name(self): + tree = DataTree.from_dict({"/": None}, name="foo") + assert tree.name == "foo" + + tree = DataTree.from_dict({"/": DataTree()}, name="foo") + assert tree.name == "foo" + + tree = DataTree.from_dict({"/": DataTree(name="bar")}, name="foo") + assert tree.name == "foo" + class TestDatasetView: def test_view_contents(self) -> None: