From d528a8d33c08cac171fe8189831390bca85ccc5d Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 18 Sep 2023 22:49:44 +0200 Subject: [PATCH] improved typing of rolling instance attrs --- xarray/core/rolling.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index d49cb6e13a4..fa1f146a8a5 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -61,6 +61,11 @@ class Rolling(Generic[T_Xarray]): __slots__ = ("obj", "window", "min_periods", "center", "dim") _attributes = ("window", "min_periods", "center", "dim") + dim: list[Hashable] + window: list[int] + center: list[bool] + obj: T_Xarray + min_periods: int def __init__( self, @@ -91,8 +96,8 @@ def __init__( ------- rolling : type of input argument """ - self.dim: list[Hashable] = [] - self.window: list[int] = [] + self.dim = [] + self.window = [] for d, w in windows.items(): self.dim.append(d) if w <= 0: @@ -100,7 +105,7 @@ def __init__( self.window.append(w) self.center = self._mapping_to_list(center, default=False) - self.obj: T_Xarray = obj + self.obj = obj missing_dims = tuple(dim for dim in self.dim if dim not in self.obj.dims) if missing_dims: @@ -811,6 +816,10 @@ class Coarsen(CoarsenArithmetic, Generic[T_Xarray]): ) _attributes = ("windows", "side", "trim_excess") obj: T_Xarray + windows: Mapping[Hashable, int] + side: SideOptions | Mapping[Hashable, SideOptions] + boundary: CoarsenBoundaryOptions + coord_func: Mapping[Hashable, str | Callable] def __init__( self, @@ -852,12 +861,15 @@ def __init__( f"Window dimensions {missing_dims} not found in {self.obj.__class__.__name__} " f"dimensions {tuple(self.obj.dims)}" ) - if not utils.is_dict_like(coord_func): - coord_func = {d: coord_func for d in self.obj.dims} # type: ignore[misc] + + if utils.is_dict_like(coord_func): + coord_func_map = coord_func + else: + coord_func_map = {d: coord_func for d in self.obj.dims} for c in self.obj.coords: - if c not in coord_func: - coord_func[c] = duck_array_ops.mean # type: ignore[index] - self.coord_func: Mapping[Hashable, str | Callable] = coord_func + if c not in coord_func_map: + coord_func_map[c] = duck_array_ops.mean # type: ignore[index] + self.coord_func = coord_func_map def _get_keep_attrs(self, keep_attrs): if keep_attrs is None: