Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added two of the three algorithms for the top-down approach, plus tests #179

Merged
merged 8 commits into from
May 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 266 additions & 0 deletions examples/hierarchical_model.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Hierarchical model\n",
"This example shows how the hierarchical model can be used."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('../')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import networkx as nx\n",
"%matplotlib inline \n",
"\n",
"from gtime.hierarchical import HierarchicalMiddleOut\n",
"from gtime.hierarchical import HierarchicalTopDown\n",
"from gtime.hierarchical import HierarchicalBottomUp\n",
"import pandas._testing as testing\n",
"from gtime.time_series_models import AR"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"testing.N, testing.K = 20, 1\n",
"\n",
"data1 = testing.makeTimeDataFrame(freq=\"s\")\n",
"data2 = testing.makeTimeDataFrame(freq=\"s\")\n",
"data3 = testing.makeTimeDataFrame(freq=\"s\")\n",
"data4 = testing.makeTimeDataFrame(freq=\"s\")\n",
"data5 = testing.makeTimeDataFrame(freq=\"s\")\n",
"data6 = testing.makeTimeDataFrame(freq=\"s\")\n",
"data = {'data1': data1, 'data2': data2, 'data3' : data3, 'data4' : data4, 'data5' : data5, 'data6' : data6}"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"tree_adj = {'data1' : ['data2','data3'], 'data2': ['data4', 'data5'], 'data3':['data6'], 'data4':[], 'data5':[], 'data6':[]} "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Pass n_jobs=None as keyword args. From version 0.25 passing these as positional arguments will result in an error\n"
]
}
],
"source": [
"stat_model = AR(p=2, horizon=3)\n",
"middle_out_model = HierarchicalMiddleOut(model=stat_model, hierarchy_tree=tree_adj, method='tdsga', level=0)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AR(horizon=3, p=2)\n"
]
}
],
"source": [
"fitting_middle_out = middle_out_model.fit(data)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"{'data1': y_1 y_2 y_3\n",
" 2000-01-01 00:00:01 -0.158229 -0.108937 0.013279\n",
" 2000-01-01 00:00:02 -0.119528 -0.059880 -0.107075\n",
" 2000-01-01 00:00:03 -0.045503 0.021338 0.047517\n",
" 2000-01-01 00:00:04 0.012864 0.096452 -0.168450\n",
" 2000-01-01 00:00:05 -0.151905 -0.096779 -0.132723\n",
" 2000-01-01 00:00:06 0.069999 0.150488 0.214797\n",
" 2000-01-01 00:00:07 0.001033 0.093826 -0.508995\n",
" 2000-01-01 00:00:08 -0.568503 -0.583920 -0.085755\n",
" 2000-01-01 00:00:09 -0.457690 -0.464454 0.210157\n",
" 2000-01-01 00:00:10 -0.122242 -0.067394 0.025629\n",
" 2000-01-01 00:00:11 -0.118071 -0.056340 -0.163261\n",
" 2000-01-01 00:00:12 -0.191902 -0.147248 -0.015386\n",
" 2000-01-01 00:00:13 -0.294175 -0.260952 -0.183406\n",
" 2000-01-01 00:00:14 -0.181954 -0.143633 0.228090\n",
" 2000-01-01 00:00:15 -0.045477 0.031479 -0.260897\n",
" 2000-01-01 00:00:16 -0.226103 -0.186952 -0.020310\n",
" 2000-01-01 00:00:17 -0.601745 -0.612361 -0.400169\n",
" 2000-01-01 00:00:18 -0.775542 -0.839130 0.337599\n",
" 2000-01-01 00:00:19 -0.330639 -0.312211 0.083676,\n",
" 'data2': y_1 y_2 y_3\n",
" 2000-01-01 00:00:01 0.270159 0.185998 -0.022672\n",
" 2000-01-01 00:00:02 0.204081 0.102239 0.182818\n",
" 2000-01-01 00:00:03 0.077691 -0.036433 -0.081129\n",
" 2000-01-01 00:00:04 -0.021964 -0.164682 0.287609\n",
" 2000-01-01 00:00:05 0.259361 0.165239 0.226610\n",
" 2000-01-01 00:00:06 -0.119516 -0.256941 -0.366741\n",
" 2000-01-01 00:00:07 -0.001764 -0.160198 0.869053\n",
" 2000-01-01 00:00:08 0.970656 0.996979 0.146418\n",
" 2000-01-01 00:00:09 0.781456 0.793004 -0.358821\n",
" 2000-01-01 00:00:10 0.208715 0.115068 -0.043758\n",
" 2000-01-01 00:00:11 0.201593 0.096194 0.278750\n",
" 2000-01-01 00:00:12 0.327652 0.251409 0.026269\n",
" 2000-01-01 00:00:13 0.502271 0.445547 0.313145\n",
" 2000-01-01 00:00:14 0.310666 0.245238 -0.389439\n",
" 2000-01-01 00:00:15 0.077647 -0.053747 0.445453\n",
" 2000-01-01 00:00:16 0.386045 0.319200 0.034676\n",
" 2000-01-01 00:00:17 1.027413 1.045538 0.683245\n",
" 2000-01-01 00:00:18 1.324152 1.432722 -0.576413\n",
" 2000-01-01 00:00:19 0.564529 0.533066 -0.142867,\n",
" 'data4': y_1 y_2 y_3\n",
" 2000-01-01 00:00:01 -0.290519 -0.200016 0.024381\n",
" 2000-01-01 00:00:02 -0.219461 -0.109944 -0.196596\n",
" 2000-01-01 00:00:03 -0.083546 0.039178 0.087243\n",
" 2000-01-01 00:00:04 0.023619 0.177093 -0.309285\n",
" 2000-01-01 00:00:05 -0.278907 -0.177692 -0.243688\n",
" 2000-01-01 00:00:06 0.128523 0.276306 0.394381\n",
" 2000-01-01 00:00:07 0.001897 0.172271 -0.934548\n",
" 2000-01-01 00:00:08 -1.043809 -1.072116 -0.157453\n",
" 2000-01-01 00:00:09 -0.840350 -0.852768 0.385863\n",
" 2000-01-01 00:00:10 -0.224444 -0.123740 0.047056\n",
" 2000-01-01 00:00:11 -0.216786 -0.103444 -0.299757\n",
" 2000-01-01 00:00:12 -0.352345 -0.270357 -0.028249\n",
" 2000-01-01 00:00:13 -0.540124 -0.479125 -0.336745\n",
" 2000-01-01 00:00:14 -0.334079 -0.263720 0.418789\n",
" 2000-01-01 00:00:15 -0.083499 0.057797 -0.479024\n",
" 2000-01-01 00:00:16 -0.415140 -0.343256 -0.037290\n",
" 2000-01-01 00:00:17 -1.104843 -1.124335 -0.734737\n",
" 2000-01-01 00:00:18 -1.423946 -1.540698 0.619854\n",
" 2000-01-01 00:00:19 -0.607074 -0.573241 0.153634,\n",
" 'data5': y_1 y_2 y_3\n",
" 2000-01-01 00:00:01 -0.391357 -0.269440 0.032843\n",
" 2000-01-01 00:00:02 -0.295635 -0.148105 -0.264833\n",
" 2000-01-01 00:00:03 -0.112545 0.052777 0.117525\n",
" 2000-01-01 00:00:04 0.031817 0.238561 -0.416636\n",
" 2000-01-01 00:00:05 -0.375714 -0.239368 -0.328271\n",
" 2000-01-01 00:00:06 0.173133 0.372210 0.531268\n",
" 2000-01-01 00:00:07 0.002555 0.232065 -1.258925\n",
" 2000-01-01 00:00:08 -1.406109 -1.444241 -0.212103\n",
" 2000-01-01 00:00:09 -1.132030 -1.148759 0.519794\n",
" 2000-01-01 00:00:10 -0.302348 -0.166689 0.063389\n",
" 2000-01-01 00:00:11 -0.292032 -0.139349 -0.403801\n",
" 2000-01-01 00:00:12 -0.474642 -0.364196 -0.038054\n",
" 2000-01-01 00:00:13 -0.727598 -0.645427 -0.453627\n",
" 2000-01-01 00:00:14 -0.450036 -0.355256 0.564148\n",
" 2000-01-01 00:00:15 -0.112481 0.077858 -0.645291\n",
" 2000-01-01 00:00:16 -0.559232 -0.462398 -0.050233\n",
" 2000-01-01 00:00:17 -1.488328 -1.514584 -0.989760\n",
" 2000-01-01 00:00:18 -1.918189 -2.075466 0.835002\n",
" 2000-01-01 00:00:19 -0.817786 -0.772209 0.206960,\n",
" 'data3': y_1 y_2 y_3\n",
" 2000-01-01 00:00:01 -0.181452 -0.124926 0.015228\n",
" 2000-01-01 00:00:02 -0.137071 -0.068669 -0.122790\n",
" 2000-01-01 00:00:03 -0.052181 0.024470 0.054490\n",
" 2000-01-01 00:00:04 0.014752 0.110609 -0.193173\n",
" 2000-01-01 00:00:05 -0.174200 -0.110983 -0.152203\n",
" 2000-01-01 00:00:06 0.080273 0.172575 0.246322\n",
" 2000-01-01 00:00:07 0.001185 0.107597 -0.583699\n",
" 2000-01-01 00:00:08 -0.651941 -0.669621 -0.098342\n",
" 2000-01-01 00:00:09 -0.524865 -0.532621 0.241002\n",
" 2000-01-01 00:00:10 -0.140183 -0.077285 0.029390\n",
" 2000-01-01 00:00:11 -0.135400 -0.064609 -0.187222\n",
" 2000-01-01 00:00:12 -0.220067 -0.168859 -0.017644\n",
" 2000-01-01 00:00:13 -0.337350 -0.299252 -0.210324\n",
" 2000-01-01 00:00:14 -0.208659 -0.164714 0.261567\n",
" 2000-01-01 00:00:15 -0.052152 0.036099 -0.299188\n",
" 2000-01-01 00:00:16 -0.259287 -0.214390 -0.023290\n",
" 2000-01-01 00:00:17 -0.690062 -0.702236 -0.458901\n",
" 2000-01-01 00:00:18 -0.889367 -0.962288 0.387148\n",
" 2000-01-01 00:00:19 -0.379166 -0.358034 0.095957,\n",
" 'data6': y_1 y_2 y_3\n",
" 2000-01-01 00:00:01 -24.923866 -17.159492 2.091650\n",
" 2000-01-01 00:00:02 -18.827772 -9.432186 -16.866128\n",
" 2000-01-01 00:00:03 -7.167503 3.361151 7.484687\n",
" 2000-01-01 00:00:04 2.026274 15.192954 -26.533799\n",
" 2000-01-01 00:00:05 -23.927666 -15.244357 -20.906226\n",
" 2000-01-01 00:00:06 11.026122 23.704474 33.834240\n",
" 2000-01-01 00:00:07 0.162740 14.779270 -80.175653\n",
" 2000-01-01 00:00:08 -89.549225 -91.977685 -13.507979\n",
" 2000-01-01 00:00:09 -72.094285 -73.159638 33.103496\n",
" 2000-01-01 00:00:10 -19.255264 -10.615743 4.036978\n",
" 2000-01-01 00:00:11 -18.598285 -8.874552 -25.716433\n",
" 2000-01-01 00:00:12 -30.227951 -23.194102 -2.423510\n",
" 2000-01-01 00:00:13 -46.337691 -41.104553 -28.889641\n",
" 2000-01-01 00:00:14 -28.660901 -22.624761 35.928214\n",
" 2000-01-01 00:00:15 -7.163456 4.958464 -41.095860\n",
" 2000-01-01 00:00:16 -35.615159 -29.448193 -3.199118\n",
" 2000-01-01 00:00:17 -94.785394 -96.457568 -63.033679\n",
" 2000-01-01 00:00:18 -122.161436 -132.177754 53.177769\n",
" 2000-01-01 00:00:19 -52.081397 -49.178782 13.180422}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fitting_middle_out.predict(data)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
2 changes: 1 addition & 1 deletion examples/simple_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
"version": "3.7.6"
}
},
"nbformat": 4,
Expand Down
6 changes: 3 additions & 3 deletions gtime/causality/granger_causality.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,11 @@ def fit(self, data: pd.DataFrame):
shifts = data.copy()
x_columns, y_columns = [], []
for i in range(1, self.max_shift + 1):
shifts[f"x_shift_{i}"] = data[self.x_col].shift(i)
shifts[f"y_shift_{i-1}"] = data[self.target_col].shift(i)
shifts[f"x_shift_{i}"] = data[self.target_col].shift(i)
shifts[f"y_shift_{i-1}"] = data[self.x_col].shift(i)
x_columns.append(f"x_shift_{i}")
y_columns.append(f"y_shift_{i-1}")
shifts.drop([self.x_col, self.target_col], axis="columns", inplace=True)
shifts.drop([self.target_col, self.x_col], axis="columns", inplace=True)
shifts = shifts.dropna()

