Skip to content

Commit

Permalink
fix progress bar when n_jobs=1
Browse files Browse the repository at this point in the history
  • Loading branch information
bernardodionisi committed Dec 9, 2023
1 parent 83188e0 commit 9e6a041
Showing 1 changed file with 131 additions and 98 deletions.
229 changes: 131 additions & 98 deletions src/differences/attgt/attgt_cal.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,34 @@
from ..tools.panel_utility import panel_2_cross_section_diff
from ..tools.utility import tqdm_joblib


# todo: FIX important, for unbalanced panels the cohort dummy and the rest has, when cluster var
# ------------ att and influence function for a single ct --------------


def did_single_gt(
data: DataFrame,
entities: Index,
entity_name: str,
n_total: int,
y_name: str,
cohort_name: str,
strata_name: str,
weights_name: str,
x_covariates: list,
x_base: list,
x_delta: list,
anticipation: int,
control_group: str,
is_panel: bool,
is_balanced_panel: bool,
cluster_by_entity: bool,
att_ct_func: Callable,
# will be the iterable to iterate over
cohort: int,
base_period: int,
time: int,
stratum: str | float | int = None,
data: DataFrame,
entities: Index,
entity_name: str,
n_total: int,
y_name: str,
cohort_name: str,
strata_name: str,
weights_name: str,
x_covariates: list,
x_base: list,
x_delta: list,
anticipation: int,
control_group: str,
is_panel: bool,
is_balanced_panel: bool,
cluster_by_entity: bool,
att_ct_func: Callable,
# will be the iterable to iterate over
cohort: int,
base_period: int,
time: int,
stratum: str | float | int = None,
) -> namedtuple:
"""
computes the ATT for a single cohort-time
Expand Down Expand Up @@ -84,7 +85,7 @@ def did_single_gt(

if strata_name is not None:
cohort_stratum_mask = (data[cohort_name] == cohort) & (
data[strata_name] == stratum
data[strata_name] == stratum
)

cohort_stratum_dummy = csr_array(
Expand Down Expand Up @@ -117,7 +118,7 @@ def did_single_gt(
mask_treated = (data[cohort_name] == cohort).values
else:
mask_treated = (
(data[cohort_name] == cohort) & (data[strata_name] == stratum)
(data[cohort_name] == cohort) & (data[strata_name] == stratum)
).values

mask_control = (data[cohort_name].isnull()).values # never treated
Expand All @@ -127,14 +128,14 @@ def did_single_gt(
maxt_ = max(time, base_period)

mask_control = mask_control | (
(~mask_treated) & (data[cohort_name] > maxt_ + anticipation).values
(~mask_treated) & (data[cohort_name] > maxt_ + anticipation).values
)

# ---------------- pre / post (base period / time) -----------------

times_values = data.index.get_level_values(1)
mask_pre_post = (times_values == base_period) | (
times_values == time
times_values == time
) # pre or post

# must be treated or control & must be observed pre or post
Expand All @@ -144,7 +145,7 @@ def did_single_gt(

# cluster_by_entity should be False for true rc (would be redundant)
if (
not is_balanced_panel and cluster_by_entity
not is_balanced_panel and cluster_by_entity
): # panels (as rc) but entity level if

# cohort_dummy needs to be defined before masking for unbalanced panels,
Expand Down Expand Up @@ -341,26 +342,26 @@ def did_single_gt(


def get_att_gt(
group_time: list[dict],
data: DataFrame,
y_name: str,
cohort_name: str,
is_panel: bool,
is_balanced_panel: bool,
cluster_by_entity: bool,
x_covariates: list,
x_base: list,
x_delta: list,
strata_name: str = None,
weights_name: str = None,
control_group: str = "never_treated",
anticipation: int = 0,
att_function_ct: Callable = None,
n_jobs_ct: int = 1,
backend_ct: str = "loky",
progress_bar: bool = True,
sample_name: str = None,
release_workers: bool = True,
group_time: list[dict],
data: DataFrame,
y_name: str,
cohort_name: str,
is_panel: bool,
is_balanced_panel: bool,
cluster_by_entity: bool,
x_covariates: list,
x_base: list,
x_delta: list,
strata_name: str = None,
weights_name: str = None,
control_group: str = "never_treated",
anticipation: int = 0,
att_function_ct: Callable = None,
n_jobs_ct: int = 1,
backend_ct: str = "loky",
progress_bar: bool = True,
sample_name: str = None,
release_workers: bool = True,
) -> list[namedtuple]:
if control_group not in ["never_treated", "not_yet_treated"]:
raise ValueError(
Expand All @@ -380,57 +381,89 @@ def get_att_gt(
jobs = cpu_count() + 1 - n_jobs_ct if n_jobs_ct < 0 else n_jobs_ct
sample_name = f"for {sample_name} " if sample_name else ""

with tqdm_joblib(
tqdm(
disable=not progress_bar,
desc=f"Computing ATTgt {sample_name}[workers={jobs}]",
total=len(group_time),
bar_format="{desc:<30}{percentage:3.0f}%|{bar:20}{r_bar}",
)
):

# compute the ct att. order of parameters must conform to did_single_gt()
res_ntl = Parallel(n_jobs=n_jobs_ct, backend=backend_ct,)(
delayed(did_single_gt)(
data,
entities,
entity_name,
n_total,
y_name,
cohort_name,
strata_name,
weights_name,
x_covariates,
x_base,
x_delta,
anticipation,
control_group,
is_panel,
is_balanced_panel,
cluster_by_entity,
att_function_ct,
**gt,
tqdm_object = tqdm(
disable=not progress_bar,
desc=f"Computing ATTgt {sample_name}[workers={jobs}]",
total=len(group_time),
bar_format="{desc:<30}{percentage:3.0f}%|{bar:20}{r_bar}",
)

if jobs > 1: # tqdm_joblib doesn't work anymore for jobs=1
with tqdm_joblib(tqdm_object):

# compute the ct att. order of parameters must conform to did_single_gt()
# compute the ct att. order of parameters must conform to did_single_gt()
res_ntl = Parallel(n_jobs=n_jobs_ct, backend=backend_ct, )(
delayed(did_single_gt)(
data,
entities,
entity_name,
n_total,
y_name,
cohort_name,
strata_name,
weights_name,
x_covariates,
x_base,
x_delta,
anticipation,
control_group,
is_panel,
is_balanced_panel,
cluster_by_entity,
att_function_ct,
**gt,
)
for gt in group_time
)
for gt in group_time
)

if release_workers:
get_reusable_executor().shutdown(wait=True)
if release_workers:
get_reusable_executor().shutdown(wait=True)

else: # n_jobs=1

with tqdm_joblib(tqdm_object):

res_ntl = []
for gt in group_time:
res_ntl.append(
did_single_gt(
data,
entities,
entity_name,
n_total,
y_name,
cohort_name,
strata_name,
weights_name,
x_covariates,
x_base,
x_delta,
anticipation,
control_group,
is_panel,
is_balanced_panel,
cluster_by_entity,
att_function_ct,
**gt,
)
)

tqdm_object.update(1)
return res_ntl


def get_standard_errors(
ntl: list[namedtuple],
cluster_groups: np.ndarray = None,
alpha: float = 0.05,
boot_iterations: int = 0,
random_state: int = None,
backend_boot: str = "loky",
n_jobs_boot: int = -1,
progress_bar: bool = True,
sample_name: str = None,
release_workers: bool = True,
ntl: list[namedtuple],
cluster_groups: np.ndarray = None,
alpha: float = 0.05,
boot_iterations: int = 0,
random_state: int = None,
backend_boot: str = "loky",
n_jobs_boot: int = -1,
progress_bar: bool = True,
sample_name: str = None,
release_workers: bool = True,
) -> list[namedtuple]:
if boot_iterations < 0:
raise ValueError(
Expand Down Expand Up @@ -490,13 +523,13 @@ def get_standard_errors(


def get_cohort_stratum_dummies(
data: DataFrame,
entities: Index | list,
cohort_name: str,
cohort: int,
strata_name: str,
stratum: str | int | float,
repeated_cross_section: bool = False,
data: DataFrame,
entities: Index | list,
cohort_name: str,
cohort: int,
strata_name: str,
stratum: str | int | float,
repeated_cross_section: bool = False,
) -> tuple:
"""helper for did_single_gt"""
if repeated_cross_section:
Expand Down

0 comments on commit 9e6a041

Please sign in to comment.