Skip to content

Commit

Permalink
Improve typehints of xr.Dataset.__getitem__ (#4144)
Browse files Browse the repository at this point in the history
* Improve typehints of xr.Dataset.__getitem__

Resolves #4125

* Add overload for Mapping behavior

Sadly this is not working with my version of mypy. See python/mypy#7328

* Overload only Hashable inputs

Given mypy's use of overloads, I think this is all we can do. If the argument is not Hashable, then return the Union type as before.

* Lint

* Quote the DataArray to avoid error in py3.6

* Code review

Co-authored-by: crusaderky <[email protected]>
  • Loading branch information
nbren12 and crusaderky authored Jun 15, 2020
1 parent 2ba5300 commit bc5c79e
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.761 # Must match ci/requirements/*.yml
rev: v0.780 # Must match ci/requirements/*.yml
hooks:
- id: mypy
# run this occasionally, ref discussion https://github.com/pydata/xarray/pull/3194
Expand Down
2 changes: 1 addition & 1 deletion ci/requirements/py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies:
- isort
- lxml # Optional dep of pydap
- matplotlib
- mypy=0.761 # Must match .pre-commit-config.yaml
- mypy=0.780 # Must match .pre-commit-config.yaml
- nc-time-axis
- netcdf4
- numba
Expand Down
17 changes: 15 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TypeVar,
Union,
cast,
overload,
)

import numpy as np
Expand Down Expand Up @@ -1241,13 +1242,25 @@ def loc(self) -> _LocIndexer:
"""
return _LocIndexer(self)

def __getitem__(self, key: Any) -> "Union[DataArray, Dataset]":
# FIXME https://github.com/python/mypy/issues/7328
@overload
def __getitem__(self, key: Mapping) -> "Dataset": # type: ignore
...

@overload
def __getitem__(self, key: Hashable) -> "DataArray": # type: ignore
...

@overload
def __getitem__(self, key: Any) -> "Dataset":
...

def __getitem__(self, key):
"""Access variables or coordinates this dataset as a
:py:class:`~xarray.DataArray`.
Indexing with a list of names will return a new ``Dataset`` object.
"""
# TODO(shoyer): type this properly: https://github.com/python/mypy/issues/7328
if utils.is_dict_like(key):
return self.isel(**cast(Mapping, key))

Expand Down
6 changes: 3 additions & 3 deletions xarray/core/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ class Weighted:
def __init__(self, obj: "DataArray", weights: "DataArray") -> None:
...

@overload # noqa: F811
def __init__(self, obj: "Dataset", weights: "DataArray") -> None: # noqa: F811
@overload
def __init__(self, obj: "Dataset", weights: "DataArray") -> None:
...

def __init__(self, obj, weights): # noqa: F811
def __init__(self, obj, weights):
"""
Create a Weighted object
Expand Down

0 comments on commit bc5c79e

Please sign in to comment.