From 35866f98b298403b733b7ef57998b90548f23e49 Mon Sep 17 00:00:00 2001 From: Kun Jinkao <45487685+Snoopy1866@users.noreply.github.com> Date: Fri, 20 Sep 2024 10:04:17 +0800 Subject: [PATCH] update --- cspell.json | 11 +- main.py | 35 ++ src/pystatpower/basic.py | 296 +++++++++++ src/pystatpower/interval.py | 87 ---- .../procedures/one_sample_proportion.py | 2 +- src/pystatpower/procedures/two_proportion.py | 458 ++++++++++++++++++ tests/test_basic.py | 451 +++++++++++++++++ 7 files changed, 1251 insertions(+), 89 deletions(-) create mode 100644 main.py create mode 100644 src/pystatpower/basic.py delete mode 100644 src/pystatpower/interval.py create mode 100644 src/pystatpower/procedures/two_proportion.py create mode 100644 tests/test_basic.py diff --git a/cspell.json b/cspell.json index 2e6c441..403777d 100644 --- a/cspell.json +++ b/cspell.json @@ -13,6 +13,15 @@ "brenth", "proportion", "nullproportion", - "ospp" + "ospp", + "unpooled", + "ndigits", + "radd", + "rmul", + "rtruediv", + "rfloordiv", + "rmod", + "rpow", + "brentq" ] } diff --git a/main.py b/main.py new file mode 100644 index 0000000..c9e9b90 --- /dev/null +++ b/main.py @@ -0,0 +1,35 @@ +from objprint import op + +from pystatpower.procedures.two_proportion import * + +a = solve_for_sample_size( + alpha=0.05, + power=0.8, + alternative="two_sided", + test_type="z_test_pooled", + treatment_proportion=0.65, + reference_proportion=0.85, + full_output=True, +) + +op(a) + +a = GroupAllocationOption.SIZE_OF_TOTAL | GroupAllocationOption.PERCENT_OF_TREATMENT +b = GroupAllocationOption.PERCENT_OF_TREATMENT | GroupAllocationOption.SIZE_OF_TOTAL + +print(a == b) + +a = solve_for_power( + alpha=0.05, + alternative="two_sided", + test_type="z_test_pooled", + treatment_proportion=0.65, + reference_proportion=0.85, + group_allocation=GroupAllocationSolveForPower( + GroupAllocationOption.PERCENT_OF_TREATMENT | GroupAllocationOption.SIZE_OF_TOTAL, + size_of_total=100, + percent_of_treatment=0.50, + ), + full_output=True, +) +op(a) diff --git a/src/pystatpower/basic.py b/src/pystatpower/basic.py new file mode 100644 index 0000000..530af88 --- /dev/null +++ b/src/pystatpower/basic.py @@ -0,0 +1,296 @@ +from math import ceil, floor, inf, isclose, trunc +from numbers import Real + + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class Interval: + """定义一个区间,可指定是否包含上下限,不支持单点区间(例如:[1, 1])。 + + Parameters + ---------- + lower (Real): 区间下限 + upper (Real): 区间上限 + lower_inclusive (bool): 是否包含区间下限 + upper_inclusive (bool): 是否包含区间上限 + + Examples + -------- + >>> interval = Interval(0, 1, lower_inclusive=True, upper_inclusive=False) + >>> 0.5 in interval + True + >>> 1 in interval + False + >>> 0 in interval + False + >>> interval.pseudo_bound() + (0, 0.9999999999) + """ + + lower: Real + upper: Real + lower_inclusive: bool = False + upper_inclusive: bool = False + + def __contains__(self, value: Real) -> bool: + if isinstance(value, Real): + if self.lower_inclusive: + if self.upper_inclusive: + return self.lower <= value <= self.upper + else: + return self.lower <= value < self.upper + else: + if self.upper_inclusive: + return self.lower < value <= self.upper + else: + return self.lower < value < self.upper + + raise RuntimeError(f"Interval.__contains__ only supports real numbers, but you passed in a {type(value)}.") + + def __eq__(self, other: object) -> bool: + if isinstance(other, Interval): + return ( + isclose(self.lower, other.lower) + and isclose(self.upper, other.upper) + and self.lower_inclusive == other.lower_inclusive + and self.upper_inclusive == other.upper_inclusive + ) + + raise RuntimeError(f"Interval.__eq__ only supports Interval, but you passed in a {type(other)}.") + + def __repr__(self) -> str: + if self.lower_inclusive: + if self.upper_inclusive: + return f"[{self.lower}, {self.upper}]" + else: + return f"[{self.lower}, {self.upper})" + else: + if self.upper_inclusive: + return f"({self.lower}, {self.upper}]" + else: + return f"({self.lower}, {self.upper})" + + def pseudo_lbound(self, eps: Real = 1e-10) -> Real: + """区间的伪下界,用于数值计算。""" + if self.lower_inclusive: + return self.lower + else: + return self.lower + eps + + def pseudo_ubound(self, eps: Real = 1e-10) -> Real: + """区间的伪上界,用于数值计算。""" + if self.upper_inclusive: + return self.upper + else: + return self.upper - eps + + def pseudo_bound(self, eps: Real = 1e-10) -> tuple[Real, Real]: + """区间的伪上下界,用于数值计算。""" + return (self.pseudo_lbound(eps), self.pseudo_ubound(eps)) + + +class PowerAnalysisNumeric(Real): + + _domain = Interval(-inf, inf, lower_inclusive=True, upper_inclusive=True) + + def __init__(self, value: Real): + if not isinstance(value, Real): + raise TypeError(f"{value} is not a real number.") + if value not in type(self)._domain: + raise ValueError(f"{value} is not in {type(self)._domain}.") + self._value = value + + def __repr__(self): + return f"{type(self).__name__}({self._value})" + + def __add__(self, other): + if isinstance(other, Real): + return self._value + other + return NotImplemented + + def __radd__(self, other): + if isinstance(other, Real): + return other + self._value + return NotImplemented + + def __sub__(self, other): + if isinstance(other, Real): + return self._value - other + return NotImplemented + + def __rsub__(self, other): + if isinstance(other, Real): + return other - self._value + return NotImplemented + + def __mul__(self, other): + if isinstance(other, Real): + return self._value * other + return NotImplemented + + def __rmul__(self, other): + if isinstance(other, Real): + return other * self._value + return NotImplemented + + def __truediv__(self, other): + if isinstance(other, Real): + return self._value / other + return NotImplemented + + def __rtruediv__(self, other): + if isinstance(other, Real): + return other / self._value + return NotImplemented + + def __floordiv__(self, other): + if isinstance(other, Real): + return self._value // other + return NotImplemented + + def __rfloordiv__(self, other): + if isinstance(other, Real): + return other // self._value + return NotImplemented + + def __mod__(self, other): + if isinstance(other, Real): + return self._value % other + return NotImplemented + + def __rmod__(self, other): + if isinstance(other, Real): + return other % self._value + return NotImplemented + + def __pow__(self, other): + if isinstance(other, Real): + return self._value**other + return NotImplemented + + def __rpow__(self, base): + if isinstance(base, Real): + return base**self._value + return NotImplemented + + def __abs__(self): + return abs(self._value) + + def __neg__(self): + return -self._value + + def __pos__(self): + return +self._value + + def __trunc__(self): + return trunc(self._value) + + def __floor__(self): + return floor(self._value) + + def __ceil__(self): + return ceil(self._value) + + def __round__(self, ndigits=None): + return round(self._value, ndigits) + + def __eq__(self, other): + if isinstance(other, Real): + return self._value == other + raise RuntimeError(f"{type(self)}.__eq__ only supports real numbers, but you passed in a {type(other)}.") + + def __ne__(self, other): + if isinstance(other, Real): + return self._value != other + raise RuntimeError(f"{type(self)}.__ne__ only supports real numbers, but you passed in a {type(other)}.") + + def __lt__(self, other): + if isinstance(other, Real): + return self._value < other + raise RuntimeError(f"{type(self)}.__lt__ only supports real numbers, but you passed in a {type(other)}.") + + def __le__(self, other): + if isinstance(other, Real): + return self._value <= other + raise RuntimeError(f"{type(self)}.__le__ only supports real numbers, but you passed in a {type(other)}.") + + def __gt__(self, other): + if isinstance(other, Real): + return self._value > other + raise RuntimeError(f"{type(self)}.__gt__ only supports real numbers, but you passed in a {type(other)}.") + + def __ge__(self, other): + if isinstance(other, Real): + return self._value >= other + raise RuntimeError(f"{type(self)}.__ge__ only supports real numbers, but you passed in a {type(other)}.") + + def __int__(self): + return int(self._value) + + def __float__(self): + return float(self._value) + + def __complex__(self): + return complex(self._value) + + def __hash__(self): + return hash(self._value) + + def __bool__(self): + return bool(self._value) + + +class Alpha(PowerAnalysisNumeric): + """显著性水平""" + + _domain = Interval(0, 1) + + +class Power(PowerAnalysisNumeric): + """检验效能""" + + _domain = Interval(0, 1) + + +class Mean(PowerAnalysisNumeric): + """均值""" + + _domain = Interval(-inf, inf) + + +class STD(PowerAnalysisNumeric): + """标准差""" + + _domain = Interval(0, inf) + + +class Proportion(PowerAnalysisNumeric): + """率""" + + _domain = Interval(0, 1) + + +class Percent(PowerAnalysisNumeric): + """百分比""" + + _domain = Interval(0, 1) + + +class Ratio(PowerAnalysisNumeric): + """比例""" + + _domain = Interval(0, inf) + + +class Size(PowerAnalysisNumeric): + """样本量""" + + _domain = Interval(0, inf) + + +class DropOutRate(PowerAnalysisNumeric): + """脱落率""" + + _domain = Interval(0, 1, lower_inclusive=True) diff --git a/src/pystatpower/interval.py b/src/pystatpower/interval.py deleted file mode 100644 index e5b14ae..0000000 --- a/src/pystatpower/interval.py +++ /dev/null @@ -1,87 +0,0 @@ -from dataclasses import dataclass - - -@dataclass(frozen=True) -class Interval: - """定义一个区间,可指定是否包含上下限,不支持单点区间(例如:[1, 1])。 - - Parameters - ---------- - lower (Any): 区间下限 - upper (Any): 区间上限 - lower_inclusive (bool): 是否包含区间下限 - upper_inclusive (bool): 是否包含区间上限 - - Examples - -------- - >>> interval = Interval(0, 1, lower_inclusive=True, upper_inclusive=False) - >>> 0.5 in interval - True - >>> 1 in interval - False - >>> 0 in interval - False - >>> interval.pseudo_bound() - (0, 0.9999999999) - """ - - lower: int | float - upper: int | float - lower_inclusive: bool = False - upper_inclusive: bool = False - - def __contains__(self, value: int | float) -> bool: - if not isinstance(value, (int, float)): - raise TypeError(f"unsupported operand type(s) for in: 'Interval' and '{type(value)}'") - - if self.lower_inclusive: - if self.upper_inclusive: - return self.lower <= value <= self.upper - else: - return self.lower <= value < self.upper - else: - if self.upper_inclusive: - return self.lower < value <= self.upper - else: - return self.lower < value < self.upper - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Interval): - raise NotImplementedError(f"unsupported operand type(s) for ==: 'Interval' and '{type(other)}'") - - return (self.lower, self.upper, self.lower_inclusive, self.upper_inclusive) == ( - other.lower, - other.upper, - other.lower_inclusive, - other.upper_inclusive, - ) - - def __repr__(self) -> str: - if self.lower_inclusive: - if self.upper_inclusive: - return f"[{self.lower}, {self.upper}]" - else: - return f"[{self.lower}, {self.upper})" - else: - if self.upper_inclusive: - return f"({self.lower}, {self.upper}]" - else: - return f"({self.lower}, {self.upper})" - - def pseudo_lbound(self, eps=1e-10) -> int | float: - """区间的伪下界,用于数值计算。""" - if self.lower_inclusive: - return self.lower - else: - return self.lower + eps - - def pseudo_ubound(self, eps=1e-10) -> int | float: - """区间的伪上界,用于数值计算。""" - if self.upper_inclusive: - return self.upper - else: - return self.upper - eps - - def pseudo_bound(self) -> tuple[int | float, int | float]: - """区间的伪上下界,用于数值计算。""" - return (self.pseudo_lbound(), self.pseudo_ubound()) diff --git a/src/pystatpower/procedures/one_sample_proportion.py b/src/pystatpower/procedures/one_sample_proportion.py index eee3724..9f5ca1e 100644 --- a/src/pystatpower/procedures/one_sample_proportion.py +++ b/src/pystatpower/procedures/one_sample_proportion.py @@ -17,7 +17,7 @@ TargetParameterNotExistError, ) from pystatpower.utils import get_enum_by_name -from pystatpower.interval import Interval +from pystatpower.basic import Interval # 最大样本量 diff --git a/src/pystatpower/procedures/two_proportion.py b/src/pystatpower/procedures/two_proportion.py new file mode 100644 index 0000000..d2803ea --- /dev/null +++ b/src/pystatpower/procedures/two_proportion.py @@ -0,0 +1,458 @@ +"""两独立样本差异性检验""" + +from enum import Enum, Flag, auto, unique +from math import sqrt + +from scipy.stats import norm +from scipy.optimize import brentq + +from pystatpower.basic import Alpha, Power, Proportion, Size + + +# __all__ = ["Alternative", "TestType", "GroupAllocationOption", "TwoProportionSolveForSize", "solve_for_sample_size"] + + +@unique +class Alternative(Enum): + """备择假设类型 + + Attributes + ---------- + TWO_SIDED : (int) + 双侧检验 + ONE_SIDED : (int) + 单侧检验 + """ + + TWO_SIDED = 1 + ONE_SIDED = 2 + + +@unique +class TestType(Enum): + """检验类型 + + Attributes + ---------- + Z_TEST_POOLED : (int) + Z 检验(合并方差) + Z_TEST_UNPOOLED : (int) + Z 检验(独立方差) + Z_TEST_CC_POOLED : (int) + Z 检验(连续性校正,合并方差) + Z_TEST_CC_UNPOOLED : (int) + Z 检验(连续性校正,独立方差) + """ + + Z_TEST_POOLED = 1 + Z_TEST_UNPOOLED = 2 + Z_TEST_CC_POOLED = 3 + Z_TEST_CC_UNPOOLED = 4 + + +class GroupAllocationOption(Flag): + """样本量分配类型(求解目标:样本量) + + Attributes + ---------- + EQUAL + 等量分配 + SIZE_OF_TOTAL + 总样本量 + SIZE_OF_EACH + 单组样本量 + SIZE_OF_TREATMENT + 试验组样本量 + SIZE_OF_REFERENCE + 对照组样本量 + RATIO_OF_TREATMENT_TO_REFERENCE + 试验组与对照组样本量比例 + RATIO_OF_REFERENCE_TO_TREATMENT + 对照组与试验组样本量比例 + PERCENT_OF_TREATMENT + 试验组样本百分比 + PERCENT_OF_REFERENCE + 对照组样本百分比 + """ + + EQUAL = auto() + SIZE_OF_TOTAL = auto() + SIZE_OF_EACH = auto() + SIZE_OF_TREATMENT = auto() + SIZE_OF_REFERENCE = auto() + RATIO_OF_TREATMENT_TO_REFERENCE = auto() + RATIO_OF_REFERENCE_TO_TREATMENT = auto() + PERCENT_OF_TREATMENT = auto() + PERCENT_OF_REFERENCE = auto() + + +class GroupAllocationSolveForSize: + def __init__( + self, + group_allocation_option: GroupAllocationOption = GroupAllocationOption.EQUAL, + size_of_treatment: float = None, + size_of_reference: float = None, + ratio_of_treatment_to_reference: float = None, + ratio_of_reference_to_treatment: float = None, + percent_of_treatment: float = None, + percent_of_reference: float = None, + ): + match group_allocation_option: + case GroupAllocationOption.EQUAL: + self.treatment_size_formula = lambda n: n + self.reference_size_formula = lambda n: n + case GroupAllocationOption.SIZE_OF_TREATMENT: + self.treatment_size_formula = lambda n: size_of_treatment + self.reference_size_formula = lambda n: n + case GroupAllocationOption.SIZE_OF_REFERENCE: + self.treatment_size_formula = lambda n: n + self.reference_size_formula = lambda n: size_of_reference + case GroupAllocationOption.RATIO_OF_TREATMENT_TO_REFERENCE: + self.treatment_size_formula = lambda n: ratio_of_treatment_to_reference * n + self.reference_size_formula = lambda n: n + case GroupAllocationOption.RATIO_OF_REFERENCE_TO_TREATMENT: + self.treatment_size_formula = lambda n: n + self.reference_size_formula = lambda n: ratio_of_reference_to_treatment * n + case GroupAllocationOption.PERCENT_OF_TREATMENT: + self.treatment_size_formula = lambda n: n + self.reference_size_formula = lambda n: (1 - percent_of_treatment) / percent_of_treatment * n + case GroupAllocationOption.PERCENT_OF_REFERENCE: + self.treatment_size_formula = lambda n: (1 - percent_of_reference) / percent_of_reference * n + self.reference_size_formula = lambda n: n + case _: + raise ValueError("未知的样本量分配类型") + + +class GroupAllocationSolveForPower: + def __init__( + self, + group_allocation_option: GroupAllocationOption, + size_of_total: float = None, + size_of_each: float = None, + size_of_treatment: float = None, + size_of_reference: float = None, + ratio_of_treatment_to_reference: float = None, + ratio_of_reference_to_treatment: float = None, + percent_of_treatment: float = None, + percent_of_reference: float = None, + ): + match group_allocation_option: + case x if x == GroupAllocationOption.EQUAL | GroupAllocationOption.SIZE_OF_TOTAL: + self.treatment_size_formula = lambda: size_of_total / 2 + self.reference_size_formula = lambda: size_of_total / 2 + case x if x == GroupAllocationOption.EQUAL | GroupAllocationOption.SIZE_OF_EACH: + self.treatment_size_formula = lambda: size_of_each + self.reference_size_formula = lambda: size_of_each + case x if x == GroupAllocationOption.EQUAL | GroupAllocationOption.SIZE_OF_TREATMENT: + self.treatment_size_formula = lambda: size_of_treatment + self.reference_size_formula = lambda: size_of_treatment + case x if x == GroupAllocationOption.EQUAL | GroupAllocationOption.SIZE_OF_REFERENCE: + self.treatment_size_formula = lambda: size_of_reference + self.reference_size_formula = lambda: size_of_reference + case x if x == GroupAllocationOption.SIZE_OF_TOTAL | GroupAllocationOption.SIZE_OF_TREATMENT: + self.treatment_size_formula = lambda: size_of_treatment + self.reference_size_formula = lambda: size_of_total - size_of_treatment + case x if x == GroupAllocationOption.SIZE_OF_TOTAL | GroupAllocationOption.SIZE_OF_REFERENCE: + self.treatment_size_formula = lambda: size_of_total - size_of_reference + self.reference_size_formula = lambda: size_of_reference + case x if x == GroupAllocationOption.SIZE_OF_TOTAL | GroupAllocationOption.RATIO_OF_TREATMENT_TO_REFERENCE: + self.treatment_size_formula = ( + lambda: size_of_total * ratio_of_treatment_to_reference / (1 + ratio_of_treatment_to_reference) + ) + self.reference_size_formula = lambda: size_of_total / (1 + ratio_of_treatment_to_reference) + case x if x == GroupAllocationOption.SIZE_OF_TOTAL | GroupAllocationOption.RATIO_OF_REFERENCE_TO_TREATMENT: + self.treatment_size_formula = lambda: size_of_total / (1 + ratio_of_reference_to_treatment) + self.reference_size_formula = ( + size_of_total * ratio_of_reference_to_treatment / (1 + ratio_of_reference_to_treatment) + ) + case x if x == GroupAllocationOption.SIZE_OF_TOTAL | GroupAllocationOption.PERCENT_OF_TREATMENT: + self.treatment_size_formula = lambda: size_of_total * percent_of_treatment + self.reference_size_formula = lambda: size_of_total * (1 - percent_of_treatment) + case x if x == GroupAllocationOption.SIZE_OF_TOTAL | GroupAllocationOption.PERCENT_OF_REFERENCE: + self.treatment_size_formula = lambda: size_of_total * (1 - percent_of_reference) + self.reference_size_formula = lambda: size_of_total * percent_of_reference + case x if x == GroupAllocationOption.SIZE_OF_EACH: + self.treatment_size_formula = lambda: size_of_each + self.reference_size_formula = lambda: size_of_each + case x if x == GroupAllocationOption.SIZE_OF_TREATMENT | GroupAllocationOption.SIZE_OF_REFERENCE: + self.treatment_size_formula = lambda: size_of_treatment + self.reference_size_formula = lambda: size_of_reference + case ( + x + ) if x == GroupAllocationOption.SIZE_OF_TREATMENT | GroupAllocationOption.RATIO_OF_TREATMENT_TO_REFERENCE: + self.treatment_size_formula = lambda: size_of_treatment + self.reference_size_formula = lambda: size_of_treatment / ratio_of_treatment_to_reference + case ( + x + ) if x == GroupAllocationOption.SIZE_OF_TREATMENT | GroupAllocationOption.RATIO_OF_REFERENCE_TO_TREATMENT: + self.treatment_size_formula = lambda: size_of_treatment + self.reference_size_formula = lambda: size_of_treatment * ratio_of_reference_to_treatment + case x if x == GroupAllocationOption.SIZE_OF_TREATMENT | GroupAllocationOption.PERCENT_OF_TREATMENT: + self.treatment_size_formula = lambda: size_of_treatment + self.reference_size_formula = ( + lambda: size_of_treatment * (1 - percent_of_treatment) / percent_of_treatment + ) + case x if x == GroupAllocationOption.SIZE_OF_TREATMENT | GroupAllocationOption.PERCENT_OF_REFERENCE: + self.treatment_size_formula = lambda: size_of_treatment + self.reference_size_formula = ( + lambda: size_of_treatment * percent_of_reference / (1 - percent_of_reference) + ) + case ( + x + ) if x == GroupAllocationOption.SIZE_OF_REFERENCE | GroupAllocationOption.RATIO_OF_TREATMENT_TO_REFERENCE: + self.treatment_size_formula = lambda: size_of_reference * ratio_of_treatment_to_reference + self.reference_size_formula = lambda: size_of_reference + case ( + x + ) if x == GroupAllocationOption.SIZE_OF_REFERENCE | GroupAllocationOption.RATIO_OF_REFERENCE_TO_TREATMENT: + self.treatment_size_formula = lambda: size_of_reference / ratio_of_reference_to_treatment + self.reference_size_formula = lambda: size_of_reference + case x if x == GroupAllocationOption.SIZE_OF_REFERENCE | GroupAllocationOption.PERCENT_OF_TREATMENT: + self.treatment_size_formula = ( + lambda: size_of_reference * percent_of_treatment / (1 - percent_of_treatment) + ) + self.reference_size_formula = lambda: size_of_reference + case x if x == GroupAllocationOption.SIZE_OF_REFERENCE | GroupAllocationOption.PERCENT_OF_REFERENCE: + self.treatment_size_formula = ( + lambda: size_of_reference * (1 - percent_of_reference) / percent_of_reference + ) + self.reference_size_formula = lambda: size_of_reference + case _: + raise ValueError("未知的样本量分配类型") + + +GroupAllocationSolveForAlpha = GroupAllocationSolveForPower + + +def fun_power( + alpha: float, + treatment_n: float, + reference_n: float, + treatment_proportion: float, + reference_proportion: float, + alternative: Alternative, + test_type: TestType, +): + n1 = treatment_n + n2 = reference_n + p1 = treatment_proportion + p2 = reference_proportion + + # 计算标准误 + match test_type: + case TestType.Z_TEST_POOLED | TestType.Z_TEST_CC_POOLED: + p_hat = (n1 * p1 + n2 * p2) / (n1 + n2) + se = sqrt(p_hat * (1 - p_hat) * (1 / n1 + 1 / n2)) + case TestType.Z_TEST_UNPOOLED | TestType.Z_TEST_CC_UNPOOLED: + se = sqrt(p1 * (1 - p1) / n1 + p2 * (1 - p2) / n2) + case _: + assert False, "未知的检验类型" + + # 连续性校正 + c = 0 + if test_type in [TestType.Z_TEST_CC_POOLED, TestType.Z_TEST_CC_UNPOOLED]: + c = (1 / 2) * (1 / n1 + 1 / n2) + + # 计算检验效能 + match alternative: + case Alternative.TWO_SIDED: + z_alpha = norm.ppf(1 - alpha / 2) + z_stat = [(p1 - p2 - c) / se, (p1 - p2 + c) / se] + power = norm.cdf(-z_alpha - z_stat[0]) + 1 - norm.cdf(z_alpha - z_stat[1]) + case Alternative.ONE_SIDED: + z_alpha = norm.ppf(1 - alpha) + if p1 > p2: + z_stat = (p1 - p2 + c) / se + power = 1 - norm.cdf(z_alpha - z_stat) + elif p1 <= p2: + z_stat = (p1 - p2 - c) / se + power = norm.cdf(-z_alpha - z_stat) + case _: + assert False, "未知的备择假设类型" + + return power + + +class TwoProportionSolveForSize: + def __init__( + self, + alpha: Alpha, + power: Power, + alternative: Alternative, + test_type: TestType, + treatment_proportion: Proportion, + reference_proportion: Proportion, + group_allocation: GroupAllocationSolveForSize, + ): + self.alpha = alpha + self.power = power + self.alternative = alternative + self.test_type = test_type + self.treatment_proportion = treatment_proportion + self.reference_proportion = reference_proportion + self.group_allocation = group_allocation + + def solve(self): + self._eval = ( + lambda n: fun_power( + self.alpha, + self.group_allocation.treatment_size_formula(n), + self.group_allocation.reference_size_formula(n), + self.treatment_proportion, + self.reference_proportion, + self.alternative, + self.test_type, + ) + - self.power + ) + try: + n = brentq(self._eval, 1, 1e10) + except ValueError as e: + raise ValueError("无法求解样本量") from e + self.treatment_size = Size(self.group_allocation.treatment_size_formula(n)) + self.reference_size = Size(self.group_allocation.reference_size_formula(n)) + + +class TwoProportionSolveForPower: + def __init__( + self, + alpha: Alpha, + alternative: Alternative, + test_type: TestType, + treatment_proportion: Proportion, + reference_proportion: Proportion, + group_allocation: GroupAllocationSolveForPower, + ): + self.alpha = alpha + self.alternative = alternative + self.test_type = test_type + self.treatment_proportion = treatment_proportion + self.reference_proportion = reference_proportion + self.group_allocation = group_allocation + + def solve(self): + power = fun_power( + self.alpha, + self.group_allocation.treatment_size_formula(), + self.group_allocation.reference_size_formula(), + self.treatment_proportion, + self.reference_proportion, + self.alternative, + self.test_type, + ) + self.power = Power(power) + self.treatment_size = Size(self.group_allocation.treatment_size_formula()) + self.reference_size = Size(self.group_allocation.reference_size_formula()) + + +class TwoProportionSolveForAlpha: + def __init__( + self, + power: Power, + alternative: Alternative, + test_type: TestType, + treatment_proportion: Proportion, + reference_proportion: Proportion, + group_allocation: GroupAllocationSolveForAlpha, + ): + self.power = power + self.alternative = alternative + self.test_type = test_type + self.treatment_proportion = treatment_proportion + self.reference_proportion = reference_proportion + self.group_allocation = group_allocation + + def solve(self): + self._eval = ( + lambda alpha: fun_power( + alpha, + self.group_allocation.treatment_size_formula(), + self.group_allocation.reference_size_formula(), + self.treatment_proportion, + self.reference_proportion, + self.alternative, + self.test_type, + ) + - self.power + ) + try: + alpha = brentq(self._eval, 1, 1e10) + except ValueError as e: + raise ValueError("无法求解样本量") from e + self.alpha = Alpha(alpha) + self.treatment_size = Size(self.group_allocation.treatment_size_formula()) + self.reference_size = Size(self.group_allocation.reference_size_formula()) + + +def solve_for_sample_size( + alpha: float, + power: float, + alternative: str, + test_type: str, + treatment_proportion: float, + reference_proportion: float, + group_allocation: GroupAllocationSolveForSize = GroupAllocationSolveForSize(GroupAllocationOption.EQUAL), + full_output: bool = False, +): + model = TwoProportionSolveForSize( + Alpha(alpha), + Power(power), + Alternative[alternative.upper()], + TestType[test_type.upper()], + Proportion(treatment_proportion), + Proportion(reference_proportion), + group_allocation, + ) + model.solve() + + if full_output: + return model + return model.treatment_size, model.reference_size + + +def solve_for_power( + alpha: float, + alternative: str, + test_type: str, + treatment_proportion: float, + reference_proportion: float, + group_allocation: GroupAllocationSolveForPower, + full_output: bool = False, +): + model = TwoProportionSolveForPower( + Alpha(alpha), + Alternative[alternative.upper()], + TestType[test_type.upper()], + Proportion(treatment_proportion), + Proportion(reference_proportion), + group_allocation, + ) + model.solve() + + if full_output: + return model + return model.power + + +def solve_for_alpha( + power: float, + alternative: str, + test_type: str, + treatment_proportion: float, + reference_proportion: float, + group_allocation: GroupAllocationSolveForAlpha, + full_output: bool = False, +): + model = TwoProportionSolveForAlpha( + Power(power), + Alternative[alternative.upper()], + TestType[test_type.upper()], + Proportion(treatment_proportion), + Proportion(reference_proportion), + group_allocation, + ) + model.solve() + + if full_output: + return model + return model.alpha diff --git a/tests/test_basic.py b/tests/test_basic.py new file mode 100644 index 0000000..dcb98ce --- /dev/null +++ b/tests/test_basic.py @@ -0,0 +1,451 @@ +from enum import Enum +from math import nan + +import pytest + +from pystatpower.basic import * + + +class TestInterval: + def test_contains(self): + assert 0.5 in Interval(0, 1) + assert 0 in Interval(0, 1, lower_inclusive=True) + assert 1 in Interval(0, 1, upper_inclusive=True) + assert 0 in Interval(0, 1, lower_inclusive=True, upper_inclusive=True) + assert 1 in Interval(0, 1, lower_inclusive=True, upper_inclusive=True) + + with pytest.raises(RuntimeError): + assert "0.5" in Interval(0, 1) + + def test_eq(self): + assert Interval(0, 1) == Interval(0, 1) + assert Interval(0, 1, lower_inclusive=True) == Interval(0, 1, lower_inclusive=True) + assert Interval(0, 1, upper_inclusive=True) == Interval(0, 1, upper_inclusive=True) + assert Interval(0, 1, lower_inclusive=True, upper_inclusive=True) == Interval( + 0, 1, lower_inclusive=True, upper_inclusive=True + ) + + # 区间范围近似相同 + assert Interval(0, 1e10) == Interval(0, 1e10 + 1) + # 区间范围近似相同,但另一个区间包含边界 + assert Interval(0, 1e10) != Interval(0, 1e10 + 1, lower_inclusive=True) + + with pytest.raises(RuntimeError): + assert Interval(0, 1) == "Interval(0, 1)" + + def test_repr(self): + assert repr(Interval(0, 1)) == "(0, 1)" + assert repr(Interval(0, 1, lower_inclusive=True)) == "[0, 1)" + assert repr(Interval(0, 1, upper_inclusive=True)) == "(0, 1]" + assert repr(Interval(0, 1, lower_inclusive=True, upper_inclusive=True)) == "[0, 1]" + + def test_pseudo_lbound(self): + assert Interval(0, 1).pseudo_lbound() == 1e-10 + assert Interval(0, 1, lower_inclusive=True).pseudo_lbound() == 0 + + def test_pseudo_ubound(self): + assert Interval(0, 1).pseudo_ubound() == 1 - 1e-10 + assert Interval(0, 1, upper_inclusive=True).pseudo_ubound() == 1 + + def test_pseudo_bound(self): + assert Interval(0, 1).pseudo_bound() == (1e-10, 1 - 1e-10) + assert Interval(0, 1, lower_inclusive=True).pseudo_bound() == (0, 1 - 1e-10) + assert Interval(0, 1, upper_inclusive=True).pseudo_bound() == (1e-10, 1) + assert Interval(0, 1, lower_inclusive=True, upper_inclusive=True).pseudo_bound() == (0, 1) + + +class TestPowerAnalysisNumeric: + def test_domain(self): + assert PowerAnalysisNumeric._domain == Interval(-inf, inf, lower_inclusive=True, upper_inclusive=True) + + def test_init(self): + assert PowerAnalysisNumeric(0) == 0 + assert PowerAnalysisNumeric(0.5) == 0.5 + assert PowerAnalysisNumeric(1) == 1 + assert PowerAnalysisNumeric(-inf) == -inf + assert PowerAnalysisNumeric(inf) == inf + + with pytest.raises(TypeError): + PowerAnalysisNumeric("0.5") + with pytest.raises(ValueError): + PowerAnalysisNumeric(nan) + + def test_repr(self): + assert repr(PowerAnalysisNumeric(0)) == "PowerAnalysisNumeric(0)" + assert repr(PowerAnalysisNumeric(0.5)) == "PowerAnalysisNumeric(0.5)" + assert repr(PowerAnalysisNumeric(1)) == "PowerAnalysisNumeric(1)" + assert repr(PowerAnalysisNumeric(-inf)) == "PowerAnalysisNumeric(-inf)" + assert repr(PowerAnalysisNumeric(inf)) == "PowerAnalysisNumeric(inf)" + + def test_add(self): + assert PowerAnalysisNumeric(1) + 1 == 2 + assert PowerAnalysisNumeric(1) + PowerAnalysisNumeric(1) == 2 + + assert PowerAnalysisNumeric(1) + 0.5 == 1.5 + assert PowerAnalysisNumeric(1) + PowerAnalysisNumeric(0.5) == 1.5 + + with pytest.raises(TypeError): + PowerAnalysisNumeric(1) + "1" + + def test_radd(self): + assert 1 + PowerAnalysisNumeric(1) == 2 + assert 1 + PowerAnalysisNumeric(0.5) == 1.5 + + with pytest.raises(TypeError): + "1" + PowerAnalysisNumeric(1) + + def test_sub(self): + assert PowerAnalysisNumeric(1) - 1 == 0 + assert PowerAnalysisNumeric(1) - PowerAnalysisNumeric(1) == 0 + + assert PowerAnalysisNumeric(1) - 0.5 == 0.5 + assert PowerAnalysisNumeric(1) - PowerAnalysisNumeric(0.5) == 0.5 + + with pytest.raises(TypeError): + PowerAnalysisNumeric(1) - "1" + + def test_rsub(self): + assert 1 - PowerAnalysisNumeric(1) == 0 + assert 1 - PowerAnalysisNumeric(0.5) == 0.5 + + with pytest.raises(TypeError): + "1" - PowerAnalysisNumeric(1) + + def test_mul(self): + assert PowerAnalysisNumeric(2) * 4 == 8 + assert PowerAnalysisNumeric(2) * PowerAnalysisNumeric(4) == 8 + + assert PowerAnalysisNumeric(2) * 0.5 == 1 + assert PowerAnalysisNumeric(2) * PowerAnalysisNumeric(0.5) == 1 + + with pytest.raises(TypeError): + PowerAnalysisNumeric(2) * "4" + + def test_rmul(self): + assert 2 * PowerAnalysisNumeric(4) == 8 + assert 2 * PowerAnalysisNumeric(0.5) == 1 + + with pytest.raises(TypeError): + "4" * PowerAnalysisNumeric(2) + + def test_truediv(self): + assert PowerAnalysisNumeric(1) / 2 == 0.5 + assert PowerAnalysisNumeric(1) / PowerAnalysisNumeric(2) == 0.5 + + assert PowerAnalysisNumeric(1) / 0.5 == 2 + assert PowerAnalysisNumeric(1) / PowerAnalysisNumeric(0.5) == 2 + + assert PowerAnalysisNumeric(1) / 3 == 1 / 3 + assert PowerAnalysisNumeric(1) / PowerAnalysisNumeric(3) == 1 / 3 + + with pytest.raises(TypeError): + PowerAnalysisNumeric(1) / "2" + + def test_rtruediv(self): + assert 1 / PowerAnalysisNumeric(2) == 0.5 + assert 1 / PowerAnalysisNumeric(0.5) == 2 + + assert 1 / PowerAnalysisNumeric(3) == 1 / 3 + + with pytest.raises(TypeError): + "2" / PowerAnalysisNumeric(1) + + def test_floordiv(self): + assert PowerAnalysisNumeric(3) // 2 == 1 + assert PowerAnalysisNumeric(3) // PowerAnalysisNumeric(2) == 1 + + with pytest.raises(TypeError): + PowerAnalysisNumeric(3) // "2" + + def test_rfloordiv(self): + assert 3 // PowerAnalysisNumeric(2) == 1 + + with pytest.raises(TypeError): + "3" // PowerAnalysisNumeric(2) + + def test_mod(self): + assert PowerAnalysisNumeric(5) % 3 == 2 + assert PowerAnalysisNumeric(5) % PowerAnalysisNumeric(3) == 2 + + with pytest.raises(TypeError): + PowerAnalysisNumeric(5) % "3" + + def test_rmod(self): + assert 3 % PowerAnalysisNumeric(2) == 1 + + class Person: + pass + + with pytest.raises(TypeError): + Person() % PowerAnalysisNumeric(2) + + def test_pow(self): + assert PowerAnalysisNumeric(2) ** 3 == 8 + assert PowerAnalysisNumeric(2) ** PowerAnalysisNumeric(3) == 8 + + assert PowerAnalysisNumeric(2) ** 0.5 == 2**0.5 + + with pytest.raises(TypeError): + PowerAnalysisNumeric(2) ** "3" + + def test_rpow(self): + assert 2 ** PowerAnalysisNumeric(3) == 8 + assert 2 ** PowerAnalysisNumeric(0.5) == 2**0.5 + + with pytest.raises(TypeError): + "2" ** PowerAnalysisNumeric(3) + + def test_abs(self): + assert abs(PowerAnalysisNumeric(3)) == 3 + assert abs(PowerAnalysisNumeric(-3)) == 3 + + def test_neg(self): + assert -PowerAnalysisNumeric(3) == -3 + assert -PowerAnalysisNumeric(-3) == 3 + + def test_pos(self): + assert +PowerAnalysisNumeric(3) == 3 + assert +PowerAnalysisNumeric(-3) == -3 + + def test_trunc(self): + assert trunc(PowerAnalysisNumeric(3.14)) == 3 + assert trunc(PowerAnalysisNumeric(-3.14)) == -3 + + def test_int(self): + assert int(PowerAnalysisNumeric(3.14)) == 3 + assert int(PowerAnalysisNumeric(-3.14)) == -3 + + def test_floor(self): + assert floor(PowerAnalysisNumeric(3.14)) == 3 + assert floor(PowerAnalysisNumeric(-3.14)) == -4 + + def test_ceil(self): + assert ceil(PowerAnalysisNumeric(3.14)) == 4 + assert ceil(PowerAnalysisNumeric(-3.14)) == -3 + + def test_round(self): + assert round(PowerAnalysisNumeric(3.1415)) == 3 + assert round(PowerAnalysisNumeric(-3.1415)) == -3 + assert round(PowerAnalysisNumeric(3.1415), 3) == 3.142 + assert round(PowerAnalysisNumeric(-3.1415), 3) == -3.142 + + def test_eq(self): + assert PowerAnalysisNumeric(0) == 0 + assert PowerAnalysisNumeric(0) == PowerAnalysisNumeric(0) + assert 0 == PowerAnalysisNumeric(0) + assert PowerAnalysisNumeric(0.5) == 0.5 + assert PowerAnalysisNumeric(0.5) == PowerAnalysisNumeric(0.5) + assert 0.5 == PowerAnalysisNumeric(0.5) + + with pytest.raises(RuntimeError): + PowerAnalysisNumeric(0) == "0" + + def test_ne(self): + assert PowerAnalysisNumeric(0) != 1 + assert PowerAnalysisNumeric(0) != PowerAnalysisNumeric(1) + assert 0 != PowerAnalysisNumeric(1) + assert PowerAnalysisNumeric(0.5) != 0.6 + assert PowerAnalysisNumeric(0.5) != PowerAnalysisNumeric(0.6) + assert 0.5 != PowerAnalysisNumeric(0.6) + + with pytest.raises(RuntimeError): + PowerAnalysisNumeric(0) != "0" + + def test_lt(self): + assert PowerAnalysisNumeric(0) < 1 + assert PowerAnalysisNumeric(0) < PowerAnalysisNumeric(1) + assert 0 < PowerAnalysisNumeric(1) + assert PowerAnalysisNumeric(0.5) < 0.6 + assert PowerAnalysisNumeric(0.5) < PowerAnalysisNumeric(0.6) + assert 0.5 < PowerAnalysisNumeric(0.6) + + with pytest.raises(RuntimeError): + PowerAnalysisNumeric(0) < "1" + + def test_le(self): + assert PowerAnalysisNumeric(0) <= 1 + assert PowerAnalysisNumeric(0) <= PowerAnalysisNumeric(1) + assert 0 <= PowerAnalysisNumeric(1) + assert PowerAnalysisNumeric(0.5) <= 0.5 + assert PowerAnalysisNumeric(0.5) <= PowerAnalysisNumeric(0.5) + assert 0.5 <= PowerAnalysisNumeric(0.5) + + with pytest.raises(RuntimeError): + PowerAnalysisNumeric(0) <= "1" + + def test_gt(self): + assert PowerAnalysisNumeric(1) > 0 + assert PowerAnalysisNumeric(1) > PowerAnalysisNumeric(0) + assert 1 > PowerAnalysisNumeric(0) + assert PowerAnalysisNumeric(0.6) > 0.5 + assert PowerAnalysisNumeric(0.6) > PowerAnalysisNumeric(0.5) + assert 0.6 > PowerAnalysisNumeric(0.5) + + with pytest.raises(RuntimeError): + PowerAnalysisNumeric(1) > "0" + + def test_ge(self): + assert PowerAnalysisNumeric(1) >= 0 + assert PowerAnalysisNumeric(1) >= PowerAnalysisNumeric(0) + assert 1 >= PowerAnalysisNumeric(0) + assert PowerAnalysisNumeric(0.6) >= 0.5 + assert PowerAnalysisNumeric(0.6) >= PowerAnalysisNumeric(0.5) + assert 0.6 >= PowerAnalysisNumeric(0.5) + + with pytest.raises(RuntimeError): + PowerAnalysisNumeric(1) >= "0" + + def test_float(self): + assert float(PowerAnalysisNumeric(3)) == float(3.0) + assert float(PowerAnalysisNumeric(-3)) == float(-3.0) + + def test_complex(self): + assert complex(PowerAnalysisNumeric(3)) == complex(3.0) + assert complex(PowerAnalysisNumeric(-3)) == complex(-3.0) + + def test_hash(self): + assert hash(PowerAnalysisNumeric(3)) == hash(3.0) + assert hash(PowerAnalysisNumeric(inf)) == hash(inf) + assert hash(PowerAnalysisNumeric(-inf)) == hash(-inf) + + def test_bool(self): + assert bool(PowerAnalysisNumeric(3)) == bool(3.0) + assert bool(PowerAnalysisNumeric(0)) == bool(0.0) + assert bool(PowerAnalysisNumeric(-3)) == bool(-3.0) + assert bool(PowerAnalysisNumeric(inf)) == bool(inf) + assert bool(PowerAnalysisNumeric(-inf)) == bool(-inf) + + +def test_alpha(): + assert Alpha(0.05) == 0.05 + + with pytest.raises(ValueError): + Alpha(-1) + with pytest.raises(ValueError): + Alpha(0) + with pytest.raises(ValueError): + Alpha(1) + + assert repr(Alpha(0.05)) == "Alpha(0.05)" + + +def test_power(): + assert Power(0.05) == 0.05 + + with pytest.raises(ValueError): + Power(-1) + with pytest.raises(ValueError): + Power(0) + with pytest.raises(ValueError): + Power(1) + + assert repr(Power(0.05)) == "Power(0.05)" + + +def test_mean(): + assert Mean(0) == 0 + + with pytest.raises(ValueError): + Mean(-inf) + with pytest.raises(ValueError): + Mean(inf) + + assert repr(Mean(0)) == "Mean(0)" + + +def test_std(): + assert STD(10) == 10 + + with pytest.raises(ValueError): + STD(-10) + with pytest.raises(ValueError): + STD(0) + with pytest.raises(ValueError): + STD(inf) + + assert repr(STD(10)) == "STD(10)" + + +def test_proportion(): + assert Proportion(0.5) == 0.5 + + with pytest.raises(ValueError): + Proportion(-1) + with pytest.raises(ValueError): + Proportion(0) + with pytest.raises(ValueError): + Proportion(1) + + assert repr(Proportion(0.5)) == "Proportion(0.5)" + + +def test_percent(): + assert Percent(0.5) == 0.5 + + with pytest.raises(ValueError): + Percent(-1) + with pytest.raises(ValueError): + Percent(0) + with pytest.raises(ValueError): + Percent(1) + + assert repr(Percent(0.5)) == "Percent(0.5)" + + +def test_ratio(): + assert Ratio(0.5) == 0.5 + + with pytest.raises(ValueError): + Ratio(-1) + with pytest.raises(ValueError): + Ratio(0) + with pytest.raises(ValueError): + Ratio(inf) + + assert repr(Ratio(0.5)) == "Ratio(0.5)" + + +def test_size(): + assert Size(20) == 20 + assert Size(20.142857) == 20.142857 + + with pytest.raises(ValueError): + Size(-1) + with pytest.raises(ValueError): + Size(0) + with pytest.raises(ValueError): + Size(inf) + + assert repr(Size(20)) == "Size(20)" + + +def test_dropout_rate(): + assert DropOutRate(0) == 0 + assert DropOutRate(0.5) == 0.5 + + with pytest.raises(ValueError): + DropOutRate(-1) + with pytest.raises(ValueError): + DropOutRate(1) + + assert repr(DropOutRate(0.5)) == "DropOutRate(0.5)" + + +def test_mix(): + alpha = Alpha(0.05) + power = Power(0.8) + mean = Mean(0) + std = STD(10) + proportion = Proportion(0.5) + percent = Percent(0.5) + ratio = Ratio(0.5) + size = Size(7) + dropout_rate = DropOutRate(0.5) + + assert alpha + power == 0.8 + 0.05 + assert alpha - mean == 0.05 - 0 + assert power * std == 0.8 * 10 + assert std / proportion == 10 / 0.5 + assert power // percent == 0.8 // 0.5 + assert ratio % size == 0.5 % 7 + assert std**dropout_rate == 3.1622776601683795