Skip to content

Commit

Permalink
refactor(basic): custom numeric class (#25)
Browse files Browse the repository at this point in the history
Custom a numeric class inherited from numbers.Real, and move the Intercval class
  • Loading branch information
Snoopy1866 authored Sep 22, 2024
1 parent 15ce61e commit 19b56f0
Show file tree
Hide file tree
Showing 4 changed files with 748 additions and 88 deletions.
296 changes: 296 additions & 0 deletions src/pystatpower/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
from math import ceil, floor, inf, isclose, trunc
from numbers import Real


from dataclasses import dataclass


@dataclass(frozen=True)
class Interval:
"""定义一个区间,可指定是否包含上下限,不支持单点区间(例如:[1, 1])。
Parameters
----------
lower (Real): 区间下限
upper (Real): 区间上限
lower_inclusive (bool): 是否包含区间下限
upper_inclusive (bool): 是否包含区间上限
Examples
--------
>>> interval = Interval(0, 1, lower_inclusive=True, upper_inclusive=False)
>>> 0.5 in interval
True
>>> 1 in interval
False
>>> 0 in interval
False
>>> interval.pseudo_bound()
(0, 0.9999999999)
"""

lower: Real
upper: Real
lower_inclusive: bool = False
upper_inclusive: bool = False

def __contains__(self, value: Real) -> bool:
if isinstance(value, Real):
if self.lower_inclusive:
if self.upper_inclusive:
return self.lower <= value <= self.upper
else:
return self.lower <= value < self.upper
else:
if self.upper_inclusive:
return self.lower < value <= self.upper
else:
return self.lower < value < self.upper

raise RuntimeError(f"Interval.__contains__ only supports real numbers, but you passed in a {type(value)}.")

def __eq__(self, other: object) -> bool:
if isinstance(other, Interval):
return (
isclose(self.lower, other.lower)
and isclose(self.upper, other.upper)
and self.lower_inclusive == other.lower_inclusive
and self.upper_inclusive == other.upper_inclusive
)

raise RuntimeError(f"Interval.__eq__ only supports Interval, but you passed in a {type(other)}.")

def __repr__(self) -> str:
if self.lower_inclusive:
if self.upper_inclusive:
return f"[{self.lower}, {self.upper}]"
else:
return f"[{self.lower}, {self.upper})"
else:
if self.upper_inclusive:
return f"({self.lower}, {self.upper}]"
else:
return f"({self.lower}, {self.upper})"

def pseudo_lbound(self, eps: Real = 1e-10) -> Real:
"""区间的伪下界,用于数值计算。"""
if self.lower_inclusive:
return self.lower
else:
return self.lower + eps

def pseudo_ubound(self, eps: Real = 1e-10) -> Real:
"""区间的伪上界,用于数值计算。"""
if self.upper_inclusive:
return self.upper
else:
return self.upper - eps

def pseudo_bound(self, eps: Real = 1e-10) -> tuple[Real, Real]:
"""区间的伪上下界,用于数值计算。"""
return (self.pseudo_lbound(eps), self.pseudo_ubound(eps))


class PowerAnalysisNumeric(Real):

_domain = Interval(-inf, inf, lower_inclusive=True, upper_inclusive=True)

def __init__(self, value: Real):
if not isinstance(value, Real):
raise TypeError(f"{value} is not a real number.")
if value not in type(self)._domain:
raise ValueError(f"{value} is not in {type(self)._domain}.")
self._value = value

def __repr__(self):
return f"{type(self).__name__}({self._value})"

def __add__(self, other):
if isinstance(other, Real):
return self._value + other
return NotImplemented

def __radd__(self, other):
if isinstance(other, Real):
return other + self._value
return NotImplemented

def __sub__(self, other):
if isinstance(other, Real):
return self._value - other
return NotImplemented

def __rsub__(self, other):
if isinstance(other, Real):
return other - self._value
return NotImplemented

def __mul__(self, other):
if isinstance(other, Real):
return self._value * other
return NotImplemented

def __rmul__(self, other):
if isinstance(other, Real):
return other * self._value
return NotImplemented

def __truediv__(self, other):
if isinstance(other, Real):
return self._value / other
return NotImplemented

def __rtruediv__(self, other):
if isinstance(other, Real):
return other / self._value
return NotImplemented

def __floordiv__(self, other):
if isinstance(other, Real):
return self._value // other
return NotImplemented

def __rfloordiv__(self, other):
if isinstance(other, Real):
return other // self._value
return NotImplemented

def __mod__(self, other):
if isinstance(other, Real):
return self._value % other
return NotImplemented

def __rmod__(self, other):
if isinstance(other, Real):
return other % self._value
return NotImplemented

def __pow__(self, other):
if isinstance(other, Real):
return self._value**other
return NotImplemented

def __rpow__(self, base):
if isinstance(base, Real):
return base**self._value
return NotImplemented

def __abs__(self):
return abs(self._value)

def __neg__(self):
return -self._value

def __pos__(self):
return +self._value

def __trunc__(self):
return trunc(self._value)

def __floor__(self):
return floor(self._value)

def __ceil__(self):
return ceil(self._value)

def __round__(self, ndigits=None):
return round(self._value, ndigits)

def __eq__(self, other):
if isinstance(other, Real):
return self._value == other
raise RuntimeError(f"{type(self)}.__eq__ only supports real numbers, but you passed in a {type(other)}.")

def __ne__(self, other):
if isinstance(other, Real):
return self._value != other
raise RuntimeError(f"{type(self)}.__ne__ only supports real numbers, but you passed in a {type(other)}.")

def __lt__(self, other):
if isinstance(other, Real):
return self._value < other
raise RuntimeError(f"{type(self)}.__lt__ only supports real numbers, but you passed in a {type(other)}.")

def __le__(self, other):
if isinstance(other, Real):
return self._value <= other
raise RuntimeError(f"{type(self)}.__le__ only supports real numbers, but you passed in a {type(other)}.")

def __gt__(self, other):
if isinstance(other, Real):
return self._value > other
raise RuntimeError(f"{type(self)}.__gt__ only supports real numbers, but you passed in a {type(other)}.")

def __ge__(self, other):
if isinstance(other, Real):
return self._value >= other
raise RuntimeError(f"{type(self)}.__ge__ only supports real numbers, but you passed in a {type(other)}.")

def __int__(self):
return int(self._value)

def __float__(self):
return float(self._value)

def __complex__(self):
return complex(self._value)

def __hash__(self):
return hash(self._value)

def __bool__(self):
return bool(self._value)


class Alpha(PowerAnalysisNumeric):
"""显著性水平"""

_domain = Interval(0, 1)


class Power(PowerAnalysisNumeric):
"""检验效能"""

_domain = Interval(0, 1)


class Mean(PowerAnalysisNumeric):
"""均值"""

_domain = Interval(-inf, inf)


class STD(PowerAnalysisNumeric):
"""标准差"""

_domain = Interval(0, inf)


class Proportion(PowerAnalysisNumeric):
"""率"""

_domain = Interval(0, 1)


class Percent(PowerAnalysisNumeric):
"""百分比"""

_domain = Interval(0, 1)


class Ratio(PowerAnalysisNumeric):
"""比例"""

_domain = Interval(0, inf)


class Size(PowerAnalysisNumeric):
"""样本量"""

_domain = Interval(0, inf)


class DropOutRate(PowerAnalysisNumeric):
"""脱落率"""

_domain = Interval(0, 1, lower_inclusive=True)
87 changes: 0 additions & 87 deletions src/pystatpower/interval.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/pystatpower/procedures/one_sample_proportion.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
TargetParameterNotExistError,
)
from pystatpower.utils import get_enum_by_name
from pystatpower.interval import Interval
from pystatpower.basic import Interval


# 最大样本量
Expand Down
Loading

0 comments on commit 19b56f0

Please sign in to comment.