diff --git a/docs/examples/oceanmodelling.py b/docs/examples/oceanmodelling.py index f917ec044..77f125b12 100644 --- a/docs/examples/oceanmodelling.py +++ b/docs/examples/oceanmodelling.py @@ -10,7 +10,7 @@ from jax import config config.update("jax_enable_x64", True) -from dataclasses import dataclass +from dataclasses import dataclass, field from jax import hessian from jax import config @@ -195,10 +195,16 @@ def dataset_3d(pos, vel): # %% + + @dataclass class VelocityKernel(gpx.kernels.AbstractKernel): - kernel0: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1]) - kernel1: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1]) + kernel0: gpx.kernels.AbstractKernel = field( + default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1]) + ) + kernel1: gpx.kernels.AbstractKernel = field( + default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1]) + ) def __call__( self, X: Float[Array, "1 D"], Xp: Float[Array, "1 D"] @@ -429,8 +435,12 @@ def plot_fields( @dataclass class HelmholtzKernel(gpx.kernels.AbstractKernel): # initialise Phi and Psi kernels as any stationary kernel in gpJax - potential_kernel: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1]) - stream_kernel: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1]) + potential_kernel: gpx.kernels.AbstractKernel = field( + default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1]) + ) + stream_kernel: gpx.kernels.AbstractKernel = field( + default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1]) + ) def __call__( self, X: Float[Array, "1 D"], Xp: Float[Array, "1 D"]