diff --git a/src/differences/attgt/attgt_cal.py b/src/differences/attgt/attgt_cal.py index e86701e..e8855b0 100644 --- a/src/differences/attgt/attgt_cal.py +++ b/src/differences/attgt/attgt_cal.py @@ -17,33 +17,34 @@ from ..tools.panel_utility import panel_2_cross_section_diff from ..tools.utility import tqdm_joblib + # todo: FIX important, for unbalanced panels the cohort dummy and the rest has, when cluster var # ------------ att and influence function for a single ct -------------- def did_single_gt( - data: DataFrame, - entities: Index, - entity_name: str, - n_total: int, - y_name: str, - cohort_name: str, - strata_name: str, - weights_name: str, - x_covariates: list, - x_base: list, - x_delta: list, - anticipation: int, - control_group: str, - is_panel: bool, - is_balanced_panel: bool, - cluster_by_entity: bool, - att_ct_func: Callable, - # will be the iterable to iterate over - cohort: int, - base_period: int, - time: int, - stratum: str | float | int = None, + data: DataFrame, + entities: Index, + entity_name: str, + n_total: int, + y_name: str, + cohort_name: str, + strata_name: str, + weights_name: str, + x_covariates: list, + x_base: list, + x_delta: list, + anticipation: int, + control_group: str, + is_panel: bool, + is_balanced_panel: bool, + cluster_by_entity: bool, + att_ct_func: Callable, + # will be the iterable to iterate over + cohort: int, + base_period: int, + time: int, + stratum: str | float | int = None, ) -> namedtuple: """ computes the ATT for a single cohort-time @@ -84,7 +85,7 @@ def did_single_gt( if strata_name is not None: cohort_stratum_mask = (data[cohort_name] == cohort) & ( - data[strata_name] == stratum + data[strata_name] == stratum ) cohort_stratum_dummy = csr_array( @@ -117,7 +118,7 @@ def did_single_gt( mask_treated = (data[cohort_name] == cohort).values else: mask_treated = ( - (data[cohort_name] == cohort) & (data[strata_name] == stratum) + (data[cohort_name] == cohort) & (data[strata_name] == stratum) ).values mask_control = (data[cohort_name].isnull()).values # never treated @@ -127,14 +128,14 @@ def did_single_gt( maxt_ = max(time, base_period) mask_control = mask_control | ( - (~mask_treated) & (data[cohort_name] > maxt_ + anticipation).values + (~mask_treated) & (data[cohort_name] > maxt_ + anticipation).values ) # ---------------- pre / post (base period / time) ----------------- times_values = data.index.get_level_values(1) mask_pre_post = (times_values == base_period) | ( - times_values == time + times_values == time ) # pre or post # must be treated or control & must be observed pre or post @@ -144,7 +145,7 @@ def did_single_gt( # cluster_by_entity should be False for true rc (would be redundant) if ( - not is_balanced_panel and cluster_by_entity + not is_balanced_panel and cluster_by_entity ): # panels (as rc) but entity level if # cohort_dummy needs to be defined before masking for unbalanced panels, @@ -341,26 +342,26 @@ def did_single_gt( def get_att_gt( - group_time: list[dict], - data: DataFrame, - y_name: str, - cohort_name: str, - is_panel: bool, - is_balanced_panel: bool, - cluster_by_entity: bool, - x_covariates: list, - x_base: list, - x_delta: list, - strata_name: str = None, - weights_name: str = None, - control_group: str = "never_treated", - anticipation: int = 0, - att_function_ct: Callable = None, - n_jobs_ct: int = 1, - backend_ct: str = "loky", - progress_bar: bool = True, - sample_name: str = None, - release_workers: bool = True, + group_time: list[dict], + data: DataFrame, + y_name: str, + cohort_name: str, + is_panel: bool, + is_balanced_panel: bool, + cluster_by_entity: bool, + x_covariates: list, + x_base: list, + x_delta: list, + strata_name: str = None, + weights_name: str = None, + control_group: str = "never_treated", + anticipation: int = 0, + att_function_ct: Callable = None, + n_jobs_ct: int = 1, + backend_ct: str = "loky", + progress_bar: bool = True, + sample_name: str = None, + release_workers: bool = True, ) -> list[namedtuple]: if control_group not in ["never_treated", "not_yet_treated"]: raise ValueError( @@ -380,57 +381,89 @@ def get_att_gt( jobs = cpu_count() + 1 - n_jobs_ct if n_jobs_ct < 0 else n_jobs_ct sample_name = f"for {sample_name} " if sample_name else "" - with tqdm_joblib( - tqdm( - disable=not progress_bar, - desc=f"Computing ATTgt {sample_name}[workers={jobs}]", - total=len(group_time), - bar_format="{desc:<30}{percentage:3.0f}%|{bar:20}{r_bar}", - ) - ): - - # compute the ct att. order of parameters must conform to did_single_gt() - res_ntl = Parallel(n_jobs=n_jobs_ct, backend=backend_ct,)( - delayed(did_single_gt)( - data, - entities, - entity_name, - n_total, - y_name, - cohort_name, - strata_name, - weights_name, - x_covariates, - x_base, - x_delta, - anticipation, - control_group, - is_panel, - is_balanced_panel, - cluster_by_entity, - att_function_ct, - **gt, + tqdm_object = tqdm( + disable=not progress_bar, + desc=f"Computing ATTgt {sample_name}[workers={jobs}]", + total=len(group_time), + bar_format="{desc:<30}{percentage:3.0f}%|{bar:20}{r_bar}", + ) + + if jobs > 1: # tqdm_joblib doesn't work anymore for jobs=1 + with tqdm_joblib(tqdm_object): + + # compute the ct att. order of parameters must conform to did_single_gt() + # compute the ct att. order of parameters must conform to did_single_gt() + res_ntl = Parallel(n_jobs=n_jobs_ct, backend=backend_ct, )( + delayed(did_single_gt)( + data, + entities, + entity_name, + n_total, + y_name, + cohort_name, + strata_name, + weights_name, + x_covariates, + x_base, + x_delta, + anticipation, + control_group, + is_panel, + is_balanced_panel, + cluster_by_entity, + att_function_ct, + **gt, + ) + for gt in group_time ) - for gt in group_time - ) - if release_workers: - get_reusable_executor().shutdown(wait=True) + if release_workers: + get_reusable_executor().shutdown(wait=True) + + else: # n_jobs=1 + + with tqdm_joblib(tqdm_object): + + res_ntl = [] + for gt in group_time: + res_ntl.append( + did_single_gt( + data, + entities, + entity_name, + n_total, + y_name, + cohort_name, + strata_name, + weights_name, + x_covariates, + x_base, + x_delta, + anticipation, + control_group, + is_panel, + is_balanced_panel, + cluster_by_entity, + att_function_ct, + **gt, + ) + ) + tqdm_object.update(1) return res_ntl def get_standard_errors( - ntl: list[namedtuple], - cluster_groups: np.ndarray = None, - alpha: float = 0.05, - boot_iterations: int = 0, - random_state: int = None, - backend_boot: str = "loky", - n_jobs_boot: int = -1, - progress_bar: bool = True, - sample_name: str = None, - release_workers: bool = True, + ntl: list[namedtuple], + cluster_groups: np.ndarray = None, + alpha: float = 0.05, + boot_iterations: int = 0, + random_state: int = None, + backend_boot: str = "loky", + n_jobs_boot: int = -1, + progress_bar: bool = True, + sample_name: str = None, + release_workers: bool = True, ) -> list[namedtuple]: if boot_iterations < 0: raise ValueError( @@ -490,13 +523,13 @@ def get_standard_errors( def get_cohort_stratum_dummies( - data: DataFrame, - entities: Index | list, - cohort_name: str, - cohort: int, - strata_name: str, - stratum: str | int | float, - repeated_cross_section: bool = False, + data: DataFrame, + entities: Index | list, + cohort_name: str, + cohort: int, + strata_name: str, + stratum: str | int | float, + repeated_cross_section: bool = False, ) -> tuple: """helper for did_single_gt""" if repeated_cross_section: