
added restoring of pars if val loss increased at the last evaluation step
nepslor committed Dec 14, 2023
1 parent ec0e2e4 commit f028517
Showing 2 changed files with 16 additions and 11 deletions.
11 changes: 7 additions & 4 deletions pyforecaster/forecasting_models/neural_forecasters.py
@@ -69,7 +69,7 @@ def probabilistic_loss_fn(params, inputs, targets, model=None):
     out = model(params, inputs)
     predictions = out[:, :out.shape[1]//2]
     sigma_square = out[:, out.shape[1]//2:]
-    ll = jnp.mean((predictions - targets)**2 / sigma_square + jnp.log(sigma_square))
+    ll = jnp.mean(((predictions - targets)**2) / sigma_square + jnp.log(sigma_square))
     return ll

 def train_step(params, optimizer_state, inputs_batch, targets_batch, model=None, loss_fn=None, **kwargs):
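
Aside: both versions of the changed line compute the same value, since ** binds tighter than /, so the new parentheses only make the precedence explicit. The quantity is the Gaussian negative log-likelihood, up to an additive constant, averaged over samples and horizon steps:

    \mathcal{L}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \left[ \frac{(\hat{y}_i - y_i)^2}{\hat{\sigma}_i^2} + \log \hat{\sigma}_i^2 \right]

where \hat{y}_i and \hat{\sigma}_i^2 are the two halves of the network output.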
@@ -289,6 +289,7 @@ def fit(self, inputs, targets, n_epochs=None, savepath_tr_plots=None, stats_step
                 pars = reproject_weights(pars, rec_stable=self.rec_stable, monotone=self.monotone)

                 if k % stats_step == 0 and k > 0:
+                    old_pars = self.pars
                     self.pars = pars
                     rand_idx_val = np.random.choice(validation_len, np.minimum(batch_size, validation_len), replace=False)
                     inputs_val_sampled = [i[rand_idx_val, :] for i in inputs_val] if isinstance(inputs_val, tuple) else inputs_val[rand_idx_val, :]
@@ -314,7 +315,9 @@
                 k += 1
                 if finished:
                     break
-
+        if len(val_loss)>1:
+            if val_loss[-1] > val_loss[-2]:
+                pars = old_pars
         self.pars = pars
         super().fit(inputs_val_0, targets_val_0)
         return self
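
Together with the old_pars snapshot added above, this gives fit a one-step rollback: if the last recorded validation loss is worse than the previous one, training falls back to the parameters from the previous evaluation. A minimal sketch of the same pattern, assuming generic train_step and eval_loss callables (illustrative names, not pyforecaster API):

    import copy

    def fit_with_rollback(params, train_step, eval_loss, n_steps, stats_step=10):
        # Snapshot the parameters at every evaluation; if the final
        # validation loss is worse than the previous one, restore the
        # snapshot taken at the previous evaluation.
        val_loss = []
        prev_snapshot = snapshot = copy.deepcopy(params)
        for k in range(n_steps):
            params = train_step(params)
            if k % stats_step == 0 and k > 0:
                prev_snapshot = snapshot           # params from the previous evaluation
                snapshot = copy.deepcopy(params)   # current candidate
                val_loss.append(eval_loss(params))
        if len(val_loss) > 1 and val_loss[-1] > val_loss[-2]:
            params = prev_snapshot                 # last stretch got worse: roll back
        return params

Note that only the last two evaluations are compared, so this is a one-step rollback rather than full early stopping on the best validation loss.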
@@ -418,7 +421,7 @@ def __call__(self, x, y):
                          augment_ctrl_inputs=self.augment_ctrl_inputs,
                          layer_normalization=self.layer_normalization)(y, u, z)
         if self.probabilistic:
-            return jnp.hstack([z[:self.features_out//2], nn.softplus(z[self.features_out//2:]) + 1e-10])
+            return jnp.hstack([z[:self.features_out//2], nn.softplus(z[self.features_out//2:]) + 1e-8])
         return z
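
Aside on the constant: nn.softplus is close to zero for strongly negative inputs, and the probabilistic losses above divide by sigma_square and take its log, so this additive floor bounds both terms; 1e-8 gives more headroom than 1e-10 in float32. A quick illustration (assumes JAX defaults):

    import jax.numpy as jnp
    from jax import nn

    z = jnp.array([-30.0])              # strongly negative pre-activation
    sigma_sq = nn.softplus(z)           # ~9.4e-14, so 1/sigma_sq alone is ~1e13
    print(jnp.log(sigma_sq + 1e-10))    # ~ -23.0
    print(jnp.log(sigma_sq + 1e-8))     # ~ -18.4, a tighter bound on the log term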


@@ -485,7 +488,7 @@ def probabilistic_causal_loss_fn(params, inputs, targets, model=None, causal_mat
     predictions = out[:, :out.shape[1]//2]
     sigma_square = out[:, out.shape[1]//2:]
     causal_loss = vmap(_my_jmp, in_axes=(None, None, 0, 0, None))(model, params, ex_inputs, ctrl_inputs, causal_matrix.T)
-    ll = jnp.mean((predictions - targets) ** 2 / sigma_square + jnp.log(sigma_square))
+    ll = jnp.mean(((predictions - targets) ** 2) / sigma_square + jnp.log(sigma_square))
     return ll + jnp.mean(causal_loss)


16 changes: 9 additions & 7 deletions tests/test_nns.py
@@ -67,22 +67,24 @@ def test_picnn(self):
         optimization_vars = x_tr.columns[:100]


-        m_1 = PICNN(learning_rate=1e-3, batch_size=5000, load_path=None, n_hidden_x=200, n_hidden_y=200,
-                    n_out=y_tr.shape[1], n_layers=3, optimization_vars=optimization_vars,probabilistic=True).fit(x_tr,
+        m_1 = PICNN(learning_rate=1e-3, batch_size=500, load_path=None, n_hidden_x=200, n_hidden_y=200,
+                    n_out=y_tr.shape[1], n_layers=3, optimization_vars=optimization_vars,probabilistic=True, rel_tol=-1,
+                    val_ratio=0.2).fit(x_tr,
                     y_tr,
                     n_epochs=1,
-                    stats_step=10)
+                    stats_step=200,
+                    savepath_tr_plots=savepath_tr_plots)

         y_hat_1 = m_1.predict(x_te)
         m_1.save('tests/results/ffnn_model.pk')

-        rnd_idxs = np.random.choice(x_tr.shape[0], 1)
+        rnd_idxs = np.random.choice(x_te.shape[0], 1)
         for r in rnd_idxs:
-            y_hat = m_1.predict(x_tr.iloc[[r], :])
-            q_hat = m_1.predict_quantiles(x_tr.iloc[[r], :])
+            y_hat = m_1.predict(x_te.iloc[[r], :])
+            q_hat = m_1.predict_quantiles(x_te.iloc[[r], :])
             plt.figure()
             plt.plot(y_hat.values.ravel(), label='y_hat')
-            plt.plot(y_tr.iloc[r, :].values.ravel(), label='y_true')
+            plt.plot(y_te.iloc[r, :].values.ravel(), label='y_true')
             plt.plot(np.squeeze(q_hat), label='q_hat', color='red', alpha=0.2)
             plt.legend()