Skip to content

Commit

Permalink
Start handling keep_attrs in rolling class constructors (pydata#3376)
Browse files Browse the repository at this point in the history
  • Loading branch information
amcnicho committed Feb 27, 2020
1 parent af7bd10 commit 4a45840
Showing 1 changed file with 33 additions and 11 deletions.
44 changes: 33 additions & 11 deletions xarray/core/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from . import dtypes, duck_array_ops, utils
from .dask_array_ops import dask_rolling_wrapper
from .ops import inject_reduce_methods
from .options import _get_keep_attrs
from .pycompat import dask_array_type

try:
Expand Down Expand Up @@ -42,10 +43,10 @@ class Rolling:
DataArray.rolling
"""

__slots__ = ("obj", "window", "min_periods", "center", "dim")
_attributes = ("window", "min_periods", "center", "dim")
__slots__ = ("obj", "window", "min_periods", "center", "dim", "keep_attrs")
_attributes = ("window", "min_periods", "center", "dim", "keep_attrs")

def __init__(self, obj, windows, min_periods=None, center=False):
def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None):
"""
Moving window object.
Expand All @@ -65,6 +66,10 @@ def __init__(self, obj, windows, min_periods=None, center=False):
setting min_periods equal to the size of the window.
center : boolean, default False
Set the labels at the center of the window.
keep_attrs : bool, optional
If True, the object's attributes (`attrs`) will be copied from
the original object to the new one. If False (default), the new
object will be returned without attributes.
Returns
-------
Expand All @@ -89,6 +94,10 @@ def __init__(self, obj, windows, min_periods=None, center=False):
self.center = center
self.dim = dim

if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)
self.keep_attrs = keep_attrs

@property
def _min_periods(self):
return self.min_periods if self.min_periods is not None else self.window
Expand Down Expand Up @@ -143,7 +152,7 @@ def count(self):
class DataArrayRolling(Rolling):
__slots__ = ("window_labels",)

def __init__(self, obj, windows, min_periods=None, center=False):
def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None):
"""
Moving window object for DataArray.
You should use DataArray.rolling() method to construct this object
Expand All @@ -165,6 +174,10 @@ def __init__(self, obj, windows, min_periods=None, center=False):
setting min_periods equal to the size of the window.
center : boolean, default False
Set the labels at the center of the window.
keep_attrs : bool, optional
If True, the object's attributes (`attrs`) will be copied from
the original object to the new one. If False (default), the new
object will be returned without attributes.
Returns
-------
Expand All @@ -177,7 +190,9 @@ def __init__(self, obj, windows, min_periods=None, center=False):
Dataset.rolling
Dataset.groupby
"""
super().__init__(obj, windows, min_periods=min_periods, center=center)
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)
super().__init__(obj, windows, min_periods=min_periods, center=center, keep_attrs=keep_attrs)

self.window_labels = self.obj[self.dim]

Expand Down Expand Up @@ -374,7 +389,7 @@ def _numpy_or_bottleneck_reduce(
class DatasetRolling(Rolling):
__slots__ = ("rollings",)

def __init__(self, obj, windows, min_periods=None, center=False):
def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None):
"""
Moving window object for Dataset.
You should use Dataset.rolling() method to construct this object
Expand All @@ -396,6 +411,10 @@ def __init__(self, obj, windows, min_periods=None, center=False):
setting min_periods equal to the size of the window.
center : boolean, default False
Set the labels at the center of the window.
keep_attrs : bool, optional
If True, the object's attributes (`attrs`) will be copied from
the original object to the new one. If False (default), the new
object will be returned without attributes.
Returns
-------
Expand All @@ -408,15 +427,15 @@ def __init__(self, obj, windows, min_periods=None, center=False):
Dataset.groupby
DataArray.groupby
"""
super().__init__(obj, windows, min_periods, center)
super().__init__(obj, windows, min_periods, center, keep_attrs)
if self.dim not in self.obj.dims:
raise KeyError(self.dim)
# Keep each Rolling object as a dictionary
self.rollings = {}
for key, da in self.obj.data_vars.items():
# keeps rollings only for the dataset depending on slf.dim
if self.dim in da.dims:
self.rollings[key] = DataArrayRolling(da, windows, min_periods, center)
self.rollings[key] = DataArrayRolling(da, windows, min_periods, center, keep_attrs)

def _dataset_implementation(self, func, **kwargs):
from .dataset import Dataset
Expand Down Expand Up @@ -466,7 +485,7 @@ def _numpy_or_bottleneck_reduce(
**kwargs,
)

def construct(self, window_dim, stride=1, fill_value=dtypes.NA):
def construct(self, window_dim, stride=1, fill_value=dtypes.NA, keep_attrs=None):
"""
Convert this rolling object to xr.Dataset,
where the window dimension is stacked as a new dimension
Expand All @@ -487,6 +506,9 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA):

from .dataset import Dataset

if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)

dataset = {}
for key, da in self.obj.data_vars.items():
if self.dim in da.dims:
Expand All @@ -509,10 +531,10 @@ class Coarsen:
DataArray.coarsen
"""

__slots__ = ("obj", "boundary", "coord_func", "windows", "side", "trim_excess")
__slots__ = ("obj", "boundary", "coord_func", "windows", "side", "trim_excess", "keep_attrs")
_attributes = ("windows", "side", "trim_excess")

def __init__(self, obj, windows, boundary, side, coord_func):
def __init__(self, obj, windows, boundary, side, coord_func, keep_attrs):
"""
Moving window object.
Expand Down

0 comments on commit 4a45840

Please sign in to comment.