diff --git a/src/pystatpower/basic.py b/src/pystatpower/basic.py index 530af88..75dda13 100644 --- a/src/pystatpower/basic.py +++ b/src/pystatpower/basic.py @@ -1,3 +1,4 @@ +from enum import Enum, EnumMeta from math import ceil, floor, inf, isclose, trunc from numbers import Real @@ -92,6 +93,7 @@ def pseudo_bound(self, eps: Real = 1e-10) -> tuple[Real, Real]: class PowerAnalysisNumeric(Real): + """自定义功效分析数值类型""" _domain = Interval(-inf, inf, lower_inclusive=True, upper_inclusive=True) @@ -242,6 +244,16 @@ def __bool__(self): return bool(self._value) +class PowerAnalysisOption(EnumMeta): + """自定义功效分析选项的枚举元类,用于支持大小写不敏感的枚举值访问。""" + + def __getitem__(self, name): + if isinstance(name, str): + return super().__getitem__(name.upper()) + else: + return super().__getitem__(name) + + class Alpha(PowerAnalysisNumeric): """显著性水平""" diff --git a/tests/test_basic.py b/tests/test_basic.py index dcb98ce..9cecd79 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -316,6 +316,23 @@ def test_bool(self): assert bool(PowerAnalysisNumeric(-inf)) == bool(-inf) +class TestPowerAnalysisOption: + def test_getitem(self): + class TestEnum(Enum, metaclass=PowerAnalysisOption): + A = 1 + B = 2 + + assert TestEnum["A"] == TestEnum.A + assert TestEnum["a"] == TestEnum.A + assert TestEnum["B"] == TestEnum.B + assert TestEnum["b"] == TestEnum.B + + with pytest.raises(KeyError): + TestEnum["C"] + with pytest.raises(KeyError): + TestEnum[TestEnum.A] + + def test_alpha(): assert Alpha(0.05) == 0.05