From 239666ed06fe342dabbd71547b338d6a2da0a930 Mon Sep 17 00:00:00 2001 From: LGonzalezGomez Date: Tue, 17 Dec 2024 01:21:29 +0100 Subject: [PATCH] feat: add power_config capabilities --- cluster_experiments/experiment_analysis.py | 17 +++++++++++++++++ cluster_experiments/power_config.py | 16 ++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/cluster_experiments/experiment_analysis.py b/cluster_experiments/experiment_analysis.py index cc1a7dc..43b76d7 100644 --- a/cluster_experiments/experiment_analysis.py +++ b/cluster_experiments/experiment_analysis.py @@ -1263,11 +1263,16 @@ def __init__( hypothesis=hypothesis, ) self.scale_col = scale_col + self.cluster_cols = cluster_cols or [] if covariates is not None: warnings.warn( "Covariates are not supported in the Delta Method approximation for the time being. They will be ignored." ) + if cluster_cols is None: + raise ValueError( + "cluster_cols must be provided for the Delta Method analysis" + ) def _aggregate_to_cluster(self, df: pd.DataFrame) -> pd.DataFrame: """ @@ -1366,6 +1371,18 @@ def analysis_standard_error(self, df: pd.DataFrame) -> float: _mean_diff, standard_error = self._get_mean_standard_error(df) return standard_error + @classmethod + def from_config(cls, config): + """Creates a DeltaMethodAnalysis object from a PowerConfig object""" + return cls( + cluster_cols=config.cluster_cols, + target_col=config.target_col, + scale_col=config.scale_col, + treatment_col=config.treatment_col, + treatment=config.treatment, + hypothesis=config.hypothesis, + ) + def __warn_small_group_size(self): warnings.warn( "Delta Method approximation may not be accurate for small group sizes" diff --git a/cluster_experiments/power_config.py b/cluster_experiments/power_config.py index de9ff0f..036838f 100644 --- a/cluster_experiments/power_config.py +++ b/cluster_experiments/power_config.py @@ -39,6 +39,10 @@ class MissingArgumentError(ValueError): pass +class UnexpectedArgumentError(ValueError): + pass + + @dataclass(eq=True) class PowerConfig: """ @@ -106,6 +110,7 @@ class PowerConfig: # optional mappings cupac_model: str = "" + scale_col: Optional[str] = None # Shared target_col: str = "target" @@ -197,6 +202,10 @@ def __post_init__(self): if "segmented" in self.perturbator: self._raise_error_if_missing("segment_cols", "perturbator") + if "delta" not in self.analysis: + if self.scale_col is not None: + self._raise_error_if_missing("scale_col", "analysis") + def _are_different(self, arg1, arg2) -> bool: return arg1 != arg2 @@ -215,6 +224,13 @@ def _raise_error_if_missing(self, attr, other_attr): f"{other_attr} = {getattr(self, other_attr)}." ) + def _raise_error_if_present(self, attr, other_attr): + if getattr(self, attr) is None: + raise UnexpectedArgumentError( + f"{attr} is not expected when using " + f"{other_attr} = {getattr(self, other_attr)}." + ) + perturbator_mapping = { "binary": BinaryPerturbator,