Skip to content

Commit

Permalink
ENH: make ndarray generic over dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
person142 committed Mar 31, 2020
1 parent ba67281 commit c753c41
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 38 deletions.
89 changes: 76 additions & 13 deletions numpy-stubs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ from typing import (
ByteString,
Container,
Dict,
Generic,
IO,
Iterable,
List,
Expand Down Expand Up @@ -283,7 +284,16 @@ class _ArrayOrScalarCommon(

_BufferType = Union[ndarray, bytes, bytearray, memoryview]

class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
_ArbitraryDtype = TypeVar('_ArbitraryDtype', bound=generic)
_ArrayDtype = TypeVar('_ArrayDtype', bound=generic)

class ndarray(
Generic[_ArrayDtype],
_ArrayOrScalarCommon,
Iterable,
Sized,
Container,
):
real: ndarray
imag: ndarray
def __new__(
Expand All @@ -296,7 +306,7 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
order: Optional[str] = ...,
) -> ndarray: ...
@property
def dtype(self) -> _Dtype: ...
def dtype(self) -> Type[_ArrayDtype]: ...
@property
def ctypes(self) -> _ctypes: ...
@property
Expand Down Expand Up @@ -326,6 +336,16 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
) -> None: ...
def dump(self, file: str) -> None: ...
def dumps(self) -> bytes: ...
@overload
def astype(
self,
dtype: _ArbitraryDtype,
order: str = ...,
casting: str = ...,
subok: bool = ...,
copy: bool = ...,
) -> ndarray[_ArbitraryDtype]: ...
@overload
def astype(
self,
dtype: _DtypeLike,
Expand All @@ -334,40 +354,74 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
subok: bool = ...,
copy: bool = ...,
) -> ndarray: ...
def byteswap(self, inplace: bool = ...) -> ndarray: ...
def copy(self, order: str = ...) -> ndarray: ...
def byteswap(self, inplace: bool = ...) -> ndarray[_ArrayDtype]: ...
@overload
def copy(self) -> ndarray[_ArrayDtype]: ...
@overload
def copy(self, order: str = ...) -> ndarray[_ArrayDtype]: ...
@overload
def view(self) -> ndarray[_ArrayDtype]: ...
@overload
def view(self, dtype: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
@overload
def view(self, dtype: Type[_ArbitraryDtype]) -> ndarray[_ArbitraryDtype]: ...
@overload
def view(self, dtype: _DtypeLike = ...) -> ndarray: ...
@overload
def view(
self, dtype: _DtypeLike, type: Type[_NdArraySubClass]
self,
dtype: _ArbitraryDtype,
type: Type[_NdArraySubClass],
) -> _NdArraySubClass[_ArbitraryDtype]: ...
@overload
def view(
self,
dtype: _DtypeLike,
type: Type[_NdArraySubClass],
) -> _NdArraySubClass: ...
@overload
def view(self, *, type: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
@overload
def getfield(
self,
dtype: Type[_ArbitraryDtype],
offset: int = ...,
) -> ndarray[_ArbitraryDtype]: ...
@overload
def getfield(self, dtype: Union[_DtypeLike, str], offset: int = ...) -> ndarray: ...
def setflags(
self, write: bool = ..., align: bool = ..., uic: bool = ...
) -> None: ...
def fill(self, value: Any) -> None: ...
# Shape manipulation
@overload
def reshape(self, shape: Sequence[int], *, order: str = ...) -> ndarray: ...
def reshape(
self,
shape: Sequence[int],
*,
order: str = ...,
) -> ndarray[_ArrayDtype]: ...
@overload
def reshape(self, *shape: int, order: str = ...) -> ndarray: ...
def reshape(
self,
*shape: int,
order: str = ...,
) -> ndarray[_ArrayDtype]: ...
@overload
def resize(self, new_shape: Sequence[int], *, refcheck: bool = ...) -> None: ...
@overload
def resize(self, *new_shape: int, refcheck: bool = ...) -> None: ...
@overload
def transpose(self, axes: Sequence[int]) -> ndarray: ...
def transpose(self, axes: Sequence[int]) -> ndarray[_ArrayDtype]: ...
@overload
def transpose(self, *axes: int) -> ndarray: ...
def swapaxes(self, axis1: int, axis2: int) -> ndarray: ...
def flatten(self, order: str = ...) -> ndarray: ...
def ravel(self, order: str = ...) -> ndarray: ...
def squeeze(self, axis: Union[int, Tuple[int, ...]] = ...) -> ndarray: ...
def transpose(self, *axes: int) -> ndarray[_ArrayDtype]: ...
def swapaxes(self, axis1: int, axis2: int) -> ndarray[_ArrayDtype]: ...
def flatten(self, order: str = ...) -> ndarray[_ArrayDtype]: ...
def ravel(self, order: str = ...) -> ndarray[_ArrayDtype]: ...
def squeeze(
self,
axis: Union[int, Tuple[int, ...]] = ...,
) -> ndarray[_ArrayDtype]: ...
# Many of these special methods are irrelevant currently, since protocols
# aren't supported yet. That said, I'm adding them for completeness.
# https://docs.python.org/3/reference/datamodel.html
Expand Down Expand Up @@ -472,6 +526,15 @@ class str_(character): ...
# float128, complex256
# float96

@overload
def array(
object: object,
dtype: Type[_ArbitraryDtype] = ...,
copy: bool = ...,
subok: bool = ...,
ndmin: int = ...,
) -> ndarray[_ArbitraryDtype]: ...
@overload
def array(
object: object,
dtype: _DtypeLike = ...,
Expand Down
2 changes: 1 addition & 1 deletion tests/fail/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
# https://github.com/numpy/numpy-stubs/issues/7
#
# for more context.
float_array = np.array([1.0])
float_array = np.array([1.0], dtype=np.float64)
float_array.dtype = np.bool_ # E: Property "dtype" defined in "ndarray" is read-only
2 changes: 1 addition & 1 deletion tests/pass/ndarray_conversion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

nd = np.array([[1, 2], [3, 4]])
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)

# item
nd.item() # `nd` should be one-element in runtime
Expand Down
2 changes: 1 addition & 1 deletion tests/pass/ndarray_shape_manipulation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

nd = np.array([[1, 2], [3, 4]])
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)

# reshape
nd.reshape()
Expand Down
8 changes: 5 additions & 3 deletions tests/pass/simple.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Simple expression that should pass with mypy."""
import operator
from typing import TypeVar

import numpy as np
from typing import Iterable # noqa: F401

# Basic checks
array = np.array([1, 2])
array = np.array([1, 2], dtype=np.int64)
T = TypeVar('T', bound=np.generic)
def ndarray_func(x):
# type: (np.ndarray) -> np.ndarray
# type: (np.ndarray[T]) -> np.ndarray[T]
return x
ndarray_func(np.array([1, 2]))
array == 1
Expand Down Expand Up @@ -70,7 +72,7 @@ def iterable_func(x):
# Other special methods
len(array)
str(array)
array_scalar = np.array(1)
array_scalar = np.array(1, dtype=np.int64)
int(array_scalar)
float(array_scalar)
# currently does not work due to https://github.com/python/typeshed/issues/1904
Expand Down
2 changes: 1 addition & 1 deletion tests/pass/simple_py3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

array = np.array([1, 2])
array = np.array([1, 2], dtype=np.int64)

# The @ operator is not in python 2
array @ array
38 changes: 21 additions & 17 deletions tests/reveal/ndarray_conversion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np

nd = np.array([[1, 2], [3, 4]])
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)

# dtype of the array
reveal_type(nd) # E: numpy.ndarray[numpy.int64*]

# item
reveal_type(nd.item()) # E: Any
Expand All @@ -19,36 +22,37 @@
# dumps is pretty simple

# astype
reveal_type(nd.astype("float")) # E: numpy.ndarray
reveal_type(nd.astype(float)) # E: numpy.ndarray
reveal_type(nd.astype(float, "K")) # E: numpy.ndarray
reveal_type(nd.astype(float, "K", "unsafe")) # E: numpy.ndarray
reveal_type(nd.astype(float, "K", "unsafe", True)) # E: numpy.ndarray
reveal_type(nd.astype(float, "K", "unsafe", True, True)) # E: numpy.ndarray
reveal_type(nd.astype("float")) # E: numpy.ndarray[Any]
reveal_type(nd.astype(float)) # E: numpy.ndarray[Any]
reveal_type(nd.astype(float, "K")) # E: numpy.ndarray[Any]
reveal_type(nd.astype(float, "K", "unsafe")) # E: numpy.ndarray[Any]
reveal_type(nd.astype(float, "K", "unsafe", True)) # E: numpy.ndarray[Any]
reveal_type(nd.astype(float, "K", "unsafe", True, True)) # E: numpy.ndarray[Any]

# byteswap
reveal_type(nd.byteswap()) # E: numpy.ndarray
reveal_type(nd.byteswap(True)) # E: numpy.ndarray
reveal_type(nd.byteswap()) # E: numpy.ndarray[numpy.int64*]
reveal_type(nd.byteswap(True)) # E: numpy.ndarray[numpy.int64*]

# copy
reveal_type(nd.copy()) # E: numpy.ndarray
reveal_type(nd.copy("C")) # E: numpy.ndarray
reveal_type(nd.copy()) # E: numpy.ndarray[numpy.int64*]
reveal_type(nd.copy("C")) # E: numpy.ndarray[numpy.int64*]

# view
class SubArray(np.ndarray):
pass

reveal_type(nd.view()) # E: numpy.ndarray
reveal_type(nd.view(np.int64)) # E: numpy.ndarray
reveal_type(nd.view()) # E: numpy.ndarray[numpy.int64*]
reveal_type(nd.view(np.float64)) # E: numpy.ndarray[numpy.float64*]
# replace `Any` with `numpy.matrix` when `matrix` will be added to stubs
reveal_type(nd.view(np.int64, np.matrix)) # E: Any
# FIXME: get subclasses working correctly
reveal_type(nd.view(np.int64, SubArray)) # E: SubArray

# getfield
reveal_type(nd.getfield("float")) # E: numpy.ndarray
reveal_type(nd.getfield(float)) # E: numpy.ndarray
reveal_type(nd.getfield(float, 8)) # E: numpy.ndarray
reveal_type(nd.getfield("float")) # E: numpy.ndarray[Any]
reveal_type(nd.getfield(float)) # E: numpy.ndarray[Any]
reveal_type(nd.getfield(float, 8)) # E: numpy.ndarray[Any]
reveal_type(nd.getfield(np.int32, 4)) # E: numpy.ndarray[numpy.int32*]

# setflags does not return a value
# fill does not return a value

2 changes: 1 addition & 1 deletion tests/reveal/ndarray_shape_manipulation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

nd = np.array([[1, 2], [3, 4]])
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)

# reshape
reveal_type(nd.reshape()) # E: numpy.ndarray
Expand Down

0 comments on commit c753c41

Please sign in to comment.