Mean function and kernel with common/shared hyperparameters #474

Answered by theo-brown
theo-brown asked this question in Q&A
Looking at the source code rather than the docs, the implemented kernels have an additional check in their constructor:

        if isinstance(lengthscale, nnx.Variable):
            self.lengthscale = lengthscale
        else:
            self.lengthscale = PositiveReal(lengthscale)

which seems to get round this. If I implement this logic:

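import jax
import jax.numpy as jnp
import gpjax as gpx
from flax import nnx

class MyKernel(gpx.kernels.AbstractKernel):
    def __init__(self, a: float | nnx.Variable, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Reuse an existing variable so it can be shared; otherwise wrap the raw value.
        if isinstance(a, nnx.Variable):
            self.a = a
        else:
            self.a = gpx.parameters.PositiveReal(jnp.array(a))

    def __call__(self, x1: jax.Array, x2: jax.Array) -> jax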
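
As far as I can tell, the `isinstance` check is what allows sharing: passing the same variable object to a mean function and a kernel gives both a reference to one parameter instead of two copies. Here is a minimal sketch of that pattern in plain Python, so it runs without GPJax. `Variable` and `PositiveReal` are hypothetical stand-ins for `nnx.Variable` and `gpx.parameters.PositiveReal`, and `as_positive`, `SharedMean` and `SharedKernel` are illustrative names of my own, not GPJax classes.

```python
import math


class Variable:
    """Hypothetical stand-in for nnx.Variable: a mutable box around a value."""

    def __init__(self, value):
        self.value = value


class PositiveReal(Variable):
    """Hypothetical stand-in for gpx.parameters.PositiveReal."""

    def __init__(self, value):
        if value <= 0:
            raise ValueError("PositiveReal requires a value > 0")
        super().__init__(value)


def as_positive(a):
    # Same check as in the kernel constructors: reuse an existing variable,
    # otherwise wrap the raw number in a fresh parameter.
    return a if isinstance(a, Variable) else PositiveReal(a)


class SharedMean:
    """Illustrative mean function m(x) = a * x."""

    def __init__(self, a):
        self.a = as_positive(a)

    def __call__(self, x):
        return self.a.value * x


class SharedKernel:
    """Illustrative kernel k(x1, x2) = exp(-a * (x1 - x2)^2)."""

    def __init__(self, a):
        self.a = as_positive(a)

    def __call__(self, x1, x2):
        return math.exp(-self.a.value * (x1 - x2) ** 2)


a = PositiveReal(2.0)
mean, kernel = SharedMean(a), SharedKernel(a)
assert mean.a is kernel.a  # one parameter object, two owners

a.value = 0.5  # an update through the shared object is seen by both
print(mean(4.0), kernel(1.0, 3.0))
```

Passing a raw float to each constructor instead would create two independent parameters. Whether the real GPJax and NNX machinery treats a shared variable as a single trainable parameter during optimisation is not something this sketch demonstrates.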

Answer selected by theo-brown