Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add non-negative reconciliation heuristic for MinTraceSparse #284

Merged
merged 7 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hierarchicalforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.MinTraceSparse._get_PW_matrices': ( 'methods.html#mintracesparse._get_pw_matrices',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.MinTraceSparse.fit': ( 'methods.html#mintracesparse.fit',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.OptimalCombination': ( 'methods.html#optimalcombination',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.OptimalCombination.__init__': ( 'methods.html#optimalcombination.__init__',
Expand Down
78 changes: 69 additions & 9 deletions hierarchicalforecast/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,11 +1150,6 @@ def _get_PW_matrices(
"Only the methods with diagonal W are supported as sparse operations"
)

if self.nonnegative:
raise NotImplementedError(
"Non-negative MinT is currently not implemented as sparse"
)

S = sparse.csr_matrix(S)

if self.method in res_methods and y_insample is None and y_hat_insample is None:
Expand Down Expand Up @@ -1213,7 +1208,7 @@ def get_P_action(y):
(b.size, b.size), matvec=lambda v: R @ (S @ v)
)

x_tilde, exit_code = sparse.linalg.bicgstab(A, b, atol="legacy")
x_tilde, exit_code = sparse.linalg.bicgstab(A, b)
christophertitchen marked this conversation as resolved.
Show resolved Hide resolved

return x_tilde

Expand All @@ -1224,7 +1219,72 @@ def get_P_action(y):

return P, W

# %% ../nbs/methods.ipynb 55
def fit(self,
S: sparse.csr_matrix,
y_hat: np.ndarray,
y_insample: Optional[np.ndarray] = None,
y_hat_insample: Optional[np.ndarray] = None,
sigmah: Optional[np.ndarray] = None,
intervals_method: Optional[str] = None,
num_samples: Optional[int] = None,
seed: Optional[int] = None,
tags: Optional[Dict[str, np.ndarray]] = None,
idx_bottom: Optional[np.ndarray] = None):
# Clip the base forecasts if required to align them with their use in practice.
if self.nonnegative:
self.y_hat = np.clip(y_hat, 0, None)
else:
self.y_hat = y_hat
# Get the reconciliation matrices.
self.P, self.W = self._get_PW_matrices(
S=S,
y_hat=self.y_hat,
y_insample=y_insample,
y_hat_insample=y_hat_insample,
idx_bottom=idx_bottom,
)

if self.nonnegative:
# Get the number of leaf nodes.
_, n_bottom = S.shape
# Although it is now sufficient to ensure that all of the entries in P are
# positive, as it is implemented as a linear operator for the iterative
# method to solve the sparse linear system, we need to reconcile to find
# if any of the coherent bottom level point forecasts are negative.
y_tilde = self._reconcile(
S=S, P=self.P, y_hat=self.y_hat, level=None, sampler=None
)["mean"][-n_bottom:]
# Find if any of the forecasts are negative.
if np.any(y_tilde < 0):
# Clip the negative forecasts.
y_tilde = np.clip(y_tilde, 0, None)
# Force non-negative coherence by overwriting the base forecasts with
# the aggregated, clipped bottom level forecasts.
self.y_hat = S @ y_tilde
# Overwrite the attributes for the P and W matrices with those for
# bottom-up reconciliation to force projection onto the non-negative
# coherent subspace.
self.P, self.W = BottomUpSparse()._get_PW_matrices(S=S, idx_bottom=None)

# Get the sampler for probabilistic reconciliation.
self.sampler = self._get_sampler(
S=S,
P=self.P,
W=self.W,
y_hat=self.y_hat,
y_insample=y_insample,
y_hat_insample=y_hat_insample,
sigmah=sigmah,
intervals_method=intervals_method,
num_samples=num_samples,
seed=seed,
tags=tags,
)
# Set the instance as fitted.
self.fitted = True
return self

# %% ../nbs/methods.ipynb 56
class OptimalCombination(MinTrace):
"""Optimal Combination Reconciliation Class.

Expand Down Expand Up @@ -1258,7 +1318,7 @@ def __init__(self,
super().__init__(method=method, nonnegative=nonnegative, num_threads=num_threads)
self.insample = False

# %% ../nbs/methods.ipynb 64
# %% ../nbs/methods.ipynb 65
@njit
def lasso(X: np.ndarray, y: np.ndarray,
lambda_reg: float, max_iters: int = 1_000,
Expand Down Expand Up @@ -1290,7 +1350,7 @@ def lasso(X: np.ndarray, y: np.ndarray,
#print(it)
return beta

# %% ../nbs/methods.ipynb 65
# %% ../nbs/methods.ipynb 66
class ERM(HReconciler):
"""Optimal Combination Reconciliation Class.

