From 7c645c267387be91dc7f3c8f80d92eae1ce03a82 Mon Sep 17 00:00:00 2001 From: Stefano <46034160+stefanocortinovis@users.noreply.github.com> Date: Fri, 1 Nov 2024 15:23:57 +0000 Subject: [PATCH] Allow to pass trainable inducing inputs to AbstractVariationalGaussian (#485) --- gpjax/variational_families.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index f1c40bdcc..88c5f7b3c 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -108,10 +108,17 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]): def __init__( self, posterior: AbstractPosterior[P, L], - inducing_inputs: Float[Array, "N D"], + inducing_inputs: tp.Union[ + Float[Array, "N D"], + Real, + Static, + ], jitter: ScalarFloat = 1e-6, ): - self.inducing_inputs = Static(inducing_inputs) + if not isinstance(inducing_inputs, (Real, Static)): + inducing_inputs = Real(inducing_inputs) + + self.inducing_inputs = inducing_inputs self.jitter = jitter super().__init__(posterior)