Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mypy issues & reenable in tests #6581

Merged
merged 12 commits into from
May 8, 2022
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .github/workflows/ci-additional.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,9 @@ jobs:
run: |
python -m pip install mypy

# Temporarily overriding to be true due to https://github.com/pydata/xarray/issues/6551
# python -m mypy --install-types --non-interactive
- name: Run mypy
run: |
python -m mypy --install-types --non-interactive || true
python -m mypy --install-types --non-interactive

min-version-policy:
name: Minimum Version Policy
Expand Down
2 changes: 1 addition & 1 deletion xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
try:
from dask.delayed import Delayed
except ImportError:
Delayed = None
Delayed = None # type: ignore


DATAARRAY_NAME = "__xarray_dataarray_name__"
Expand Down
4 changes: 2 additions & 2 deletions xarray/backends/locks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from dask.utils import SerializableLock
except ImportError:
# no need to worry about serializing the lock
SerializableLock = threading.Lock
SerializableLock = threading.Lock # type: ignore

try:
from dask.distributed import Lock as DistributedLock
except ImportError:
DistributedLock = None
DistributedLock = None # type: ignore


# Locks used by multiple backends.
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/_typed_ops.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ from .variable import Variable
try:
from dask.array import Array as DaskArray
except ImportError:
DaskArray = np.ndarray
DaskArray = np.ndarray # type: ignore

# DatasetOpsMixin etc. are parent classes of Dataset etc.
# Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally
Expand Down
27 changes: 14 additions & 13 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
Iterable,
Mapping,
Sequence,
overload,
)

import numpy as np
Expand Down Expand Up @@ -1846,24 +1845,26 @@ def where(cond, x, y, keep_attrs=None):
)


@overload
def polyval(coord: DataArray, coeffs: DataArray, degree_dim: Hashable) -> DataArray:
...
# These overloads seem not to work — mypy says it can't find a matching overload for
# `DataArray` & `DataArray`, despite that being in the first overload. Would be nice to
# have overloaded functions rather than just `T_Xarray` for everything.

# @overload
# def polyval(coord: DataArray, coeffs: DataArray, degree_dim: Hashable) -> DataArray:
# ...

@overload
def polyval(coord: T_Xarray, coeffs: Dataset, degree_dim: Hashable) -> Dataset:
...

# @overload
# def polyval(coord: T_Xarray, coeffs: Dataset, degree_dim: Hashable) -> Dataset:
# ...

@overload
def polyval(coord: Dataset, coeffs: T_Xarray, degree_dim: Hashable) -> Dataset:
...

# @overload
# def polyval(coord: Dataset, coeffs: T_Xarray, degree_dim: Hashable) -> Dataset:
# ...


def polyval(
coord: T_Xarray, coeffs: T_Xarray, degree_dim: Hashable = "degree"
) -> T_Xarray:
def polyval(coord: T_Xarray, coeffs: T_Xarray, degree_dim=None) -> T_Xarray:
"""Evaluate a polynomial at specific values

Parameters
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dask_array_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
try:
import dask.array as da
except ImportError:
da = None
da = None # type: ignore


def _validate_pad_output_shape(input_shape, pad_width, output_shape):
Expand Down
8 changes: 6 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import numpy as np
import pandas as pd

from xarray.backends.common import AbstractDataStore, ArrayWriter
max-sixty marked this conversation as resolved.
Show resolved Hide resolved

from ..coding.calendar_ops import convert_calendar, interp_calendar
from ..coding.cftimeindex import CFTimeIndex
from ..plot.plot import _PlotMethods
Expand Down Expand Up @@ -67,7 +69,7 @@
try:
from dask.delayed import Delayed
except ImportError:
Delayed = None
Delayed = None # type: ignore
try:
from cdms2 import Variable as cdms2_Variable
except ImportError:
Expand Down Expand Up @@ -2875,7 +2877,9 @@ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray:
isnull = pd.isnull(values)
return np.ma.MaskedArray(data=values, mask=isnull, copy=copy)

def to_netcdf(self, *args, **kwargs) -> bytes | Delayed | None:
def to_netcdf(
self, *args, **kwargs
) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None:
"""Write DataArray contents to a netCDF file.

All parameters are passed directly to :py:meth:`xarray.Dataset.to_netcdf`.
Expand Down
5 changes: 3 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import xarray as xr

from ..backends.common import ArrayWriter
from ..coding.calendar_ops import convert_calendar, interp_calendar
from ..coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings
from ..plot.dataset_plot import _Dataset_PlotMethods
Expand Down Expand Up @@ -110,7 +111,7 @@
try:
from dask.delayed import Delayed
except ImportError:
Delayed = None
Delayed = None # type: ignore


# list of attributes of pd.DatetimeIndex that are ndarrays of time info
Expand Down Expand Up @@ -1686,7 +1687,7 @@ def to_netcdf(
unlimited_dims: Iterable[Hashable] = None,
compute: bool = True,
invalid_netcdf: bool = False,
) -> bytes | Delayed | None:
) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None:
"""Write dataset contents to a netCDF file.

Parameters
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import dask.array as dask_array
from dask.base import tokenize
except ImportError:
dask_array = None
dask_array = None # type: ignore


def _dask_or_eager_func(
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from . import dask_array_compat
except ImportError:
dask_array = None
dask_array = None # type: ignore[assignment]
dask_array_compat = None # type: ignore[assignment]


Expand Down
2 changes: 1 addition & 1 deletion xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
try:
from dask.array import Array as DaskArray
except ImportError:
DaskArray = np.ndarray
DaskArray = np.ndarray # type: ignore


T_Dataset = TypeVar("T_Dataset", bound="Dataset")
Expand Down
7 changes: 5 additions & 2 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import functools
import operator
import pickle
Expand All @@ -22,6 +24,7 @@
unified_dim_sizes,
)
from xarray.core.pycompat import dask_version
from xarray.core.types import T_Xarray

from . import has_dask, raise_if_dask_computes, requires_dask

Expand Down Expand Up @@ -2009,14 +2012,14 @@ def test_where_attrs() -> None:
),
],
)
def test_polyval(use_dask, x, coeffs, expected) -> None:
def test_polyval(use_dask, x: T_Xarray, coeffs: T_Xarray, expected) -> None:
if use_dask:
if not has_dask:
pytest.skip("requires dask")
coeffs = coeffs.chunk({"degree": 2})
x = x.chunk({"x": 2})
with raise_if_dask_computes():
actual = xr.polyval(x, coeffs)
actual = xr.polyval(coord=x, coeffs=coeffs)
xr.testing.assert_allclose(actual, expected)


Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
try:
from dask.array import from_array as dask_from_array
except ImportError:
dask_from_array = lambda x: x
dask_from_array = lambda x: x # type: ignore

try:
import pint
Expand Down
2 changes: 1 addition & 1 deletion xarray/util/generate_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def inplace():
try:
from dask.array import Array as DaskArray
except ImportError:
DaskArray = np.ndarray
DaskArray = np.ndarray # type: ignore

# DatasetOpsMixin etc. are parent classes of Dataset etc.
# Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally
Expand Down