diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 035179393e0..0dbfe3f8837 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -7147,16 +7147,17 @@ def to_dask_dataframe( # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names str = utils.UncachedAccessor(StringAccessor["DataArray"]) - def drop_attrs(self) -> Self: + def drop_attrs(self, deep: bool = True) -> Self: """ Removes all attributes from the DataArray. + Parameters + ---------- + deep : bool, default True + Removes attributes from coordinates. + Returns ------- DataArray """ - self = self.copy() - - self.attrs = {} - - return self + return self._to_temp_dataset().drop_attrs(deep=deep).to_array() diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 97a7c0f8de8..0217c414c59 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -10339,10 +10339,15 @@ def resample( **indexer_kwargs, ) - def drop_attrs(self) -> Self: + def drop_attrs(self, deep: bool = True) -> Self: """ Removes all attributes from the Dataset and its variables. + Parameters + ---------- + deep : bool, default True + Removes attributes from all variables. + Returns ------- Dataset @@ -10350,6 +10355,9 @@ def drop_attrs(self) -> Self: # Remove attributes from the dataset self = self._replace(attrs={}) + if not deep: + return self + # Remove attributes from each variable in the dataset for var in self.variables: # variables don't have a `._replace` method, so we copy and then remove @@ -10359,7 +10367,6 @@ def drop_attrs(self) -> Self: self[var].attrs = {} new_idx_variables = {} - # Not sure this is the most elegant way of doing this, but it works. # (Should we have a more general "map over all variables, including # indexes" approach?) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index fd575a6171c..82c9be0eece 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4375,13 +4375,16 @@ def test_drop_attrs(self) -> None: pd.MultiIndex.from_tuples([(1, 2), (3, 4)], names=["d", "e"]), "z" ) ds = Dataset(dict(var1=var), coords=dict(y=idx, z=mx)).assign_attrs(a=1, b=2) + assert ds.attrs != {} + assert ds["var1"].attrs != {} + assert ds["y"].attrs != {} assert ds.coords["y"].attrs != {} original = ds.copy(deep=True) result = ds.drop_attrs() assert result.attrs == {} - assert result["x"].attrs == {} + assert result["var1"].attrs == {} assert result["y"].attrs == {} assert list(result.data_vars) == list(ds.data_vars) assert list(result.coords) == list(ds.coords) @@ -4392,6 +4395,14 @@ def test_drop_attrs(self) -> None: # can't currently contain `attrs`, so we can't test those.) assert ds.coords["y"].attrs != {} + # Test for deep=False + result_shallow = ds.drop_attrs(deep=False) + assert result_shallow.attrs == {} + assert result_shallow["var1"].attrs != {} + assert result_shallow["y"].attrs != {} + assert list(result.data_vars) == list(ds.data_vars) + assert list(result.coords) == list(ds.coords) + def test_assign_multiindex_level(self) -> None: data = create_test_multiindex() with pytest.raises(ValueError, match=r"cannot drop or update.*corrupt.*index "):