diff --git a/differences/tools/panel_validation.py b/differences/tools/panel_validation.py index 72d0a48..36e87ea 100644 --- a/differences/tools/panel_validation.py +++ b/differences/tools/panel_validation.py @@ -133,9 +133,12 @@ def __set__(self, instance, data): ) # for each entity-event_date, - # create the start (amin) and end (amax) of the panel for the corresponding entity + # create the start (_min_time) and end (_max_time) of the panel for the corresponding entity cohort_data = cohort_data[cohort_info_list].join( - data.groupby(entity_name)[time_name].agg([np.min, np.max]) + data + .groupby(entity_name)[time_name] + .agg(['min', 'max']) + .set_axis(labels=['_min_time', '_max_time'], axis=1) ) # ----------- pre-process cohorts / data ------------------- @@ -330,7 +333,7 @@ def pre_process_treated_before( # entities whose event happened BEFORE the start of their time treated_before = cohort_data.loc[ - lambda x: x[cohort_name] <= x["amin"] + lambda x: x[cohort_name] <= x["_min_time"] ].index.unique() if len(treated_before): @@ -357,12 +360,12 @@ def pre_process_treated_after( """drops event dates that come after the end of the panel (for the entity)""" # entities whose event happened AFTER the end of their time - bool_after_end = cohort_data[cohort_name] > cohort_data["amax"] + bool_after_end = cohort_data[cohort_name] > cohort_data["_max_time"] # shouldn't this be [cohort - anticipation] ? # line 85: https://github.com/bcallaway11/did/blob/master/R/pre_process_did.R # bool_after_end = (cohort_data[cohort_name] - - # anticipation > cohort_data['amax']) + # anticipation > cohort_data['_max_time']) n_entities_after_end = np.sum(bool_after_end) # number of entities if n_entities_after_end: