Skip to content

Commit

Permalink
Add a deep kwarg (default True?)
Browse files Browse the repository at this point in the history
  • Loading branch information
max-sixty committed Oct 17, 2023
1 parent 8baff44 commit cf03f4b
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 9 deletions.
13 changes: 7 additions & 6 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -7147,16 +7147,17 @@ def to_dask_dataframe(
# https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names
str = utils.UncachedAccessor(StringAccessor["DataArray"])

def drop_attrs(self) -> Self:
def drop_attrs(self, deep: bool = True) -> Self:
"""
Removes all attributes from the DataArray.
Parameters
----------
deep : bool, default True
Removes attributes from coordinates.
Returns
-------
DataArray
"""
self = self.copy()

self.attrs = {}

return self
return self._to_temp_dataset().drop_attrs(deep=deep).to_array()
11 changes: 9 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10339,17 +10339,25 @@ def resample(
**indexer_kwargs,
)

def drop_attrs(self) -> Self:
def drop_attrs(self, deep: bool = True) -> Self:
"""
Removes all attributes from the Dataset and its variables.
Parameters
----------
deep : bool, default True
Removes attributes from all variables.
Returns
-------
Dataset
"""
# Remove attributes from the dataset
self = self._replace(attrs={})

if not deep:
return self

# Remove attributes from each variable in the dataset
for var in self.variables:
# variables don't have a `._replace` method, so we copy and then remove
Expand All @@ -10359,7 +10367,6 @@ def drop_attrs(self) -> Self:
self[var].attrs = {}

new_idx_variables = {}

# Not sure this is the most elegant way of doing this, but it works.
# (Should we have a more general "map over all variables, including
# indexes" approach?)
Expand Down
13 changes: 12 additions & 1 deletion xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4375,13 +4375,16 @@ def test_drop_attrs(self) -> None:
pd.MultiIndex.from_tuples([(1, 2), (3, 4)], names=["d", "e"]), "z"
)
ds = Dataset(dict(var1=var), coords=dict(y=idx, z=mx)).assign_attrs(a=1, b=2)
assert ds.attrs != {}
assert ds["var1"].attrs != {}
assert ds["y"].attrs != {}
assert ds.coords["y"].attrs != {}

original = ds.copy(deep=True)
result = ds.drop_attrs()

assert result.attrs == {}
assert result["x"].attrs == {}
assert result["var1"].attrs == {}
assert result["y"].attrs == {}
assert list(result.data_vars) == list(ds.data_vars)
assert list(result.coords) == list(ds.coords)
Expand All @@ -4392,6 +4395,14 @@ def test_drop_attrs(self) -> None:
# can't currently contain `attrs`, so we can't test those.)
assert ds.coords["y"].attrs != {}

# Test for deep=False
result_shallow = ds.drop_attrs(deep=False)
assert result_shallow.attrs == {}
assert result_shallow["var1"].attrs != {}
assert result_shallow["y"].attrs != {}
assert list(result.data_vars) == list(ds.data_vars)
assert list(result.coords) == list(ds.coords)

def test_assign_multiindex_level(self) -> None:
data = create_test_multiindex()
with pytest.raises(ValueError, match=r"cannot drop or update.*corrupt.*index "):
Expand Down

0 comments on commit cf03f4b

Please sign in to comment.