Skip to content

Commit

Permalink
Allow passing of positional arguments in apply for Groupby objects (#…
Browse files Browse the repository at this point in the history
…2413)

* Allow passing of positional arguments to `func` in `XGroupBy.apply`

* Fix the documentation call to `func` to show the actual call.

* Add changes to whats-new.rst

* Fix typo

* Allow passing args to func from DatasetResample and DataArrayResample

* Update whats-new with Resample changes

* Add tests for func arguments in groupby

* Add tests for apply func arguments in resample

* flake8 fixes
  • Loading branch information
maaleske authored and shoyer committed Dec 24, 2018
1 parent 7fcb80f commit d8d87d2
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 10 deletions.
8 changes: 8 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ Enhancements
- :py:meth:`DataArray.resample` and :py:meth:`Dataset.resample` now supports the
``loffset`` kwarg just like Pandas.
By `Deepak Cherian <https://github.com/dcherian>`_
- The `apply` methods for `DatasetGroupBy`, `DataArrayGroupBy`,
`DatasetResample` and `DataArrayResample` can now pass positional arguments to
the applied function.
By `Matti Eskelinen <https://github.com/maaleske>`_.
- 0d slices of ndarrays are now obtained directly through indexing, rather than
extracting and wrapping a scalar, avoiding unnecessary copying. By `Daniel
Wennberg <https://github.com/danielwe>`_.
Expand Down Expand Up @@ -260,13 +264,17 @@ Announcements of note:
for more details.
- We have a new :doc:`roadmap` that outlines our future development plans.

- `Dataset.apply` now properly documents the way `func` is called.
By `Matti Eskelinen <https://github.com/maaleske>`_.

Enhancements
~~~~~~~~~~~~

- :py:meth:`~xarray.DataArray.differentiate` and
:py:meth:`~xarray.Dataset.differentiate` are newly added.
(:issue:`1332`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.

- Default colormap for sequential and divergent data can now be set via
:py:func:`~xarray.set_options()`
(:issue:`2394`)
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2953,8 +2953,8 @@ def apply(self, func, keep_attrs=None, args=(), **kwargs):
Parameters
----------
func : function
Function which can be called in the form `f(x, **kwargs)` to
transform each DataArray `x` in this dataset into another
Function which can be called in the form `func(x, *args, **kwargs)`
to transform each DataArray `x` in this dataset into another
DataArray.
keep_attrs : bool, optional
If True, the dataset's attributes (`attrs`) will be copied from
Expand Down
12 changes: 8 additions & 4 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def lookup_order(dimension):
new_order = sorted(stacked.dims, key=lookup_order)
return stacked.transpose(*new_order)

def apply(self, func, shortcut=False, **kwargs):
def apply(self, func, shortcut=False, args=(), **kwargs):
"""Apply a function over each array in the group and concatenate them
together into a new array.
Expand Down Expand Up @@ -532,6 +532,8 @@ def apply(self, func, shortcut=False, **kwargs):
If these conditions are satisfied `shortcut` provides significant
speedup. This should be the case for many common groupby operations
(e.g., applying numpy ufuncs).
args : tuple, optional
Positional arguments passed to `func`.
**kwargs
Used to call `func(ar, **kwargs)` for each array `ar`.
Expand All @@ -544,7 +546,7 @@ def apply(self, func, shortcut=False, **kwargs):
grouped = self._iter_grouped_shortcut()
else:
grouped = self._iter_grouped()
applied = (maybe_wrap_array(arr, func(arr, **kwargs))
applied = (maybe_wrap_array(arr, func(arr, *args, **kwargs))
for arr in grouped)
return self._combine(applied, shortcut=shortcut)

Expand Down Expand Up @@ -642,7 +644,7 @@ def wrapped_func(self, dim=DEFAULT_DIMS, axis=None,


class DatasetGroupBy(GroupBy, ImplementsDatasetReduce):
def apply(self, func, **kwargs):
def apply(self, func, args=(), **kwargs):
"""Apply a function over each Dataset in the group and concatenate them
together into a new Dataset.
Expand All @@ -661,6 +663,8 @@ def apply(self, func, **kwargs):
----------
func : function
Callable to apply to each sub-dataset.
args : tuple, optional
Positional arguments to pass to `func`.
**kwargs
Used to call `func(ds, **kwargs)` for each sub-dataset `ar`.
Expand All @@ -670,7 +674,7 @@ def apply(self, func, **kwargs):
The result of splitting, applying and combining this dataset.
"""
kwargs.pop('shortcut', None) # ignore shortcut if set (for now)
applied = (func(ds, **kwargs) for ds in self._iter_grouped())
applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped())
return self._combine(applied)

def _combine(self, applied):
Expand Down
12 changes: 8 additions & 4 deletions xarray/core/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __init__(self, *args, **kwargs):
"('{}')! ".format(self._resample_dim, self._dim))
super(DataArrayResample, self).__init__(*args, **kwargs)

def apply(self, func, shortcut=False, **kwargs):
def apply(self, func, shortcut=False, args=(), **kwargs):
"""Apply a function over each array in the group and concatenate them
together into a new array.
Expand Down Expand Up @@ -158,6 +158,8 @@ def apply(self, func, shortcut=False, **kwargs):
If these conditions are satisfied `shortcut` provides significant
speedup. This should be the case for many common groupby operations
(e.g., applying numpy ufuncs).
args : tuple, optional
Positional arguments passed on to `func`.
**kwargs
Used to call `func(ar, **kwargs)` for each array `ar`.
Expand All @@ -167,7 +169,7 @@ def apply(self, func, shortcut=False, **kwargs):
The result of splitting, applying and combining this array.
"""
combined = super(DataArrayResample, self).apply(
func, shortcut=shortcut, **kwargs)
func, shortcut=shortcut, args=args, **kwargs)

# If the aggregation function didn't drop the original resampling
# dimension, then we need to do so before we can rename the proxy
Expand Down Expand Up @@ -240,7 +242,7 @@ def __init__(self, *args, **kwargs):
"('{}')! ".format(self._resample_dim, self._dim))
super(DatasetResample, self).__init__(*args, **kwargs)

def apply(self, func, **kwargs):
def apply(self, func, args=(), **kwargs):
"""Apply a function over each Dataset in the groups generated for
resampling and concatenate them together into a new Dataset.
Expand All @@ -259,6 +261,8 @@ def apply(self, func, **kwargs):
----------
func : function
Callable to apply to each sub-dataset.
args : tuple, optional
Positional arguments passed on to `func`.
**kwargs
Used to call `func(ds, **kwargs)` for each sub-dataset `ar`.
Expand All @@ -268,7 +272,7 @@ def apply(self, func, **kwargs):
The result of splitting, applying and combining this dataset.
"""
kwargs.pop('shortcut', None) # ignore shortcut if set (for now)
applied = (func(ds, **kwargs) for ds in self._iter_grouped())
applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped())
combined = self._combine(applied)

return combined.rename({self._resample_dim: self._dim})
Expand Down
11 changes: 11 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2295,6 +2295,17 @@ def test_resample(self):
with raises_regex(ValueError, 'index must be monotonic'):
array[[2, 0, 1]].resample(time='1D')

def test_da_resample_func_args(self):

def func(arg1, arg2, arg3=0.):
return arg1.mean('time') + arg2 + arg3

times = pd.date_range('2000', periods=3, freq='D')
da = xr.DataArray([1., 1., 1.], coords=[times], dims=['time'])
expected = xr.DataArray([3., 3., 3.], coords=[times], dims=['time'])
actual = da.resample(time='D').apply(func, args=(1.,), arg3=1.)
assert_identical(actual, expected)

@requires_cftime
def test_resample_cftimeindex(self):
cftime = _import_cftime()
Expand Down
13 changes: 13 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2886,6 +2886,19 @@ def test_resample_old_api(self):
with raises_regex(TypeError, r'resample\(\) no longer supports'):
ds.resample('1D', dim='time')

def test_ds_resample_apply_func_args(self):

def func(arg1, arg2, arg3=0.):
return arg1.mean('time') + arg2 + arg3

times = pd.date_range('2000', freq='D', periods=3)
ds = xr.Dataset({'foo': ('time', [1., 1., 1.]),
'time': times})
expected = xr.Dataset({'foo': ('time', [3., 3., 3.]),
'time': times})
actual = ds.resample(time='D').apply(func, args=(1.,), arg3=1.)
assert_identical(expected, actual)

def test_to_array(self):
ds = Dataset(OrderedDict([('a', 1), ('b', ('x', [1, 2, 3]))]),
coords={'c': 42}, attrs={'Conventions': 'None'})
Expand Down
22 changes: 22 additions & 0 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,26 @@ def test_groupby_input_mutation():
assert_identical(array, array_copy) # should not modify inputs


def test_da_groupby_apply_func_args():

def func(arg1, arg2, arg3=0):
return arg1 + arg2 + arg3

array = xr.DataArray([1, 1, 1], [('x', [1, 2, 3])])
expected = xr.DataArray([3, 3, 3], [('x', [1, 2, 3])])
actual = array.groupby('x').apply(func, args=(1,), arg3=1)
assert_identical(expected, actual)


def test_ds_groupby_apply_func_args():

def func(arg1, arg2, arg3=0):
return arg1 + arg2 + arg3

dataset = xr.Dataset({'foo': ('x', [1, 1, 1])}, {'x': [1, 2, 3]})
expected = xr.Dataset({'foo': ('x', [3, 3, 3])}, {'x': [1, 2, 3]})
actual = dataset.groupby('x').apply(func, args=(1,), arg3=1)
assert_identical(expected, actual)


# TODO: move other groupby tests from test_dataset and test_dataarray over here

0 comments on commit d8d87d2

Please sign in to comment.