diff --git a/gpjax/gps.py b/gpjax/gps.py
index cca9ca36..015f3b5b 100644
--- a/gpjax/gps.py
+++ b/gpjax/gps.py
@@ -652,7 +652,8 @@ def __init__(
         """
         super().__init__(prior=prior, likelihood=likelihood, jitter=jitter)
 
-        latent = latent or jr.normal(key, shape=(self.likelihood.num_datapoints, 1))
+        if latent is None:
+            latent = jr.normal(key, shape=(self.likelihood.num_datapoints, 1))
 
         # TODO: static or intermediate?
         self.latent = latent if isinstance(latent, Parameter) else Real(latent)
diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py
index 88c5f7b3..251b8756 100644
--- a/gpjax/variational_families.py
+++ b/gpjax/variational_families.py
@@ -149,12 +149,14 @@ def __init__(
     ):
         super().__init__(posterior, inducing_inputs, jitter)
 
-        self.variational_mean = Real(
-            variational_mean or jnp.zeros((self.num_inducing, 1))
-        )
-        self.variational_root_covariance = LowerTriangular(
-            variational_root_covariance or jnp.eye(self.num_inducing)
-        )
+        if variational_mean is None:
+            variational_mean = jnp.zeros((self.num_inducing, 1))
+
+        if variational_root_covariance is None:
+            variational_root_covariance = jnp.eye(self.num_inducing)
+
+        self.variational_mean = Real(variational_mean)
+        self.variational_root_covariance = LowerTriangular(variational_root_covariance)
 
     def prior_kl(self) -> ScalarFloat:
         r"""Compute the prior KL divergence.
@@ -378,12 +380,14 @@ def __init__(
     ):
         super().__init__(posterior, inducing_inputs, jitter)
 
-        self.natural_vector = Static(
-            natural_vector or jnp.zeros((self.num_inducing, 1))
-        )
-        self.natural_matrix = Static(
-            natural_matrix or -0.5 * jnp.eye(self.num_inducing)
-        )
+        if natural_vector is None:
+            natural_vector = jnp.zeros((self.num_inducing, 1))
+
+        if natural_matrix is None:
+            natural_matrix = -0.5 * jnp.eye(self.num_inducing)
+
+        self.natural_vector = Static(natural_vector)
+        self.natural_matrix = Static(natural_matrix)
 
     def prior_kl(self) -> ScalarFloat:
         r"""Compute the KL-divergence between our current variational approximation
@@ -540,13 +544,14 @@ def __init__(
     ):
         super().__init__(posterior, inducing_inputs, jitter)
 
-        # must come after super().__init__
-        self.expectation_vector = Static(
-            expectation_vector or jnp.zeros((self.num_inducing, 1))
-        )
-        self.expectation_matrix = Static(
-            expectation_matrix or jnp.eye(self.num_inducing)
-        )
+        if expectation_vector is None:
+            expectation_vector = jnp.zeros((self.num_inducing, 1))
+
+        if expectation_matrix is None:
+            expectation_matrix = jnp.eye(self.num_inducing)
+
+        self.expectation_vector = Static(expectation_vector)
+        self.expectation_matrix = Static(expectation_matrix)
 
     def prior_kl(self) -> ScalarFloat:
         r"""Evaluate the prior KL-divergence.
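
Context for the change, as a note outside the patch: the `value or default` idiom evaluates the truthiness of the supplied argument, which raises a `TypeError` for multi-element JAX arrays and silently substitutes the default when a legitimate zero-valued scalar array is passed. The explicit `is None` checks avoid both failure modes. Below is a minimal sketch of the behaviour, assuming the GPJax pattern above; the helper names `build_mean_unsafe` and `build_mean_safe` are illustrative and do not exist in the library.

```python
# Minimal sketch (not part of the diff) of why `value or default`
# is replaced with an explicit `is None` check.
import jax.numpy as jnp


def build_mean_unsafe(variational_mean=None, num_inducing=3):
    # `or` calls bool() on the array: a multi-element array raises a
    # TypeError, and a falsy (all-zero, single-element) array is
    # silently discarded in favour of the default.
    return variational_mean or jnp.zeros((num_inducing, 1))


def build_mean_safe(variational_mean=None, num_inducing=3):
    # Explicit None check: only a missing argument triggers the default,
    # so a user-supplied array is always kept as-is.
    if variational_mean is None:
        variational_mean = jnp.zeros((num_inducing, 1))
    return variational_mean


user_mean = jnp.zeros((3, 1))        # a legitimate, user-supplied value
print(build_mean_safe(user_mean))    # returned unchanged
# build_mean_unsafe(user_mean) raises:
# TypeError: The truth value of an array with more than one element is ambiguous.
```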