Skip to content

Commit

Permalink
Merge pull request #338 from pymc-labs/multi-cell-geolift
Browse files Browse the repository at this point in the history
Add example analysis of multiple geo lift test analysis
  • Loading branch information
drbenvincent authored Jun 21, 2024
2 parents e90bab9 + f89c53b commit 67181c6
Show file tree
Hide file tree
Showing 8 changed files with 1,716 additions and 14 deletions.
1 change: 1 addition & 0 deletions causalpy/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"sc": {"filename": "synthetic_control.csv"},
"anova1": {"filename": "ancova_generated.csv"},
"geolift1": {"filename": "geolift1.csv"},
"geolift_multi_cell": {"filename": "geolift_multi_cell.csv"},
"risk": {"filename": "AJR2001.csv"},
"nhefs": {"filename": "nhefs.csv"},
"schoolReturns": {"filename": "schoolingReturns.csv"},
Expand Down
209 changes: 209 additions & 0 deletions causalpy/data/geolift_multi_cell.csv

Large diffs are not rendered by default.

90 changes: 79 additions & 11 deletions causalpy/data/simulate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def generate_synthetic_control_data(
Generates data for synthetic control example.
:param N:
Number fo data points
Number of data points
:param treatment_time:
Index where treatment begins in the generated dataframe
:param grw_mu:
Expand Down Expand Up @@ -324,15 +324,6 @@ def generate_geolift_data():
treatment_time = pd.to_datetime("2022-01-01")
causal_impact = 0.2

def create_series(n=52, amplitude=1, length_scale=2):
"""
Returns numpy tile with generated seasonality data repeated over
multiple years
"""
return np.tile(
generate_seasonality(n=n, amplitude=amplitude, length_scale=2) + 3, n_years
)

time = pd.date_range(start="2019-01-01", periods=52 * n_years, freq="W")

untreated = [
Expand All @@ -345,7 +336,12 @@ def create_series(n=52, amplitude=1, length_scale=2):
]

df = (
pd.DataFrame({country: create_series() for country in untreated})
pd.DataFrame(
{
country: create_series(n_years=n_years, intercept=3)
for country in untreated
}
)
.assign(time=time)
.set_index("time")
)
Expand All @@ -360,6 +356,67 @@ def create_series(n=52, amplitude=1, length_scale=2):

# add treatment effect
df["Denmark"] += np.where(df.index < treatment_time, 0, causal_impact)

# ensure we never see any negative sales
df = df.clip(lower=0)

return df


def generate_multicell_geolift_data():
"""Generate synthetic data for a geolift example. This will consists of 6 untreated
countries. The treated unit `Denmark` is a weighted combination of the untreated
units. We additionally specify a treatment effect which takes effect after the
`treatment_time`. The timeseries data is observed at weekly resolution and has
annual seasonality, with this seasonality being a drawn from a Gaussian Process with
a periodic kernel."""
n_years = 4
treatment_time = pd.to_datetime("2022-01-01")
causal_impact = 0.2
time = pd.date_range(start="2019-01-01", periods=52 * n_years, freq="W")

untreated = [
"u1",
"u2",
"u3",
"u4",
"u5",
"u6",
"u7",
"u8",
"u9",
"u10",
"u11",
"u12",
]

df = (
pd.DataFrame(
{
country: create_series(n_years=n_years, intercept=3)
for country in untreated
}
)
.assign(time=time)
.set_index("time")
)

treated = ["t1", "t2", "t3", "t4"]

for treated_geo in treated:
# create treated unit as a weighted sum of the untreated units
weights = np.random.dirichlet(np.ones(len(untreated)), size=1)[0]
df[treated_geo] = np.dot(df[untreated].values, weights)
# add treatment effect
df[treated_geo] += np.where(df.index < treatment_time, 0, causal_impact)

# add observation noise to all geos
for col in untreated + treated:
df[col] += np.random.normal(size=len(df), scale=0.1)

# ensure we never see any negative sales
df = df.clip(lower=0)

return df


Expand Down Expand Up @@ -387,3 +444,14 @@ def periodic_kernel(x1, x2, period=1, length_scale=1, amplitude=1):
return amplitude**2 * np.exp(
-2 * np.sin(np.pi * np.abs(x1 - x2) / period) ** 2 / length_scale**2
)


def create_series(n=52, amplitude=1, length_scale=2, n_years=4, intercept=3):
"""
Returns numpy tile with generated seasonality data repeated over
multiple years
"""
return np.tile(
generate_seasonality(n=n, amplitude=amplitude, length_scale=2) + intercept,
n_years,
)
41 changes: 41 additions & 0 deletions causalpy/tests/test_synthetic_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for the simulated data functions
"""

import numpy as np
import pandas as pd


def test_generate_multicell_geolift_data():
"""
Test the generate_multicell_geolift_data function.
"""
from causalpy.data.simulate_data import generate_multicell_geolift_data

df = generate_multicell_geolift_data()
assert type(df) == pd.DataFrame
assert np.all(df >= 0), "Found negative values in dataset"


def test_generate_geolift_data():
"""
Test the generate_geolift_data function.
"""
from causalpy.data.simulate_data import generate_geolift_data

df = generate_geolift_data()
assert type(df) == pd.DataFrame
assert np.all(df >= 0), "Found negative values in dataset"
6 changes: 3 additions & 3 deletions docs/source/_static/interrogate_badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 8 additions & 0 deletions docs/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,15 @@ Synthetic Control
notebooks/sc_pymc.ipynb
notebooks/sc_skl.ipynb
notebooks/sc_pymc_brexit.ipynb

Geographical lift testing
=========================

.. toctree::
:titlesonly:

notebooks/geolift1.ipynb
notebooks/multi_cell_geolift.ipynb


Difference in Differences
Expand Down
Loading

0 comments on commit 67181c6

Please sign in to comment.