From 1048df26386c5355469e08cecfed11696d30258f Mon Sep 17 00:00:00 2001 From: Michael Delgado Date: Tue, 12 Apr 2022 08:33:04 -0700 Subject: [PATCH] allow other and drop arguments in where (gh#6466) (#6467) * allow other and drop arguments in where (gh#6466) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add whatsnew entry to bug fixes? * drop extraneous edit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- doc/whats-new.rst | 3 +++ xarray/core/common.py | 14 +++++++++----- xarray/tests/test_dataset.py | 7 +++++-- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b609e0cde5e..85392779163 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -78,6 +78,9 @@ Bug fixes By `Stan West `_. - Fix bug in :py:func:`where` when passing non-xarray objects with ``keep_attrs=True``. (:issue:`6444`, :pull:`6461`) By `Sam Levang `_. +- Allow passing both ``other`` and ``drop=True`` arguments to ``xr.DataArray.where`` + and ``xr.Dataset.where`` (:pull:`6466`, :pull:`6467`). + By `Michael Delgado `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/common.py b/xarray/core/common.py index c33db4a62ea..817b63161bc 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1197,8 +1197,7 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): By default, these locations filled with NA. drop : bool, optional If True, coordinate labels that only correspond to False values of - the condition are dropped from the result. Mutually exclusive with - ``other``. + the condition are dropped from the result. Returns ------- @@ -1251,6 +1250,14 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): [15., nan, nan, nan]]) Dimensions without coordinates: x, y + >>> a.where(a.x + a.y < 4, -1, drop=True) + + array([[ 0, 1, 2, 3], + [ 5, 6, 7, -1], + [10, 11, -1, -1], + [15, -1, -1, -1]]) + Dimensions without coordinates: x, y + See Also -------- numpy.where : corresponding numpy function @@ -1264,9 +1271,6 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): cond = cond(self) if drop: - if other is not dtypes.NA: - raise ValueError("cannot set `other` if drop=True") - if not isinstance(cond, (Dataset, DataArray)): raise TypeError( f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r}" diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 25adda2df84..a905e0c2210 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4548,8 +4548,11 @@ def test_where_other(self): actual = ds.where(lambda x: x > 1, -1) assert_equal(expected, actual) - with pytest.raises(ValueError, match=r"cannot set"): - ds.where(ds > 1, other=0, drop=True) + actual = ds.where(ds > 1, other=-1, drop=True) + expected_nodrop = ds.where(ds > 1, -1) + _, expected = xr.align(actual, expected_nodrop, join="left") + assert_equal(actual, expected) + assert actual.a.dtype == int with pytest.raises(ValueError, match=r"cannot align .* are not equal"): ds.where(ds > 1, ds.isel(x=slice(3)))