Skip to content

Commit

Permalink
implement idxmax and idxmin
Browse files Browse the repository at this point in the history
  • Loading branch information
toddrjen committed Mar 27, 2020
1 parent 82f78cc commit c896135
Show file tree
Hide file tree
Showing 7 changed files with 1,214 additions and 0 deletions.
4 changes: 4 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ Computation
:py:attr:`~Dataset.any`
:py:attr:`~Dataset.argmax`
:py:attr:`~Dataset.argmin`
:py:attr:`~Dataset.idxmax`
:py:attr:`~Dataset.idxmin`
:py:attr:`~Dataset.max`
:py:attr:`~Dataset.mean`
:py:attr:`~Dataset.median`
Expand Down Expand Up @@ -362,6 +364,8 @@ Computation
:py:attr:`~DataArray.any`
:py:attr:`~DataArray.argmax`
:py:attr:`~DataArray.argmin`
:py:attr:`~DataArray.idxmax`
:py:attr:`~DataArray.idxmin`
:py:attr:`~DataArray.max`
:py:attr:`~DataArray.mean`
:py:attr:`~DataArray.median`
Expand Down
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ New Features
:py:func:`combine_by_coords` and :py:func:`combine_nested` using
combine_attrs keyword argument. (:issue:`3865`, :pull:`3877`)
By `John Omotani <https://github.com/johnomotani>`_
- Implement :py:meth:`~xarray.DataArray.idxmax`, :py:meth:`~xarray.DataArray.idxmin`,
:py:meth:`~xarray.Dataset.idxmax`, :py:meth:`~xarray.Dataset.idxmin`. (:issue:`60`, :pull:`3871`)
By `Todd Jennings <https://github.com/toddrjen>`_


Bug fixes
Expand Down
66 changes: 66 additions & 0 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from . import duck_array_ops, utils
from .alignment import deep_align
from .merge import merge_coordinates_without_align
from .nanops import dask_array
from .options import OPTIONS
from .pycompat import dask_array_type
from .utils import is_dict_like
Expand Down Expand Up @@ -1338,3 +1339,68 @@ def polyval(coord, coeffs, degree_dim="degree"):
coords={coord.name: coord, degree_dim: np.arange(deg_coord.max() + 1)[::-1]},
)
return (lhs * coeffs).sum(degree_dim)


def _calc_idxminmax(
*,
array,
func: Callable,
dim: Optional[Hashable],
skipna: Optional[bool],
fill_value: Any = None,
keep_attrs: Optional[bool],
**kwargs: Any,
):
"""Apply common operations for idxmin and idxmax."""
# This function doesn't make sense for scalars so don't try
if not array.ndim:
raise ValueError("This function does not apply for scalars")

if dim is not None:
pass # Use the dim if available
elif array.ndim == 1:
# it is okay to guess the dim if there is only 1
dim = array.dims[0]
else:
# The dim is not specified and ambiguous. Don't guess.
raise ValueError("Must supply 'dim' argument for multidimensional arrays")

if dim not in array.dims:
raise KeyError(f'Dimension "{dim}" not in dimension')
if dim not in array.coords:
raise KeyError(f'Dimension "{dim}" does not have coordinates')

# These are dtypes with NaN values argmin and argmax can handle
na_dtypes = "cfO"

if skipna or (skipna is None and array.dtype.kind in na_dtypes):
# Need to skip NaN values since argmin and argmax can't handle them
allna = array.isnull().all(dim)
array = array.where(~allna, 0)

# This will run argmin or argmax.
indx = func(array, dim=dim, axis=None, keep_attrs=False, skipna=skipna, **kwargs)

# Get the coordinate we want.
coordarray = array[dim]

# Handle dask arrays.
if isinstance(array, dask_array_type):
res = dask_array.map_blocks(coordarray, indx, dtype=indx.dtype)
else:
res = coordarray[
indx,
]

if skipna or (skipna is None and array.dtype.kind in na_dtypes):
# Put the NaN values back in after removing them
res = res.where(~allna, fill_value)

# The dim is gone but we need to remove the corresponding coordinate.
del res.coords[dim]

# Put the attrs back in if needed
if keep_attrs:
res.attrs = array.attrs

return res
196 changes: 196 additions & 0 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3508,6 +3508,202 @@ def pad(
)
return self._from_temp_dataset(ds)

