Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
nepslor committed Feb 27, 2024
1 parent 7044b4e commit f43141f
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions pyforecaster/forecasting_models/neural_models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,49 @@ def __call__(self, y, u, z):
z_next = nn.sigmoid(z_next) * (self.z_max - self.z_min) + self.z_min
return None, z_next

class CausalInvertibleLayer(nn.Module):
features: int
negative_slope:int = 0.01
activation: callable = nn.leaky_relu(negative_slope=negative_slope)
init_type: str = 'normal'
layer_normalization: bool = False
prediction_layer: bool = False

def setup(self):
pass
def __call__(self, inputs):
inner_pars = self.param(
'kernel',
self.kernel_init,
(jnp.shape(inputs)[-1], self.features),
self.param_dtype,
)

kernel = jnp.tril(inner_pars) + jnp.eye(self.features)

if self.dot_general_cls is not None:
dot_general = self.dot_general_cls()
else:
dot_general = self.dot_general
y = dot_general(
inputs,
kernel,
(((inputs.ndim - 1,), (0,)), ((), ())),
precision=self.precision,
)

if self.prediction_layer:
return y
else:
if self.layer_normalization:
y = nn.LayerNorm()(y)
return self.activation(y)

def invert(self, y):
if not self.prediction_layer:
y = self.inverse_leaky_relu(y)
return jnp.dot(y, jnp.linalg.inv(self.kernel))


def inverse_leaky_relu(self, y):
return jnp.where(y >= 0, y, y / self.negative_slope)

0 comments on commit f43141f

Please sign in to comment.