Skip to content

Commit

Permalink
exploit multi-target feature of RandomForestQuantileRegressor
Browse files Browse the repository at this point in the history
  • Loading branch information
vascomedici committed Sep 10, 2024
1 parent d522cab commit 5144493
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 34 deletions.
49 changes: 17 additions & 32 deletions pyforecaster/forecasting_models/randomforests.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,28 +107,11 @@ def fit(self, x, y):
n_sa = y.shape[1]
self.n_multistep = n_sa - self.n_single
if self.n_multistep>0:
x_pd = x.copy()
x_pd.reset_index(drop=True, inplace=True)
red_frac = np.maximum(1, self.red_frac_multistep*self.n_multistep)/self.n_multistep
n_batch = int(len(x_pd)*red_frac)
n_long = n_batch*self.n_multistep
rand_idx = np.random.choice(x_pd.index, n_long).reshape(self.n_multistep, -1)
x_long = []
for sa in range(self.n_multistep):
x_i = pd.concat([x_pd.loc[rand_idx[sa], :].reset_index(drop=True), pd.Series(np.ones(n_batch) * sa, name='sa')], axis=1)
x_long.append(x_i)
x_long = pd.concat(x_long, axis=0)
y = y
y_long = []
for i in range(self.n_multistep):
y_long.append(y.iloc[rand_idx[i], i])
y_long = pd.concat(y_long)

t_0 = time()
qrf_pars_global = self.qrf_pars.copy()
if 'n_jobs' in qrf_pars_global and qrf_pars_global['n_jobs'] is not None and qrf_pars_global['n_jobs'] > 0:
qrf_pars_global['n_jobs'] *= self.max_parallel_workers
self.multi_step_model = RandomForestQuantileRegressor(**qrf_pars_global).fit(x_long, y_long, sparse_pickle=True)
self.multi_step_model = RandomForestQuantileRegressor(**qrf_pars_global).fit(x, y.iloc[:, -self.n_multistep:], sparse_pickle=True)
self.logger.info('QRF multistep fitted in {:0.2e} s, x shape: [{}, {}]'.format(time() - t_0,
x.shape[0],
x.shape[1]))
Expand All @@ -149,7 +132,7 @@ def predict(self, x, **kwargs):
preds.append(p)
x_pd = x
if self.n_multistep>0:
preds.append(self.predict_parallel(x_pd, quantiles=list(kwargs['quantiles']) if 'quantiles' in kwargs else 'mean'))
preds.append(self.predict_multi_step(x, quantiles=list(kwargs['quantiles']) if 'quantiles' in kwargs else 'mean'))
preds = np.dstack(preds)
if 'quantiles' in kwargs:
if str(kwargs['quantiles']) == 'mean':
Expand All @@ -173,21 +156,14 @@ def _predict(self, i, x, period, **kwargs):
p = np.expand_dims(p, 1)
return p

def predict_single(self, i, x, quantiles='mean', add_step=True):
if add_step:
x = pd.concat([x.reset_index(drop=True), pd.Series(np.ones(len(x)) * i, name='sa')], axis=1)
def predict_multi_step(self, x, quantiles='mean'):
preds = self.multi_step_model.predict(x, quantiles)
if len(preds.shape) == 1:
if len(preds.shape) == 2:
preds = np.expand_dims(preds, 1)
else:
preds = np.swapaxes(preds, 1, 2)
return preds

def predict_parallel(self, x, quantiles='mean', add_step=True):
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_parallel_workers) as executor:
y_hat = [i for i in
tqdm(executor.map(partial(self.predict_single, x=x, quantiles=quantiles, add_step=add_step), range(self.n_multistep)),
total=self.n_multistep)]
return np.dstack(y_hat)

@staticmethod
def dataset_at_stepahead(df, target_col_num, metadata_features, formatter, logger, method='periodic', keep_last_n_lags=1, period="24h",
tol_period='1h', keep_last_seconds=0):
Expand All @@ -199,7 +175,16 @@ def dataset_at_stepahead(df, target_col_num, metadata_features, formatter, logge
period=period, keep_last_n_lags=keep_last_n_lags, keep_last_seconds=keep_last_seconds,
tol_period=tol_period)

def predict_quantiles(self, x, **kwargs):
@staticmethod
def quantiles_to_numpy(q_hat: pd.DataFrame):
n_taus = len(q_hat.columns.get_level_values(1).unique())
q_hat = q_hat.values
q_hat = np.reshape(q_hat, (q_hat.shape[0], n_taus, -1))
q_hat = np.swapaxes(q_hat, 1, 2)
return q_hat

def predict_quantiles(self, x, dataframe=True, **kwargs):
preds = self.predict(x, quantiles=list(kwargs['quantiles']) if 'quantiles' in kwargs else self.q_vect)

if dataframe is False:
return self.quantiles_to_numpy(preds)
return preds
5 changes: 3 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,11 @@ def test_qrf(self):

qrf = QRF(val_ratio=0.2, formatter=formatter, n_jobs=4, n_single=2).fit(x_tr, y_tr)
y_hat = qrf.predict(x_te)
q = qrf.predict_quantiles(x_te, quantiles=[0.1, 0.9])
q = qrf.predict_quantiles(x_te, quantiles=[0.1, 0.5, 0.9])
#plot_quantiles([y_te, y_hat], q, ['y_te', 'y_hat', 'y_hat_qrf'])
q = qrf.predict_quantiles(x_te, quantiles=[0.5])

#plot_quantiles([y_te, y_hat], q, ['y_te', 'y_hat', 'y_hat_qrf'])


y_hat = qrf.predict(x_te.iloc[[0], :])
q = qrf.predict(x_te.iloc[[0], :])
Expand Down

0 comments on commit 5144493

Please sign in to comment.