diff --git a/pyforecaster/forecasting_models/neural_forecasters.py b/pyforecaster/forecasting_models/neural_forecasters.py
index f86c9f3..c0fd1b3 100644
--- a/pyforecaster/forecasting_models/neural_forecasters.py
+++ b/pyforecaster/forecasting_models/neural_forecasters.py
@@ -65,6 +65,12 @@ def loss_fn(params, inputs, targets, model=None):
     predictions = model(params, inputs)
     return jnp.mean((predictions - targets) ** 2)
 
+def embedded_loss_fn(params, inputs, targets, model=None):
+    predictions, ctrl_embedding, ctrl_reconstruction = model(params, inputs)
+    target_loss = jnp.mean((predictions - targets) ** 2)
+    ctrl_reconstruction_loss = jnp.mean((ctrl_reconstruction - inputs[1]) ** 2)
+    return target_loss + ctrl_reconstruction_loss
+
 
 def probabilistic_loss(y_hat, y, sigma_square, kind='maximum_likelihood', distribution='normal'):
     if kind == 'maximum_likelihood':
@@ -123,6 +129,19 @@ def predict_batch(pars, inputs, model=None):
 def predict_batch_picnn(pars, inputs, model=None):
     return model.apply(pars, *inputs)
 
+def predict_batch_latent_picnn(pars, inputs, model=None, mode='all'):
+    z, ctrl_embedding, ctrl_reconstruction = model.apply(pars, *inputs)
+    if mode == 'all':
+        return z, ctrl_embedding, ctrl_reconstruction
+    elif mode == 'prediction':
+        return z
+    elif mode == 'embedding':
+        return ctrl_embedding
+    elif mode == 'reconstruction':
+        return ctrl_reconstruction
+
+
+
 class FeedForwardModule(nn.Module):
     n_layers: Union[int, np.array, list]
     n_out: int=None
@@ -145,6 +164,7 @@ class PICNNLayer(nn.Module):
     features_x: int
     features_y: int
     features_out: int
+    features_latent: int
     n_layer: int = 0
     prediction_layer: bool = False
     activation: callable = nn.relu
@@ -160,7 +180,7 @@ def __call__(self, y, u, z):
             y = jnp.hstack([y, -y])
 
         # Input-Convex component without bias for the element-wise multiplicative interactions
-        wzu = nn.relu(nn.Dense(features=self.features_out, use_bias=True, name='wzu')(u))
+        wzu = nn.relu(nn.Dense(features=self.features_latent, use_bias=True, name='wzu')(u))
         wyu = self.rec_activation(nn.Dense(features=self.features_y, use_bias=True, name='wyu')(u))
         z_next = nn.Dense(features=self.features_out, use_bias=False, name='wz', kernel_init=partial(positive_lecun, init_type=self.init_type))(z * wzu)
         y_next = nn.Dense(features=self.features_out, use_bias=False, name='wy')(y * wyu)
@@ -463,6 +483,7 @@ class PartiallyICNN(nn.Module):
     features_x: int
     features_y: int
     features_out: int
+    features_latent: int
     activation: callable = nn.relu
     rec_activation: callable = identity
     init_type: str = 'normal'
@@ -475,10 +496,12 @@ class PartiallyICNN(nn.Module):
     @nn.compact
     def __call__(self, x, y):
         u = x.copy()
-        z = jnp.zeros(self.features_out)  # Initialize z_0 to be the same shape as y
+        z = jnp.zeros(self.features_latent)  # initialize z_0 with the latent width
         for i in range(self.num_layers):
             prediction_layer = i == self.num_layers -1
-            u, z = PICNNLayer(features_x=self.features_x, features_y=self.features_y, features_out=self.features_out,
+            features_out = self.features_out if prediction_layer else self.features_latent
+            u, z = PICNNLayer(features_x=self.features_x, features_y=self.features_y, features_out=features_out,
+                              features_latent=self.features_latent,
                               n_layer=i, prediction_layer=prediction_layer, activation=self.activation,
                               rec_activation=self.rec_activation, init_type=self.init_type,
                               augment_ctrl_inputs=self.augment_ctrl_inputs,
@@ -492,6 +515,7 @@ def __call__(self, x, y):
             prediction_layer = i == self.num_layers - 1
             u, sigma = PICNNLayer(features_x=self.features_x, features_y=self.features_y, features_out=sigma_len,
+                                  features_latent=self.features_latent,
                                   n_layer=i, prediction_layer=prediction_layer, activation=self.activation,
                                   rec_activation=self.rec_activation, init_type=self.init_type,
                                   augment_ctrl_inputs=self.augment_ctrl_inputs,
@@ -503,17 +527,71 @@ def __call__(self, x, y):
         return z
 
+class LatentPartiallyICNN(nn.Module):
+    num_layers: int
+    features_x: int
+    features_y: int
+    features_out: int
+    features_latent: int
+    encoder_neurons: np.array = None
+    decoder_neurons: np.array = None
+    activation: callable = nn.relu
+    rec_activation: callable = identity
+    init_type: str = 'normal'
+    augment_ctrl_inputs: bool = False
+    layer_normalization: bool = False
+    probabilistic: bool = False
+    structured: bool = False
+    z_max: jnp.array = None
+    z_min: jnp.array = None
+
+
+    def setup(self):
+        ctrl_embedding_len = self.encoder_neurons[-1]
+        features_y = 2*ctrl_embedding_len if self.augment_ctrl_inputs else ctrl_embedding_len
+
+        self.encoder = FeedForwardModule(n_layers=self.encoder_neurons, name='encoder')
+        self.decoder = FeedForwardModule(n_layers=self.decoder_neurons, name='decoder')
+
+        self.picnn = PartiallyICNN(num_layers=self.num_layers, features_x=self.features_x, features_y=features_y,
+                                   features_out=self.features_out, features_latent=self.features_latent, init_type=self.init_type,
+                                   augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic,
+                                   z_min=self.z_min, z_max=self.z_max, name='picnn')
+
+    def __call__(self, x, y):
+
+        ctrl_embedding = self.encoder(jnp.hstack([x, y]))
+        z = self.picnn(x, ctrl_embedding)
+        ctrl_reconstruction = self.decoder(ctrl_embedding)
+
+        return z, ctrl_embedding, ctrl_reconstruction
+
+    def decode(self, ctrl_embedding):
+        return self.decoder(ctrl_embedding)
+
+
+def decode(params, model, ctrl_embedding):
+    def decoder(lpicnn):
+        return lpicnn.decode(ctrl_embedding)
+
+    return nn.apply(decoder, model)(params)
+
 class PartiallyIQCNN(nn.Module):
     num_layers: int
     features_x: int
     features_y: int
     features_out: int
+    features_latent: int
     activation: callable = nn.softplus
     rec_activation: callable = identity
     init_type: str = 'normal'
     augment_ctrl_inputs: bool = False
     layer_normalization:bool = False
     probabilistic: bool = False
+    z_max: jnp.array = None
+    z_min: jnp.array = None
 
     def __call_wrapper__(self, y, x):
         u = x
@@ -521,9 +599,11 @@ def __call_wrapper__(self, y, x):
         for i in range(self.num_layers):
             prediction_layer = i == self.num_layers -1
             u, z = PICNNLayer(features_x=self.features_x, features_y=self.features_y, features_out=self.features_out,
+                              features_latent=self.features_latent,
                               n_layer=i, prediction_layer=prediction_layer, activation=self.activation,
                               init_type=self.init_type, augment_ctrl_inputs=self.augment_ctrl_inputs,
-                              layer_normalization=self.layer_normalization)(y, u, z)
+                              layer_normalization=self.layer_normalization, z_min=self.z_min,
+                              z_max=self.z_max)(y, u, z)
         return z
 
     @nn.compact
@@ -585,8 +665,9 @@ class PICNN(NN):
     distribution = 'normal'
     z_min: jnp.array = None
     z_max: jnp.array = None
+    n_latent: int = 1
     def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
-                 n_hidden_x: int = 100, n_out: int = None, n_layers: int = 3, pars: dict = None, q_vect=None,
+                 n_hidden_x: int = 100, n_out: int = 1, n_latent: int = 1, n_layers: int = 3, pars: dict = None, q_vect=None,
                  val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
                  stats_step: int = 50, rel_tol: float = 1e-4, unnormalized_inputs=None, normalize_target=True,
                  stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False,
@@ -595,6 +676,7 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
                  z_min: jnp.array = None,
                  z_max: jnp.array = None,
                  **scengen_kwgs):
+
         self.set_attr({"inverter_learning_rate":inverter_learning_rate,
                        "optimization_vars":optimization_vars,
                        "target_columns":target_columns,
@@ -604,7 +686,8 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
                        "probabilistic_loss_kind":probabilistic_loss_kind,
                        "distribution": distribution,
                        "z_min": z_min,
-                       "z_max": z_max
+                       "z_max": z_max,
+                       "n_latent": n_latent
                        })
         self.n_hidden_y = 2 * len(self.optimization_vars) if augment_ctrl_inputs else len(self.optimization_vars)
         self.inverter_optimizer = optax.adabelief(learning_rate=self.inverter_learning_rate)
@@ -616,11 +699,13 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
 
     def set_arch(self):
+        z_max = jnp.array(self.z_max) if self.z_max is not None else self.z_max
+        z_min = jnp.array(self.z_min) if self.z_min is not None else self.z_min
         self.optimizer = optax.adamw(learning_rate=self.learning_rate)
         self.model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
-                                   features_out=self.n_out, init_type=self.init_type,
-                                   augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic, z_min=self.z_min,
-                                   z_max=self.z_max)
+                                   features_out=self.n_out, features_latent=self.n_latent, init_type=self.init_type,
+                                   augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic,
+                                   z_min=z_min, z_max=z_max)
         self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0))
 
         if self.causal_df is not None:
@@ -668,7 +753,7 @@ def _objective(y, x, **objective_kwargs):
             return objective(self.predict_batch(self.pars, [x, y]), y, **objective_kwargs)
 
         # if the objective changes from one call to another, you need to recompile it. Slower but necessary
-        if recompile_obj:
+        if recompile_obj or self.iterate is None:
             @jit
             def iterate(x, y, opt_state, **objective_kwargs):
                 for i in range(10):
@@ -680,20 +765,6 @@ def iterate(x, y, opt_state, **objective_kwargs):
                     values, grads = value_and_grad(partial(_objective, **objective_kwargs))(y, x)
                     if vanilla_gd:
                         y -= grads * 1e-1
                     else:
                         updates, opt_state = self.inverter_optimizer.update(grads, opt_state, y)
                         y = optax.apply_updates(y, updates)
                 return y, values
             self.iterate = iterate
-        else:
-            if self.iterate is None:
-                @jit
-                def iterate(x, y, opt_state, **objective_kwargs):
-                    for i in range(10):
-                        values, grads = value_and_grad(partial(_objective, **objective_kwargs))(y, x)
-                        if vanilla_gd:
-                            y -= grads * 1e-1
-                        else:
-                            updates, opt_state = self.inverter_optimizer.update(grads, opt_state, y)
-                            y = optax.apply_updates(y, updates)
-                    return y, values
-
-                self.iterate = iterate
 
         opt_state = self.inverter_optimizer.init(y)
         y, values_old = self.iterate(x, y, opt_state, **objective_kwargs)
@@ -723,7 +794,7 @@ class PIQCNN(PICNN):
     reproject: bool = True
     rec_stable: bool = False
     def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
-                 n_hidden_x: int = 100, n_out: int = None, n_layers: int = 3, pars: dict = None, q_vect=None,
+                 n_hidden_x: int = 100, n_out: int = 1, n_latent: int = 1, n_layers: int = 3, pars: dict = None, q_vect=None,
                  val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
                  stats_step: int = 50, rel_tol: float = 1e-4, unnormalized_inputs=None, normalize_target=True,
                  stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False,
@@ -733,7 +804,7 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
                  z_min: jnp.array = None,
                  z_max: jnp.array = None,
                  **scengen_kwgs):
-        super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_layers, pars, q_vect, val_ratio,
+        super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_latent, n_layers, pars, q_vect, val_ratio,
                          nodes_at_step, n_epochs, savepath_tr_plots, stats_step, rel_tol, unnormalized_inputs,
                          normalize_target, stopping_rounds, subtract_mean_when_normalizing, causal_df, probabilistic,
                          probabilistic_loss_kind, distribution, inverter_learning_rate, optimization_vars, target_columns, init_type,
@@ -741,10 +812,13 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
 
     def set_arch(self):
+        z_max = jnp.array(self.z_max) if self.z_max is not None else self.z_max
+        z_min = jnp.array(self.z_min) if self.z_min is not None else self.z_min
         self.optimizer = optax.adamw(learning_rate=self.learning_rate)
         self.model = PartiallyIQCNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
-                                    features_out=self.n_out, init_type=self.init_type,
-                                    augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic)
+                                    features_out=self.n_out, features_latent=self.n_latent, init_type=self.init_type,
+                                    augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic,
+                                    z_min=z_min, z_max=z_max)
         self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0))
         self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind, distribution=self.distribution) if self.probabilistic else (
             jitting_wrapper(loss_fn, self.predict_batch))
@@ -757,7 +831,7 @@ class PIQCNNSigmoid(PICNN):
 
     monotone: bool = True
     def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
-                 n_hidden_x: int = 100, n_out: int = None, n_layers: int = 3, pars: dict = None, q_vect=None,
+                 n_hidden_x: int = 100, n_out: int = 1, n_latent: int = 1, n_layers: int = 3, pars: dict = None, q_vect=None,
                  val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
                  stats_step: int = 50, rel_tol: float = 1e-4, unnormalized_inputs=None, normalize_target=True,
                  stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False,
@@ -766,7 +840,7 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
                  z_min: jnp.array = None,
                  z_max: jnp.array = None,
                  **scengen_kwgs):
-        super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_layers, pars, q_vect, val_ratio,
+        super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_latent, n_layers, pars, q_vect, val_ratio,
                          nodes_at_step, n_epochs, savepath_tr_plots, stats_step, rel_tol, unnormalized_inputs,
                          normalize_target, stopping_rounds, subtract_mean_when_normalizing, causal_df, probabilistic,
                          probabilistic_loss_kind, distribution, inverter_learning_rate, optimization_vars, target_columns, init_type,
@@ -774,12 +848,14 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
 
     def set_arch(self):
+        z_max = jnp.array(self.z_max) if self.z_max is not None else self.z_max
+        z_min = jnp.array(self.z_min) if self.z_min is not None else self.z_min
         self.optimizer = optax.adamw(learning_rate=self.learning_rate)
         self.model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
-                                   features_out=self.n_out, init_type=self.init_type,
+                                   features_out=self.n_out, features_latent=self.n_latent, init_type=self.init_type,
                                    augment_ctrl_inputs=self.augment_ctrl_inputs, activation=nn.sigmoid,
-                                   rec_activation=nn.sigmoid, probabilistic=self.probabilistic,z_min=self.z_min,
-                                   z_max=self.z_max)
+                                   rec_activation=nn.sigmoid, probabilistic=self.probabilistic, z_min=z_min,
+                                   z_max=z_max)
         self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0))
         self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind, distribution=self.distribution) if self.probabilistic else (
             jitting_wrapper(loss_fn, self.predict_batch))
@@ -789,7 +865,7 @@ class RecStablePICNN(PICNN):
     reproject: bool = True
     rec_stable = True
     def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
-                 n_hidden_x: int = 100, n_out: int = None, n_layers: int = 3, pars: dict = None, q_vect=None,
+                 n_hidden_x: int = 100, n_out: int = 1, n_latent: int = 1, n_layers: int = 3, pars: dict = None, q_vect=None,
                  val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
                  stats_step: int = 50, rel_tol: float = 1e-4, unnormalized_inputs=None, normalize_target=True,
                  stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False,
@@ -797,18 +873,20 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
                  target_columns: list = None, init_type='normal', augment_ctrl_inputs=False, layer_normalization=False,
                  z_min: jnp.array = None, z_max: jnp.array = None,
                  **scengen_kwgs):
-        super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_layers, pars, q_vect, val_ratio,
+        super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_latent, n_layers, pars, q_vect, val_ratio,
                          nodes_at_step, n_epochs, savepath_tr_plots, stats_step, rel_tol, unnormalized_inputs,
                          normalize_target, stopping_rounds, subtract_mean_when_normalizing, causal_df, probabilistic,
                          probabilistic_loss_kind, distribution, inverter_learning_rate, optimization_vars, target_columns, init_type,
                          augment_ctrl_inputs, layer_normalization, z_min, z_max, **scengen_kwgs)
 
     def set_arch(self):
+        z_max = jnp.array(self.z_max) if self.z_max is not None else self.z_max
+        z_min = jnp.array(self.z_min) if self.z_min is not None else self.z_min
         self.optimizer = optax.adamw(learning_rate=self.learning_rate)
         self.model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
-                                   features_out=self.n_out, activation=nn.relu,
-                                   init_type=self.init_type, probabilistic=self.probabilistic,z_min=self.z_min,
-                                   z_max=self.z_max)
+                                   features_out=self.n_out, features_latent=self.n_latent, activation=nn.relu,
+                                   init_type=self.init_type, probabilistic=self.probabilistic, z_min=z_min,
+                                   z_max=z_max)
         self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0))
         self.loss_fn = jitting_wrapper(probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind, distribution=self.distribution) if self.probabilistic else (
             jitting_wrapper(loss_fn, self.predict_batch))
@@ -876,7 +954,7 @@ class StructuredPICNN(PICNN):
     monotone: bool = True
     objective_fun=None
     def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
-                 n_hidden_x: int = 100, n_out: int = None, n_layers: int = 3, pars: dict = None, q_vect=None,
+                 n_hidden_x: int = 100, n_out: int = 1, n_latent: int = 1, n_layers: int = 3, pars: dict = None, q_vect=None,
                  val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
                  stats_step: int = 50, rel_tol: float = 1e-4, unnormalized_inputs=None, normalize_target=True,
                  stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False,
@@ -887,19 +965,21 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
 
         self.objective_fun = objective_fun
         self.objective = vmap(objective_fun, in_axes=(0, 0))
-        super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_layers, pars, q_vect, val_ratio,
+        super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_latent, n_layers, pars, q_vect, val_ratio,
                          nodes_at_step, n_epochs, savepath_tr_plots, stats_step, rel_tol, unnormalized_inputs,
                          normalize_target, stopping_rounds, subtract_mean_when_normalizing, causal_df, probabilistic,
                          probabilistic_loss_kind, distribution, inverter_learning_rate, optimization_vars, target_columns, init_type,
                          augment_ctrl_inputs, layer_normalization, z_min, z_max, **scengen_kwgs)
 
     def set_arch(self):
+        z_max = jnp.array(self.z_max) if self.z_max is not None else self.z_max
+        z_min = jnp.array(self.z_min) if self.z_min is not None else self.z_min
         self.optimizer = optax.adamw(learning_rate=self.learning_rate)
         self.model = PartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
-                                   features_out=self.n_out, init_type=self.init_type,
+                                   features_out=self.n_out, features_latent=self.n_latent, init_type=self.init_type,
                                    augment_ctrl_inputs=self.augment_ctrl_inputs, activation=nn.sigmoid,
                                    rec_activation=nn.sigmoid, probabilistic=self.probabilistic, structured=True,
-                                   z_min=self.z_min, z_max=self.z_max)
+                                   z_min=z_min, z_max=z_max)
         self.predict_batch = vmap(jitting_wrapper(predict_batch_picnn, self.model), in_axes=(None, 0))
         self.loss_fn = jitting_wrapper(structured_loss_fn, self.predict_batch, objective=self.objective) if not self.probabilistic \
            else jitting_wrapper(structured_probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind, objective=self.objective, distribution=self.distribution)
@@ -972,4 +1052,104 @@ def predict_quantiles(self, inputs, normalize=True, **kwargs):
                 neg_qs = 2 * jnp.exp(mu_hat) - np.exp(qn)
                 preds[:, :, i] = (s==-1)*neg_qs + (s==1)*pos_qs
-        return preds
\ No newline at end of file
+        return preds
+
+
+
+class LatentStructuredPICNN(PICNN):
+    reproject: bool = True
+    rec_stable: bool = False
+    monotone: bool = True
+    objective_fun=None
+    def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_path: str = None,
+                 n_hidden_x: int = 100, n_out: int = 1, n_latent: int = 1, n_layers: int = 3, pars: dict = None, q_vect=None,
+                 val_ratio=None, nodes_at_step=None, n_epochs: int = 10, savepath_tr_plots: str = None,
+                 stats_step: int = 50, rel_tol: float = 1e-4, unnormalized_inputs=None, normalize_target=True,
+                 stopping_rounds=5, subtract_mean_when_normalizing=False, causal_df=None, probabilistic=False,
+                 probabilistic_loss_kind='maximum_likelihood', distribution='normal', inverter_learning_rate: float = 0.1, optimization_vars: list = (),
+                 target_columns: list = None, init_type='normal', augment_ctrl_inputs=False, layer_normalization=False,
+                 objective_fun=None, z_min: jnp.array = None, z_max: jnp.array = None,
+                 n_first_encoder: int = 10, n_last_encoder: int = 10, n_encoder_layers: int = 3,
+                 n_first_decoder: int = 10, n_decoder_layers: int = 3,
+                 **scengen_kwgs):
+
+        self.set_attr({"encoder_neurons": np.linspace(n_first_encoder, n_last_encoder, n_encoder_layers).astype(int),
+                       "decoder_neurons": np.linspace(n_first_decoder, len(optimization_vars), n_decoder_layers).astype(int),
+                       "n_first_encoder": n_first_encoder,
+                       "n_last_encoder": n_last_encoder,
+                       "n_encoder_layers": n_encoder_layers,
+                       "n_first_decoder": n_first_decoder,
+                       "n_decoder_layers": n_decoder_layers
+                       })
+
+        super().__init__(learning_rate, batch_size, load_path, n_hidden_x, n_out, n_latent, n_layers, pars, q_vect, val_ratio,
+                         nodes_at_step, n_epochs, savepath_tr_plots, stats_step, rel_tol, unnormalized_inputs,
+                         normalize_target, stopping_rounds, subtract_mean_when_normalizing, causal_df, probabilistic,
+                         probabilistic_loss_kind, distribution, inverter_learning_rate, optimization_vars, target_columns, init_type,
+                         augment_ctrl_inputs, layer_normalization, z_min, z_max, **scengen_kwgs)
+
+    def set_arch(self):
+        z_max = jnp.array(self.z_max) if self.z_max is not None else self.z_max
+        z_min = jnp.array(self.z_min) if self.z_min is not None else self.z_min
+        self.optimizer = optax.adamw(learning_rate=self.learning_rate)
+        self.model = LatentPartiallyICNN(num_layers=self.n_layers, features_x=self.n_hidden_x, features_y=self.n_hidden_y,
+                                         features_out=self.n_out, features_latent=self.n_latent, init_type=self.init_type,
+                                         augment_ctrl_inputs=self.augment_ctrl_inputs, probabilistic=self.probabilistic,
+                                         z_min=z_min, z_max=z_max, encoder_neurons=self.encoder_neurons, decoder_neurons=self.decoder_neurons)
+
+        self.predict_batch_training = vmap(jitting_wrapper(predict_batch_latent_picnn, self.model, mode='all'), in_axes=(None, 0))
+        self.predict_batch = vmap(jitting_wrapper(predict_batch_latent_picnn, self.model, mode='prediction'), in_axes=(None, 0))
+        self.loss_fn = jitting_wrapper(embedded_loss_fn, self.predict_batch_training) if not self.probabilistic else jitting_wrapper(probabilistic_loss_fn, self.predict_batch, kind=self.probabilistic_loss_kind, distribution=self.distribution)
+        self.train_step = jitting_wrapper(partial(train_step, loss_fn=self.loss_fn), self.optimizer)
+
+
+    def optimize(self, inputs, objective, n_iter=200, rel_tol=1e-4, recompile_obj=True, vanilla_gd=False, **objective_kwargs):
+        rel_tol = rel_tol if rel_tol is not None else self.rel_tol
+        inputs = inputs.copy()
+        normalized_inputs, _ = self.get_normalized_inputs(inputs)
+        x, y = normalized_inputs
+
+        def _objective(ctrl_embedding, x, **objective_kwargs):
+            ctrl = decode(self.pars, self.model, ctrl_embedding)
+            return objective(self.predict_batch(self.pars, [x, ctrl_embedding]), ctrl, **objective_kwargs)
+
+        # if the objective changes from one call to another, you need to recompile it. Slower but necessary
+        if recompile_obj or self.iterate is None:
+            @jit
+            def iterate(x, y, opt_state, **objective_kwargs):
+                for i in range(10):
+                    values, grads = value_and_grad(partial(_objective, **objective_kwargs))(y, x)
+                    if vanilla_gd:
+                        y -= grads * 1e-1
+                    else:
+                        updates, opt_state = self.inverter_optimizer.update(grads, opt_state, y)
+                        y = optax.apply_updates(y, updates)
+                return y, values
+            self.iterate = iterate
+
+        opt_state = self.inverter_optimizer.init(y)
+        _, ctrl_embedding, ctrl_reconstruct = self.predict_batch_training(self.pars, [x, y])
+        ctrl_embedding, values_old = self.iterate(x, ctrl_embedding, opt_state, **objective_kwargs)
+        values_init = np.copy(values_old)
+
+        # do 10 iterations at a time to speed up, check for convergence
+        for i in range(n_iter//10):
+            ctrl_embedding, values = self.iterate(x, ctrl_embedding, opt_state, **objective_kwargs)
+            rel_improvement = (values_old - values) / (np.abs(values_old) + 1e-12)
+            values_old = values
+            if rel_improvement < rel_tol:
+                break
+
+        print('optimization terminated at iter {}, final objective value: {:0.2e} '
+              'rel improvement: {:0.2e}'.format((i+1)*10, values,
+                                                (values_init-values)/(np.abs(values_init)+1e-12)))
+
+        y = decode(self.pars, self.model, ctrl_embedding)
+
+        inputs.loc[:, self.optimization_vars] = y.ravel()
+        inputs.loc[:, [c for c in inputs.columns if c not in self.optimization_vars]] = x.ravel()
+        inputs.loc[:, self.to_be_normalized] = self.input_scaler.inverse_transform(inputs[self.to_be_normalized].values)
+        target_opt = self.predict(inputs)
+
+        y_opt = inputs.loc[:, self.optimization_vars].values.ravel()
+        return y_opt, inputs, target_opt, values
\ No newline at end of file
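Note on the pattern introduced above: `LatentStructuredPICNN.optimize` no longer searches over the raw control variables; it encodes them, runs gradient steps on the low-dimensional embedding, and maps the optimum back through the decoder. A minimal, self-contained sketch of that encode/optimize/decode loop follows, using plain Flax and Optax; the module and names here (`Autoencoder`, `n_embed`, `n_ctrl`, the layer widths) are illustrative assumptions, not the pyforecaster API.

# Sketch of latent-space optimization: encode controls, descend on the
# embedding, decode the optimum. Illustrative only; not the pyforecaster API.
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class Autoencoder(nn.Module):
    n_embed: int = 8     # embedding width (plays the role of the ctrl embedding)
    n_ctrl: int = 100    # control-vector width

    def setup(self):
        self.encoder = nn.Sequential([nn.Dense(32), nn.relu, nn.Dense(self.n_embed)])
        self.decoder = nn.Sequential([nn.Dense(32), nn.relu, nn.Dense(self.n_ctrl)])

    def __call__(self, ctrl):
        e = self.encoder(ctrl)
        return e, self.decoder(e)

    def decode(self, e):
        return self.decoder(e)

model = Autoencoder()
ctrl = jnp.ones((1, 100))
params = model.init(jax.random.PRNGKey(0), ctrl)

def objective(e):
    # objective on the decoded controls; a real use would also involve the
    # PICNN prediction, as in the _objective closure of the patch
    decoded = model.apply(params, e, method=Autoencoder.decode)
    return jnp.sum(decoded ** 2)

e, _ = model.apply(params, ctrl)            # start from the encoding of ctrl
opt = optax.adabelief(1e-2)
opt_state = opt.init(e)
for _ in range(100):                        # gradient descent in latent space
    value, grads = jax.value_and_grad(objective)(e)
    updates, opt_state = opt.update(grads, opt_state, e)
    e = optax.apply_updates(e, updates)
ctrl_opt = model.apply(params, e, method=Autoencoder.decode)

One motivation for this design is that the PICNN is convex in the embedding, so descending on a small latent vector is cheaper and better conditioned than descending on hundreds of raw controls.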
diff --git a/tests/test_nns.py b/tests/test_nns.py
index c6360d2..de8a872 100644
--- a/tests/test_nns.py
+++ b/tests/test_nns.py
@@ -4,7 +4,7 @@ import pandas as pd
 import numpy as np
 import logging
 
-from pyforecaster.forecasting_models.neural_forecasters import PICNN, RecStablePICNN, NN, PIQCNN, PIQCNNSigmoid, StructuredPICNN
+from pyforecaster.forecasting_models.neural_forecasters import PICNN, RecStablePICNN, NN, PIQCNN, PIQCNNSigmoid, StructuredPICNN, LatentStructuredPICNN
 from pyforecaster.trainer import hyperpar_optimizer
 from pyforecaster.formatter import Formatter
 from pyforecaster.metrics import nmae
@@ -351,6 +351,45 @@ def test_structured_picnn_sigmoid(self):
             plt.legend()
 
 
+    def test_latent_picnn(self):
+
+        # normalize inputs
+        x = (self.x - self.x.mean(axis=0)) / (self.x.std(axis=0)+0.01)
+        y = (self.y - self.y.mean(axis=0)) / (self.y.std(axis=0)+0.01)
+
+        n_tr = int(len(x) * 0.8)
+        x_tr, x_te, y_tr, y_te = [x.iloc[:n_tr, :].copy(), x.iloc[n_tr:, :].copy(), y.iloc[:n_tr].copy(),
+                                  y.iloc[n_tr:].copy()]
+
+        savepath_tr_plots = 'tests/results/ffnn_tr_plots'
+
+        # if not there, create directory savepath_tr_plots
+        if not exists(savepath_tr_plots):
+            makedirs(savepath_tr_plots)
+
+        optimization_vars = x_tr.columns[:100]
+
+        m = LatentStructuredPICNN(learning_rate=1e-3, batch_size=1000, load_path=None, n_hidden_x=200,
+                                  n_out=y_tr.shape[1], n_layers=3, optimization_vars=optimization_vars, inverter_learning_rate=1e-3,
+                                  augment_ctrl_inputs=True, layer_normalization=True, unnormalized_inputs=optimization_vars,
+                                  n_first_encoder=20, n_last_encoder=100, n_first_decoder=100).fit(x_tr, y_tr,
+                                                                                                   n_epochs=1,
+                                                                                                   savepath_tr_plots=savepath_tr_plots,
+                                                                                                   stats_step=40)
+
+        objective = lambda y_hat, ctrl: jnp.mean(y_hat ** 2) + 0.0001*jnp.sum(ctrl**2)
+        ctrl_opt, inputs_opt, y_hat_opt, v_opt = m.optimize(x_te.iloc[[100], :], objective=objective, n_iter=5000)
+
+        plt.plot(y_hat_opt.values.ravel())
+        rnd_idxs = np.random.choice(x_te.shape[0], 1)
+        rnd_idxs = [100]
+        for r in rnd_idxs:
+            y_hat = m.predict(x_te.iloc[[r], :])
+            plt.figure()
+            plt.plot(y_te.iloc[r, :].values.ravel(), label='y_true')
+            plt.plot(y_hat.values.ravel(), label='y_hat')
+            plt.legend()
+
+
 if __name__ == '__main__':
     unittest.main()
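For quick reference, a condensed usage sketch of the new forecaster, mirroring `test_latent_picnn` above. The random stand-in data and the hyperparameter values are illustrative assumptions; the constructor, `fit`, and `optimize` signatures are the ones exercised by the test.

# Condensed usage of LatentStructuredPICNN, following tests/test_nns.py.
# The synthetic data and hyperparameters below are placeholders.
import numpy as np
import pandas as pd
import jax.numpy as jnp
from pyforecaster.forecasting_models.neural_forecasters import LatentStructuredPICNN

rng = np.random.default_rng(0)
x = pd.DataFrame(rng.standard_normal((500, 120)))   # features; first 100 act as controls
y = pd.DataFrame(rng.standard_normal((500, 24)))    # multi-step targets
optimization_vars = list(x.columns[:100])

m = LatentStructuredPICNN(learning_rate=1e-3, batch_size=100, n_hidden_x=50,
                          n_out=y.shape[1], n_layers=3,
                          optimization_vars=optimization_vars,
                          inverter_learning_rate=1e-3,
                          n_first_encoder=20, n_last_encoder=10,
                          n_first_decoder=20).fit(x, y, n_epochs=1)

# the objective receives the PICNN predictions and the *decoded* controls
objective = lambda y_hat, ctrl: jnp.mean(y_hat ** 2) + 1e-4 * jnp.sum(ctrl ** 2)
ctrl_opt, inputs_opt, y_hat_opt, v_opt = m.optimize(x.iloc[[0], :],
                                                    objective=objective, n_iter=500)

`optimize` returns the de-normalized optimal controls, the full input row with those controls substituted in, the corresponding forecast, and the final objective value, matching the four-tuple unpacked in the test.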