From a890c3f5a36fa38ff9546e6d334570b41be248ea Mon Sep 17 00:00:00 2001 From: Avasam Date: Sun, 20 Oct 2024 14:52:58 -0400 Subject: [PATCH 1/4] Annotate `scipy.fft._realtransforms.dct` --- scipy-stubs/fft/_realtransforms.pyi | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/scipy-stubs/fft/_realtransforms.pyi b/scipy-stubs/fft/_realtransforms.pyi index 4d9cdd89..5cbdc726 100644 --- a/scipy-stubs/fft/_realtransforms.pyi +++ b/scipy-stubs/fft/_realtransforms.pyi @@ -1,3 +1,6 @@ +from numpy import float64, generic +from numpy.typing import NDArray + from scipy._typing import Untyped def dctn( @@ -42,15 +45,15 @@ def idstn( orthogonalize: Untyped | None = None, ) -> Untyped: ... def dct( - x: Untyped, - type: int = 2, - n: Untyped | None = None, + x: NDArray[generic], + type: Literal[1, 2, 3, 4] = 2, + n: int | None = None, axis: int = -1, - norm: Untyped | None = None, + norm: Literal["backward", "ortho", "forward"] | None = None, overwrite_x: bool = False, - workers: Untyped | None = None, - orthogonalize: Untyped | None = None, -) -> Untyped: ... + workers: int | None = None, + orthogonalize: bool | None = None, +) -> NDArray[float64]: ... def idct( x: Untyped, type: int = 2, From 47b0c6e9b8a7861f29f92bf3c8ea040321d3faf7 Mon Sep 17 00:00:00 2001 From: Avasam Date: Sun, 20 Oct 2024 15:18:11 -0400 Subject: [PATCH 2/4] Add type aliases for "DCT type" and "normalization mode" --- scipy-stubs/_typing.pyi | 2 ++ scipy-stubs/fft/_realtransforms.pyi | 35 ++++++++++++++--------------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/scipy-stubs/_typing.pyi b/scipy-stubs/_typing.pyi index 3d6dba15..44592345 100644 --- a/scipy-stubs/_typing.pyi +++ b/scipy-stubs/_typing.pyi @@ -93,6 +93,8 @@ CorrelateMode: TypeAlias = Literal["valid", "same", "full"] # scipy literals NanPolicy: TypeAlias = Literal["raise", "propagate", "omit"] Alternative: TypeAlias = Literal["two-sided", "less", "greater"] +DCTType: TypeAlias = Literal[1, 2, 3, 4] +NormalizationMode: TypeAlias = Literal["backward", "ortho", "forward"] # used in `scipy.linalg.blas` and `scipy.linalg.lapack` @type_check_only diff --git a/scipy-stubs/fft/_realtransforms.pyi b/scipy-stubs/fft/_realtransforms.pyi index 5cbdc726..7600ce46 100644 --- a/scipy-stubs/fft/_realtransforms.pyi +++ b/scipy-stubs/fft/_realtransforms.pyi @@ -1,14 +1,13 @@ from numpy import float64, generic from numpy.typing import NDArray - -from scipy._typing import Untyped +from scipy._typing import DCTType, NormalizationMode, Untyped def dctn( x: Untyped, - type: int = 2, + type: DCTType = 2, s: Untyped | None = None, axes: Untyped | None = None, - norm: Untyped | None = None, + norm: NormalizationMode | None = None, overwrite_x: bool = False, workers: Untyped | None = None, *, @@ -16,70 +15,70 @@ def dctn( ) -> Untyped: ... def idctn( x: Untyped, - type: int = 2, + type: DCTType = 2, s: Untyped | None = None, axes: Untyped | None = None, - norm: Untyped | None = None, + norm: NormalizationMode | None = None, overwrite_x: bool = False, workers: Untyped | None = None, orthogonalize: Untyped | None = None, ) -> Untyped: ... def dstn( x: Untyped, - type: int = 2, + type: DCTType = 2, s: Untyped | None = None, axes: Untyped | None = None, - norm: Untyped | None = None, + norm: NormalizationMode | None = None, overwrite_x: bool = False, workers: Untyped | None = None, orthogonalize: Untyped | None = None, ) -> Untyped: ... def idstn( x: Untyped, - type: int = 2, + type: DCTType = 2, s: Untyped | None = None, axes: Untyped | None = None, - norm: Untyped | None = None, + norm: NormalizationMode | None = None, overwrite_x: bool = False, workers: Untyped | None = None, orthogonalize: Untyped | None = None, ) -> Untyped: ... def dct( x: NDArray[generic], - type: Literal[1, 2, 3, 4] = 2, + type: DCTType = 2, n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] | None = None, + norm: NormalizationMode | None = None, overwrite_x: bool = False, workers: int | None = None, orthogonalize: bool | None = None, ) -> NDArray[float64]: ... def idct( x: Untyped, - type: int = 2, + type: DCTType = 2, n: Untyped | None = None, axis: int = -1, - norm: Untyped | None = None, + norm: NormalizationMode | None = None, overwrite_x: bool = False, workers: Untyped | None = None, orthogonalize: Untyped | None = None, ) -> Untyped: ... def dst( x: Untyped, - type: int = 2, + type: DCTType = 2, n: Untyped | None = None, axis: int = -1, - norm: Untyped | None = None, + norm: NormalizationMode | None = None, overwrite_x: bool = False, workers: Untyped | None = None, orthogonalize: Untyped | None = None, ) -> Untyped: ... def idst( x: Untyped, - type: int = 2, + type: DCTType = 2, n: Untyped | None = None, axis: int = -1, - norm: Untyped | None = None, + norm: NormalizationMode | None = None, overwrite_x: bool = False, workers: Untyped | None = None, orthogonalize: Untyped | None = None, From 7de0ca0718e73d9946cbd430d4a1559658e6d6a2 Mon Sep 17 00:00:00 2001 From: Avasam Date: Sun, 20 Oct 2024 21:48:07 -0400 Subject: [PATCH 3/4] Address PR comments and fill in other simple types --- scipy-stubs/fft/_realtransforms.pyi | 40 ++++++++++++++++------------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/scipy-stubs/fft/_realtransforms.pyi b/scipy-stubs/fft/_realtransforms.pyi index 7600ce46..cc7ab99d 100644 --- a/scipy-stubs/fft/_realtransforms.pyi +++ b/scipy-stubs/fft/_realtransforms.pyi @@ -1,5 +1,6 @@ -from numpy import float64, generic -from numpy.typing import NDArray +import numpy as np +import numpy.typing as npt +from numpy._typing import _ArrayLikeNumber_co from scipy._typing import DCTType, NormalizationMode, Untyped def dctn( @@ -9,9 +10,9 @@ def dctn( axes: Untyped | None = None, norm: NormalizationMode | None = None, overwrite_x: bool = False, - workers: Untyped | None = None, + workers: int | None = None, *, - orthogonalize: Untyped | None = None, + orthogonalize: bool | None = None, ) -> Untyped: ... def idctn( x: Untyped, @@ -20,8 +21,8 @@ def idctn( axes: Untyped | None = None, norm: NormalizationMode | None = None, overwrite_x: bool = False, - workers: Untyped | None = None, - orthogonalize: Untyped | None = None, + workers: int | None = None, + orthogonalize: bool | None = None, ) -> Untyped: ... def dstn( x: Untyped, @@ -30,8 +31,8 @@ def dstn( axes: Untyped | None = None, norm: NormalizationMode | None = None, overwrite_x: bool = False, - workers: Untyped | None = None, - orthogonalize: Untyped | None = None, + workers: int | None = None, + orthogonalize: bool | None = None, ) -> Untyped: ... def idstn( x: Untyped, @@ -40,11 +41,14 @@ def idstn( axes: Untyped | None = None, norm: NormalizationMode | None = None, overwrite_x: bool = False, - workers: Untyped | None = None, - orthogonalize: Untyped | None = None, + workers: int | None = None, + orthogonalize: bool | None = None, ) -> Untyped: ... + +# We could use overloads based on the type of x to get more accurate return type +# see https://github.com/jorenham/scipy-stubs/pull/118#discussion_r1807957439 def dct( - x: NDArray[generic], + x: _ArrayLikeNumber_co, type: DCTType = 2, n: int | None = None, axis: int = -1, @@ -52,7 +56,7 @@ def dct( overwrite_x: bool = False, workers: int | None = None, orthogonalize: bool | None = None, -) -> NDArray[float64]: ... +) -> npt.NDArray[Untyped]: ... def idct( x: Untyped, type: DCTType = 2, @@ -60,8 +64,8 @@ def idct( axis: int = -1, norm: NormalizationMode | None = None, overwrite_x: bool = False, - workers: Untyped | None = None, - orthogonalize: Untyped | None = None, + workers: int | None = None, + orthogonalize: bool | None = None, ) -> Untyped: ... def dst( x: Untyped, @@ -70,8 +74,8 @@ def dst( axis: int = -1, norm: NormalizationMode | None = None, overwrite_x: bool = False, - workers: Untyped | None = None, - orthogonalize: Untyped | None = None, + workers: int | None = None, + orthogonalize: bool | None = None, ) -> Untyped: ... def idst( x: Untyped, @@ -80,6 +84,6 @@ def idst( axis: int = -1, norm: NormalizationMode | None = None, overwrite_x: bool = False, - workers: Untyped | None = None, - orthogonalize: Untyped | None = None, + workers: int | None = None, + orthogonalize: bool | None = None, ) -> Untyped: ... From eba7dec87e86a8e0f7e5cdd2fa42c61138653fca Mon Sep 17 00:00:00 2001 From: Avasam Date: Sun, 20 Oct 2024 21:54:32 -0400 Subject: [PATCH 4/4] Leftover import + `n` is also a simple param type --- scipy-stubs/fft/_realtransforms.pyi | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/scipy-stubs/fft/_realtransforms.pyi b/scipy-stubs/fft/_realtransforms.pyi index cc7ab99d..70c6386b 100644 --- a/scipy-stubs/fft/_realtransforms.pyi +++ b/scipy-stubs/fft/_realtransforms.pyi @@ -1,4 +1,3 @@ -import numpy as np import numpy.typing as npt from numpy._typing import _ArrayLikeNumber_co from scipy._typing import DCTType, NormalizationMode, Untyped @@ -60,7 +59,7 @@ def dct( def idct( x: Untyped, type: DCTType = 2, - n: Untyped | None = None, + n: int | None = None, axis: int = -1, norm: NormalizationMode | None = None, overwrite_x: bool = False, @@ -70,7 +69,7 @@ def idct( def dst( x: Untyped, type: DCTType = 2, - n: Untyped | None = None, + n: int | None = None, axis: int = -1, norm: NormalizationMode | None = None, overwrite_x: bool = False, @@ -80,7 +79,7 @@ def dst( def idst( x: Untyped, type: DCTType = 2, - n: Untyped | None = None, + n: int | None = None, axis: int = -1, norm: NormalizationMode | None = None, overwrite_x: bool = False,