From 560a03896edbd38f6a218a6aac88d7534a5bc07c Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 7 Sep 2024 12:28:38 -0700 Subject: [PATCH 1/3] DataTree should not be "Generic" DataTree isn't a Generic tree type. It's a specific tree type -- the nodes are DataTree objects. This resulted in many cases where mypy insisted on explicit type annotations, e.g., `tree: DataTree = DataTree(...)`, which is unnecessary and annoying boilerplate. --- xarray/core/datatree.py | 7 +- xarray/tests/test_datatree.py | 118 +++++++++++++++++----------------- 2 files changed, 62 insertions(+), 63 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 3a3fb19daa4..1ea6c43aab9 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -11,7 +11,7 @@ Mapping, ) from html import escape -from typing import TYPE_CHECKING, Any, Generic, Literal, NoReturn, Union, overload +from typing import TYPE_CHECKING, Any, Literal, NoReturn, Union, overload from xarray.core import utils from xarray.core.alignment import align @@ -37,7 +37,7 @@ from xarray.core.indexes import Index, Indexes from xarray.core.merge import dataset_update_method from xarray.core.options import OPTIONS as XR_OPTS -from xarray.core.treenode import NamedNode, NodePath, Tree +from xarray.core.treenode import NamedNode, NodePath from xarray.core.utils import ( Default, Frozen, @@ -369,8 +369,7 @@ class DataTree( MappedDataWithCoords, DataTreeArithmeticMixin, TreeAttrAccessMixin, - Generic[Tree], - Mapping, + Mapping[str, "DataArray | DataTree"], ): """ A tree-like hierarchical collection of xarray objects. 
diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 1a840dbb9e4..6be6319f3d1 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -16,14 +16,14 @@ class TestTreeCreation: def test_empty(self): - dt: DataTree = DataTree(name="root") + dt = DataTree(name="root") assert dt.name == "root" assert dt.parent is None assert dt.children == {} assert_identical(dt.to_dataset(), xr.Dataset()) def test_unnamed(self): - dt: DataTree = DataTree() + dt = DataTree() assert dt.name is None def test_bad_names(self): @@ -37,7 +37,7 @@ def test_bad_names(self): class TestFamilyTree: def test_dont_modify_children_inplace(self): # GH issue 9196 - child: DataTree = DataTree() + child = DataTree() DataTree(children={"child": child}) assert child.parent is None @@ -69,8 +69,8 @@ def test_create_full_tree(self, simple_datatree): class TestNames: def test_child_gets_named_on_attach(self): - sue: DataTree = DataTree() - mary: DataTree = DataTree(children={"Sue": sue}) # noqa + sue = DataTree() + mary = DataTree(children={"Sue": sue}) # noqa assert mary.children["Sue"].name == "Sue" @@ -102,7 +102,7 @@ def test_same_tree(self): assert john["/Mary"].same_tree(john["/Kate"]) def test_relative_paths(self): - john: DataTree = DataTree.from_dict( + john = DataTree.from_dict( { "/Mary/Sue": DataTree(), "/Annie": DataTree(), @@ -122,7 +122,7 @@ def test_relative_paths(self): assert sue.relative_to(annie) == "../Mary/Sue" assert sue.relative_to(sue) == "." 
- evil_kate: DataTree = DataTree() + evil_kate = DataTree() with pytest.raises( NotFoundInTreeError, match="nodes do not lie within the same tree" ): @@ -132,7 +132,7 @@ def test_relative_paths(self): class TestStoreDatasets: def test_create_with_data(self): dat = xr.Dataset({"a": 0}) - john: DataTree = DataTree(name="john", data=dat) + john = DataTree(name="john", data=dat) assert_identical(john.to_dataset(), dat) @@ -140,7 +140,7 @@ def test_create_with_data(self): DataTree(name="mary", data="junk") # type: ignore[arg-type] def test_set_data(self): - john: DataTree = DataTree(name="john") + john = DataTree(name="john") dat = xr.Dataset({"a": 0}) john.ds = dat # type: ignore[assignment] @@ -150,17 +150,17 @@ def test_set_data(self): john.ds = "junk" # type: ignore[assignment] def test_has_data(self): - john: DataTree = DataTree(name="john", data=xr.Dataset({"a": 0})) + john = DataTree(name="john", data=xr.Dataset({"a": 0})) assert john.has_data - john_no_data: DataTree = DataTree(name="john", data=None) + john_no_data = DataTree(name="john", data=None) assert not john_no_data.has_data def test_is_hollow(self): - john: DataTree = DataTree(data=xr.Dataset({"a": 0})) + john = DataTree(data=xr.Dataset({"a": 0})) assert john.is_hollow - eve: DataTree = DataTree(children={"john": john}) + eve = DataTree(children={"john": john}) assert eve.is_hollow eve.ds = xr.Dataset({"a": 1}) # type: ignore[assignment] @@ -188,7 +188,7 @@ def test_parent_already_has_variable_with_childs_name(self): DataTree.from_dict({"/": xr.Dataset({"a": [0], "b": 1}), "/a": None}) def test_parent_already_has_variable_with_childs_name_update(self): - dt: DataTree = DataTree(data=xr.Dataset({"a": [0], "b": 1})) + dt = DataTree(data=xr.Dataset({"a": [0], "b": 1})) with pytest.raises(ValueError, match="already contains a variable named a"): dt.update({"a": DataTree()}) @@ -224,12 +224,12 @@ def test_getitem_node(self): assert folder1["results/highres"].name == "highres" def test_getitem_self(self): - 
dt: DataTree = DataTree() + dt = DataTree() assert dt["."] is dt def test_getitem_single_data_variable(self): data = xr.Dataset({"temp": [0, 50]}) - results: DataTree = DataTree(name="results", data=data) + results = DataTree(name="results", data=data) assert_identical(results["temp"], data["temp"]) def test_getitem_single_data_variable_from_node(self): @@ -242,20 +242,20 @@ def test_getitem_single_data_variable_from_node(self): assert_identical(folder1["results/highres/temp"], data["temp"]) def test_getitem_nonexistent_node(self): - folder1: DataTree = DataTree.from_dict({"/results": DataTree()}, name="folder1") + folder1 = DataTree.from_dict({"/results": DataTree()}, name="folder1") with pytest.raises(KeyError): folder1["results/highres"] def test_getitem_nonexistent_variable(self): data = xr.Dataset({"temp": [0, 50]}) - results: DataTree = DataTree(name="results", data=data) + results = DataTree(name="results", data=data) with pytest.raises(KeyError): results["pressure"] @pytest.mark.xfail(reason="Should be deprecated in favour of .subset") def test_getitem_multiple_data_variables(self): data = xr.Dataset({"temp": [0, 50], "p": [5, 8, 7]}) - results: DataTree = DataTree(name="results", data=data) + results = DataTree(name="results", data=data) assert_identical(results[["temp", "p"]], data[["temp", "p"]]) # type: ignore[index] @pytest.mark.xfail( @@ -263,13 +263,13 @@ def test_getitem_multiple_data_variables(self): ) def test_getitem_dict_like_selection_access_to_dataset(self): data = xr.Dataset({"temp": [0, 50]}) - results: DataTree = DataTree(name="results", data=data) + results = DataTree(name="results", data=data) assert_identical(results[{"temp": 1}], data[{"temp": 1}]) # type: ignore[index] class TestUpdate: def test_update(self): - dt: DataTree = DataTree() + dt = DataTree() dt.update({"foo": xr.DataArray(0), "a": DataTree()}) expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "a": None}) assert_equal(dt, expected) @@ -277,13 +277,13 @@ def 
test_update(self): def test_update_new_named_dataarray(self): da = xr.DataArray(name="temp", data=[0, 50]) - folder1: DataTree = DataTree(name="folder1") + folder1 = DataTree(name="folder1") folder1.update({"results": da}) expected = da.rename("results") assert_equal(folder1["results"], expected) def test_update_doesnt_alter_child_name(self): - dt: DataTree = DataTree() + dt = DataTree() dt.update({"foo": xr.DataArray(0), "a": DataTree(name="b")}) assert "a" in dt.children child = dt["a"] @@ -404,8 +404,8 @@ def test_copy_with_data(self, create_test_datatree): class TestSetItem: def test_setitem_new_child_node(self): - john: DataTree = DataTree(name="john") - mary: DataTree = DataTree(name="mary") + john = DataTree(name="john") + mary = DataTree(name="mary") john["mary"] = mary grafted_mary = john["mary"] @@ -413,13 +413,13 @@ def test_setitem_new_child_node(self): assert grafted_mary.name == "mary" def test_setitem_unnamed_child_node_becomes_named(self): - john2: DataTree = DataTree(name="john2") + john2 = DataTree(name="john2") john2["sonny"] = DataTree() assert john2["sonny"].name == "sonny" def test_setitem_new_grandchild_node(self): john = DataTree.from_dict({"/Mary/Rose": DataTree()}) - new_rose: DataTree = DataTree(data=xr.Dataset({"x": 0})) + new_rose = DataTree(data=xr.Dataset({"x": 0})) john["Mary/Rose"] = new_rose grafted_rose = john["Mary/Rose"] @@ -427,20 +427,20 @@ def test_setitem_new_grandchild_node(self): assert grafted_rose.name == "Rose" def test_grafted_subtree_retains_name(self): - subtree: DataTree = DataTree(name="original_subtree_name") - root: DataTree = DataTree(name="root") + subtree = DataTree(name="original_subtree_name") + root = DataTree(name="root") root["new_subtree_name"] = subtree # noqa assert subtree.name == "original_subtree_name" def test_setitem_new_empty_node(self): - john: DataTree = DataTree(name="john") + john = DataTree(name="john") john["mary"] = DataTree() mary = john["mary"] assert isinstance(mary, DataTree) 
assert_identical(mary.to_dataset(), xr.Dataset()) def test_setitem_overwrite_data_in_node_with_none(self): - john: DataTree = DataTree.from_dict({"/mary": xr.Dataset()}, name="john") + john = DataTree.from_dict({"/mary": xr.Dataset()}, name="john") john["mary"] = DataTree() assert_identical(john["mary"].to_dataset(), xr.Dataset()) @@ -452,57 +452,57 @@ def test_setitem_overwrite_data_in_node_with_none(self): @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_on_this_node(self): data = xr.Dataset({"temp": [0, 50]}) - results: DataTree = DataTree(name="results") + results = DataTree(name="results") results["."] = data assert_identical(results.to_dataset(), data) @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_as_new_node(self): data = xr.Dataset({"temp": [0, 50]}) - folder1: DataTree = DataTree(name="folder1") + folder1 = DataTree(name="folder1") folder1["results"] = data assert_identical(folder1["results"].to_dataset(), data) @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_as_new_node_requiring_intermediate_nodes(self): data = xr.Dataset({"temp": [0, 50]}) - folder1: DataTree = DataTree(name="folder1") + folder1 = DataTree(name="folder1") folder1["results/highres"] = data assert_identical(folder1["results/highres"].to_dataset(), data) def test_setitem_named_dataarray(self): da = xr.DataArray(name="temp", data=[0, 50]) - folder1: DataTree = DataTree(name="folder1") + folder1 = DataTree(name="folder1") folder1["results"] = da expected = da.rename("results") assert_equal(folder1["results"], expected) def test_setitem_unnamed_dataarray(self): data = xr.DataArray([0, 50]) - folder1: DataTree = DataTree(name="folder1") + folder1 = DataTree(name="folder1") folder1["results"] = data assert_equal(folder1["results"], data) def test_setitem_variable(self): var = xr.Variable(data=[0, 50], dims="x") - folder1: DataTree = 
DataTree(name="folder1") + folder1 = DataTree(name="folder1") folder1["results"] = var assert_equal(folder1["results"], xr.DataArray(var)) def test_setitem_coerce_to_dataarray(self): - folder1: DataTree = DataTree(name="folder1") + folder1 = DataTree(name="folder1") folder1["results"] = 0 assert_equal(folder1["results"], xr.DataArray(0)) def test_setitem_add_new_variable_to_empty_node(self): - results: DataTree = DataTree(name="results") + results = DataTree(name="results") results["pressure"] = xr.DataArray(data=[2, 3]) assert "pressure" in results.ds results["temp"] = xr.Variable(data=[10, 11], dims=["x"]) assert "temp" in results.ds # What if there is a path to traverse first? - results_with_path: DataTree = DataTree(name="results") + results_with_path = DataTree(name="results") results_with_path["highres/pressure"] = xr.DataArray(data=[2, 3]) assert "pressure" in results_with_path["highres"].ds results_with_path["highres/temp"] = xr.Variable(data=[10, 11], dims=["x"]) @@ -510,7 +510,7 @@ def test_setitem_add_new_variable_to_empty_node(self): def test_setitem_dataarray_replace_existing_node(self): t = xr.Dataset({"temp": [0, 50]}) - results: DataTree = DataTree(name="results", data=t) + results = DataTree(name="results", data=t) p = xr.DataArray(data=[2, 3]) results["pressure"] = p expected = t.assign(pressure=p) @@ -567,8 +567,8 @@ def test_full(self, simple_datatree): ] def test_datatree_values(self): - dat1: DataTree = DataTree(data=xr.Dataset({"a": 1})) - expected: DataTree = DataTree() + dat1 = DataTree(data=xr.Dataset({"a": 1})) + expected = DataTree() expected["a"] = dat1 actual = DataTree.from_dict({"a": dat1}) @@ -617,7 +617,7 @@ def test_insertion_order(self): class TestDatasetView: def test_view_contents(self): ds = create_test_data() - dt: DataTree = DataTree(data=ds) + dt = DataTree(data=ds) assert ds.identical( dt.ds ) # this only works because Dataset.identical doesn't check types @@ -648,7 +648,7 @@ def test_immutability(self): def 
test_methods(self): ds = create_test_data() - dt: DataTree = DataTree(data=ds) + dt = DataTree(data=ds) assert ds.mean().identical(dt.ds.mean()) assert isinstance(dt.ds.mean(), xr.Dataset) @@ -669,7 +669,7 @@ def test_init_via_type(self): dims=["x", "y", "time"], coords={"area": (["x", "y"], np.random.rand(3, 4))}, ).to_dataset(name="data") - dt: DataTree = DataTree(data=a) + dt = DataTree(data=a) def weighted_mean(ds): return ds.weighted(ds.area).mean(["x", "y"]) @@ -720,7 +720,7 @@ def test_operation_with_attrs_but_no_data(self): class TestRepr: def test_repr(self): - dt: DataTree = DataTree.from_dict( + dt = DataTree.from_dict( { "/": xr.Dataset( {"e": (("x",), [1.0, 2.0])}, @@ -859,12 +859,12 @@ def test_inconsistent_dims(self): } ) - dt: DataTree = DataTree() + dt = DataTree() dt["/a"] = xr.DataArray([1.0, 2.0], dims=["x"]) with pytest.raises(ValueError, match=expected_msg): dt["/b/c"] = xr.DataArray([3.0], dims=["x"]) - b: DataTree = DataTree(data=xr.Dataset({"c": (("x",), [3.0])})) + b = DataTree(data=xr.Dataset({"c": (("x",), [3.0])})) with pytest.raises(ValueError, match=expected_msg): DataTree( data=xr.Dataset({"a": (("x",), [1.0, 2.0])}), @@ -896,13 +896,13 @@ def test_inconsistent_child_indexes(self): } ) - dt: DataTree = DataTree() + dt = DataTree() dt.ds = xr.Dataset(coords={"x": [1.0]}) # type: ignore dt["/b"] = DataTree() with pytest.raises(ValueError, match=expected_msg): dt["/b"].ds = xr.Dataset(coords={"x": [2.0]}) - b: DataTree = DataTree(xr.Dataset(coords={"x": [2.0]})) + b = DataTree(xr.Dataset(coords={"x": [2.0]})) with pytest.raises(ValueError, match=expected_msg): DataTree(data=xr.Dataset(coords={"x": [1.0]}), children={"b": b}) @@ -931,14 +931,14 @@ def test_inconsistent_grandchild_indexes(self): } ) - dt: DataTree = DataTree() + dt = DataTree() dt.ds = xr.Dataset(coords={"x": [1.0]}) # type: ignore dt["/b/c"] = DataTree() with pytest.raises(ValueError, match=expected_msg): dt["/b/c"].ds = xr.Dataset(coords={"x": [2.0]}) - c: DataTree = 
DataTree(xr.Dataset(coords={"x": [2.0]})) - b: DataTree = DataTree(children={"c": c}) + c = DataTree(xr.Dataset(coords={"x": [2.0]})) + b = DataTree(children={"c": c}) with pytest.raises(ValueError, match=expected_msg): DataTree(data=xr.Dataset(coords={"x": [1.0]}), children={"b": b}) @@ -965,7 +965,7 @@ def test_inconsistent_grandchild_dims(self): } ) - dt: DataTree = DataTree() + dt = DataTree() dt["/a"] = xr.DataArray([1.0, 2.0], dims=["x"]) with pytest.raises(ValueError, match=expected_msg): dt["/b/c/d"] = xr.DataArray([3.0], dims=["x"]) @@ -993,7 +993,7 @@ def test_drop_nodes(self): assert childless.children == {} def test_assign(self): - dt: DataTree = DataTree() + dt = DataTree() expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "/a": None}) # kwargs form @@ -1040,7 +1040,7 @@ def f(x, tree, y): class TestSubset: def test_match(self): # TODO is this example going to cause problems with case sensitivity? - dt: DataTree = DataTree.from_dict( + dt = DataTree.from_dict( { "/a/A": None, "/a/B": None, @@ -1058,7 +1058,7 @@ def test_match(self): assert_identical(result, expected) def test_filter(self): - simpsons: DataTree = DataTree.from_dict( + simpsons = DataTree.from_dict( d={ "/": xr.Dataset({"age": 83}), "/Herbert": xr.Dataset({"age": 40}), @@ -1177,7 +1177,7 @@ def test_binary_op_on_datatree(self): expected = DataTree.from_dict({"/": ds1 * ds1, "/subnode": ds2 * ds2}) # TODO: Remove ignore when ops.py is migrated? 
- result: DataTree = dt * dt # type: ignore[operator] + result = dt * dt # type: ignore[operator] assert_equal(result, expected) From aff87f0bd94ca60758394226ac7d0c3ba3c535d7 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 7 Sep 2024 13:01:14 -0700 Subject: [PATCH 2/3] Fix type error --- xarray/core/datatree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 1ea6c43aab9..b82e7e34e15 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -704,7 +704,7 @@ def __contains__(self, key: object) -> bool: def __bool__(self) -> bool: return bool(self._data_variables) or bool(self._children) - def __iter__(self) -> Iterator[Hashable]: + def __iter__(self) -> Iterator[str]: return itertools.chain(self._data_variables, self._children) def __array__(self, dtype=None, copy=None): From 24bdc1106d4062c8f0e862bfa15bd00d8fddc337 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 7 Sep 2024 13:06:43 -0700 Subject: [PATCH 3/3] type ignore --- xarray/core/datatree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index b82e7e34e15..becd3558228 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -705,7 +705,7 @@ def __bool__(self) -> bool: return bool(self._data_variables) or bool(self._children) def __iter__(self) -> Iterator[str]: - return itertools.chain(self._data_variables, self._children) + return itertools.chain(self._data_variables, self._children) # type: ignore def __array__(self, dtype=None, copy=None): raise TypeError(