diff --git a/pyforecaster/forecasting_models/neural_forecasters.py b/pyforecaster/forecasting_models/neural_forecasters.py
index cdccc3b..7bbb54c 100644
--- a/pyforecaster/forecasting_models/neural_forecasters.py
+++ b/pyforecaster/forecasting_models/neural_forecasters.py
@@ -40,14 +40,20 @@ def __call__(self, x):
         x = nn.Dense(features=self.n_out, name='dense')(x)
         return x
 
-def reproject_weights(params, rec_stable=False):
+def reproject_weights(params, rec_stable=False, monotone=False):
     # Loop through each layer and reproject the input-convex weights
     for layer_name in params['params']:
         if 'PICNNLayer' in layer_name:
-            params['params'][layer_name]['wz']['kernel'] = jnp.maximum(0, params['params'][layer_name]['wz']['kernel'])
-            if rec_stable:
-                params['params'][layer_name]['wy']['kernel'] = jnp.maximum(0, params['params'][layer_name]['wy']['kernel'])
-
+            if monotone:
+                for name in {'wz', 'wy'} & set(params['params'][layer_name].keys()):
+                    params['params'][layer_name][name]['kernel'] = jnp.maximum(0, params['params'][layer_name][name]['kernel'])
+                # for name in {'wzu', 'wyu', 'wuz', 'u_dense'} & set(params['params'][layer_name].keys()):
+                #     params['params'][layer_name][name]['bias'] = jnp.maximum(0, params['params'][layer_name][name][
+                #         'bias'])
+            else:
+                params['params'][layer_name]['wz']['kernel'] = jnp.maximum(0, params['params'][layer_name]['wz']['kernel'])
+                if rec_stable:
+                    params['params'][layer_name]['wy']['kernel'] = jnp.maximum(0, params['params'][layer_name]['wy']['kernel'])
     return params
@@ -104,14 +110,13 @@ def __call__(self, y, u, z):
             y = jnp.hstack([y, -y])
 
         # Input-Convex component without bias for the element-wise multiplicative interactions
-        wzu = self.rec_activation(nn.Dense(features=self.features_out, use_bias=True, name='wzu')(u))
-        wzu = self.activation(wzu)
+        wzu = nn.relu(nn.Dense(features=self.features_out, use_bias=True, name='wzu')(u))
         wyu = self.rec_activation(nn.Dense(features=self.features_y, use_bias=True, name='wyu')(u))
-
         z_next = nn.Dense(features=self.features_out, use_bias=False, name='wz', kernel_init=partial(positive_lecun, init_type=self.init_type))(z * wzu)
         y_next = nn.Dense(features=self.features_out, use_bias=False, name='wy')(y * wyu)
         u_add = nn.Dense(features=self.features_out, use_bias=True, name='wuz')(u)
+
         if self.layer_normalization:
             y_next = nn.LayerNorm()(y_next)
             z_next = nn.LayerNorm()(z_next)
@@ -151,6 +156,7 @@ class NN(ScenarioGenerator):
     causal_df: pd.DataFrame = None
     reproject: bool = False
     rec_stable: bool = False
+    monotone: bool = False
     def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
                  n_hidden_x: int = 100, n_out: int = None, n_layers: int = 3, pars: dict = None, q_vect=None,
                  val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
@@ -267,7 +273,7 @@ def fit(self, inputs, targets, n_epochs=None, savepath_tr_plots=None, stats_step
                 pars, opt_state, values = self.train_step(pars, opt_state, inputs_batch, targets_batch)
 
                 if self.reproject:
-                    pars = reproject_weights(pars, rec_stable=self.rec_stable)
+                    pars = reproject_weights(pars, rec_stable=self.rec_stable, monotone=self.monotone)
 
                 if k % stats_step == 0 and k > 0:
                     self.pars = pars
@@ -372,7 +378,7 @@ def __call__(self, x, y):
         return z
 
 
-class PartiallyQICNN(nn.Module):
+class PartiallyIQCNN(nn.Module):
     num_layers: int
     features_x: int
     features_y: int
@@ -383,22 +389,22 @@ class PartiallyQICNN(nn.Module):
     augment_ctrl_inputs: bool = False
     layer_normalization:bool = False
 
+    def __call_wrapper__(self, y, x):
+        u = x
+        z = jnp.zeros(self.features_out)  # initialize z_0 with the output feature dimension
+        for i in range(self.num_layers):
+            prediction_layer = i == self.num_layers - 1
+            u, z = PICNNLayer(features_x=self.features_x, features_y=self.features_y,
+                              features_out=self.features_out, n_layer=i, prediction_layer=prediction_layer,
+                              activation=self.activation, init_type=self.init_type,
+                              augment_ctrl_inputs=self.augment_ctrl_inputs,
+                              layer_normalization=self.layer_normalization)(y, u, z)
+        return z
+
     @nn.compact
     def __call__(self, x, y):
-        _, qcvx_preds = jvp(lambda y: PartiallyICNN(num_layers=self.num_layers, features_x=self.features_x,
-                                                    features_y=self.features_y, features_out=self.features_out,
-                                                    activation=self.activation, rec_activation=self.rec_activation,
-                                                    init_type=self.init_type,
-                                                    augment_ctrl_inputs=self.augment_ctrl_inputs,
-                                                    layer_normalization=self.layer_normalization)(x, y),
-                            (y,),
-                            (jnp.ones_like(y)/y.shape[0],))
-
-        qcvx_preds = jnp.abs(qcvx_preds)
-
+        qcvx_preds = jnp.abs(jvp(partial(self.__call_wrapper__, x=x), (y,), (y,))[1])
         ex_preds = FeedForwardModule(n_layers=self.num_layers, n_neurons=self.features_x,
                                      n_out=self.features_out)(x)
         return qcvx_preds + ex_preds
@@ -562,7 +568,7 @@ def iterate(x, y, opt_state, **objective_kwargs):
         y_opt = inputs.loc[:, self.optimization_vars].values.ravel()
         return y_opt, inputs, target_opt, values
 
-class PQICNN(PICNN):
+class PIQCNN(PICNN):
     reproject: bool = True
     rec_stable: bool = False
     def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
@@ -582,10 +588,38 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
 
     def set_arch(self):
-        model = PartiallyQICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
+        model = PartiallyIQCNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
                                features_out=self.n_out, init_type=self.init_type,
                                augment_ctrl_inputs=self.augment_ctrl_inputs)
         return model
 
+
+class PIQCNNSigmoid(PICNN):
+    reproject: bool = True
+    rec_stable: bool = False
+    monotone: bool = True
+
+    def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
+                 n_hidden_x: int = 100, n_out: int = None, n_layers: int = 3, pars: dict = None, q_vect=None,
+                 val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
+                 stats_step: int = 50, rel_tol: float = 1e-4, unnormalized_inputs=None, normalize_target=True,
+                 stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None,
+                 inverter_learning_rate: float = 0.1, optimization_vars: list = (),
+                 target_columns: list = None, init_type='normal', augment_ctrl_inputs=False, layer_normalization=False,
+                 optimizer=None, **scengen_kwgs):
+
+        super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_layers, pars, q_vect, val_ratio,
+                         nodes_at_step, n_epochs, savepath_tr_plots, stats_step, rel_tol, unnormalized_inputs,
+                         normalize_target, stopping_rounds, subtract_mean_when_normalizing, causal_df,
+                         inverter_learning_rate, optimization_vars, target_columns, init_type, augment_ctrl_inputs,
+                         layer_normalization, optimizer, **scengen_kwgs)
+
+    def set_arch(self):
+        model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
+                              features_out=self.n_out, init_type=self.init_type,
+                              augment_ctrl_inputs=self.augment_ctrl_inputs, activation=nn.sigmoid,
+                              rec_activation=nn.sigmoid)
+        return model
 
 class RecStablePICNN(PICNN):
     reproject: bool = True
     rec_stable = True
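Note on the new `monotone` flag: with `monotone=True`, `reproject_weights` clamps both the `wz` and `wy` kernels of every `PICNNLayer` onto the nonnegative orthant after each optimizer step, instead of only `wz` (plus `wy` under `rec_stable`). Combined with monotone activations such as `nn.sigmoid`, nonnegative kernels make each layer, and hence the composed network, nondecreasing in the optimization variables. A minimal sketch of the projection on a toy params pytree (names and values here are illustrative, not the library's API):

```python
import jax.numpy as jnp

# Toy Flax-style params pytree; layer/kernel names mirror reproject_weights
params = {'params': {'PICNNLayer_0': {
    'wz': {'kernel': jnp.array([[-0.5, 0.2], [0.3, -0.1]])},
    'wy': {'kernel': jnp.array([[0.4, -0.2], [-0.3, 0.1]])},
}}}

for layer_name in params['params']:
    if 'PICNNLayer' in layer_name:
        # monotone=True: clamp both wz and wy kernels to be elementwise nonnegative
        for name in {'wz', 'wy'} & set(params['params'][layer_name].keys()):
            kernel = params['params'][layer_name][name]['kernel']
            params['params'][layer_name][name]['kernel'] = jnp.maximum(0, kernel)

print(params['params']['PICNNLayer_0']['wz']['kernel'])  # all entries now >= 0
```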
diff --git a/tests/test_nns.py b/tests/test_nns.py
index b78d5c1..eea7bd3 100644
--- a/tests/test_nns.py
+++ b/tests/test_nns.py
@@ -4,7 +4,7 @@
 import pandas as pd
 import numpy as np
 import logging
-from pyforecaster.forecasting_models.neural_forecasters import PQICNN, PICNN, RecStablePICNN, NN
+from pyforecaster.forecasting_models.neural_forecasters import PICNN, RecStablePICNN, NN, PIQCNN, PIQCNNSigmoid
 from pyforecaster.trainer import hyperpar_optimizer
 from pyforecaster.formatter import Formatter
 from pyforecaster.metrics import nmae
@@ -142,7 +142,7 @@ def test_pqicnn(self):
         x_tr, x_te, y_tr, y_te = [x.iloc[:n_tr, :].copy(), x.iloc[n_tr:, :].copy(), y.iloc[:n_tr].copy(),
                                   y.iloc[n_tr:].copy()]
 
-        savepath_tr_plots = 'wp3/tests/results/figs/convexity'
+        savepath_tr_plots = 'tests/results/figs/convexity'
 
         # if not there, create directory savepath_tr_plots
         if not exists(savepath_tr_plots):
@@ -152,7 +152,7 @@ def test_pqicnn(self):
         optimization_vars = x_tr.columns[:10]
         optimizer = optax.adamw(learning_rate=1e-3)
 
-        m_1 = PQICNN(learning_rate=1e-3, batch_size=1000, load_path=None, n_hidden_x=200, n_hidden_y=200,
+        m_1 = PIQCNN(learning_rate=1e-3, batch_size=1000, load_path=None, n_hidden_x=200, n_hidden_y=200,
                      n_out=y_tr.shape[1], n_layers=4, optimization_vars=optimization_vars,stopping_rounds=100,
                      optimizer=optimizer).fit(x_tr, y_tr, n_epochs=1,
@@ -177,6 +177,54 @@ def test_pqicnn(self):
             plt.xlabel(cc)
             plt.savefig(join(savepath_tr_plots, '{}.png'.format(cc)), dpi=300)
 
+    def test_piqcnn_sigmoid(self):
+
+        # normalize inputs
+        x_cols = ['all_lag_101', 'all_lag_062', 'ghi_6', 'all_lag_138', 'all_lag_090']
+        # x_cols = np.random.choice(self.x.columns, 5)
+        x = (self.x[x_cols] - self.x[x_cols].mean(axis=0)) / (self.x[x_cols].std(axis=0) + 0.01)
+        y = (self.y - self.y.mean(axis=0)) / (self.y.std(axis=0) + 0.01)
+
+        n_tr = int(len(x) * 0.8)
+        x_tr, x_te, y_tr, y_te = [x.iloc[:n_tr, :].copy(), x.iloc[n_tr:, :].copy(), y.iloc[:n_tr].copy(),
+                                  y.iloc[n_tr:].copy()]
+
+        savepath_tr_plots = 'tests/results/figs/convexity'
+
+        # if not there, create directory savepath_tr_plots
+        if not exists(savepath_tr_plots):
+            makedirs(savepath_tr_plots)
+
+        optimization_vars = x_tr.columns[:-2]
+        optimizer = optax.adamw(learning_rate=1e-3)
+
+        m_1 = PIQCNNSigmoid(learning_rate=1e-2, batch_size=200, load_path=None, n_hidden_x=200, n_hidden_y=200,
+                            n_out=y_tr.shape[1], n_layers=4, optimization_vars=optimization_vars,
+                            stopping_rounds=100, optimizer=optimizer,
+                            layer_normalization=True).fit(x_tr, y_tr, n_epochs=1,
+                                                          savepath_tr_plots=savepath_tr_plots,
+                                                          stats_step=200, rel_tol=-1)
+
+        # check convexity of the PICNN
+        rnd_idxs = np.random.choice(x_tr.shape[0], 1)
+        rand_opt_vars = np.random.choice(optimization_vars, 5)
+        for cc in rand_opt_vars:
+            x = x_tr.iloc[rnd_idxs, :]
+            x = pd.concat([x] * 100, axis=0)
+            x[cc] = np.linspace(-1, 1, 100)
+            y_hat = m_1.predict(x)
+            # sign changes of the first derivative (monotonicity diagnostic, currently unused)
+            d = np.diff(np.sign(np.diff(y_hat.values, axis=0)), axis=0)
+            approx_second_der = np.round(np.diff(y_hat.values, 2, axis=0), 5)
+            approx_second_der[approx_second_der == 0] = 0  # map -0.0 to 0.0 so np.sign returns 0
+            is_convex = not np.any(np.abs(np.diff(np.sign(approx_second_der), axis=0)) > 1)
+            print('output is convex w.r.t. input {}: {}'.format(cc, is_convex))
+            plt.figure(layout='tight')
+            plt.plot(np.tile(x[cc].values.reshape(-1, 1), y_hat.shape[1]), y_hat.values, alpha=0.3)
+            plt.xlabel(cc)
+            plt.savefig(join(savepath_tr_plots, '{}.png'.format(cc)), dpi=300)
+
     def test_optimization(self):
 
         # normalize inputs
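Since `PIQCNNSigmoid` constrains the network to be monotone (rather than convex) in the optimization variables, the convexity check reused above from `test_pqicnn` may legitimately report non-convex outputs. A first-derivative counterpart could complement it; a minimal sketch under that assumption (the helper name is hypothetical, not part of the test suite):

```python
import numpy as np

def check_monotonicity(y_hat: np.ndarray, decimals: int = 5) -> bool:
    """y_hat holds predictions over an increasing sweep of a single input,
    one column per output step. Each column must be entirely nondecreasing
    or entirely nonincreasing for the output to count as monotone."""
    first_der = np.round(np.diff(y_hat, axis=0), decimals)  # finite differences along the sweep
    non_decreasing = np.all(first_der >= 0, axis=0)
    non_increasing = np.all(first_der <= 0, axis=0)
    return bool(np.all(non_decreasing | non_increasing))

# A cubic is monotone on [-1, 1]; a parabola is not
grid = np.linspace(-1, 1, 100).reshape(-1, 1)
assert check_monotonicity(grid ** 3)
assert not check_monotonicity(grid ** 2)
```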