Skip to content

Commit

Permalink
feat: add power_config capabilities
Browse files Browse the repository at this point in the history
  • Loading branch information
LGonzalezGomez committed Dec 17, 2024
1 parent dc09174 commit 239666e
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
17 changes: 17 additions & 0 deletions cluster_experiments/experiment_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,11 +1263,16 @@ def __init__(
hypothesis=hypothesis,
)
self.scale_col = scale_col
self.cluster_cols = cluster_cols or []

if covariates is not None:
warnings.warn(
"Covariates are not supported in the Delta Method approximation for the time being. They will be ignored."
)
if cluster_cols is None:
raise ValueError(
"cluster_cols must be provided for the Delta Method analysis"
)

def _aggregate_to_cluster(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Expand Down Expand Up @@ -1366,6 +1371,18 @@ def analysis_standard_error(self, df: pd.DataFrame) -> float:
_mean_diff, standard_error = self._get_mean_standard_error(df)
return standard_error

@classmethod
def from_config(cls, config):
"""Creates a DeltaMethodAnalysis object from a PowerConfig object"""
return cls(
cluster_cols=config.cluster_cols,
target_col=config.target_col,
scale_col=config.scale_col,
treatment_col=config.treatment_col,
treatment=config.treatment,
hypothesis=config.hypothesis,
)

def __warn_small_group_size(self):
warnings.warn(
"Delta Method approximation may not be accurate for small group sizes"
Expand Down
16 changes: 16 additions & 0 deletions cluster_experiments/power_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class MissingArgumentError(ValueError):
pass


class UnexpectedArgumentError(ValueError):
pass


@dataclass(eq=True)
class PowerConfig:
"""
Expand Down Expand Up @@ -106,6 +110,7 @@ class PowerConfig:

# optional mappings
cupac_model: str = ""
scale_col: Optional[str] = None

# Shared
target_col: str = "target"
Expand Down Expand Up @@ -197,6 +202,10 @@ def __post_init__(self):
if "segmented" in self.perturbator:
self._raise_error_if_missing("segment_cols", "perturbator")

if "delta" not in self.analysis:
if self.scale_col is not None:
self._raise_error_if_missing("scale_col", "analysis")

def _are_different(self, arg1, arg2) -> bool:
return arg1 != arg2

Expand All @@ -215,6 +224,13 @@ def _raise_error_if_missing(self, attr, other_attr):
f"{other_attr} = {getattr(self, other_attr)}."
)

def _raise_error_if_present(self, attr, other_attr):
if getattr(self, attr) is None:
raise UnexpectedArgumentError(
f"{attr} is not expected when using "
f"{other_attr} = {getattr(self, other_attr)}."
)


perturbator_mapping = {
"binary": BinaryPerturbator,
Expand Down

0 comments on commit 239666e

Please sign in to comment.