Skip to content

Commit

Permalink
👽️ stats: 1.15.0 annotate order_statistic and make_distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham committed Dec 21, 2024
1 parent 3546561 commit 3160f1d
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 97 deletions.
2 changes: 0 additions & 2 deletions .mypyignore-todo
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
scipy\.stats\.__all__
scipy\.stats\.(Normal|Uniform)
scipy\.stats\.(_distribution_infrastructure\.)?(make_distribution|order_statistic)
scipy\.stats\._distribution_infrastructure\.(Folded|OrderStatistic)Distribution\.__init__
8 changes: 3 additions & 5 deletions scipy-stubs/stats/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ from ._bws_test import bws_test
from ._censored_data import CensoredData
from ._correlation import chatterjeexi
from ._covariance import Covariance

# TODO(jorenham)
from ._distribution_infrastructure import Mixture, abs, exp, log, truncate # make_distribution, order_statistic
from ._distribution_infrastructure import Mixture, abs, exp, log, make_distribution, order_statistic, truncate
from ._entropy import differential_entropy, entropy
from ._fit import fit, goodness_of_fit
from ._hypotests import (
Expand Down Expand Up @@ -470,7 +468,7 @@ __all__ = [
"logser",
"loguniform",
"lomax",
# "make_distribution",
"make_distribution",
"mannwhitneyu",
"matrix_normal",
"maxwell",
Expand Down Expand Up @@ -506,7 +504,7 @@ __all__ = [
"normaltest",
"norminvgauss",
"obrientransform",
# "order_statistic",
"order_statistic",
"ortho_group",
"page_trend_test",
"pareto",
Expand Down
172 changes: 143 additions & 29 deletions scipy-stubs/stats/_distribution_infrastructure.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import abc
from collections.abc import Callable, Mapping, Sequence, Set as AbstractSet
from typing import Any, ClassVar, Final, Generic, Literal as L, Protocol, TypeAlias, overload, type_check_only
from typing_extensions import LiteralString, Never, Self, TypeVar, override
from typing_extensions import LiteralString, Never, Self, TypeIs, TypeVar, override

import numpy as np
import optype as op
import optype.numpy as onp
import optype.typing as opt
from scipy._typing import AnyShape, ToRNG
from ._distn_infrastructure import rv_continuous
from ._probability_distribution import _BaseDistribution

# TODO:
Expand Down Expand Up @@ -86,6 +87,8 @@ class _ParameterField(Protocol[_FloatingT_co, _ShapeT0_co]):

_null: Final[_Null] = ...

def _isnull(x: object) -> TypeIs[_Null | None]: ...

# TODO(jorenham): Generic dtype and shape
class _Domain(abc.ABC):
# NOTE: This is a `ClassVar[dict[str, float]]` that's overridden as instance attribute in `_SimpleDomain`.
Expand All @@ -100,11 +103,7 @@ class _Domain(abc.ABC):
@abc.abstractmethod
def draw(self, /, n: int) -> onp.ArrayND[_FloatingT]: ...
@abc.abstractmethod
def get_numerical_endpoints(
self,
/,
x: _ParamValues,
) -> tuple[onp.ArrayND[_Float], onp.ArrayND[_Float]]: ...
def get_numerical_endpoints(self, /, x: _ParamValues) -> tuple[onp.ArrayND[_Float], onp.ArrayND[_Float]]: ...

# TODO(jorenham): Generic dtype
class _SimpleDomain(_Domain, metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -189,9 +188,7 @@ class _Parameterization:
def copy(self, /) -> Self: ...
def matches(self, /, parameters: AbstractSet[str]) -> bool: ...
def validation(
self,
/,
parameter_values: Mapping[str, _Parameter],
self, /, parameter_values: Mapping[str, _Parameter]
) -> tuple[onp.ArrayND[np.bool_], np.dtype[np.floating[Any]]]: ...
def draw(
self,
Expand All @@ -211,6 +208,8 @@ class ContinuousDistribution(_BaseDistribution[_FloatingT_co, _ShapeT0_co], Gene
_not_implemented: Final[str]
_original_parameters: dict[str, _FloatingT_co | onp.ArrayND[_FloatingT_co, _ShapeT0_co]]

_variable: _Parameter

@property
def tol(self, /) -> float | np.float64 | _Null | None: ...
@tol.setter
Expand Down Expand Up @@ -409,6 +408,8 @@ class TransformedDistribution(
ContinuousDistribution[_FloatingT_co, _ShapeT0_co],
Generic[_CDistT_co, _FloatingT_co, _ShapeT0_co],
):
_dist: _CDistT_co # readonly

def __init__(
self: TransformedDistribution[ContinuousDistribution[_FloatingT, _ShapeT0], _FloatingT, _ShapeT0], # nice trick, eh?
X: _CDistT_co,
Expand Down Expand Up @@ -490,21 +491,21 @@ class FoldedDistribution(
) -> None: ...

class TruncatedDistribution(
TransformedDistribution[_CDistT_co, _FloatingT_co, _ShapeT0_co],
Generic[_CDistT_co, _FloatingT_co, _ShapeT0_co],
TransformedDistribution[_CDistT_co, np.floating[Any], _ShapeT0_co],
Generic[_CDistT_co, _ShapeT0_co],
):
_lb_domain: ClassVar[_RealDomain] = ...
_lb_param: ClassVar[_RealParameter] = ...

_ub_domain: ClassVar[_RealDomain] = ...
_ub_param: ClassVar[_RealParameter] = ...

lb: _ParameterField[_FloatingT_co, _ShapeT0_co]
ub: _ParameterField[_FloatingT_co, _ShapeT0_co]
lb: _ParameterField[np.floating[Any], _ShapeT0_co]
ub: _ParameterField[np.floating[Any], _ShapeT0_co]

@overload
def __init__(
self: TruncatedDistribution[_CDistT0, np.floating[Any], tuple[()]],
self: TruncatedDistribution[_CDistT0, tuple[()]],
X: _CDistT0,
/,
*args: Never,
Expand All @@ -516,7 +517,7 @@ class TruncatedDistribution(
) -> None: ...
@overload
def __init__(
self: TruncatedDistribution[_CDistT1, np.floating[Any], tuple[int]],
self: TruncatedDistribution[_CDistT1, tuple[int]],
X: _CDistT1,
/,
*args: Never,
Expand All @@ -528,7 +529,7 @@ class TruncatedDistribution(
) -> None: ...
@overload
def __init__(
self: TruncatedDistribution[_CDistT2, np.floating[Any], tuple[int, int]],
self: TruncatedDistribution[_CDistT2, tuple[int, int]],
X: _CDistT2,
/,
*args: Never,
Expand All @@ -540,7 +541,7 @@ class TruncatedDistribution(
) -> None: ...
@overload
def __init__(
self: TruncatedDistribution[_CDistT3, np.floating[Any], tuple[int, int, int]],
self: TruncatedDistribution[_CDistT3, tuple[int, int, int]],
X: _CDistT3,
/,
*args: Never,
Expand All @@ -552,7 +553,7 @@ class TruncatedDistribution(
) -> None: ...
@overload
def __init__(
self: TruncatedDistribution[_CDistT, np.floating[Any], tuple[int, ...]],
self: TruncatedDistribution[_CDistT, tuple[int, ...]],
X: _CDistT,
/,
*args: Never,
Expand All @@ -563,6 +564,76 @@ class TruncatedDistribution(
cache_policy: _CachePolicy = None,
) -> None: ...

# always float64 or longdouble
class OrderStatisticDistribution(TransformedDistribution[_CDistT_co, _Float, _ShapeT0_co], Generic[_CDistT_co, _ShapeT0_co]):
# these should actually be integral; but the `_IntegerDomain` isn't finished yet
_r_domain: ClassVar[_RealDomain] = ...
_r_param: ClassVar[_RealParameter] = ...

_n_domain: ClassVar[_RealDomain] = ...
_n_param: ClassVar[_RealParameter] = ...

@overload
def __init__(
self: OrderStatisticDistribution[_CDistT0, tuple[()]],
dist: _CDistT0,
/,
*args: Never,
r: onp.ToJustInt,
n: onp.ToJustInt,
tol: opt.Just[float] | _Null = ...,
validation_policy: _ValidationPolicy = None,
cache_policy: _CachePolicy = None,
) -> None: ...
@overload
def __init__(
self: OrderStatisticDistribution[_CDistT1, tuple[int]],
dist: _CDistT1,
/,
*args: Never,
r: onp.ToJustInt | onp.ToJustIntStrict1D,
n: onp.ToJustInt | onp.ToJustIntStrict1D,
tol: opt.Just[float] | _Null = ...,
validation_policy: _ValidationPolicy = None,
cache_policy: _CachePolicy = None,
) -> None: ...
@overload
def __init__(
self: OrderStatisticDistribution[_CDistT2, tuple[int, int]],
dist: _CDistT2,
/,
*args: Never,
r: onp.ToJustInt | onp.ToJustIntStrict1D | onp.ToJustIntStrict2D,
n: onp.ToJustInt | onp.ToJustIntStrict1D | onp.ToJustIntStrict2D,
tol: opt.Just[float] | _Null = ...,
validation_policy: _ValidationPolicy = None,
cache_policy: _CachePolicy = None,
) -> None: ...
@overload
def __init__(
self: OrderStatisticDistribution[_CDistT3, tuple[int, int, int]],
dist: _CDistT3,
/,
*args: Never,
r: onp.ToJustInt | onp.ToJustIntStrict1D | onp.ToJustIntStrict2D | onp.ToJustIntStrict3D,
n: onp.ToJustInt | onp.ToJustIntStrict1D | onp.ToJustIntStrict2D | onp.ToJustIntStrict3D,
tol: opt.Just[float] | _Null = ...,
validation_policy: _ValidationPolicy = None,
cache_policy: _CachePolicy = None,
) -> None: ...
@overload
def __init__(
self: OrderStatisticDistribution[_CDistT, tuple[int, ...]],
X: _CDistT,
/,
*args: Never,
r: onp.ToJustInt | onp.ToJustIntND,
n: onp.ToJustInt | onp.ToJustIntND,
tol: opt.Just[float] | _Null = ...,
validation_policy: _ValidationPolicy = None,
cache_policy: _CachePolicy = None,
) -> None: ...

# without HKT there's no reasonable way tot determine the floating scalar type
class MonotonicTransformedDistribution(
TransformedDistribution[_CDistT_co, np.floating[Any], _ShapeT0_co],
Expand Down Expand Up @@ -591,10 +662,6 @@ class MonotonicTransformedDistribution(
cache_policy: _CachePolicy = None,
) -> None: ...

class OrderStatisticDistribution(TransformedDistribution[_CDistT_co, np.float64, _ShapeT0_co], Generic[_CDistT_co, _ShapeT0_co]):
# TODO(jorenham)
...

class Mixture(_BaseDistribution[_FloatingT_co, tuple[()]], Generic[_FloatingT_co]):
_shape: tuple[()]
_dtype: np.dtype[_FloatingT_co]
Expand All @@ -606,10 +673,8 @@ class Mixture(_BaseDistribution[_FloatingT_co, tuple[()]], Generic[_FloatingT_co
def components(self, /) -> list[_CDist0[_FloatingT_co]]: ...
@property
def weights(self, /) -> onp.Array1D[_FloatingT_co]: ...

#
def __init__(self, /, components: Sequence[_CDist0[_FloatingT_co]], *, weights: onp.ToFloat1D | None = None) -> None: ...

#
@override
def kurtosis(self, /, *, method: _SMomentMethod | None = None) -> _Float: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
Expand All @@ -622,31 +687,73 @@ def truncate(
X: _CDistT0,
lb: onp.ToFloat = ...,
ub: onp.ToFloat = ...,
) -> TruncatedDistribution[_CDistT0, np.floating[Any], tuple[()]]: ...
) -> TruncatedDistribution[_CDistT0, tuple[()]]: ...
@overload
def truncate(
X: _CDistT1,
lb: onp.ToFloat | onp.ToFloatStrict1D = ...,
ub: onp.ToFloat | onp.ToFloatStrict1D = ...,
) -> TruncatedDistribution[_CDistT1, np.floating[Any], tuple[int]]: ...
) -> TruncatedDistribution[_CDistT1, tuple[int]]: ...
@overload
def truncate(
X: _CDistT2,
lb: onp.ToFloat | onp.ToFloatStrict1D | onp.ToFloatStrict2D = ...,
ub: onp.ToFloat | onp.ToFloatStrict1D | onp.ToFloatStrict2D = ...,
) -> TruncatedDistribution[_CDistT2, np.floating[Any], tuple[int, int]]: ...
) -> TruncatedDistribution[_CDistT2, tuple[int, int]]: ...
@overload
def truncate(
X: _CDistT3,
lb: onp.ToFloat | onp.ToFloatStrict1D | onp.ToFloatStrict2D | onp.ToFloatStrict3D = ...,
ub: onp.ToFloat | onp.ToFloatStrict1D | onp.ToFloatStrict2D | onp.ToFloatStrict3D = ...,
) -> TruncatedDistribution[_CDistT3, np.floating[Any], tuple[int, int, int]]: ...
) -> TruncatedDistribution[_CDistT3, tuple[int, int, int]]: ...
@overload
def truncate(
X: _CDistT,
lb: onp.ToFloat | onp.ToFloatND = ...,
ub: onp.ToFloat | onp.ToFloatND = ...,
) -> TruncatedDistribution[_CDistT, np.floating[Any], tuple[int, ...]]: ...
) -> TruncatedDistribution[_CDistT, tuple[int, ...]]: ...

#
@overload
def order_statistic(
X: _CDistT0,
/,
*,
r: onp.ToJustInt,
n: onp.ToJustInt,
) -> OrderStatisticDistribution[_CDistT0, tuple[()]]: ...
@overload
def order_statistic(
X: _CDistT1,
/,
*,
r: onp.ToJustInt | onp.ToJustIntStrict1D,
n: onp.ToJustInt | onp.ToJustIntStrict1D,
) -> OrderStatisticDistribution[_CDistT1, tuple[int]]: ...
@overload
def order_statistic(
X: _CDistT2,
/,
*,
r: onp.ToJustInt | onp.ToJustIntStrict1D | onp.ToJustIntStrict2D,
n: onp.ToJustInt | onp.ToJustIntStrict1D | onp.ToJustIntStrict2D,
) -> OrderStatisticDistribution[_CDistT2, tuple[int, int]]: ...
@overload
def order_statistic(
X: _CDistT3,
/,
*,
r: onp.ToJustInt | onp.ToJustIntStrict1D | onp.ToJustIntStrict2D | onp.ToJustIntStrict3D,
n: onp.ToJustInt | onp.ToJustIntStrict1D | onp.ToJustIntStrict2D | onp.ToJustIntStrict3D,
) -> OrderStatisticDistribution[_CDistT3, tuple[int, int, int]]: ...
@overload
def order_statistic(
X: _CDistT,
/,
*,
r: onp.ToJustInt | onp.ToJustIntND,
n: onp.ToJustInt | onp.ToJustIntND,
) -> OrderStatisticDistribution[_CDistT, tuple[int, ...]]: ...

#
@overload
Expand Down Expand Up @@ -683,3 +790,10 @@ def log(X: _CDistT2, /) -> MonotonicTransformedDistribution[_CDistT2, tuple[int,
def log(X: _CDistT3, /) -> MonotonicTransformedDistribution[_CDistT3, tuple[int, int, int]]: ...
@overload
def log(X: _CDistT, /) -> MonotonicTransformedDistribution[_CDistT, tuple[int, ...]]: ...

# NOTE: These currently don't support >0-d parameters, and it looks like they always return float64, regardless of dtype
@type_check_only
class CustomDistribution(ContinuousDistribution[np.float64, tuple[()]]):
_dtype: np.dtype[np.floating[Any]] # ignored

def make_distribution(dist: rv_continuous) -> type[CustomDistribution]: ...
Loading

0 comments on commit 3160f1d

Please sign in to comment.