Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a temporal validation period to synthetic control and interrupted time series experiments #367

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
77 changes: 61 additions & 16 deletions causalpy/pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@
A pandas dataframe
:param treatment_time:
The time when treatment occured, should be in reference to the data index
:param validation_time:
Optional time to split the data into training and validation data sets
:param formula:
A statistical model formula
:param model:
Expand Down Expand Up @@ -160,6 +162,7 @@
>>> result.summary(round_to=1)
==================================Pre-Post Fit==================================
Formula: actual ~ 0 + a + g
Pre-intervention Bayesian $R^2$: 0.9 (std = 0.01)
Model coefficients:
a 0.6, 94% HDI [0.6, 0.6]
g 0.4, 94% HDI [0.4, 0.4]
Expand All @@ -171,17 +174,30 @@
data: pd.DataFrame,
treatment_time: Union[int, float, pd.Timestamp],
formula: str,
validation_time=None,
model=None,
**kwargs,
) -> None:
super().__init__(model=model, **kwargs)
self._input_validation(data, treatment_time)
self.treatment_time = treatment_time
self.validation_time = validation_time
# validate arguments
if self.validation_time is not None:
# check that validation time is less than treatment time
if self.validation_time >= self.treatment_time:
raise ValueError(
"Validation time must be less than the treatment time."
)
# set experiment type - usually done in subclasses
self.expt_type = "Pre-Post Fit"
# split data in to pre and post intervention
self.datapre = data[data.index < self.treatment_time]
self.datapost = data[data.index >= self.treatment_time]
if self.validation_time is None:
self.datapre = data[data.index < self.treatment_time]
self.datapost = data[data.index >= self.treatment_time]
else:
self.datapre = data[data.index < self.validation_time]
self.datapost = data[data.index >= self.validation_time]

self.formula = formula

Expand All @@ -203,8 +219,22 @@
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.pre_X.shape[0])}
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)

# score the goodness of fit to the pre-intervention data
self.score = self.model.score(X=self.pre_X, y=self.pre_y)
if self.validation_time is None:
# We just have pre and post data, no validation data. So we can score the pre intervention data
self.score = self.model.score(X=self.pre_X, y=self.pre_y)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we could replace on validation score by bayesian tail prob instead of R2? So, the interpretation here is about how much the real mean during the validation diverge from the posterior mean.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Just looking into it so that I get it right - I just need the high level algorithm because I've not heard that much about it. Google search for "bayesian tail probability" shows very few hits. Is this a widely used approach? Doesn't matter if not, as long as it does what we want :)

else:
# Score on the training data - before the validation time
self.datatrain = data[data.index < self.validation_time]
y, X = dmatrices(formula, self.datatrain)
self.score = self.model.score(X=X, y=y)
# Score on the validation data - after the validation time but
# before the treatment time
self.datavalidate = data[
(data.index >= self.validation_time)
& (data.index < self.treatment_time)
]
y, X = dmatrices(formula, self.datavalidate)
self.score_validation = self.model.score(X=X, y=y)

# get the model predictions of the observed (pre-intervention) data
self.pre_pred = self.model.predict(X=self.pre_X)
Expand Down Expand Up @@ -275,13 +305,6 @@
handles.append(h)
labels.append("Causal impact")

ax[0].set(
title=f"""
Pre-intervention Bayesian $R^2$: {round_num(self.score.r2, round_to)}
(std = {round_num(self.score.r2_std, round_to)})
"""
)

# MIDDLE PLOT -----------------------------------------------
plot_xY(
self.datapre.index,
Expand All @@ -303,10 +326,10 @@
alpha=0.25,
label="Causal impact",
)
ax[1].set(title="Causal Impact")
ax[1].set(ylabel="Causal Impact")

Check warning on line 329 in causalpy/pymc_experiments.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_experiments.py#L329

Added line #L329 was not covered by tests

# BOTTOM PLOT -----------------------------------------------
ax[2].set(title="Cumulative Causal Impact")
ax[2].set(ylabel="Cumulative Causal Impact")

Check warning on line 332 in causalpy/pymc_experiments.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_experiments.py#L332

Added line #L332 was not covered by tests
plot_xY(
self.datapost.index,
self.post_impact_cumulative,
Expand All @@ -319,10 +342,17 @@
for i in [0, 1, 2]:
ax[i].axvline(
x=self.treatment_time,
ls="-",
lw=3,
color="r",
ls="--",
# lw=3,
color="k",
)
if self.validation_time is not None:
ax[i].axvline(

Check warning on line 350 in causalpy/pymc_experiments.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_experiments.py#L349-L350

Added lines #L349 - L350 were not covered by tests
x=self.validation_time,
ls="--",
# lw=3,
color="k",
)

ax[0].legend(
handles=(h_tuple for h_tuple in handles),
Expand All @@ -342,6 +372,17 @@

print(f"{self.expt_type:=^80}")
print(f"Formula: {self.formula}")
# print goodness of fit scores
if self.validation_time is None:
print(
f"Pre-intervention Bayesian $R^2$: {round_num(self.score.r2, round_to)} (std = {round_num(self.score.r2_std, round_to)})"
)
else:
print(
f"Pre-intervention Bayesian $R^2$: {round_num(self.score.r2, round_to)} (std = {round_num(self.score.r2_std, round_to)})\n"
f"Validation Bayesian $R^2$: {round_num(self.score_validation.r2, round_to)} (std = {round_num(self.score_validation.r2_std, round_to)})"
)
# print coefficients
self.print_coefficients(round_to)


Expand All @@ -355,6 +396,8 @@
The time when treatment occured, should be in reference to the data index
:param formula:
A statistical model formula
:param validation_time:
Optional time to split the data into training and validation data sets
:param model:
A PyMC model

Expand Down Expand Up @@ -394,6 +437,8 @@
The time when treatment occured, should be in reference to the data index
:param formula:
A statistical model formula
:param validation_time:
Optional time to split the data into training and validation data sets
:param model:
A PyMC model

Expand Down
52 changes: 47 additions & 5 deletions causalpy/tests/test_integration_pymc_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,14 +317,15 @@ def test_rkink_bandwidth():
result.summary()


@pytest.mark.parametrize("validation_time", [None, pd.to_datetime("2015-01-01")])
@pytest.mark.integration
def test_its():
def test_its(validation_time):
"""
Test Interrupted Time-Series experiment.

Loads data and checks:
1. data is a dataframe
2. pymc_experiments.SyntheticControl returns correct type
2. pymc_experiments.InterruptedTimeSeries returns correct type
3. the correct number of MCMC chains exists in the posterior inference data
4. the correct number of MCMC draws exists in the posterior inference data
"""
Expand All @@ -334,19 +335,41 @@ def test_its():
.set_index("date")
)
treatment_time = pd.to_datetime("2017-01-01")
result = cp.pymc_experiments.SyntheticControl(
result = cp.pymc_experiments.InterruptedTimeSeries(
df,
treatment_time,
validation_time=validation_time,
formula="y ~ 1 + t + C(month)",
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
)
assert isinstance(df, pd.DataFrame)
assert isinstance(result, cp.pymc_experiments.SyntheticControl)
assert isinstance(result, cp.pymc_experiments.InterruptedTimeSeries)
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
result.summary()


def test_its_with_invalid_validation_time():
"""
Test that we get a ValueError when validation_time is greater than validation_time.
"""
df = (
cp.load_data("its")
.assign(date=lambda x: pd.to_datetime(x["date"]))
.set_index("date")
)
treatment_time = pd.to_datetime("2017-01-01")
validation_time = pd.to_datetime("2018-01-01")
with pytest.raises(ValueError):
_ = cp.pymc_experiments.InterruptedTimeSeries(
df,
treatment_time,
validation_time=validation_time,
formula="y ~ 1 + t + C(month)",
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
)


@pytest.mark.integration
def test_its_covid():
"""
Expand Down Expand Up @@ -379,7 +402,8 @@ def test_its_covid():


@pytest.mark.integration
def test_sc():
@pytest.mark.parametrize("validation_time", [None, 50])
def test_sc(validation_time):
"""
Test Synthetic Control experiment.

Expand All @@ -395,6 +419,7 @@ def test_sc():
result = cp.pymc_experiments.SyntheticControl(
df,
treatment_time,
validation_time=validation_time,
formula="actual ~ 0 + a + b + c + d + e + f + g",
model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
)
Expand All @@ -405,6 +430,23 @@ def test_sc():
result.summary()


def test_sc_with_invalid_validation_time():
"""
Test that we get a ValueError when validation_time is greater than validation_time.
"""
df = cp.load_data("sc")
treatment_time = 70
validation_time = 80
with pytest.raises(ValueError):
_ = cp.pymc_experiments.SyntheticControl(
df,
treatment_time,
validation_time=validation_time,
formula="actual ~ 0 + a + b + c + d + e + f + g",
model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
)


@pytest.mark.integration
def test_sc_brexit():
"""
Expand Down
Binary file modified docs/source/_static/classes.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 3 additions & 3 deletions docs/source/_static/interrogate_badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
194 changes: 178 additions & 16 deletions docs/source/notebooks/its_pymc.ipynb

Large diffs are not rendered by default.

213 changes: 190 additions & 23 deletions docs/source/notebooks/sc_pymc.ipynb

Large diffs are not rendered by default.

Loading