data_single = shifts[x_columns].copy()
Expand Down
10 changes: 5 additions & 5 deletions gtime/causality/tests/test_granger_causality.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
@pytest.mark.parametrize(
"test_input, expected",
[
(["ssr_f"], 0.8420421667509344),
(["ssr_chi2"], 0.8327660223526767),
(["likelihood_chi2"], 0.8341270186135072),
(["zero_f"], 0.8420421667508992),
(["ssr_f"], 0.93058225),
(["ssr_chi2"], 0.92597228),
(["likelihood_chi2"], 0.92651128),
(["zero_f"], 0.93058225),
],
)
def test_granger_pvalues_ssr_f(test_input, expected):
Expand All @@ -25,7 +25,7 @@ def test_granger_pvalues_ssr_f(test_input, expected):

data = testing.makeTimeDataFrame(freq="s", nper=1000)
granger = (
GrangerCausality(target_col="B", x_col="A", max_shift=10, statistics=test_input)
GrangerCausality(target_col="A", x_col="B", max_shift=10, statistics=test_input)
.fit(data)
.results_[0]
)
Expand Down
6 changes: 6 additions & 0 deletions gtime/hierarchical/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,14 @@

from .base import HierarchicalBase
from .naive import HierarchicalNaive
from .bottom_up import HierarchicalBottomUp
from .top_down import HierarchicalTopDown
from .middle_out import HierarchicalMiddleOut

__all__ = [
"HierarchicalBase",
"HierarchicalNaive",
"HierarchicalBottomUp",
"HierarchicalTopDown",
"HierarchicalMiddleOut"
]
Loading