Skip to content

Commit

Permalink
special: type-test some ufuncs, and dot some "i"'s
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham committed Dec 24, 2024
1 parent a991b68 commit d5cbfd6
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ warn_unreachable = true
warn_unused_ignores = true
disallow_any_explicit = false # no other way to type e.g. `float64 <: number[Any]`
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
# plugins = ["numpy.typing.mypy_plugin"]
plugins = ["numpy.typing.mypy_plugin"]

[tool.pyright]
include = ["codegen", "scipy-stubs", "tests"]
Expand Down
21 changes: 18 additions & 3 deletions scipy-stubs/special/_ufuncs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ from typing_extensions import LiteralString, Never, TypeAliasType, TypeVar, Unpa
import numpy as np
import optype as op
import optype.numpy as onp
import optype.typing as opt
from scipy._typing import AnyShape, Casting, EnterNoneMixin, OrderKACF

# NOTE: The only way I was able to get `ComplexWarning` to work on both numpy<1.25 and >=1.25, and is needed until `scipy>=1.16`
Expand Down Expand Up @@ -325,14 +326,16 @@ _CoFloat64ND: TypeAlias = onp.ArrayND[_CoFloat64]
_CoComplex128ND: TypeAlias = onp.ArrayND[_CoComplex128]

_SubFloat: TypeAlias = _Float16 | _CoInt # anything "below" float32 | float64 that isn't float32 | float64
_ToSubFloat: TypeAlias = float | _SubFloat # does not overlap with float32 | float64
_ToSubComplex: TypeAlias = complex | _SubFloat # does not overlap with complex64 | complex128
_ToSubFloat: TypeAlias = opt.Just[float] | int | _SubFloat # does not overlap with float32 | float64
_ToSubFloatND: TypeAlias = _ToND[_SubFloat, opt.Just[float] | int]

_ToSubComplex: TypeAlias = opt.Just[complex] | _ToSubFloat # does not overlap with complex64 | complex128

_CoT = TypeVar("_CoT", bound=np.generic)
_ToT = TypeVar("_ToT")
_ToND: TypeAlias = onp.CanArrayND[_CoT] | onp.SequenceND[onp.CanArrayND[_CoT]] | onp.SequenceND[_ToT]

_ToFloat32 = TypeAliasType("_ToFloat32", float | _Float32 | _SubFloat)
_ToFloat32 = TypeAliasType("_ToFloat32", int | _Float32 | _SubFloat)
_ToFloat64 = TypeAliasType("_ToFloat64", float | _CoFloat64)
_ToFloat64ND = TypeAliasType("_ToFloat64ND", _ToND[_CoFloat64, _ToFloat64])
_ToFloat64_D: TypeAlias = _ToFloat64 | _ToFloat64ND
Expand Down Expand Up @@ -777,8 +780,12 @@ class _UFunc11f(_UFunc11[_NameT_co, _IdentityT_co], Generic[_NameT_co, _Identity
def types(self, /) -> list[L["f->f", "d->d"]]: ...
#
@overload
def __call__(self, x: _ToSubFloat, /, out: _Out1 = None, **kw: Unpack[_KwBase]) -> _Float64: ...
@overload
def __call__(self, x: _ToSubFloat, /, out: _Out1 = None, **kw: Unpack[_Kw11f]) -> _Float: ...
@overload
def __call__(self, x: _ToSubFloatND, /, out: _Out1 = None, **kw: Unpack[_KwBase]) -> _Float64ND: ...
@overload
def __call__(self, x: _Float_DT, /, out: _Out1 = None, **kw: Unpack[_KwBase]) -> _Float_DT: ...
@overload
def __call__(self, x: _ToFloat64ND, /, out: _Out1 = None, **kw: Unpack[_Kw11f]) -> _FloatND: ...
Expand All @@ -799,8 +806,12 @@ class _UFunc11g(_UFunc11[_NameT_co, _IdentityT_co], Generic[_NameT_co, _Identity
def types(self, /) -> list[L["f->f", "d->d", "g->g"]]: ...
#
@overload
def __call__(self, x: _ToSubFloat, /, out: _Out1 = None, **kw: Unpack[_KwBase]) -> _Float64: ...
@overload
def __call__(self, x: _ToSubFloat, /, out: _Out1 = None, **kw: Unpack[_Kw11g]) -> _LFloat: ...
@overload
def __call__(self, x: _ToSubFloatND, /, out: _Out1 = None, **kw: Unpack[_KwBase]) -> _Float64ND: ...
@overload
def __call__(self, x: _LFloat_DT, /, out: _Out1 = None, **kw: Unpack[_KwBase]) -> _LFloat_DT: ...
@overload
def __call__(self, x: onp.ToFloatND, /, out: _Out1 = None, **kw: Unpack[_Kw11g]) -> _LFloatND: ...
Expand Down Expand Up @@ -843,8 +854,12 @@ class _UFunc11fc(_UFunc11[_NameT_co, _IdentityT_co], Generic[_NameT_co, _Identit
def types(self, /) -> list[L["f->f", "d->d", "F->F", "D->D"]]: ...
#
@overload
def __call__(self, x: opt.Just[float] | opt.JustInt, /, out: _Out1 = None, **kw: Unpack[_KwBase]) -> _Float64: ...
@overload
def __call__(self, x: _ToSubFloat, /, out: _Out1 = None, **kw: Unpack[_Kw11fc]) -> _Float: ...
@overload
def __call__(self, x: opt.Just[complex], /, out: _Out1 = None, **kw: Unpack[_KwBase]) -> _Complex128: ...
@overload
def __call__(self, x: _ToSubComplex, /, out: _Out1 = None, **kw: Unpack[_Kw11fc]) -> _Inexact: ...
@overload
def __call__(self, x: _Inexact_DT, /, out: _Out1 = None, **kw: Unpack[_KwBase]) -> _Inexact_DT: ...
Expand Down
115 changes: 115 additions & 0 deletions tests/special/test_ufuncs.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from typing import Any, Literal as L, TypeAlias
from typing_extensions import assert_type

import numpy as np
import optype.numpy as onp
import scipy.special as sp

_Float32ND: TypeAlias = onp.ArrayND[np.float32]
_Float64ND: TypeAlias = onp.ArrayND[np.float64]
_LongDoubleND: TypeAlias = onp.ArrayND[np.longdouble]
_Complex64ND: TypeAlias = onp.ArrayND[np.complex64]
_Complex128ND: TypeAlias = onp.ArrayND[np.complex128]

_b1: np.bool_
_i: np.integer[Any]
_f: np.floating[Any]
_f2: np.float16
_f4: np.float32
_f8: np.float64
_g: np.longdouble
_c8: np.complex64
_c16: np.complex128

_b1_nd: onp.ArrayND[np.bool_]
_i1_nd: onp.ArrayND[np.uint8 | np.int8]
_f2_nd: onp.ArrayND[np.float16]
_f4_nd: _Float32ND
_f8_nd: _Float64ND
_g_nd: _LongDoubleND
_c8_nd: _Complex64ND
_c16_nd: _Complex128ND

# _UFunc
assert_type(sp.cbrt.__name__, L["cbrt"])
assert_type(sp.cbrt.identity, L[0])

# _UFunc11
assert_type(sp.cbrt.nin, L[1])
assert_type(sp.cbrt.nout, L[1])
assert_type(sp.cbrt.nargs, L[2])
assert_type(sp.cbrt.ntypes, L[2])
assert_type(sp.cbrt.types, list[L["f->f", "d->d"]])
assert_type(sp.exprel.identity, None)

# _UFunc11f
assert_type(sp.cbrt(_b1), np.float64)
assert_type(sp.cbrt(_b1_nd), _Float64ND)
assert_type(sp.cbrt(_i), np.float64)
assert_type(sp.cbrt(_i1_nd), _Float64ND)
assert_type(sp.cbrt(_f2), np.float64)
assert_type(sp.cbrt(_f2_nd), _Float64ND)
assert_type(sp.cbrt(_f4), np.float32)
assert_type(sp.cbrt(_f4_nd), _Float32ND)
assert_type(sp.cbrt(_f8), np.float64)
assert_type(sp.cbrt(_f8_nd), _Float64ND)
sp.cbrt(_c16) # type:ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
sp.cbrt(_c16_nd) # type:ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
assert_type(sp.cbrt(False), np.float64)
assert_type(sp.cbrt([False]), _Float64ND)
assert_type(sp.cbrt(0), np.float64)
assert_type(sp.cbrt([0]), _Float64ND)
assert_type(sp.cbrt(0.0), np.float64)
assert_type(sp.cbrt([0.0]), _Float64ND)
sp.cbrt(0j) # type:ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
sp.cbrt([0j]) # type:ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
assert_type(sp.cbrt.at(_b1_nd, _i), None)
assert_type(sp.cbrt.at(_f8_nd, _i), None)
sp.cbrt.at(_c16, _i) # type:ignore[arg-type] # pyright: ignore[reportArgumentType]

# _UFunc11g
assert_type(sp.logit.ntypes, L[3])
assert_type(sp.logit(_b1), np.float64)
assert_type(sp.logit(_b1_nd), _Float64ND)
assert_type(sp.logit(_f4), np.float32)
assert_type(sp.logit(_f4_nd), _Float32ND)
assert_type(sp.logit(_f8), np.float64)
assert_type(sp.logit(_f8_nd), _Float64ND)
assert_type(sp.logit(_g), np.longdouble)
assert_type(sp.logit(_g_nd), _LongDoubleND)
sp.logit(_c16) # type:ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
sp.logit(_c16_nd) # type:ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
assert_type(sp.logit(0), np.float64)
assert_type(sp.logit([0]), _Float64ND)
assert_type(sp.logit(0.0), np.float64)
assert_type(sp.logit([0.0]), _Float64ND)
assert_type(sp.logit.at(_b1_nd, _i), None)
assert_type(sp.logit.at(_f8_nd, _i), None)
assert_type(sp.logit.at(_g_nd, _i), None)
sp.logit.at(_c16, _i) # type:ignore[arg-type] # pyright: ignore[reportArgumentType]

# _UFunc11c - TODO: wofz
# _UFunc11fc - TODO: erf

# _UFunc12 - TODO
# _UFunc14 - TODO

###

# _UFunc21 - TODO
# _UFunc22 - TODO
# _UFunc24 - TODO

###

# _UFunc31 - TODO
# _UFunc32 - TODO

###

# _UFunc41 - TODO + sph_harm deprecation test
# _UFunc42 - TODO

###

# _UFunc52 - TODO

0 comments on commit d5cbfd6

Please sign in to comment.