Expand Down
96 changes: 89 additions & 7 deletions nbs/methods.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1762,11 +1762,6 @@
" \"Only the methods with diagonal W are supported as sparse operations\"\n",
" )\n",
"\n",
" if self.nonnegative:\n",
" raise NotImplementedError(\n",
" \"Non-negative MinT is currently not implemented as sparse\"\n",
" )\n",
"\n",
" S = sparse.csr_matrix(S)\n",
"\n",
" if self.method in res_methods and y_insample is None and y_hat_insample is None:\n",
Expand Down Expand Up @@ -1825,7 +1820,7 @@
" (b.size, b.size), matvec=lambda v: R @ (S @ v)\n",
" )\n",
"\n",
" x_tilde, exit_code = sparse.linalg.bicgstab(A, b, atol=\"legacy\")\n",
" x_tilde, exit_code = sparse.linalg.bicgstab(A, b)\n",
"\n",
" return x_tilde\n",
"\n",
Expand All @@ -1834,7 +1829,72 @@
" )\n",
" W = sparse.spdiags(W_diag, 0, W_diag.size, W_diag.size)\n",
"\n",
" return P, W"
" return P, W\n",
"\n",
" def fit(self,\n",
" S: sparse.csr_matrix,\n",
" y_hat: np.ndarray,\n",
" y_insample: Optional[np.ndarray] = None,\n",
" y_hat_insample: Optional[np.ndarray] = None,\n",
" sigmah: Optional[np.ndarray] = None,\n",
" intervals_method: Optional[str] = None,\n",
" num_samples: Optional[int] = None,\n",
" seed: Optional[int] = None, \n",
" tags: Optional[Dict[str, np.ndarray]] = None,\n",
" idx_bottom: Optional[np.ndarray] = None):\n",
" # Clip the base forecasts if required to align them with their use in practice.\n",
" if self.nonnegative:\n",
" self.y_hat = np.clip(y_hat, 0, None)\n",
" else:\n",
" self.y_hat = y_hat\n",
" # Get the reconciliation matrices.\n",
" self.P, self.W = self._get_PW_matrices(\n",
" S=S, \n",
" y_hat=self.y_hat, \n",
" y_insample=y_insample, \n",
" y_hat_insample=y_hat_insample, \n",
" idx_bottom=idx_bottom,\n",
" )\n",
"\n",
" if self.nonnegative:\n",
" # Get the number of leaf nodes.\n",
" _, n_bottom = S.shape\n",
" # Although it is now sufficient to ensure that all of the entries in P are \n",
" # positive, as it is implemented as a linear operator for the iterative \n",
" # method to solve the sparse linear system, we need to reconcile to find \n",
" # if any of the coherent bottom level point forecasts are negative.\n",
" y_tilde = self._reconcile(\n",
" S=S, P=self.P, y_hat=self.y_hat, level=None, sampler=None\n",
" )[\"mean\"][-n_bottom:]\n",
" # Find if any of the forecasts are negative.\n",
" if np.any(y_tilde < 0):\n",
" # Clip the negative forecasts.\n",
" y_tilde = np.clip(y_tilde, 0, None)\n",
" # Force non-negative coherence by overwriting the base forecasts with \n",
" # the aggregated, clipped bottom level forecasts.\n",
" self.y_hat = S @ y_tilde\n",
" # Overwrite the attributes for the P and W matrices with those for \n",
" # bottom-up reconciliation to force projection onto the non-negative \n",
" # coherent subspace.\n",
" self.P, self.W = BottomUpSparse()._get_PW_matrices(S=S, idx_bottom=None) \n",
"\n",
" # Get the sampler for probabilistic reconciliation.\n",
" self.sampler = self._get_sampler(\n",
" S=S,\n",
" P=self.P,\n",
" W=self.W,\n",
" y_hat=self.y_hat,\n",
" y_insample=y_insample,\n",
" y_hat_insample=y_hat_insample,\n",
" sigmah=sigmah,\n",
" intervals_method=intervals_method,\n",
" num_samples=num_samples,\n",
" seed=seed,\n",
" tags=tags,\n",
" )\n",
" # Set the instance as fitted.\n",
" self.fitted = True\n",
" return self"
]
},
{
Expand Down Expand Up @@ -1977,6 +2037,28 @@
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"for method in [\"ols\", \"wls_struct\"]:\n",
" for nonnegative in [False, True]:\n",
" cls_min_trace = MinTraceSparse(method=method, nonnegative=nonnegative)\n",
" test_close(\n",
" cls_min_trace(\n",
" S=S,\n",
" y_hat=S @ y_hat_bottom,\n",
" y_insample=S @ y_bottom,\n",
" y_hat_insample=S @ y_hat_bottom_insample,\n",
" idx_bottom=idx_bottom if nonnegative else None,\n",
" )[\"mean\"],\n",
" S @ y_hat_bottom,\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down