diff --git a/examples/barycentres.pct.py b/examples/barycentres.pct.py index 1716317db..eae8a8c02 100644 --- a/examples/barycentres.pct.py +++ b/examples/barycentres.pct.py @@ -23,16 +23,19 @@ # %% import typing as tp +import distrax as dx import jax import jax.numpy as jnp import jax.random as jr import jax.scipy.linalg as jsl import matplotlib.pyplot as plt -import distrax as dx import optax as ox +from jax.config import config import gpjax as gpx +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) key = jr.PRNGKey(123) # %% [markdown] diff --git a/examples/classification.pct.py b/examples/classification.pct.py index 4c0f59bc7..4dc5e8a76 100644 --- a/examples/classification.pct.py +++ b/examples/classification.pct.py @@ -28,10 +28,13 @@ import jax.scipy as jsp import matplotlib.pyplot as plt import optax as ox +from jax.config import config from jaxtyping import Array, Float import gpjax as gpx +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) I = jnp.eye key = jr.PRNGKey(123) diff --git a/examples/collapsed_vi.pct.py b/examples/collapsed_vi.pct.py index 473b0bfbd..eef6e8d39 100644 --- a/examples/collapsed_vi.pct.py +++ b/examples/collapsed_vi.pct.py @@ -25,9 +25,12 @@ import matplotlib.pyplot as plt import optax as ox from jax import jit +from jax.config import config import gpjax as gpx +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) key = jr.PRNGKey(123) # %% [markdown] @@ -116,7 +119,7 @@ latent_dist = q.predict(D, learned_params)(xtest) predictive_dist = likelihood(latent_dist, learned_params) -samples = latent_dist.sample(seed=key, sample_shape=(20, )) +samples = latent_dist.sample(seed=key, sample_shape=(20,)) predictive_mean = predictive_dist.mean() predictive_std = predictive_dist.stddev() diff --git a/examples/graph_kernels.pct.py b/examples/graph_kernels.pct.py index 8b1e2c5da..27ac78fd7 100644 --- a/examples/graph_kernels.pct.py +++ b/examples/graph_kernels.pct.py @@ -28,9 +28,12 @@ import networkx as nx import optax as ox from jax import jit +from jax.config import config import gpjax as gpx +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) key = jr.PRNGKey(123) # %% [markdown] diff --git a/examples/haiku.pct.py b/examples/haiku.pct.py index 002fe7636..10af39875 100644 --- a/examples/haiku.pct.py +++ b/examples/haiku.pct.py @@ -29,11 +29,14 @@ import matplotlib.pyplot as plt import optax as ox from chex import dataclass +from jax.config import config from scipy.signal import sawtooth import gpjax as gpx -from gpjax.kernels import Kernel, DenseKernelComputation +from gpjax.kernels import DenseKernelComputation, Kernel +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) key = jr.PRNGKey(123) # %% [markdown] diff --git a/examples/kernels.pct.py b/examples/kernels.pct.py index 902fcc588..02b8886bc 100644 --- a/examples/kernels.pct.py +++ b/examples/kernels.pct.py @@ -26,11 +26,14 @@ import jax.random as jr import matplotlib.pyplot as plt from jax import jit +from jax.config import config from jaxtyping import Array, Float from optax import adam import gpjax as gpx +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) key = jr.PRNGKey(123) # %% [markdown] diff --git a/examples/natgrads.pct.py b/examples/natgrads.pct.py index e7ee0818f..231ba338f 100644 --- a/examples/natgrads.pct.py +++ b/examples/natgrads.pct.py @@ -26,9 +26,12 @@ import matplotlib.pyplot as plt import optax as ox from jax import jit, lax +from jax.config import config import gpjax as gpx +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) key = jr.PRNGKey(123) # %% [markdown] diff --git a/examples/regression.pct.py b/examples/regression.pct.py index ecc817be3..6aaaa1c1f 100644 --- a/examples/regression.pct.py +++ b/examples/regression.pct.py @@ -26,9 +26,12 @@ import matplotlib.pyplot as plt import optax as ox from jax import jit +from jax.config import config import gpjax as gpx +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) pp = PrettyPrinter(indent=4) key = jr.PRNGKey(123) diff --git a/examples/tfp_integration.pct.py b/examples/tfp_integration.pct.py index ad373ffe7..7076d80ae 100644 --- a/examples/tfp_integration.pct.py +++ b/examples/tfp_integration.pct.py @@ -25,10 +25,13 @@ import jax.numpy as jnp import jax.random as jr import matplotlib.pyplot as plt +from jax.config import config import gpjax as gpx from gpjax.utils import dict_array_coercion +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) pp = PrettyPrinter(indent=4) key = jr.PRNGKey(123) diff --git a/examples/uncollapsed_vi.pct.py b/examples/uncollapsed_vi.pct.py index 1f6c6736b..3a04280e6 100644 --- a/examples/uncollapsed_vi.pct.py +++ b/examples/uncollapsed_vi.pct.py @@ -25,9 +25,12 @@ import matplotlib.pyplot as plt import optax as ox from jax import jit +from jax.config import config import gpjax as gpx +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) key = jr.PRNGKey(123) # %% [markdown] diff --git a/examples/yacht.pct.py b/examples/yacht.pct.py index 1d5eecdf1..a89039acc 100644 --- a/examples/yacht.pct.py +++ b/examples/yacht.pct.py @@ -19,7 +19,10 @@ import matplotlib.pyplot as plt import numpy as np import optax as ox +from jax.config import config +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) # %% [markdown] # # UCI Data Benchmarking # diff --git a/gpjax/__init__.py b/gpjax/__init__.py index 2e83fc69d..34a33d1c4 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -13,11 +13,6 @@ # limitations under the License. # ============================================================================== -from jax.config import config - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) - from .abstractions import fit, fit_batches, fit_natgrads from .gps import Prior, construct_posterior from .kernels import ( diff --git a/tests/test_abstractions.py b/tests/test_abstractions.py index af4dad932..c9ba5c389 100644 --- a/tests/test_abstractions.py +++ b/tests/test_abstractions.py @@ -17,12 +17,16 @@ import jax.random as jr import optax import pytest +from jax.config import config import gpjax as gpx from gpjax import RBF, Dataset, Gaussian, Prior, initialise from gpjax.abstractions import InferenceState, fit, fit_batches, fit_natgrads, get_batch from gpjax.parameters import ParameterState, build_bijectors +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) + @pytest.mark.parametrize("n_iters", [1, 5]) @pytest.mark.parametrize("n", [1, 20]) diff --git a/tests/test_config.py b/tests/test_config.py index 887791050..a122b4d6e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -13,9 +13,14 @@ # limitations under the License. # ============================================================================== -from ml_collections import ConfigDict -from gpjax.config import add_parameter, get_defaults, Identity import distrax as dx +from jax.config import config +from ml_collections import ConfigDict + +from gpjax.config import Identity, add_parameter, get_defaults + +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) def test_add_parameter(): diff --git a/tests/test_covariance_operator.py b/tests/test_covariance_operator.py index 80d2c5c6c..15b3aa59d 100644 --- a/tests/test_covariance_operator.py +++ b/tests/test_covariance_operator.py @@ -17,7 +17,10 @@ import jax.numpy as jnp import jax.random as jr import pytest +from jax.config import config +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) from gpjax.covariance_operator import ( CovarianceOperator, DenseCovarianceOperator, diff --git a/tests/test_gps.py b/tests/test_gps.py index 7d2b37fd4..dad6c157a 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -15,10 +15,11 @@ import typing as tp +import distrax as dx import jax.numpy as jnp import jax.random as jr -import distrax as dx import pytest +from jax.config import config from gpjax import Dataset, initialise from gpjax.gps import ( @@ -32,6 +33,8 @@ from gpjax.likelihoods import Bernoulli, Gaussian from gpjax.parameters import ParameterState +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) NonConjugateLikelihoods = [Bernoulli] diff --git a/tests/test_kernels.py b/tests/test_kernels.py index 691a74e8d..00d60d2eb 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -22,6 +22,7 @@ import networkx as nx import numpy as np import pytest +from jax.config import config from jaxtyping import Array, Float from gpjax.covariance_operator import ( @@ -47,6 +48,8 @@ from gpjax.parameters import initialise from gpjax.types import PRNGKeyType +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) """Default values for tests""" _initialise_key = jr.PRNGKey(123) _jitter = 100 diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py index f09f36c70..3c6bf1a65 100644 --- a/tests/test_likelihoods.py +++ b/tests/test_likelihoods.py @@ -15,11 +15,12 @@ import typing as tp +import distrax as dx import jax.numpy as jnp -import numpy as np import jax.random as jr +import numpy as np import pytest -import distrax as dx +from jax.config import config from gpjax.likelihoods import ( AbstractLikelihood, @@ -30,6 +31,8 @@ ) from gpjax.parameters import initialise +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) true_initialisation = { "Gaussian": ["obs_noise"], "Bernoulli": [], diff --git a/tests/test_mean_functions.py b/tests/test_mean_functions.py index 02d5e35f0..7d9e1bbc8 100644 --- a/tests/test_mean_functions.py +++ b/tests/test_mean_functions.py @@ -18,10 +18,14 @@ import jax.numpy as jnp import jax.random as jr import pytest +from jax.config import config from gpjax.mean_functions import Constant, Zero from gpjax.parameters import initialise +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) + @pytest.mark.parametrize("meanf", [Zero, Constant]) @pytest.mark.parametrize("dim", [1, 2, 5]) diff --git a/tests/test_natural_gradients.py b/tests/test_natural_gradients.py index 6cc40dc8d..dbb348fcd 100644 --- a/tests/test_natural_gradients.py +++ b/tests/test_natural_gradients.py @@ -19,6 +19,7 @@ import jax.numpy as jnp import jax.random as jr import pytest +from jax.config import config import gpjax as gpx from gpjax.abstractions import get_batch @@ -31,6 +32,8 @@ ) from gpjax.parameters import recursive_items +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) key = jr.PRNGKey(123) diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 2080a0e9c..94ed622fe 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -15,10 +15,11 @@ import typing as tp -import jax.numpy as jnp import distrax as dx +import jax.numpy as jnp import jax.random as jr import pytest +from jax.config import config from gpjax.gps import Prior from gpjax.kernels import RBF @@ -38,6 +39,8 @@ unconstrain, ) +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) ######################### # Test base functionality diff --git a/tests/test_quadrature.py b/tests/test_quadrature.py index 5bed1a995..dbe1a03e1 100644 --- a/tests/test_quadrature.py +++ b/tests/test_quadrature.py @@ -17,9 +17,13 @@ import jax import jax.numpy as jnp import pytest +from jax.config import config from gpjax.quadrature import gauss_hermite_quadrature +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) + @pytest.mark.parametrize("jit", [True, False]) def test_quadrature(jit): diff --git a/tests/test_types.py b/tests/test_types.py index 13b5fa04a..d4b11a940 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -15,9 +15,13 @@ import jax.numpy as jnp import pytest +from jax.config import config from gpjax.types import Dataset, NoneType, verify_dataset +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) + def test_nonetype(): assert isinstance(None, NoneType) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index ebefdefa4..21fbd961e 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -15,6 +15,7 @@ import jax.numpy as jnp import pytest +from jax.config import config from gpjax.utils import ( concat_dictionaries, @@ -23,6 +24,9 @@ sort_dictionary, ) +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) + def test_concat_dict(): d1 = {"a": 1, "b": 2} diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py index 0b14da758..d4fb22813 100644 --- a/tests/test_variational_families.py +++ b/tests/test_variational_families.py @@ -19,6 +19,7 @@ import jax.numpy as jnp import jax.random as jr import pytest +from jax.config import config import gpjax as gpx from gpjax.variational_families import ( @@ -30,6 +31,9 @@ WhitenedVariationalGaussian, ) +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) + def test_abstract_variational_family(): with pytest.raises(TypeError): diff --git a/tests/test_variational_inference.py b/tests/test_variational_inference.py index fb79006ee..e310f275f 100644 --- a/tests/test_variational_inference.py +++ b/tests/test_variational_inference.py @@ -19,6 +19,7 @@ import jax.numpy as jnp import jax.random as jr import pytest +from jax.config import config import gpjax as gpx from gpjax.variational_families import ( @@ -29,6 +30,9 @@ WhitenedVariationalGaussian, ) +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) + def test_abstract_variational_inference(): prior = gpx.Prior(kernel=gpx.RBF())