Skip to content

Commit

Permalink
improved typing of rolling instance attrs
Browse files Browse the repository at this point in the history
  • Loading branch information
headtr1ck committed Sep 18, 2023
1 parent 2a09b3d commit d528a8d
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions xarray/core/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ class Rolling(Generic[T_Xarray]):

__slots__ = ("obj", "window", "min_periods", "center", "dim")
_attributes = ("window", "min_periods", "center", "dim")
dim: list[Hashable]
window: list[int]
center: list[bool]
obj: T_Xarray
min_periods: int

def __init__(
self,
Expand Down Expand Up @@ -91,16 +96,16 @@ def __init__(
-------
rolling : type of input argument
"""
self.dim: list[Hashable] = []
self.window: list[int] = []
self.dim = []
self.window = []
for d, w in windows.items():
self.dim.append(d)
if w <= 0:
raise ValueError("window must be > 0")
self.window.append(w)

self.center = self._mapping_to_list(center, default=False)
self.obj: T_Xarray = obj
self.obj = obj

missing_dims = tuple(dim for dim in self.dim if dim not in self.obj.dims)
if missing_dims:
Expand Down Expand Up @@ -811,6 +816,10 @@ class Coarsen(CoarsenArithmetic, Generic[T_Xarray]):
)
_attributes = ("windows", "side", "trim_excess")
obj: T_Xarray
windows: Mapping[Hashable, int]
side: SideOptions | Mapping[Hashable, SideOptions]
boundary: CoarsenBoundaryOptions
coord_func: Mapping[Hashable, str | Callable]

def __init__(
self,
Expand Down Expand Up @@ -852,12 +861,15 @@ def __init__(
f"Window dimensions {missing_dims} not found in {self.obj.__class__.__name__} "
f"dimensions {tuple(self.obj.dims)}"
)
if not utils.is_dict_like(coord_func):
coord_func = {d: coord_func for d in self.obj.dims} # type: ignore[misc]

if utils.is_dict_like(coord_func):
coord_func_map = coord_func
else:
coord_func_map = {d: coord_func for d in self.obj.dims}
for c in self.obj.coords:
if c not in coord_func:
coord_func[c] = duck_array_ops.mean # type: ignore[index]
self.coord_func: Mapping[Hashable, str | Callable] = coord_func
if c not in coord_func_map:
coord_func_map[c] = duck_array_ops.mean # type: ignore[index]
self.coord_func = coord_func_map

def _get_keep_attrs(self, keep_attrs):
if keep_attrs is None:
Expand Down

0 comments on commit d528a8d

Please sign in to comment.