Skip to content

Commit

Permalink
fixed QRF quantile format
Browse files Browse the repository at this point in the history
  • Loading branch information
vascomedici committed Sep 9, 2024
1 parent 0a7ba4a commit a1e2d74
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
4 changes: 2 additions & 2 deletions pyforecaster/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ def anti_transform(self, x, y_hat):
def _predict_quantiles(self, x, **kwargs):
pass

def quantiles_to_df(self, q_hat:np.ndarray, index):
def quantiles_to_df(self, q_hat:np.ndarray, index, q_vect=None):
level_0_labels = self.target_cols
level_1_labels = self.q_vect
level_1_labels = self.q_vect if q_vect is None else q_vect
q_hat = np.swapaxes(q_hat, 1, 2)
q_hat = np.reshape(q_hat, (q_hat.shape[0], q_hat.shape[1] * q_hat.shape[2]))
q_hat = pd.DataFrame(q_hat, index=index, columns=pd.MultiIndex.from_product([level_0_labels, level_1_labels]))
Expand Down
12 changes: 8 additions & 4 deletions pyforecaster/forecasting_models/randomforests.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __init__(self, n_estimators=100, q_vect=None, val_ratio=None, nodes_at_step=
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_decrease = min_impurity_decrease
self.default_quantiles = q_vect
self.criterion = criterion
self.ccp_alpha = ccp_alpha
self.parallel = parallel
Expand All @@ -81,7 +80,7 @@ def __init__(self, n_estimators=100, q_vect=None, val_ratio=None, nodes_at_step=
"max_features": max_features ,
"max_leaf_nodes": max_leaf_nodes,
"min_impurity_decrease": min_impurity_decrease,
"default_quantiles":q_vect,
"default_quantiles":q_vect if q_vect is not None else 'mean',
"criterion":criterion,
"ccp_alpha":ccp_alpha
}
Expand Down Expand Up @@ -159,7 +158,7 @@ def predict(self, x, **kwargs):
if len(preds.shape) == 2:
preds = np.expand_dims(preds, 0)
preds = np.swapaxes(preds, 1, 2)
preds = self.quantiles_to_df(preds, index=x.index)
preds = self.quantiles_to_df(preds, index=x.index, q_vect=kwargs['quantiles'])
else:
preds = pd.DataFrame(np.atleast_2d(np.squeeze(preds)), index=x.index, columns=self.target_cols)
y_hat = self.anti_transform(x, preds)
Expand All @@ -171,12 +170,17 @@ def _predict(self, i, x, period, **kwargs):
keep_last_seconds=self.keep_last_seconds,
tol_period=self.tol_period, period=period)
p = self.models[i].predict(x_i, quantiles=list(kwargs['quantiles']) if 'quantiles' in kwargs else 'mean')
if len(p.shape) == 1:
p = np.expand_dims(p, 1)
return p

def predict_single(self, i, x, quantiles='mean', add_step=True):
if add_step:
x = pd.concat([x.reset_index(drop=True), pd.Series(np.ones(len(x)) * i, name='sa')], axis=1)
return self.multi_step_model.predict(x, quantiles)
preds = self.multi_step_model.predict(x, quantiles)
if len(preds.shape) == 1:
preds = np.expand_dims(preds, 1)
return preds

def predict_parallel(self, x, quantiles='mean', add_step=True):
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_parallel_workers) as executor:
Expand Down
6 changes: 5 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import pandas as pd
import numpy as np
import logging

from ray.tune import quniform

from pyforecaster.forecasting_models.holtwinters import HoltWinters, HoltWintersMulti
from pyforecaster.forecasting_models.fast_adaptive_models import Fourier_es, FK, FK_multi
from pyforecaster.forecasting_models.random_fourier_features import RFFRegression, AdditiveRFFRegression, BrutalRegressor
Expand Down Expand Up @@ -212,7 +215,8 @@ def test_qrf(self):

qrf = QRF(val_ratio=0.2, formatter=formatter, n_jobs=4, n_single=2).fit(x_tr, y_tr)
y_hat = qrf.predict(x_te)
q = qrf.predict_quantiles(x_te)
q = qrf.predict_quantiles(x_te, quantiles=[0.1, 0.9])
q = qrf.predict_quantiles(x_te, quantiles=[0.5])

#plot_quantiles([y_te, y_hat], q, ['y_te', 'y_hat', 'y_hat_qrf'])

Expand Down

0 comments on commit a1e2d74

Please sign in to comment.