Skip to content

Commit

Permalink
corrected latent_pred input order, added self._objective
Browse files Browse the repository at this point in the history
  • Loading branch information
nepslor committed Jan 29, 2024
1 parent 3121bb8 commit 710d904
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
8 changes: 4 additions & 4 deletions pyforecaster/forecasting_models/neural_forecasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,9 +589,9 @@ def decoder(lpicnn ):

return nn.apply(decoder, model)(params)

def latent_pred(params, model, ctrl_embedding, x):
def latent_pred(params, model, x, ctrl_embedding):
def _latent_pred(lpicnn ):
return lpicnn.latent_pred(ctrl_embedding, x)
return lpicnn.latent_pred(x, ctrl_embedding)

return nn.apply(_latent_pred, model)(params)

Expand Down Expand Up @@ -1133,8 +1133,8 @@ def _objective(ctrl_embedding, x, **objective_kwargs):
ctrl_reconstruct = decode(self.pars, self.model, x, ctrl_embedding)
preds_reconstruct, _ , _ = self.predict_batch_training(self.pars, [x, ctrl_reconstruct])
implicit_regularization_loss = jnp.mean((preds_reconstruct - preds)**2)
return objective(preds, ctrl_embedding, **objective_kwargs) + implicit_regularization_loss

return objective(preds, ctrl_reconstruct, **objective_kwargs) + implicit_regularization_loss
self._objective = _objective
# if the objective changes from one call to another, you need to recompile it. Slower but necessary
if recompile_obj or self.iterate is None:
@jit
Expand Down
33 changes: 29 additions & 4 deletions tests/test_nns.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,19 +377,44 @@ def test_latent_picnn(self):
savepath_tr_plots=savepath_tr_plots,
stats_step=40)

objective = lambda y_hat, ctrl: jnp.mean(y_hat ** 2)
m.inverter_optimizer = optax.adabelief(learning_rate=1e-3)
ctrl_opt, inputs_opt, y_hat_opt, v_opt = m.optimize(x_te.iloc[[100], :], objective=objective,n_iter=500)
objective = lambda y_hat, ctrl: jnp.mean(y_hat ** 2) + boxconstr(ctrl, 100, -100)
m.inverter_optimizer = optax.adabelief(learning_rate=1e-1)
ctrl_opt, inputs_opt, y_hat_opt, v_opt = m.optimize(x_te.iloc[[0], :], objective=objective,n_iter=500)

#convexity_test(x_te, m, optimization_vars)
rnd_idxs = np.random.choice(x_te.shape[0], 1)
rnd_idxs = [100]
rnd_idxs = [0]
for r in rnd_idxs:
y_hat = m.predict(x_te.iloc[[r], :])
plt.figure()
plt.plot(y_te.iloc[r, :].values.ravel(), label='y_true')
plt.plot(y_hat.values.ravel(), label='y_hat')
plt.legend()
plt.plot(y_hat_opt.values.ravel())
plt.show()

def boxconstr(x, ub, lb):
return jnp.sum(jnp.maximum(0, x - ub)**2 + jnp.maximum(0, lb - x)**2)

def convexity_test(df, forecaster, ctrl_names, **objective_kwargs):
x_names = [c for c in df.columns if not c in ctrl_names]
rand_idxs = np.random.choice(len(df), 5)
for idx in rand_idxs:
x = df[x_names].iloc[idx, :]
ctrl_embedding = np.tile(np.random.randn(forecaster.n_last_encoder, 1), 100).T

plt.figure()
for ctrl_e in range(forecaster.n_last_encoder):
ctrls = ctrl_embedding.copy()
ctrls[:, ctrl_e] = np.linspace(-1, 1, 100)
preds = np.hstack([forecaster._objective(c, np.atleast_2d(x.values.ravel()), **objective_kwargs) for c in ctrls])
approx_second_der = np.round(np.diff(preds, 2, axis=0), 5)
approx_second_der[approx_second_der == 0] = 0 # to fix the sign
is_convex = np.all(np.sign(approx_second_der) >= 0)
print('output is convex w.r.t. input {}: {}'.format(ctrl_e, is_convex))
plt.plot(ctrls[:, ctrl_e], preds, alpha=0.3)
plt.pause(0.001)
plt.show()

if __name__ == '__main__':
unittest.main()
Expand Down

0 comments on commit 710d904

Please sign in to comment.