Skip to content

Commit

Permalink
cast string entities to integers
Browse files Browse the repository at this point in the history
  • Loading branch information
bernardodionisi committed Dec 9, 2023
1 parent 9e6a041 commit bc6688c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 14 deletions.
2 changes: 1 addition & 1 deletion examples/attgt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

att_gt = ATTgt(data=panel_data, cohort_name='cohort')

att_gt.fit(formula='y ~ x0')
att_gt.fit(formula='y ~ x0', n_jobs=-1)

att_gt.aggregate('time')

Expand Down
47 changes: 34 additions & 13 deletions src/differences/tools/panel_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pandas import DataFrame
from pandas.api.types import (is_datetime64_dtype, is_integer_dtype,
is_numeric_dtype)
from pandas.core.dtypes.common import is_string_dtype

from ..tools.panel_utility import is_panel_balanced, map_dt_series_to_int

Expand Down Expand Up @@ -36,6 +37,9 @@ def __set__(self, instance, data):
# operations on data will mostly be inplace given the class is
# providing an option copy_data

# convert entities to integer if they are strings
data = replace_string_index_with_category_code(data)

# -------------- make some columns names reserved --------------

# data should not already contain columns with the following names,
Expand Down Expand Up @@ -117,7 +121,7 @@ def __set__(self, instance, data):

# if there are no nas and there is a cohort of 0 but no 0 times
if not cohort_data_has_nas and (
min(data[time_name]) != min(cohort_data[cohort_name]) == 0
min(data[time_name]) != min(cohort_data[cohort_name]) == 0
):
raise ValueError(
f"{cohort_name} should not contain 0s for never treated groups, "
Expand Down Expand Up @@ -324,10 +328,10 @@ def is_single_event(data: DataFrame, event_dummy_name: str):


def pre_process_treated_before(
cohort_data: DataFrame,
cohort_name: str,
data: DataFrame,
copy_data: bool = True, # just to issue a warning
cohort_data: DataFrame,
cohort_name: str,
data: DataFrame,
copy_data: bool = True, # just to issue a warning
):
"""drops always treated entities"""

Expand Down Expand Up @@ -355,7 +359,7 @@ def pre_process_treated_before(


def pre_process_treated_after(
cohort_data: DataFrame, cohort_name: str, anticipation: int = 0
cohort_data: DataFrame, cohort_name: str, anticipation: int = 0
):
"""drops event dates that come after the end of the panel (for the entity)"""

Expand All @@ -382,12 +386,12 @@ def pre_process_treated_after(


def pre_process_no_never_treated(
cohort_data: DataFrame,
cohort_name: str,
data: DataFrame,
time_name: str,
anticipation: int,
copy_data: bool = True, # just to raise a warning
cohort_data: DataFrame,
cohort_name: str,
data: DataFrame,
time_name: str,
anticipation: int,
copy_data: bool = True, # just to raise a warning
):
"""makes the last cohort the comparison group, if no never treated"""

Expand Down Expand Up @@ -421,7 +425,7 @@ def pre_process_no_never_treated(


def preprocess_cohort_data(
data: DataFrame, cohort_data: DataFrame, cohort_name: str, intensity_name: str
data: DataFrame, cohort_data: DataFrame, cohort_name: str, intensity_name: str
):
"""generate valid cohort_data
Expand Down Expand Up @@ -460,3 +464,20 @@ def preprocess_cohort_data(
cohort_data.set_index([entity_name], inplace=True)

return cohort_data


def replace_string_index_with_category_code(data: pd.DataFrame) -> pd.DataFrame:
    """Replace string entity labels in the index with integer category codes.

    ``data`` is expected to carry a two-level index: the first level
    identifies the entity, the second the time period. When the first level
    holds strings it is swapped for the integer codes of its categorical
    encoding; the second level and both level names are kept as they are.
    For any other first-level dtype the index is left untouched.

    The index is reassigned on ``data`` itself (no copy is made) and the
    same object is returned.

    Parameters
    ----------
    data
        frame indexed by (entity, time).

    Returns
    -------
    DataFrame
        ``data``, with the entity level converted to integer codes if it
        was a string level.

    Raises
    ------
    ValueError
        if the index of ``data`` does not have exactly two levels.
    """
    level_names = data.index.names

    # fail with an explicit message instead of an opaque tuple-unpacking
    # error when the index is not an (entity, time) pair
    if len(level_names) != 2:
        raise ValueError(
            "data must have a two-level (entity, time) index, "
            f"got {len(level_names)} level(s): {list(level_names)}"
        )

    entity_name, time_name = level_names
    entity_labels = data.index.get_level_values(0)

    if is_string_dtype(entity_labels):
        # every distinct label is assigned one integer code by the
        # categorical encoding of the entity level
        entity_codes = entity_labels.astype("category").codes

        # rebuild the index from the codes, keeping the time level as-is
        data.index = pd.MultiIndex.from_arrays(
            [entity_codes, data.index.get_level_values(1)],
            names=[entity_name, time_name],
        )

    return data

0 comments on commit bc6688c

Please sign in to comment.