Skip to content

Commit

Permalink
ENH: Add numeric_only to window ops (pandas-dev#47265)
Browse files Browse the repository at this point in the history
* ENH: Add numeric_only to window ops

* Fix corr/cov for Series; add tests
  • Loading branch information
rhshadrach authored and yehoshuadimarsky committed Jul 13, 2022
1 parent e05b362 commit 883adba
Show file tree
Hide file tree
Showing 9 changed files with 659 additions and 87 deletions.
3 changes: 3 additions & 0 deletions doc/source/whatsnew/v1.5.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,9 @@ gained the ``numeric_only`` argument.
- :meth:`.Resampler.sem`
- :meth:`.Resampler.std`
- :meth:`.Resampler.var`
- :meth:`DataFrame.rolling` operations
- :meth:`DataFrame.expanding` operations
- :meth:`DataFrame.ewm` operations

.. _whatsnew_150.deprecations.other:

Expand Down
9 changes: 9 additions & 0 deletions pandas/core/window/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ def create_section_header(header: str) -> str:
"""
).replace("\n", "", 1)

# Shared numpydoc "Parameters" entry describing the ``numeric_only`` keyword.
# It is passed to ``@doc(...)`` alongside the other docstring fragments, so the
# text must be a well-formed numpydoc parameter block on its own:
#   * the description is indented one level under the ``name : type`` line;
#   * the ``.. versionadded::`` directive is separated from the description by
#     a blank line, otherwise reST folds it into the preceding paragraph.
# The explicit ``\n`` escape keeps a trailing blank line so the next fragment
# starts as a new entry, and ``.replace("\n", "", 1)`` drops only the
# leading newline introduced by opening the triple-quoted string.
kwargs_numeric_only = dedent(
    """
    numeric_only : bool, default False
        Include only float, int, boolean columns.

        .. versionadded:: 1.5.0\n
    """
).replace("\n", "", 1)

args_compat = dedent(
"""
*args
Expand Down
71 changes: 58 additions & 13 deletions pandas/core/window/ewm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
from pandas.util._decorators import doc
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.common import is_datetime64_ns_dtype
from pandas.core.dtypes.common import (
is_datetime64_ns_dtype,
is_numeric_dtype,
)
from pandas.core.dtypes.missing import isna

import pandas.core.common as common # noqa: PDF018
Expand All @@ -45,6 +48,7 @@
args_compat,
create_section_header,
kwargs_compat,
kwargs_numeric_only,
numba_notes,
template_header,
template_returns,
Expand Down Expand Up @@ -518,6 +522,7 @@ def aggregate(self, func, *args, **kwargs):
@doc(
template_header,
create_section_header("Parameters"),
kwargs_numeric_only,
args_compat,
window_agg_numba_parameters(),
kwargs_compat,
Expand All @@ -531,7 +536,14 @@ def aggregate(self, func, *args, **kwargs):
aggregation_description="(exponential weighted moment) mean",
agg_method="mean",
)
def mean(self, *args, engine=None, engine_kwargs=None, **kwargs):
def mean(
self,
numeric_only: bool = False,
*args,
engine=None,
engine_kwargs=None,
**kwargs,
):
if maybe_use_numba(engine):
if self.method == "single":
func = generate_numba_ewm_func
Expand All @@ -545,7 +557,7 @@ def mean(self, *args, engine=None, engine_kwargs=None, **kwargs):
deltas=tuple(self._deltas),
normalize=True,
)
return self._apply(ewm_func)
return self._apply(ewm_func, name="mean")
elif engine in ("cython", None):
if engine_kwargs is not None:
raise ValueError("cython engine does not accept engine_kwargs")
Expand All @@ -560,13 +572,14 @@ def mean(self, *args, engine=None, engine_kwargs=None, **kwargs):
deltas=deltas,
normalize=True,
)
return self._apply(window_func)
return self._apply(window_func, name="mean", numeric_only=numeric_only)
else:
raise ValueError("engine must be either 'numba' or 'cython'")

@doc(
template_header,
create_section_header("Parameters"),
kwargs_numeric_only,
args_compat,
window_agg_numba_parameters(),
kwargs_compat,
Expand All @@ -580,7 +593,14 @@ def mean(self, *args, engine=None, engine_kwargs=None, **kwargs):
aggregation_description="(exponential weighted moment) sum",
agg_method="sum",
)
def sum(self, *args, engine=None, engine_kwargs=None, **kwargs):
def sum(
self,
numeric_only: bool = False,
*args,
engine=None,
engine_kwargs=None,
**kwargs,
):
if not self.adjust:
raise NotImplementedError("sum is not implemented with adjust=False")
if maybe_use_numba(engine):
Expand All @@ -596,7 +616,7 @@ def sum(self, *args, engine=None, engine_kwargs=None, **kwargs):
deltas=tuple(self._deltas),
normalize=False,
)
return self._apply(ewm_func)
return self._apply(ewm_func, name="sum")
elif engine in ("cython", None):
if engine_kwargs is not None:
raise ValueError("cython engine does not accept engine_kwargs")
Expand All @@ -611,7 +631,7 @@ def sum(self, *args, engine=None, engine_kwargs=None, **kwargs):
deltas=deltas,
normalize=False,
)
return self._apply(window_func)
return self._apply(window_func, name="sum", numeric_only=numeric_only)
else:
raise ValueError("engine must be either 'numba' or 'cython'")

Expand All @@ -624,6 +644,7 @@ def sum(self, *args, engine=None, engine_kwargs=None, **kwargs):
Use a standard estimation bias correction.
"""
).replace("\n", "", 1),
kwargs_numeric_only,
args_compat,
kwargs_compat,
create_section_header("Returns"),
Expand All @@ -634,9 +655,18 @@ def sum(self, *args, engine=None, engine_kwargs=None, **kwargs):
aggregation_description="(exponential weighted moment) standard deviation",
agg_method="std",
)
def std(self, bias: bool = False, *args, **kwargs):
def std(self, bias: bool = False, numeric_only: bool = False, *args, **kwargs):
nv.validate_window_func("std", args, kwargs)
return zsqrt(self.var(bias=bias, **kwargs))
if (
numeric_only
and self._selected_obj.ndim == 1
and not is_numeric_dtype(self._selected_obj.dtype)
):
# Raise directly so error message says std instead of var
raise NotImplementedError(
f"{type(self).__name__}.std does not implement numeric_only"
)
return zsqrt(self.var(bias=bias, numeric_only=numeric_only, **kwargs))

def vol(self, bias: bool = False, *args, **kwargs):
warnings.warn(
Expand All @@ -658,6 +688,7 @@ def vol(self, bias: bool = False, *args, **kwargs):
Use a standard estimation bias correction.
"""
).replace("\n", "", 1),
kwargs_numeric_only,
args_compat,
kwargs_compat,
create_section_header("Returns"),
Expand All @@ -668,7 +699,7 @@ def vol(self, bias: bool = False, *args, **kwargs):
aggregation_description="(exponential weighted moment) variance",
agg_method="var",
)
def var(self, bias: bool = False, *args, **kwargs):
def var(self, bias: bool = False, numeric_only: bool = False, *args, **kwargs):
nv.validate_window_func("var", args, kwargs)
window_func = window_aggregations.ewmcov
wfunc = partial(
Expand All @@ -682,7 +713,7 @@ def var(self, bias: bool = False, *args, **kwargs):
def var_func(values, begin, end, min_periods):
return wfunc(values, begin, end, min_periods, values)

return self._apply(var_func)
return self._apply(var_func, name="var", numeric_only=numeric_only)

@doc(
template_header,
Expand All @@ -703,6 +734,7 @@ def var_func(values, begin, end, min_periods):
Use a standard estimation bias correction.
"""
).replace("\n", "", 1),
kwargs_numeric_only,
kwargs_compat,
create_section_header("Returns"),
template_returns,
Expand All @@ -717,10 +749,13 @@ def cov(
other: DataFrame | Series | None = None,
pairwise: bool | None = None,
bias: bool = False,
numeric_only: bool = False,
**kwargs,
):
from pandas import Series

self._validate_numeric_only("cov", numeric_only)

def cov_func(x, y):
x_array = self._prep_values(x)
y_array = self._prep_values(y)
Expand Down Expand Up @@ -752,7 +787,9 @@ def cov_func(x, y):
)
return Series(result, index=x.index, name=x.name)

return self._apply_pairwise(self._selected_obj, other, pairwise, cov_func)
return self._apply_pairwise(
self._selected_obj, other, pairwise, cov_func, numeric_only
)

@doc(
template_header,
Expand All @@ -771,6 +808,7 @@ def cov_func(x, y):
observations will be used.
"""
).replace("\n", "", 1),
kwargs_numeric_only,
kwargs_compat,
create_section_header("Returns"),
template_returns,
Expand All @@ -784,10 +822,13 @@ def corr(
self,
other: DataFrame | Series | None = None,
pairwise: bool | None = None,
numeric_only: bool = False,
**kwargs,
):
from pandas import Series

self._validate_numeric_only("corr", numeric_only)

def cov_func(x, y):
x_array = self._prep_values(x)
y_array = self._prep_values(y)
Expand Down Expand Up @@ -825,7 +866,9 @@ def _cov(X, Y):
result = cov / zsqrt(x_var * y_var)
return Series(result, index=x.index, name=x.name)

return self._apply_pairwise(self._selected_obj, other, pairwise, cov_func)
return self._apply_pairwise(
self._selected_obj, other, pairwise, cov_func, numeric_only
)


class ExponentialMovingWindowGroupby(BaseWindowGroupby, ExponentialMovingWindow):
Expand Down Expand Up @@ -921,6 +964,7 @@ def corr(
self,
other: DataFrame | Series | None = None,
pairwise: bool | None = None,
numeric_only: bool = False,
**kwargs,
):
return NotImplementedError
Expand All @@ -930,6 +974,7 @@ def cov(
other: DataFrame | Series | None = None,
pairwise: bool | None = None,
bias: bool = False,
numeric_only: bool = False,
**kwargs,
):
return NotImplementedError
Expand Down
Loading

0 comments on commit 883adba

Please sign in to comment.