From 858909f323d5e3887f7354cea64c0b813c8ed577 Mon Sep 17 00:00:00 2001 From: Josh Wilson Date: Sun, 29 Mar 2020 12:55:51 -0700 Subject: [PATCH] ENH: make ndarray generic over dtype Closes https://github.com/numpy/numpy-stubs/issues/7. --- numpy-stubs/__init__.pyi | 68 +++++++++++++++++----- tests/fail/ndarray.py | 2 +- tests/pass/ndarray_conversion.py | 2 +- tests/pass/ndarray_shape_manipulation.py | 2 +- tests/pass/simple.py | 8 ++- tests/pass/simple_py3.py | 2 +- tests/reveal/ndarray_conversion.py | 37 ++++++------ tests/reveal/ndarray_shape_manipulation.py | 2 +- 8 files changed, 85 insertions(+), 38 deletions(-) diff --git a/numpy-stubs/__init__.pyi b/numpy-stubs/__init__.pyi index 0847317..59d4d05 100644 --- a/numpy-stubs/__init__.pyi +++ b/numpy-stubs/__init__.pyi @@ -8,6 +8,7 @@ from typing import ( ByteString, Container, Dict, + Generic, IO, Iterable, List, @@ -283,7 +284,10 @@ class _ArrayOrScalarCommon( _BufferType = Union[ndarray, bytes, bytearray, memoryview] -class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container): +_ArbitraryDtype = TypeVar("_ArbitraryDtype", bound=generic) +_ArrayDtype = TypeVar("_ArrayDtype", bound=generic) + +class ndarray(Generic[_ArrayDtype], _ArrayOrScalarCommon, Iterable, Sized, Container): real: ndarray imag: ndarray def __new__( @@ -296,7 +300,7 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container): order: Optional[str] = ..., ) -> ndarray: ... @property - def dtype(self) -> _Dtype: ... + def dtype(self) -> Type[_ArrayDtype]: ... @property def ctypes(self) -> _ctypes: ... @property @@ -326,6 +330,16 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container): ) -> None: ... def dump(self, file: str) -> None: ... def dumps(self) -> bytes: ... + @overload + def astype( + self, + dtype: _ArbitraryDtype, + order: str = ..., + casting: str = ..., + subok: bool = ..., + copy: bool = ..., + ) -> ndarray[_ArbitraryDtype]: ... + @overload def astype( self, dtype: _DtypeLike, @@ -334,18 +348,34 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container): subok: bool = ..., copy: bool = ..., ) -> ndarray: ... - def byteswap(self, inplace: bool = ...) -> ndarray: ... - def copy(self, order: str = ...) -> ndarray: ... + def byteswap(self, inplace: bool = ...) -> ndarray[_ArrayDtype]: ... + @overload + def copy(self) -> ndarray[_ArrayDtype]: ... + @overload + def copy(self, order: str = ...) -> ndarray[_ArrayDtype]: ... + @overload + def view(self) -> ndarray[_ArrayDtype]: ... @overload def view(self, dtype: Type[_NdArraySubClass]) -> _NdArraySubClass: ... @overload + def view(self, dtype: Type[_ArbitraryDtype]) -> ndarray[_ArbitraryDtype]: ... + @overload def view(self, dtype: _DtypeLike = ...) -> ndarray: ... @overload + def view( + self, dtype: _ArbitraryDtype, type: Type[_NdArraySubClass] + ) -> _NdArraySubClass[_ArbitraryDtype]: ... + @overload def view( self, dtype: _DtypeLike, type: Type[_NdArraySubClass] ) -> _NdArraySubClass: ... @overload def view(self, *, type: Type[_NdArraySubClass]) -> _NdArraySubClass: ... + @overload + def getfield( + self, dtype: Type[_ArbitraryDtype], offset: int = ... + ) -> ndarray[_ArbitraryDtype]: ... + @overload def getfield(self, dtype: Union[_DtypeLike, str], offset: int = ...) -> ndarray: ... def setflags( self, write: bool = ..., align: bool = ..., uic: bool = ... @@ -353,21 +383,25 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container): def fill(self, value: Any) -> None: ... # Shape manipulation @overload - def reshape(self, shape: Sequence[int], *, order: str = ...) -> ndarray: ... + def reshape( + self, shape: Sequence[int], *, order: str = ... + ) -> ndarray[_ArrayDtype]: ... @overload - def reshape(self, *shape: int, order: str = ...) -> ndarray: ... + def reshape(self, *shape: int, order: str = ...) -> ndarray[_ArrayDtype]: ... @overload def resize(self, new_shape: Sequence[int], *, refcheck: bool = ...) -> None: ... @overload def resize(self, *new_shape: int, refcheck: bool = ...) -> None: ... @overload - def transpose(self, axes: Sequence[int]) -> ndarray: ... + def transpose(self, axes: Sequence[int]) -> ndarray[_ArrayDtype]: ... @overload - def transpose(self, *axes: int) -> ndarray: ... - def swapaxes(self, axis1: int, axis2: int) -> ndarray: ... - def flatten(self, order: str = ...) -> ndarray: ... - def ravel(self, order: str = ...) -> ndarray: ... - def squeeze(self, axis: Union[int, Tuple[int, ...]] = ...) -> ndarray: ... + def transpose(self, *axes: int) -> ndarray[_ArrayDtype]: ... + def swapaxes(self, axis1: int, axis2: int) -> ndarray[_ArrayDtype]: ... + def flatten(self, order: str = ...) -> ndarray[_ArrayDtype]: ... + def ravel(self, order: str = ...) -> ndarray[_ArrayDtype]: ... + def squeeze( + self, axis: Union[int, Tuple[int, ...]] = ... + ) -> ndarray[_ArrayDtype]: ... # Many of these special methods are irrelevant currently, since protocols # aren't supported yet. That said, I'm adding them for completeness. # https://docs.python.org/3/reference/datamodel.html @@ -471,7 +505,15 @@ class str_(character): ... # uint_, int_, float_, complex_ # float128, complex256 # float96 - +@overload +def array( + object: object, + dtype: Type[_ArbitraryDtype] = ..., + copy: bool = ..., + subok: bool = ..., + ndmin: int = ..., +) -> ndarray[_ArbitraryDtype]: ... +@overload def array( object: object, dtype: _DtypeLike = ..., diff --git a/tests/fail/ndarray.py b/tests/fail/ndarray.py index 5a5130d..17fee42 100644 --- a/tests/fail/ndarray.py +++ b/tests/fail/ndarray.py @@ -7,5 +7,5 @@ # https://github.com/numpy/numpy-stubs/issues/7 # # for more context. -float_array = np.array([1.0]) +float_array = np.array([1.0], dtype=np.float64) float_array.dtype = np.bool_ # E: Property "dtype" defined in "ndarray" is read-only diff --git a/tests/pass/ndarray_conversion.py b/tests/pass/ndarray_conversion.py index 21b71e1..e4f6b6e 100644 --- a/tests/pass/ndarray_conversion.py +++ b/tests/pass/ndarray_conversion.py @@ -1,6 +1,6 @@ import numpy as np -nd = np.array([[1, 2], [3, 4]]) +nd = np.array([[1, 2], [3, 4]], dtype=np.int64) # item nd.item() # `nd` should be one-element in runtime diff --git a/tests/pass/ndarray_shape_manipulation.py b/tests/pass/ndarray_shape_manipulation.py index e18e407..fcccebd 100644 --- a/tests/pass/ndarray_shape_manipulation.py +++ b/tests/pass/ndarray_shape_manipulation.py @@ -1,6 +1,6 @@ import numpy as np -nd = np.array([[1, 2], [3, 4]]) +nd = np.array([[1, 2], [3, 4]], dtype=np.int64) # reshape nd.reshape() diff --git a/tests/pass/simple.py b/tests/pass/simple.py index 6c29de9..a13db0f 100644 --- a/tests/pass/simple.py +++ b/tests/pass/simple.py @@ -1,13 +1,15 @@ """Simple expression that should pass with mypy.""" import operator +from typing import TypeVar import numpy as np from typing import Iterable # noqa: F401 # Basic checks -array = np.array([1, 2]) +array = np.array([1, 2], dtype=np.int64) +T = TypeVar('T', bound=np.generic) def ndarray_func(x): - # type: (np.ndarray) -> np.ndarray + # type: (np.ndarray[T]) -> np.ndarray[T] return x ndarray_func(np.array([1, 2])) array == 1 @@ -70,7 +72,7 @@ def iterable_func(x): # Other special methods len(array) str(array) -array_scalar = np.array(1) +array_scalar = np.array(1, dtype=np.int64) int(array_scalar) float(array_scalar) # currently does not work due to https://github.com/python/typeshed/issues/1904 diff --git a/tests/pass/simple_py3.py b/tests/pass/simple_py3.py index c05a1ce..46f55af 100644 --- a/tests/pass/simple_py3.py +++ b/tests/pass/simple_py3.py @@ -1,6 +1,6 @@ import numpy as np -array = np.array([1, 2]) +array = np.array([1, 2], dtype=np.int64) # The @ operator is not in python 2 array @ array diff --git a/tests/reveal/ndarray_conversion.py b/tests/reveal/ndarray_conversion.py index 1e17d44..9a50e31 100644 --- a/tests/reveal/ndarray_conversion.py +++ b/tests/reveal/ndarray_conversion.py @@ -1,6 +1,9 @@ import numpy as np -nd = np.array([[1, 2], [3, 4]]) +nd = np.array([[1, 2], [3, 4]], dtype=np.int64) + +# dtype of the array +reveal_type(nd) # E: numpy.ndarray[numpy.int64*] # item reveal_type(nd.item()) # E: Any @@ -19,36 +22,36 @@ # dumps is pretty simple # astype -reveal_type(nd.astype("float")) # E: numpy.ndarray -reveal_type(nd.astype(float)) # E: numpy.ndarray -reveal_type(nd.astype(float, "K")) # E: numpy.ndarray -reveal_type(nd.astype(float, "K", "unsafe")) # E: numpy.ndarray -reveal_type(nd.astype(float, "K", "unsafe", True)) # E: numpy.ndarray -reveal_type(nd.astype(float, "K", "unsafe", True, True)) # E: numpy.ndarray +reveal_type(nd.astype("float")) # E: numpy.ndarray[Any] +reveal_type(nd.astype(float)) # E: numpy.ndarray[Any] +reveal_type(nd.astype(float, "K")) # E: numpy.ndarray[Any] +reveal_type(nd.astype(float, "K", "unsafe")) # E: numpy.ndarray[Any] +reveal_type(nd.astype(float, "K", "unsafe", True)) # E: numpy.ndarray[Any] +reveal_type(nd.astype(float, "K", "unsafe", True, True)) # E: numpy.ndarray[Any] # byteswap -reveal_type(nd.byteswap()) # E: numpy.ndarray -reveal_type(nd.byteswap(True)) # E: numpy.ndarray +reveal_type(nd.byteswap()) # E: numpy.ndarray[numpy.int64*] +reveal_type(nd.byteswap(True)) # E: numpy.ndarray[numpy.int64*] # copy -reveal_type(nd.copy()) # E: numpy.ndarray -reveal_type(nd.copy("C")) # E: numpy.ndarray +reveal_type(nd.copy()) # E: numpy.ndarray[numpy.int64*] +reveal_type(nd.copy("C")) # E: numpy.ndarray[numpy.int64*] # view class SubArray(np.ndarray): pass -reveal_type(nd.view()) # E: numpy.ndarray -reveal_type(nd.view(np.int64)) # E: numpy.ndarray +reveal_type(nd.view()) # E: numpy.ndarray[numpy.int64*] +reveal_type(nd.view(np.float64)) # E: numpy.ndarray[numpy.float64*] # replace `Any` with `numpy.matrix` when `matrix` will be added to stubs reveal_type(nd.view(np.int64, np.matrix)) # E: Any reveal_type(nd.view(np.int64, SubArray)) # E: SubArray # getfield -reveal_type(nd.getfield("float")) # E: numpy.ndarray -reveal_type(nd.getfield(float)) # E: numpy.ndarray -reveal_type(nd.getfield(float, 8)) # E: numpy.ndarray +reveal_type(nd.getfield("float")) # E: numpy.ndarray[Any] +reveal_type(nd.getfield(float)) # E: numpy.ndarray[Any] +reveal_type(nd.getfield(float, 8)) # E: numpy.ndarray[Any] +reveal_type(nd.getfield(np.int32, 4)) # E: numpy.ndarray[numpy.int32*] # setflags does not return a value # fill does not return a value - diff --git a/tests/reveal/ndarray_shape_manipulation.py b/tests/reveal/ndarray_shape_manipulation.py index a44e1cf..c93cb21 100644 --- a/tests/reveal/ndarray_shape_manipulation.py +++ b/tests/reveal/ndarray_shape_manipulation.py @@ -1,6 +1,6 @@ import numpy as np -nd = np.array([[1, 2], [3, 4]]) +nd = np.array([[1, 2], [3, 4]], dtype=np.int64) # reshape reveal_type(nd.reshape()) # E: numpy.ndarray