diff --git a/cspell.json b/cspell.json index 2c54c99..e361569 100644 --- a/cspell.json +++ b/cspell.json @@ -14,6 +14,7 @@ "proportion", "nullproportion", "ospp", - "unpooled" + "unpooled", + "ndigits" ] } diff --git a/main.py b/main.py index 035e2a5..d808e64 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,21 @@ -from pystatpower.basic import * -from pystatpower.procedures.two_proportion import * +from numbers import Real -a = fun_power(0.05, 64, 64, 0.60, 0.85, Alternative.AlternativeEnum.TWO_SIDED, TestType.EnumTestType.Z_TEST_POOLED) -print(a) +class Parent: + _domain = [1, 2, 3] + + +class Child(Parent): + _domain = [4, 5, 6] + + def __init__(self, value: Real): + if not isinstance(value, Real): + raise TypeError(f"{value} is not a real number") + if value not in type(self)._domain: + raise ValueError(f"{value} is not in {type(self)._domain}") + self._value = value + + +# 测试 +child_instance = Child(5) # 正常 +child_instance_invalid = Child(3) # 抛出 ValueError diff --git a/src/pystatpower/basic.py b/src/pystatpower/basic.py index cf30e84..73b0d93 100644 --- a/src/pystatpower/basic.py +++ b/src/pystatpower/basic.py @@ -1,129 +1,273 @@ -from abc import ABC, abstractmethod -from enum import Enum -from math import inf +from math import ceil, floor, inf, trunc from numbers import Real -from pystatpower.interval import Interval +from dataclasses import dataclass -class Param(ABC): - """抽象参数基类""" - domain = None +@dataclass(frozen=True) +class Interval: + """定义一个区间,可指定是否包含上下限,不支持单点区间(例如:[1, 1])。 - @abstractmethod - def __init__(self, value): - pass + Parameters + ---------- + lower (Real): 区间下限 + upper (Real): 区间上限 + lower_inclusive (bool): 是否包含区间下限 + upper_inclusive (bool): 是否包含区间上限 - @classmethod - @abstractmethod - def _check(cls, domain, value): - pass + Examples + -------- + >>> interval = Interval(0, 1, lower_inclusive=True, upper_inclusive=False) + >>> 0.5 in interval + True + >>> 1 in interval + False + >>> 0 in interval + False + >>> interval.pseudo_bound() + (0, 0.9999999999) + """ + lower: Real + upper: Real + lower_inclusive: bool = False + upper_inclusive: bool = False -class NumericParam(Param): - """数值参数基类""" - - domain = Interval(-inf, inf) + def __contains__(self, value: Real) -> bool: + if not isinstance(value, Real): + return NotImplemented + + if self.lower_inclusive: + if self.upper_inclusive: + return self.lower <= value <= self.upper + else: + return self.lower <= value < self.upper + else: + if self.upper_inclusive: + return self.lower < value <= self.upper + else: + return self.lower < value < self.upper + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Interval): + return NotImplemented + + return (self.lower, self.upper, self.lower_inclusive, self.upper_inclusive) == ( + other.lower, + other.upper, + other.lower_inclusive, + other.upper_inclusive, + ) + + def __repr__(self) -> str: + if self.lower_inclusive: + if self.upper_inclusive: + return f"[{self.lower}, {self.upper}]" + else: + return f"[{self.lower}, {self.upper})" + else: + if self.upper_inclusive: + return f"({self.lower}, {self.upper}]" + else: + return f"({self.lower}, {self.upper})" + + def pseudo_lbound(self, eps: Real = 1e-10) -> Real: + """区间的伪下界,用于数值计算。""" + if self.lower_inclusive: + return self.lower + else: + return self.lower + eps + + def pseudo_ubound(self, eps: Real = 1e-10) -> Real: + """区间的伪上界,用于数值计算。""" + if self.upper_inclusive: + return self.upper + else: + return self.upper - eps + + def pseudo_bound(self, eps: Real) -> tuple[Real, Real]: + """区间的伪上下界,用于数值计算。""" + return (self.pseudo_lbound(eps), self.pseudo_ubound(eps)) + + +class NumericParam(Real): + + _domain = Interval(-inf, inf) def __init__(self, value: Real): - cls = type(self) - cls._check(value) + if not isinstance(value, Real): + raise TypeError(f"{value} is not a real number") + if value not in self._domain: + raise ValueError(f"{value} is not in {self._domain}") self._value = value - @property - def value(self): - return self._value + def __repr__(self): + return f"{type(self).__name__}({self._value})" + + def __add__(self, other): + if isinstance(other, Real): + return type(self)(self._value + other) + return NotImplemented + + def __sub__(self, other): + if isinstance(other, Real): + return type(self)(self._value - other) + return NotImplemented - @classmethod - def _check(cls, value: Real): - domain = cls.domain - if not isinstance(value, Real): - raise TypeError(f"{value} is not a real number") - if value not in domain: - raise ValueError(f"{value} is not in {domain}") + def __mul__(self, other): + if isinstance(other, Real): + return type(self)(self._value * other) + return NotImplemented + + def __truediv__(self, other): + if isinstance(other, Real): + return type(self)(self._value / other) + return NotImplemented + + def __floordiv__(self, other): + if isinstance(other, Real): + return type(self)(self._value // other) + return NotImplemented + + def __mod__(self, other): + if isinstance(other, Real): + return type(self)(self._value % other) + return NotImplemented + + def __pow__(self, other): + if isinstance(other, Real): + return type(self)(self._value**other) + return NotImplemented + + def __radd__(self, other): + if isinstance(other, Real): + return type(self)(other + self._value) + return NotImplemented + + def __rfloordiv__(self, other): + if isinstance(other, Real): + return type(self)(other // self._value) + return NotImplemented + + def __rmul__(self, other): + if isinstance(other, Real): + return type(self)(other * self._value) + return NotImplemented + + def __rmod__(self, other): + if isinstance(other, Real): + return type(self)(other % self._value) + return NotImplemented + + def __rpow__(self, base): + if isinstance(base, Real): + return type(self)(base**self._value) + return NotImplemented + + def __rtruediv__(self, other): + if isinstance(other, Real): + return type(self)(other / self._value) + return NotImplemented + + def __trunc__(self): + return type(self)(trunc(self._value)) + + def __neg__(self): + return type(self)(-self._value) + + def __pos__(self): + return type(self)(+self._value) + + def __abs__(self): + return type(self)(abs(self._value)) + + def __ceil__(self): + return type(self)(ceil(self._value)) + def __floor__(self): + return type(self)(floor(self._value)) -class OptionalParam(Param): - """选项参数基类""" + def __round__(self, ndigits=None): + return type(self)(round(self._value, ndigits)) - class EmptyEnum(Enum): - pass + def __eq__(self, other): + if isinstance(other, Real): + return self._value == other + return NotImplemented - domain = EmptyEnum + def __lt__(self, other): + if isinstance(other, Real): + return self._value < other + return NotImplemented - def __init__(self, value: Enum | str): - cls = type(self) - self._value = cls._check(value) + def __le__(self, other): + if isinstance(other, Real): + return self._value <= other + return NotImplemented - @property - def value(self): - return self._value + def __float__(self): + return float(self._value) - @classmethod - def _check(cls, value: Enum | str) -> Enum: - domain = cls.domain + def __complex__(self): + return complex(self._value) - if isinstance(value, str): - try: - value = domain[value.upper()] - except KeyError: - raise ValueError(f"No such option '{value}' in {domain.__name__}") - elif not isinstance(value, domain): - raise TypeError(f"{value} is not a {domain.__name__}") + def __hash__(self): + return hash(self._value) - return value + def __bool__(self): + return bool(self._value) class Alpha(NumericParam): """显著性水平""" - domain = Interval(0, 1) + _domain = Interval(0, 1) class Power(NumericParam): """检验效能""" - domain = Interval(0, 1) + _domain = Interval(0, 1) class Mean(NumericParam): """均值""" - domain = Interval(-inf, inf) + _domain = Interval(-inf, inf) class STD(NumericParam): """标准差""" - domain = Interval(0, inf) + _domain = Interval(0, inf) class Proportion(NumericParam): """率""" - domain = Interval(0, 1) + _domain = Interval(0, 1) class Percent(NumericParam): """百分比""" - domain = Interval(0, 1) + _domain = Interval(0, 1) class Ratio(NumericParam): """比例""" - domain = Interval(0, inf) + _domain = Interval(0, inf) class Size(NumericParam): """样本量""" - domain = Interval(0, inf) + _domain = Interval(0, inf) class DropOutRate(NumericParam): """脱落率""" - domain = Interval(0, 1, lower_inclusive=True) + _domain = Interval(0, 1, lower_inclusive=True) diff --git a/src/pystatpower/procedures/two_proportion.py b/src/pystatpower/procedures/two_proportion.py index 89b06d0..0621773 100644 --- a/src/pystatpower/procedures/two_proportion.py +++ b/src/pystatpower/procedures/two_proportion.py @@ -8,7 +8,7 @@ from scipy.stats import norm from scipy.optimize import brenth -from pystatpower.basic import OptionalParam, Percent, Proportion, Ratio +from pystatpower.basic import Alpha, Power, Proportion class Alternative(Enum): @@ -126,13 +126,6 @@ def fun_power( return power -class TwoProportion: - - @abstractmethod - def solve(self): - pass - - # solve for sample size @@ -184,25 +177,25 @@ def __init__( # pass -class TwoProportionSolveForSize(TwoProportion): +class TwoProportionSolveForSize: def __init__( self, - alpha: Real, - power: Real, + alpha: Alpha, + power: Power, alternative: Alternative, test_type: TestType, - treatment_proportion: Real, - reference_proportion: Real, + treatment_proportion: Proportion, + reference_proportion: Proportion, group_allocation: GroupAllocation, ): - self._alpha = alpha - self._power = power - self._alternative = alternative - self._test_type = test_type - self._treatment_proportion = treatment_proportion - self._reference_proportion = reference_proportion - self._group_allocation = group_allocation + self.alpha = alpha + self.power = power + self.alternative = alternative + self.test_type = test_type + self.treatment_proportion = treatment_proportion + self.reference_proportion = reference_proportion + self.group_allocation = group_allocation @property def alpha(self) -> float: diff --git a/tests/test_basic.py b/tests/test_basic.py index afd9f8e..cde3c16 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,4 +1,7 @@ +from enum import Enum + import pytest + from pystatpower.basic import * @@ -7,8 +10,8 @@ class ErrorEnum(Enum): def test_alpha(): - assert Alpha(0.05).value == 0.05 - assert Alpha(0.001).value == 0.001 + assert Alpha(0.05) == 0.05 + assert Alpha(0.001) == 0.001 with pytest.raises(ValueError): Alpha(-1) @@ -21,8 +24,8 @@ def test_alpha(): def test_power(): - assert Power(0.05).value == 0.05 - assert Power(0.001).value == 0.001 + assert Power(0.05) == 0.05 + assert Power(0.001) == 0.001 with pytest.raises(ValueError): Power(-1) @@ -35,9 +38,9 @@ def test_power(): def test_mean(): - assert Mean(-10).value == -10 - assert Mean(0).value == 0 - assert Mean(10).value == 10 + assert Mean(-10) == -10 + assert Mean(0) == 0 + assert Mean(10) == 10 with pytest.raises(ValueError): Mean(-inf) @@ -46,7 +49,7 @@ def test_mean(): def test_std(): - assert STD(10).value == 10 + assert STD(10) == 10 with pytest.raises(ValueError): STD(-10) @@ -57,7 +60,7 @@ def test_std(): def test_proportion(): - assert Proportion(0.5).value == 0.5 + assert Proportion(0.5) == 0.5 with pytest.raises(ValueError): Proportion(-1) @@ -70,8 +73,8 @@ def test_proportion(): def test_size(): - assert Size(20).value == 20 - assert Size(20.142857).value == 20.142857 + assert Size(20) == 20 + assert Size(20.142857) == 20.142857 with pytest.raises(ValueError): Size(-1) @@ -82,8 +85,8 @@ def test_size(): def test_dropout_rate(): - assert DropOutRate(0).value == 0 - assert DropOutRate(0.5).value == 0.5 + assert DropOutRate(0) == 0 + assert DropOutRate(0.5) == 0.5 with pytest.raises(ValueError): DropOutRate(-1)