Use Self rather than concrete types, remove casts #8216

Merged: 21 commits, Sep 21, 2023
1 change: 1 addition & 0 deletions pyproject.toml
@@ -72,6 +72,7 @@ source = ["xarray"]
exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"]

[tool.mypy]
+ enable_error_code = "redundant-self"
exclude = 'xarray/util/generate_.*\.py'
files = "xarray"
show_error_codes = true
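For readers unfamiliar with the new error code: as I understand it, mypy's optional redundant-self check reports methods that explicitly annotate `self` (for example `self: Self`) where the annotation adds nothing, which pairs naturally with the switch from class-bound TypeVars to `Self` made throughout this PR. A minimal before/after sketch, with invented class names rather than xarray's real ones:

    from typing import TypeVar

    from typing_extensions import Self  # typing.Self on Python >= 3.11

    T_Box = TypeVar("T_Box", bound="Box")  # hypothetical stand-in for T_DataWithCoords


    class Box:
        # Old pattern: a TypeVar bound to the class, threaded through `self`,
        # so that subclasses keep their own type in the return annotation.
        def clip_old(self: T_Box) -> T_Box:
            return self

        # New pattern: Self means "the type of this instance", so no TypeVar
        # and no explicit annotation on `self` are needed.
        def clip_new(self) -> Self:
            return self

        # What the redundant-self error code reports: annotating `self: Self`
        # explicitly, which PEP 673 allows but which adds nothing.
        # def clip_redundant(self: Self) -> Self: ...


    class SubBox(Box):
        pass


    sub: SubBox = SubBox().clip_new()  # mypy infers SubBox, not Box

The diffs below are the mechanical application of that switch to DataWithCoords and related code.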
2 changes: 1 addition & 1 deletion xarray/core/accessor_str.py
@@ -2386,7 +2386,7 @@ def _partitioner(

# _apply breaks on an empty array in this case
if not self._obj.size:
- return self._obj.copy().expand_dims({dim: 0}, axis=-1)  # type: ignore[return-value]
+ return self._obj.copy().expand_dims({dim: 0}, axis=-1)

arrfunc = lambda x, isep: np.array(func(x, isep), dtype=self._obj.dtype)

37 changes: 15 additions & 22 deletions xarray/core/common.py
@@ -45,6 +45,7 @@
DatetimeLike,
DTypeLikeSave,
ScalarOrArray,
+ Self,
SideOptions,
T_Chunks,
T_DataWithCoords,
@@ -381,11 +382,11 @@ class DataWithCoords(AttrAccessMixin):
__slots__ = ("_close",)

def squeeze(
- self: T_DataWithCoords,
+ self,
dim: Hashable | Iterable[Hashable] | None = None,
drop: bool = False,
axis: int | Iterable[int] | None = None,
- ) -> T_DataWithCoords:
+ ) -> Self:
"""Return a new object with squeezed data.

Parameters
@@ -414,12 +415,12 @@ def squeeze(
return self.isel(drop=drop, **{d: 0 for d in dims})

def clip(
- self: T_DataWithCoords,
+ self,
min: ScalarOrArray | None = None,
max: ScalarOrArray | None = None,
*,
keep_attrs: bool | None = None,
- ) -> T_DataWithCoords:
+ ) -> Self:
"""
Return an array whose values are limited to ``[min, max]``.
At least one of max or min must be given.
@@ -472,10 +473,10 @@ def _calc_assign_results(
return {k: v(self) if callable(v) else v for k, v in kwargs.items()}

def assign_coords(
- self: T_DataWithCoords,
+ self,
coords: Mapping[Any, Any] | None = None,
**coords_kwargs: Any,
- ) -> T_DataWithCoords:
+ ) -> Self:
"""Assign new coordinates to this object.

Returns a new object with all the original data in addition to the new
@@ -620,9 +621,7 @@ def assign_coords(
data.coords.update(results)
return data

- def assign_attrs(
-     self: T_DataWithCoords, *args: Any, **kwargs: Any
- ) -> T_DataWithCoords:
+ def assign_attrs(self, *args: Any, **kwargs: Any) -> Self:
"""Assign new attrs to this object.

Returns a new object equivalent to ``self.attrs.update(*args, **kwargs)``.
@@ -1061,9 +1060,7 @@ def _resample(
restore_coord_dims=restore_coord_dims,
)

- def where(
-     self: T_DataWithCoords, cond: Any, other: Any = dtypes.NA, drop: bool = False
- ) -> T_DataWithCoords:
+ def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self:
"""Filter elements from this object according to a condition.

This operation follows the normal broadcasting and alignment rules that
@@ -1205,9 +1202,7 @@ def close(self) -> None:
self._close()
self._close = None

- def isnull(
-     self: T_DataWithCoords, keep_attrs: bool | None = None
- ) -> T_DataWithCoords:
+ def isnull(self, keep_attrs: bool | None = None) -> Self:
"""Test each value in the array for whether it is a missing value.

Parameters
@@ -1250,9 +1245,7 @@ def isnull(
keep_attrs=keep_attrs,
)

- def notnull(
-     self: T_DataWithCoords, keep_attrs: bool | None = None
- ) -> T_DataWithCoords:
+ def notnull(self, keep_attrs: bool | None = None) -> Self:
"""Test each value in the array for whether it is not a missing value.

Parameters
@@ -1295,7 +1288,7 @@ def notnull(
keep_attrs=keep_attrs,
)

- def isin(self: T_DataWithCoords, test_elements: Any) -> T_DataWithCoords:
+ def isin(self, test_elements: Any) -> Self:
"""Tests each value in the array for whether it is in test elements.

Parameters
@@ -1344,15 +1337,15 @@ def isin(self: T_DataWithCoords, test_elements: Any) -> T_DataWithCoords:
)

def astype(
- self: T_DataWithCoords,
+ self,
dtype,
*,
order=None,
casting=None,
subok=None,
copy=None,
keep_attrs=True,
- ) -> T_DataWithCoords:
+ ) -> Self:
"""
Copy of the xarray object, with data cast to a specified type.
Leaves coordinate dtype unchanged.
@@ -1419,7 +1412,7 @@ def astype(
dask="allowed",
)

- def __enter__(self: T_DataWithCoords) -> T_DataWithCoords:
+ def __enter__(self) -> Self:
return self

def __exit__(self, exc_type, exc_value, traceback) -> None:
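A hedged illustration of why the `__enter__` change at the end of this file matters (the classes below are invented, not xarray's): returning `Self` from `__enter__` means a `with` block over a subclass is typed as that subclass rather than as the base class.

    from typing_extensions import Self  # typing.Self on Python >= 3.11


    class ManagedData:
        """Invented stand-in for a DataWithCoords-style base class."""

        def __enter__(self) -> Self:
            return self

        def __exit__(self, exc_type, exc_value, traceback) -> None:
            self.close()

        def close(self) -> None:
            pass


    class ManagedSubset(ManagedData):
        def subclass_only(self) -> int:
            return 1


    with ManagedSubset() as m:
        # Because __enter__ returns Self, `m` is typed as ManagedSubset, so
        # subclass-only methods type-check without a cast.
        m.subclass_only()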
8 changes: 3 additions & 5 deletions xarray/core/concat.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import Hashable, Iterable
- from typing import TYPE_CHECKING, Any, Union, cast, overload
+ from typing import TYPE_CHECKING, Any, Union, overload

import numpy as np
import pandas as pd
@@ -504,8 +504,7 @@ def _dataset_concat(

# case where concat dimension is a coordinate or data_var but not a dimension
if (dim in coord_names or dim in data_names) and dim not in dim_names:
- # TODO: Overriding type because .expand_dims has incorrect typing:
- datasets = [cast(T_Dataset, ds.expand_dims(dim)) for ds in datasets]
+ datasets = [ds.expand_dims(dim) for ds in datasets]

# determine which variables to concatenate
concat_over, equals, concat_dim_lengths = _calc_concat_over(
@@ -708,8 +707,7 @@ def _dataarray_concat(
if compat == "identical":
raise ValueError("array names not identical")
else:
- # TODO: Overriding type because .rename has incorrect typing:
- arr = cast(T_DataArray, arr.rename(name))
+ arr = arr.rename(name)
datasets.append(arr._to_temp_dataset())

ds = _dataset_concat(
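The casts appear to be droppable because methods such as `expand_dims` and `rename` are now annotated as returning `Self` elsewhere in this PR, so the concrete `Dataset`/`DataArray` type survives the call. A small sketch of the same situation with invented names:

    from typing import TypeVar

    from typing_extensions import Self  # typing.Self on Python >= 3.11


    class Base:
        def expand(self) -> Self:  # returns the caller's own type
            return self


    class Concrete(Base):
        pass


    T = TypeVar("T", bound=Base)


    def concat_like(items: list[T]) -> list[T]:
        # Before Self, item.expand() was typed as Base, which forced
        # cast(T, item.expand()); with Self the element type T is preserved.
        return [item.expand() for item in items]


    result: list[Concrete] = concat_like([Concrete(), Concrete()])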
2 changes: 2 additions & 0 deletions xarray/core/coordinates.py
@@ -664,6 +664,8 @@ def copy(
variables = {
k: v._copy(deep=deep, memo=memo) for k, v in self.variables.items()
}
+ # Currently this requires using the concrete type of `Coordinates`, rather than
+ # `Self`, ref #8216
return Coordinates._construct_direct(
    coords=variables, indexes=dict(self.xindexes), dims=dict(self.sizes)
)

Review thread on the lines above:

Contributor: why doesn't type(self) work here?

        return type(self)._construct_direct(
            coords=variables, indexes=dict(self.xindexes), dims=dict(self.sizes)
        )

Collaborator (author): I went to answer this question, and then got it working :)

Collaborator (author): OK actually the typing worked but the tests didn't; switching back. Added a TODO but won't chase it down atm...
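For the pattern discussed in the thread above, here is a minimal, self-contained sketch (invented names, not xarray's `Coordinates` internals) of how `type(self)` plus a `Self`-returning classmethod would type-check; per the thread, xarray kept the concrete `Coordinates` call because the tests did not pass with this approach:

    from typing_extensions import Self  # typing.Self on Python >= 3.11


    class Coords:
        """Invented stand-in; not xarray's Coordinates."""

        def __init__(self, data: dict[str, int]) -> None:
            self.data = data

        @classmethod
        def _construct_direct(cls, data: dict[str, int]) -> Self:
            # Alternate constructor that skips __init__.
            obj = cls.__new__(cls)
            obj.data = data
            return obj

        def copy(self) -> Self:
            # type(self) dispatches to the subclass, so both at runtime and
            # for mypy the copy keeps the subclass type.
            return type(self)._construct_direct(dict(self.data))


    class SubCoords(Coords):
        pass


    copied: SubCoords = SubCoords({"x": 1}).copy()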