From ea5b6eb086a79f19f32352c832680ff81188df92 Mon Sep 17 00:00:00 2001 From: jorenham Date: Sat, 21 Dec 2024 01:08:20 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20`stats`:=20casually=20invent=20a=20?= =?UTF-8?q?way=20to=20do=20overload=20attributes=20(it=20even=20works=20on?= =?UTF-8?q?=20mypy)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../stats/_distribution_infrastructure.pyi | 58 +++++++++++++++---- 1 file changed, 48 insertions(+), 10 deletions(-) diff --git a/scipy-stubs/stats/_distribution_infrastructure.pyi b/scipy-stubs/stats/_distribution_infrastructure.pyi index 7c15ebd1..30cb4911 100644 --- a/scipy-stubs/stats/_distribution_infrastructure.pyi +++ b/scipy-stubs/stats/_distribution_infrastructure.pyi @@ -2,8 +2,8 @@ # pyright: reportUnannotatedClassAttribute=false import abc -from collections.abc import Mapping, Sequence, Set as AbstractSet -from typing import Any, ClassVar, Final, Generic, Literal as L, TypeAlias, overload +from collections.abc import Callable, Mapping, Sequence, Set as AbstractSet +from typing import Any, ClassVar, Final, Generic, Literal as L, Protocol, TypeAlias, overload, type_check_only from typing_extensions import LiteralString, Never, Self, TypeVar, override import numpy as np @@ -61,6 +61,24 @@ _DrawProportions: TypeAlias = tuple[onp.ToFloat, onp.ToFloat, onp.ToFloat, onp.T _CDist: TypeAlias = ContinuousDistribution[np.floating[Any], _ShapeT0] _CDist0: TypeAlias = ContinuousDistribution[_FloatingT, tuple[()]] +@type_check_only +class _ParameterField(Protocol[_FloatingT_co, _ShapeT0_co]): + # This actually works (even on mypy)! + @overload + def __get__( + self: _ParameterField[_FloatingT, tuple[()]], + instance: object, + owner: type | None = None, + /, + ) -> _FloatingT: ... + @overload + def __get__( + self: _ParameterField[_FloatingT, _ShapeT1], + instance: object, + owner: type | None = None, + /, + ) -> onp.ArrayND[_FloatingT, _ShapeT1]: ... + ### _null: Final[_Null] = ... @@ -233,7 +251,7 @@ class ContinuousDistribution(_BaseDistribution[_FloatingT_co, _ShapeT0_co], Gene def __sub__(self, lshift: onp.ToFloat, /) -> ShiftedScaledDistribution[Self, _FloatingT_co, _ShapeT0_co]: ... def __mul__(self, scale: onp.ToFloat, /) -> ShiftedScaledDistribution[Self, _FloatingT_co, _ShapeT0_co]: ... def __truediv__(self, iscale: onp.ToFloat, /) -> ShiftedScaledDistribution[Self, _FloatingT_co, _ShapeT0_co]: ... - def __pow__(self, exp: onp.ToInt, /) -> MonotonicTransformedDistribution[Self, _FloatingT_co, _ShapeT0_co]: ... + def __pow__(self, exp: onp.ToInt, /) -> MonotonicTransformedDistribution[Self, _ShapeT0_co]: ... __radd__ = __add__ __rsub__ = __sub__ __rmul__ = __mul__ @@ -284,7 +302,7 @@ class ContinuousDistribution(_BaseDistribution[_FloatingT_co, _ShapeT0_co], Gene @overload def llf(self, sample: onp.ToFloat | onp.ToFloatND, /, *, axis: AnyShape | None = -1) -> _Float | onp.ArrayND[_Float]: ... - # +_ElementwiseFunction: TypeAlias = Callable[[onp.ArrayND[np.float64]], onp.ArrayND[_FloatingT]] # 7 years of asking and >400 upvotes, but still no higher-kinded typing support: https://github.com/python/typing/issues/548 class TransformedDistribution( @@ -302,18 +320,38 @@ class TransformedDistribution( ) -> None: ... class MonotonicTransformedDistribution( - TransformedDistribution[_CDistT_co, _FloatingT_co, _ShapeT0_co], - Generic[_CDistT_co, _FloatingT_co, _ShapeT0_co], + TransformedDistribution[_CDistT_co, np.float64, _ShapeT0_co], + Generic[_CDistT_co, _ShapeT0_co], ): - # TODO(jorenham) - ... + _g: Final[_ElementwiseFunction] + _h: Final[_ElementwiseFunction] + _dh: Final[_ElementwiseFunction] + _logdh: Final[_ElementwiseFunction] + _increasing: Final[bool] + _repr_pattern: Final[str] + + def __init__( + self: MonotonicTransformedDistribution[_CDist[_ShapeT0], _ShapeT0], + X: _CDistT_co, + /, + *args: Never, + g: _ElementwiseFunction, + h: _ElementwiseFunction, + dh: _ElementwiseFunction, + logdh: _ElementwiseFunction | None = None, + increasing: bool = True, + repr_pattern: str | None = None, + tol: opt.Just[float] | _Null = ..., + validation_policy: _ValidationPolicy = None, + cache_policy: _CachePolicy = None, + ) -> None: ... class TruncatedDistribution( TransformedDistribution[_CDistT_co, _FloatingT_co, _ShapeT0_co], Generic[_CDistT_co, _FloatingT_co, _ShapeT0_co], ): - lb: _FloatingT_co | onp.ArrayND[_FloatingT_co, _ShapeT0_co] - ub: _FloatingT_co | onp.ArrayND[_FloatingT_co, _ShapeT0_co] + lb: _ParameterField[_FloatingT_co, _ShapeT0_co] + ub: _ParameterField[_FloatingT_co, _ShapeT0_co] @overload def __init__(