diff --git a/pyforecaster/forecasting_models/neural_forecasters.py b/pyforecaster/forecasting_models/neural_forecasters.py
index 492e337..15ee6cb 100644
--- a/pyforecaster/forecasting_models/neural_forecasters.py
+++ b/pyforecaster/forecasting_models/neural_forecasters.py
@@ -589,9 +589,9 @@ def decoder(lpicnn ):
     return nn.apply(decoder, model)(params)
 
 
-def latent_pred(params, model, ctrl_embedding, x):
+def latent_pred(params, model, x, ctrl_embedding):
     def _latent_pred(lpicnn ):
-        return lpicnn.latent_pred(ctrl_embedding, x)
+        return lpicnn.latent_pred(x, ctrl_embedding)
     return nn.apply(_latent_pred, model)(params)
 
 
@@ -1133,8 +1133,8 @@ def _objective(ctrl_embedding, x, **objective_kwargs):
                 ctrl_reconstruct = decode(self.pars, self.model, x, ctrl_embedding)
                 preds_reconstruct, _ , _ = self.predict_batch_training(self.pars, [x, ctrl_reconstruct])
                 implicit_regularization_loss = jnp.mean((preds_reconstruct - preds)**2)
-                return objective(preds, ctrl_embedding, **objective_kwargs) + implicit_regularization_loss
-
+                return objective(preds, ctrl_reconstruct, **objective_kwargs) + implicit_regularization_loss
+
             self._objective = _objective
         # if the objective changes from one call to another, you need to recompile it. Slower but necessary
         if recompile_obj or self.iterate is None:
             @jit
diff --git a/tests/test_nns.py b/tests/test_nns.py
index 8170d59..d807e9d 100644
--- a/tests/test_nns.py
+++ b/tests/test_nns.py
@@ -377,12 +377,13 @@ def test_latent_picnn(self):
                                   savepath_tr_plots=savepath_tr_plots, stats_step=40)
 
-        objective = lambda y_hat, ctrl: jnp.mean(y_hat ** 2)
-        m.inverter_optimizer = optax.adabelief(learning_rate=1e-3)
-        ctrl_opt, inputs_opt, y_hat_opt, v_opt = m.optimize(x_te.iloc[[100], :], objective=objective,n_iter=500)
+        objective = lambda y_hat, ctrl: jnp.mean(y_hat ** 2) + boxconstr(ctrl, 100, -100)
+        m.inverter_optimizer = optax.adabelief(learning_rate=1e-1)
+        ctrl_opt, inputs_opt, y_hat_opt, v_opt = m.optimize(x_te.iloc[[0], :], objective=objective,n_iter=500)
+        #convexity_test(x_te, m, optimization_vars)
 
         rnd_idxs = np.random.choice(x_te.shape[0], 1)
-        rnd_idxs = [100]
+        rnd_idxs = [0]
         for r in rnd_idxs:
             y_hat = m.predict(x_te.iloc[[r], :])
             plt.figure()
@@ -390,6 +391,30 @@ def test_latent_picnn(self):
             plt.plot(y_hat.values.ravel(), label='y_hat')
             plt.legend()
             plt.plot(y_hat_opt.values.ravel())
+            plt.show()
+
+def boxconstr(x, ub, lb):
+    return jnp.sum(jnp.maximum(0, x - ub)**2 + jnp.maximum(0, lb - x)**2)
+
+def convexity_test(df, forecaster, ctrl_names, **objective_kwargs):
+    x_names = [c for c in df.columns if not c in ctrl_names]
+    rand_idxs = np.random.choice(len(df), 5)
+    for idx in rand_idxs:
+        x = df[x_names].iloc[idx, :]
+        ctrl_embedding = np.tile(np.random.randn(forecaster.n_last_encoder, 1), 100).T
+
+        plt.figure()
+        for ctrl_e in range(forecaster.n_last_encoder):
+            ctrls = ctrl_embedding.copy()
+            ctrls[:, ctrl_e] = np.linspace(-1, 1, 100)
+            preds = np.hstack([forecaster._objective(c, np.atleast_2d(x.values.ravel()), **objective_kwargs) for c in ctrls])
+            approx_second_der = np.round(np.diff(preds, 2, axis=0), 5)
+            approx_second_der[approx_second_der == 0] = 0  # to fix the sign
+            is_convex = np.all(np.sign(approx_second_der) >= 0)
+            print('output is convex w.r.t. input {}: {}'.format(ctrl_e, is_convex))
+            plt.plot(ctrls[:, ctrl_e], preds, alpha=0.3)
+        plt.pause(0.001)
+        plt.show()
 
 if __name__ == '__main__':
     unittest.main()
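
Note on the objective change: after this patch the user-supplied objective receives the decoded control signal (ctrl_reconstruct) rather than the latent embedding, so penalty terms such as boxconstr act in the original control space. A minimal standalone sketch (not part of the patch) of how the quadratic box penalty behaves; the example values are made up:

    import jax.numpy as jnp

    def boxconstr(x, ub, lb):
        # quadratic hinge: zero inside [lb, ub], smooth and differentiable outside,
        # so the gradient-based inverter can push controls back into the box
        return jnp.sum(jnp.maximum(0, x - ub)**2 + jnp.maximum(0, lb - x)**2)

    ctrl = jnp.array([-150.0, 0.0, 50.0, 120.0])
    print(boxconstr(ctrl, 100, -100))  # 50**2 + 20**2 = 2900.0; only out-of-bounds entries contribute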
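
Note on convexity_test: it probes convexity empirically by sweeping one embedding coordinate at a time and checking that the discrete second derivative of the objective never goes negative. A self-contained sketch of that check on a synthetic convex slice (the quadratic stand-in is an assumption, not the PICNN objective):

    import numpy as np

    grid = np.linspace(-1, 1, 100)
    f = grid**2 + 0.5 * grid                   # convex 1-D slice used as a stand-in
    second_diff = np.round(np.diff(f, 2), 5)   # discrete second derivative
    second_diff[second_diff == 0] = 0          # map -0.0 to +0.0 before the sign test
    print(np.all(np.sign(second_diff) >= 0))   # True -> convex along this slice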