From f43141f12e8d8888b2b419eb77c86d2df567ea4b Mon Sep 17 00:00:00 2001 From: nepslor Date: Tue, 27 Feb 2024 14:28:59 +0100 Subject: [PATCH] refactoring --- .../neural_models/layers.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/pyforecaster/forecasting_models/neural_models/layers.py b/pyforecaster/forecasting_models/neural_models/layers.py index 0bd921d..44dd257 100644 --- a/pyforecaster/forecasting_models/neural_models/layers.py +++ b/pyforecaster/forecasting_models/neural_models/layers.py @@ -50,3 +50,49 @@ def __call__(self, y, u, z): z_next = nn.sigmoid(z_next) * (self.z_max - self.z_min) + self.z_min return None, z_next +class CausalInvertibleLayer(nn.Module): + features: int + negative_slope:int = 0.01 + activation: callable = nn.leaky_relu(negative_slope=negative_slope) + init_type: str = 'normal' + layer_normalization: bool = False + prediction_layer: bool = False + + def setup(self): + pass + def __call__(self, inputs): + inner_pars = self.param( + 'kernel', + self.kernel_init, + (jnp.shape(inputs)[-1], self.features), + self.param_dtype, + ) + + kernel = jnp.tril(inner_pars) + jnp.eye(self.features) + + if self.dot_general_cls is not None: + dot_general = self.dot_general_cls() + else: + dot_general = self.dot_general + y = dot_general( + inputs, + kernel, + (((inputs.ndim - 1,), (0,)), ((), ())), + precision=self.precision, + ) + + if self.prediction_layer: + return y + else: + if self.layer_normalization: + y = nn.LayerNorm()(y) + return self.activation(y) + + def invert(self, y): + if not self.prediction_layer: + y = self.inverse_leaky_relu(y) + return jnp.dot(y, jnp.linalg.inv(self.kernel)) + + + def inverse_leaky_relu(self, y): + return jnp.where(y >= 0, y, y / self.negative_slope)