diff --git a/main.py b/main.py new file mode 100644 index 0000000..5fc876a --- /dev/null +++ b/main.py @@ -0,0 +1,60 @@ +from pystatpower.procedures.two_proportion import * + +result = ( + TwoProportion(EnumSolvableParameter.N) + .set_alpha(0.05) + .set_power(0.8) + .set_alternative(EnumAlternative.TWO_SIDED) + .set_group_allocation(GroupAllocation(EnumGroupAllocation.EQUAL)) + .set_test_type(EnumTestType.Z_TEST_POOLED) + .set_treatment_proportion(0.95) + .set_reference_proportion(0.80) + .get_solver() + .solve() +) + +print(result) + + +result = ( + TwoProportion(EnumSolvableParameter.N) + .set_alpha(0.05) + .set_power(0.8) + .set_alternative(EnumAlternative.TWO_SIDED) + .set_group_allocation(GroupAllocation(EnumGroupAllocation.FIX_TREATMENT_GROUP).set_treatment_n(100)) + .set_test_type(EnumTestType.Z_TEST_POOLED) + .set_treatment_proportion(0.95) + .set_reference_proportion(0.80) + .get_solver() + .solve() +) + +print(result) + + +result = ( + TwoProportion(EnumSolvableParameter.N) + .set_alpha(0.05) + .set_power(0.8) + .set_alternative(EnumAlternative.TWO_SIDED) + .set_group_allocation(GroupAllocation(EnumGroupAllocation.EQUAL)) + .set_test_type(EnumTestType.Z_TEST_POOLED) + .set_treatment_proportion(0.68) + .set_reference_proportion(0.69) + .get_solver() + .solve() +) + +print(result) + + +result = fun_power( + alpha=0.05, + treatment_n=10000, + reference_n=10000, + treatment_proportion=0.68, + reference_proportion=0.69, + alternative=EnumAlternative.TWO_SIDED, + test_type=EnumTestType.Z_TEST_POOLED, +) +print(result) diff --git a/src/pystatpower/procedures/two_proportion.py b/src/pystatpower/procedures/two_proportion.py index 61ce69f..f5ef0b7 100644 --- a/src/pystatpower/procedures/two_proportion.py +++ b/src/pystatpower/procedures/two_proportion.py @@ -7,7 +7,7 @@ from scipy.optimize import brenth -class SolvableParameter(Enum): +class EnumSolvableParameter(Enum): """求解目标""" N = 1 @@ -17,14 +17,14 @@ class SolvableParameter(Enum): REFERENCE_PROPORTION = 5 -class Alternative(Enum): +class EnumAlternative(Enum): """假设检验的备择假设类型""" ONE_SIDED = 1 TWO_SIDED = 2 -class TestType(Enum): +class EnumTestType(Enum): """检验类型""" Z_TEST_POOLED = 1 @@ -33,7 +33,7 @@ class TestType(Enum): Z_TEST_CC_UNPOOLED = 4 -class GroupAllocation(Enum): +class EnumGroupAllocation(Enum): """样本量分配方式""" EQUAL = 1 @@ -45,14 +45,14 @@ class GroupAllocation(Enum): PERCENT_OF_REFERENCE = 7 -def _power( +def fun_power( alpha: float, treatment_n: float, reference_n: float, treatment_proportion: float, reference_proportion: float, - alternative: Alternative, - test_type: TestType, + alternative: EnumAlternative, + test_type: EnumTestType, ): n1 = treatment_n n2 = reference_n @@ -61,26 +61,26 @@ def _power( # 计算标准误 match test_type: - case TestType.Z_TEST_POOLED | TestType.Z_TEST_CC_POOLED: + case EnumTestType.Z_TEST_POOLED | EnumTestType.Z_TEST_CC_POOLED: p_hat = (n1 * p1 + n2 * p2) / (n1 + n2) se = sqrt(p_hat * (1 - p_hat) * (1 / n1 + 1 / n2)) - case TestType.Z_TEST_UNPOOLED | TestType.Z_TEST_CC_UNPOOLED: + case EnumTestType.Z_TEST_UNPOOLED | EnumTestType.Z_TEST_CC_UNPOOLED: se = sqrt(p1 * (1 - p1) / n1 + p2 * (1 - p2) / n2) case _: assert False, "未知的检验类型" # 连续性校正 c = 0 - if test_type in [TestType.Z_TEST_CC_POOLED, TestType.Z_TEST_CC_UNPOOLED]: + if test_type in [EnumTestType.Z_TEST_CC_POOLED, EnumTestType.Z_TEST_CC_UNPOOLED]: c = (1 / 2) * (1 / n1 + 1 / n2) # 计算检验效能 match alternative: - case Alternative.TWO_SIDED: + case EnumAlternative.TWO_SIDED: z_alpha = norm.ppf(1 - alpha / 2) z_stat = [(p1 - p2 - c) / se, (p1 - p2 + c) / se] power = norm.cdf(-z_alpha - z_stat[0]) + 1 - norm.cdf(z_alpha - z_stat[1]) - case Alternative.ONE_SIDED: + case EnumAlternative.ONE_SIDED: z_alpha = norm.ppf(1 - alpha) if p1 > p2: z_stat = (p1 - p2 + c) / se @@ -94,22 +94,45 @@ def _power( return power -class TwoProportionDesigner: - def __init__(self, solve_for: SolvableParameter): - if not isinstance(solve_for, SolvableParameter): - raise TypeError("solve_for must be an instance of SolvableParameter") +class GroupAllocation: + def __init__(self, group_allocation_option: EnumGroupAllocation): + self._group_allocation_option = group_allocation_option - match solve_for: - case SolvableParameter.N: - return TwoProportionSolveForNDesigner() - case SolvableParameter.ALPHA: - return TwoProportionSolveForAlphaDesigner() - case SolvableParameter.POWER: - return TwoProportionSolveForPowerDesigner() - case SolvableParameter.TREATMENT_PROPORTION: - return TwoProportionSolveForTreatmentProportionDesigner() - case SolvableParameter.REFERENCE_PROPORTION: - return TwoProportionSolveForReferenceProportionDesigner() + def set_treatment_n(self, treatment_n: float): + if self._group_allocation_option != EnumGroupAllocation.FIX_TREATMENT_GROUP: + raise ValueError("treatment_n 只能在 group_allocation 为 FIX_TREATMENT_GROUP 时指定") + self._treatment_n = treatment_n + return self + + def set_reference_n(self, reference_n: float): + if self._group_allocation_option != EnumGroupAllocation.FIX_REFERENCE_GROUP: + raise ValueError("reference_n 只能在 group_allocation 为 FIX_REFERENCE_GROUP 时指定") + self._reference_n = reference_n + return self + + def set_ratio_of_treatment_to_reference(self, ratio_of_treatment_to_reference: float): + if self._group_allocation_option != EnumGroupAllocation.RATIO_OF_TREATMENT_TO_REFERENCE: + raise ValueError("ratio 只能在 group_allocation 为 RATIO_OF_TREATMENT_TO_REFERENCE 时指定") + self._ratio_of_treatment_to_reference = ratio_of_treatment_to_reference + return self + + def set_ratio_of_reference_to_treatment(self, ratio_of_reference_to_treatment: float): + if self._group_allocation_option != EnumGroupAllocation.RATIO_OF_REFERENCE_TO_TREATMENT: + raise ValueError("ratio 只能在 group_allocation 为 RATIO_OF_REFERENCE_TO_TREATMENT 时指定") + self._ratio_of_reference_to_treatment = ratio_of_reference_to_treatment + return self + + def set_percent_of_treatment(self, percent_of_treatment: float): + if self._group_allocation_option != EnumGroupAllocation.PERCENT_OF_TREATMENT: + raise ValueError("percent_of_treatment 只能在 group_allocation 为 PERCENT_OF_TREATMENT 时指定") + self._percent_of_treatment = percent_of_treatment + return self + + def set_percent_of_reference(self, percent_of_reference: float): + if self._group_allocation_option != EnumGroupAllocation.PERCENT_OF_REFERENCE: + raise ValueError("percent_of_reference 只能在 group_allocation 为 PERCENT_OF_REFERENCE 时指定") + self._percent_of_reference = percent_of_reference + return self class TwoProportionSolveForNDesigner: @@ -125,11 +148,11 @@ def set_power(self, power: float = 0.80): self._config["power"] = power return self - def set_alternative(self, alternative: Alternative = Alternative.TWO_SIDED): + def set_alternative(self, alternative: EnumAlternative = EnumAlternative.TWO_SIDED): self._config["alternative"] = alternative return self - def set_test_type(self, test_type: TestType = TestType.Z_TEST_POOLED): + def set_test_type(self, test_type: EnumTestType = EnumTestType.Z_TEST_POOLED): self._config["test_type"] = test_type return self @@ -141,27 +164,164 @@ def set_reference_proportion(self, reference_proportion: float): self._config["reference_proportion"] = reference_proportion return self - def set_group_allocation(self, group_allocation: GroupAllocation = GroupAllocation.EQUAL): - match group_allocation: - case GroupAllocation.EQUAL: - pass - case GroupAllocation.FIX_TREATMENT_GROUP: - pass - case GroupAllocation.FIX_REFERENCE_GROUP: - pass - case GroupAllocation.RATIO_OF_TREATMENT_TO_REFERENCE: - pass - case GroupAllocation.RATIO_OF_REFERENCE_TO_TREATMENT: - pass - case GroupAllocation.PERCENT_OF_TREATMENT: - pass - case GroupAllocation.PERCENT_OF_REFERENCE: - pass + def set_group_allocation(self, group_allocation: GroupAllocation = GroupAllocation(EnumGroupAllocation.EQUAL)): + self._config["group_allocation"] = group_allocation return self def set_input_type(self, input_type): raise NotImplementedError("这个功能还没有实现") + def get_solver(self): + return TwoProportionSolveForNSolver(**self._config) + + +class TwoProportionSolveForNSolver: + def __init__( + self, + alpha: float, + power: float, + alternative: EnumAlternative, + test_type: EnumTestType, + treatment_proportion: float, + reference_proportion: float, + group_allocation: GroupAllocation, + ): + self._alpha = alpha + self._power = power + self._alternative = alternative + self._test_type = test_type + self._treatment_proportion = treatment_proportion + self._reference_proportion = reference_proportion + self._group_allocation = group_allocation + + def solve(self): + match self._group_allocation._group_allocation_option: + case EnumGroupAllocation.EQUAL: + eval = ( + lambda n: fun_power( + self._alpha, + n, + n, + self._treatment_proportion, + self._reference_proportion, + self._alternative, + self._test_type, + ) + - self._power + ) + try: + n = brenth(eval, 1, 1e10) + except ValueError as e: + raise ValueError("无法求解样本量") from e + case EnumGroupAllocation.FIX_TREATMENT_GROUP: + eval = ( + lambda n: fun_power( + self._alpha, + self._group_allocation._treatment_n, + n, + self._treatment_proportion, + self._reference_proportion, + self._alternative, + self._test_type, + ) + - self._power + ) + try: + n = brenth(eval, 1, 1e10) + except ValueError as e: + raise ValueError("无法求解样本量") from e + case EnumGroupAllocation.FIX_REFERENCE_GROUP: + eval = ( + lambda n: fun_power( + self._alpha, + n, + self._group_allocation._reference_n, + self._treatment_proportion, + self._reference_proportion, + self._alternative, + self._test_type, + ) + - self._power + ) + try: + n = brenth(eval, 1, 1e10) + except ValueError as e: + raise ValueError("无法求解样本量") from e + case EnumGroupAllocation.RATIO_OF_TREATMENT_TO_REFERENCE: + eval = ( + lambda n: fun_power( + self._alpha, + self._group_allocation._ratio_of_treatment_to_reference * n, + n, + self._treatment_proportion, + self._reference_proportion, + self._alternative, + self._test_type, + ) + - self._power + ) + try: + n = brenth(eval, 1, 1e10) + except ValueError as e: + raise ValueError("无法求解样本量") from e + case EnumGroupAllocation.RATIO_OF_REFERENCE_TO_TREATMENT: + eval = ( + lambda n: fun_power( + self._alpha, + n, + self._group_allocation._ratio_of_reference_to_treatment * n, + self._treatment_proportion, + self._reference_proportion, + self._alternative, + self._test_type, + ) + - self._power + ) + try: + n = brenth(eval, 1, 1e10) + except ValueError as e: + raise ValueError("无法求解样本量") from e + case EnumGroupAllocation.PERCENT_OF_TREATMENT: + eval = ( + lambda n: fun_power( + self._alpha, + n, + (1 - self._group_allocation._percent_of_treatment) + / self._group_allocation._percent_of_treatment + * n, + self._treatment_proportion, + self._reference_proportion, + self._alternative, + self._test_type, + ) + - self._power + ) + try: + n = brenth(eval, 1, 1e10) + except ValueError as e: + raise ValueError("无法求解样本量") from e + case EnumGroupAllocation.PERCENT_OF_REFERENCE: + eval = ( + lambda n: fun_power( + self._alpha, + (1 - self._group_allocation._percent_of_reference) + / self._group_allocation._percent_of_reference + * n, + n, + self._treatment_proportion, + self._reference_proportion, + self._alternative, + self._test_type, + ) + - self._power + ) + try: + n = brenth(eval, 1, 1e10) + except ValueError as e: + raise ValueError("无法求解样本量") from e + + return n + class TwoProportionSolveForAlphaDesigner: pass @@ -177,3 +337,24 @@ class TwoProportionSolveForTreatmentProportionDesigner: class TwoProportionSolveForReferenceProportionDesigner: pass + + +class TwoProportion: + def __new__(cls, solve_for: EnumSolvableParameter): + if not isinstance(solve_for, EnumSolvableParameter): + raise TypeError("solve_for must be an instance of SolvableParameter") + return cls._create_designer(solve_for) + + @staticmethod + def _create_designer(solve_for: EnumSolvableParameter): + match solve_for: + case EnumSolvableParameter.N: + return TwoProportionSolveForNDesigner() + case EnumSolvableParameter.ALPHA: + return TwoProportionSolveForAlphaDesigner() + case EnumSolvableParameter.POWER: + return TwoProportionSolveForPowerDesigner() + case EnumSolvableParameter.TREATMENT_PROPORTION: + return TwoProportionSolveForTreatmentProportionDesigner() + case EnumSolvableParameter.REFERENCE_PROPORTION: + return TwoProportionSolveForReferenceProportionDesigner()