def idxmin(
self,
dim: Hashable = None,
skipna: bool = None,
fill_value: Any = np.NaN,
keep_attrs: Optional[bool] = False,
**kwargs: Any,
) -> "DataArray":
"""Return the coordinate of the minimum value along a dimension.
Returns a new DataArray named after the dimension with the values of
the coordinate along that dimension corresponding to minimum value
along that dimension.
In comparison to `argmin`, this returns the coordinate while `argmin`
returns the index.
Parameters
----------
dim : str, optional
Dimension over which to apply `idxmin`. This is optional for 1D
arrays, but required for arrays with 2 or more dimensions.
skipna : bool or None, default None
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64).
fill_value : Any, default NaN
Value to be filled in case all of the values along a dimension are
null. By default this is NaN. The fill value and result are
automatically converted to a compatible dtype if possible.
Ignored if `skipna` is False.
keep_attrs : bool, default False
If True, the attributes (`attrs`) will be copied from the original
object to the new one. If False (default), the new object will be
returned without attributes.
**kwargs : dict
Additional keyword arguments passed on to the appropriate array
function for calculating `idxmin` on this object's data.
Returns
-------
reduced : DataArray
New DataArray object with `idxmin` applied to its data and the
indicated dimension removed.
Examples
--------
>>> array = xr.DataArray([0, 2, 1, 0, -2], dims="x",
... coords={"x": ['a', 'b', 'c', 'd', 'e']})
>>> array.min()
<xarray.DataArray ()>
array(-2)
>>> array.argmin()
<xarray.DataArray ()>
array(4)
>>> array.idxmin()
<xarray.DataArray 'x' ()>
array('e', dtype='<U1')
>>> array = xr.DataArray([[2.0, 1.0, 2.0, 0.0, -2.0],
... [-4.0, np.NaN, 2.0, np.NaN, -2.0],
... [np.NaN, np.NaN, 1., np.NaN, np.NaN]],
... dims=["y", "x"],
... coords={"y": [-1, 0, 1],
... "x": np.arange(5.)**2}
... )
>>> array.min(dim="x")
<xarray.DataArray (y: 3)>
array([-2., -4., 1.])
Coordinates:
* y (y) int64 -1 0 1
>>> array.argmin(dim="x")
<xarray.DataArray (y: 3)>
array([4, 0, 2])
Coordinates:
* y (y) int64 -1 0 1
>>> array.idxmin(dim="x")
<xarray.DataArray 'x' (y: 3)>
array([16., 0., 4.])
Coordinates:
* y (y) int64 -1 0 1
See also
--------
Dataset.idxmin, DataArray.idxmax, DataArray.min, DataArray.argmin
"""
return computation._calc_idxminmax(
array=self,
func=lambda x, *args, **kwargs: x.argmin(*args, **kwargs),
dim=dim,
skipna=skipna,
fill_value=fill_value,
keep_attrs=keep_attrs,
**kwargs,
)

def idxmax(
self,
dim: Hashable = None,
skipna: bool = None,
fill_value: Any = np.NaN,
keep_attrs: Optional[bool] = False,
**kwargs: Any,
) -> "DataArray":
"""Return the coordinate of the maximum value along a dimension.
Returns a new DataArray named after the dimension with the values of
the coordinate along that dimension corresponding to maximum value
along that dimension.
In comparison to `argmax`, this returns the coordinate while `argmax`
returns the index.
Parameters
----------
dim : str, optional
Dimension over which to apply `idxmax`. This is optional for 1D
arrays, but required for arrays with 2 or more dimensions.
skipna : bool or None, default None
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64).
fill_value : Any, default NaN
Value to be filled in case all of the values along a dimension are
null. By default this is NaN. The fill value and result are
automatically converted to a compatible dtype if possible.
Ignored if `skipna` is False.
keep_attrs : bool, default False
If True, the attributes (`attrs`) will be copied from the original
object to the new one. If False (default), the new object will be
returned without attributes.
**kwargs : dict
Additional keyword arguments passed on to the appropriate array
function for calculating `idxmax` on this object's data.
Returns
-------
reduced : DataArray
New DataArray object with `idxmax` applied to its data and the
indicated dimension removed.
Examples
--------
>>> array = xr.DataArray([0, 2, 1, 0, -2], dims="x",
... coords={"x": ['a', 'b', 'c', 'd', 'e']})
>>> array.max()
<xarray.DataArray ()>
array(2)
>>> array.argmax()
<xarray.DataArray ()>
array(1)
>>> array.idxmax()
<xarray.DataArray 'x' ()>
array('b', dtype='<U1')
>>> array = xr.DataArray([[2.0, 1.0, 2.0, 0.0, -2.0],
... [-4.0, np.NaN, 2.0, np.NaN, -2.0],
... [np.NaN, np.NaN, 1., np.NaN, np.NaN]],
... dims=["y", "x"],
... coords={"y": [-1, 0, 1],
... "x": np.arange(5.)**2}
... )
>>> array.max(dim="x")
<xarray.DataArray (y: 3)>
array([2., 2., 1.])
Coordinates:
* y (y) int64 -1 0 1
>>> array.argmax(dim="x")
<xarray.DataArray (y: 3)>
array([0, 2, 2])
Coordinates:
* y (y) int64 -1 0 1
>>> array.idxmax(dim="x")
<xarray.DataArray 'x' (y: 3)>
array([0., 4., 4.])
Coordinates:
* y (y) int64 -1 0 1
See also
--------
Dataset.idxmax, DataArray.idxmin, DataArray.max, DataArray.argmax
"""
return computation._calc_idxminmax(
array=self,
func=lambda x, *args, **kwargs: x.argmax(*args, **kwargs),
dim=dim,
skipna=skipna,
fill_value=fill_value,
keep_attrs=keep_attrs,
**kwargs,
)

# this needs to be at the end, or mypy will confuse with `str`
# https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names
str = property(StringAccessor)
Expand Down
Loading

0 comments on commit c896135

Please sign in to comment.