
Commit

Merge pull request #442 from Thomas-Christie/update-dependencies
Update dependencies
thomaspinder authored Mar 12, 2024
2 parents 23a85c3 + a4176e4 commit 50a4b96
Showing 47 changed files with 2,117 additions and 2,066 deletions.
README.md (2 changes: 1 addition & 1 deletion)

@@ -123,7 +123,7 @@ import jax.numpy as jnp
 import jax.random as jr
 import optax as ox

-key = jr.PRNGKey(123)
+key = jr.key(123)

 f = lambda x: 10 * jnp.sin(x)

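The README change above swaps jr.PRNGKey for jr.key, JAX's newer typed-key constructor; the same substitution recurs throughout the files below. A minimal sketch of the updated key workflow in plain JAX (the shapes and noise scale here are illustrative, not taken from the README):

    import jax.numpy as jnp
    import jax.random as jr

    key = jr.key(123)                  # new-style typed PRNG key (an ordinary jax.Array)
    key, subkey = jr.split(key)        # split before each independent random draw
    x = jr.uniform(subkey, shape=(50, 1), minval=-3.0, maxval=3.0)
    y = 10 * jnp.sin(x) + 0.1 * jr.normal(key, shape=x.shape)

Existing code built around jr.PRNGKey keeps working, since jr.split, jr.uniform, and jr.normal accept both the legacy and the typed key representation.
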
benchmarks/kernels.py (2 changes: 1 addition & 1 deletion)

@@ -12,7 +12,7 @@ class Kernels:
     params = [[10, 100, 500, 1000, 2000], [1, 2, 5]]

     def setup(self, n_datapoints: int, n_dims: int):
-        key = jr.PRNGKey(123)
+        key = jr.key(123)
         self.X = jr.uniform(
             key=key, minval=-3.0, maxval=3.0, shape=(n_datapoints, n_dims)
         )

benchmarks/linops.py (32 changes: 0 additions & 32 deletions)

This file was deleted.

benchmarks/objectives.py (12 changes: 6 additions & 6 deletions)

@@ -16,15 +16,15 @@ class Gaussian:
     params = [[10, 100, 200, 500, 1000], [1, 2, 5]]

     def setup(self, n_datapoints: int, n_dims: int):
-        key = jr.PRNGKey(123)
+        key = jr.key(123)
         self.X = jr.normal(key=key, shape=(n_datapoints, n_dims))
         self.y = jnp.sin(self.X[:, :1])
         self.data = gpx.Dataset(X=self.X, y=self.y)
         kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
         meanf = gpx.mean_functions.Constant()
         self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
         self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
-        self.objective = gpx.ConjugateMLL()
+        self.objective = gpx.objectives.ConjugateMLL()
         self.posterior = self.prior * self.likelihood

     def time_eval(self, n_datapoints: int, n_dims: int):
@@ -42,15 +42,15 @@ class Bernoulli:
     params = [[10, 100, 200, 500, 1000], [1, 2, 5]]

     def setup(self, n_datapoints: int, n_dims: int):
-        key = jr.PRNGKey(123)
+        key = jr.key(123)
         self.X = jr.normal(key=key, shape=(n_datapoints, n_dims))
         self.y = jnp.where(jnp.sin(self.X[:, :1]) > 0, 1, 0)
         self.data = gpx.Dataset(X=self.X, y=self.y)
         kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
         meanf = gpx.mean_functions.Constant()
         self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
         self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
-        self.objective = gpx.LogPosteriorDensity()
+        self.objective = gpx.objectives.LogPosteriorDensity()
         self.posterior = self.prior * self.likelihood

     def time_eval(self, n_datapoints: int, n_dims: int):
@@ -68,7 +68,7 @@ class Poisson:
     params = [[10, 100, 200, 500, 1000], [1, 2, 5]]

     def setup(self, n_datapoints: int, n_dims: int):
-        key = jr.PRNGKey(123)
+        key = jr.key(123)
         self.X = jr.normal(key=key, shape=(n_datapoints, n_dims))
         f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x # latent function
         self.y = jr.poisson(key, jnp.exp(f(self.X)))
@@ -77,7 +77,7 @@ def setup(self, n_datapoints: int, n_dims: int):
         meanf = gpx.mean_functions.Constant()
         self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
         self.likelihood = gpx.likelihoods.Poisson(num_datapoints=self.data.n)
-        self.objective = gpx.LogPosteriorDensity()
+        self.objective = gpx.objectives.LogPosteriorDensity()
         self.posterior = self.prior * self.likelihood

     def time_eval(self, n_datapoints: int, n_dims: int):

benchmarks/predictions.py (6 changes: 3 additions & 3 deletions)

@@ -15,7 +15,7 @@ class Gaussian:
     params = [[100, 200, 500, 1000, 2000, 3000], [1, 2, 5]]

     def setup(self, n_test: int, n_dims: int):
-        key = jr.PRNGKey(123)
+        key = jr.key(123)
         self.X = jr.normal(key=key, shape=(100, n_dims))
         self.y = jnp.sin(self.X[:, :1])
         self.data = gpx.Dataset(X=self.X, y=self.y)
@@ -39,7 +39,7 @@ class Bernoulli:
     params = [[100, 200, 500, 1000, 2000, 3000], [1, 2, 5]]

     def setup(self, n_test: int, n_dims: int):
-        key = jr.PRNGKey(123)
+        key = jr.key(123)
         self.X = jr.normal(key=key, shape=(100, n_dims))
         self.y = jnp.sin(self.X[:, :1])
         self.y = jnp.array(jnp.where(self.y > 0, 1, 0), dtype=jnp.float64)
@@ -64,7 +64,7 @@ class Poisson:
     params = [[100, 200, 500, 1000, 2000, 3000], [1, 2, 5]]

     def setup(self, n_test: int, n_dims: int):
-        key = jr.PRNGKey(123)
+        key = jr.key(123)
         self.X = jr.normal(key=key, shape=(100, n_dims))
         f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x # latent function
         self.y = jnp.array(jr.poisson(key, jnp.exp(f(self.X))), dtype=jnp.float64)

benchmarks/sparse.py (6 changes: 3 additions & 3 deletions)

@@ -13,7 +13,7 @@ class Sparse:
     params = [[2000, 5000, 10000, 20000], [10, 20, 50, 100, 200]]

     def setup(self, n_datapoints: int, n_inducing: int):
-        key = jr.PRNGKey(123)
+        key = jr.key(123)
         self.X = jr.normal(key=key, shape=(n_datapoints, 1))
         self.y = jnp.sin(self.X[:, :1])
         self.data = gpx.Dataset(X=self.X, y=self.y)
@@ -24,10 +24,10 @@ def setup(self, n_datapoints: int, n_inducing: int):
         self.posterior = self.prior * self.likelihood

         Z = jnp.linspace(self.X.min(), self.X.max(), n_inducing).reshape(-1, 1)
-        self.q = gpx.CollapsedVariationalGaussian(
+        self.q = gpx.variational_families.CollapsedVariationalGaussian(
             posterior=self.posterior, inducing_inputs=Z
         )
-        self.objective = gpx.CollapsedELBO(negative=True)
+        self.objective = gpx.objectives.CollapsedELBO(negative=True)

     def time_eval(self, n_datapoints: int, n_dims: int):
         self.objective(self.q, self.data)

benchmarks/stochastic.py (12 changes: 7 additions & 5 deletions)

@@ -14,7 +14,7 @@ class Sparse:
     params = [[10000, 20000, 50000], [10, 20, 50, 100, 200], [32, 64, 128, 256]]

     def setup(self, n_datapoints: int, n_inducing: int, batch_size: int):
-        key = jr.PRNGKey(123)
+        key = jr.key(123)
         self.X = jr.normal(key=key, shape=(n_datapoints, 1))
         self.y = jnp.sin(self.X[:, :1])
         self.data = gpx.Dataset(X=self.X, y=self.y)
@@ -25,15 +25,17 @@ def setup(self, n_datapoints: int, n_inducing: int, batch_size: int):
         self.posterior = self.prior * self.likelihood

         Z = jnp.linspace(self.X.min(), self.X.max(), n_inducing).reshape(-1, 1)
-        self.q = gpx.VariationalGaussian(posterior=self.posterior, inducing_inputs=Z)
-        self.objective = gpx.ELBO(negative=True)
+        self.q = gpx.variational_families.VariationalGaussian(
+            posterior=self.posterior, inducing_inputs=Z
+        )
+        self.objective = gpx.objectives.ELBO(negative=True)

     def time_eval(self, n_datapoints: int, n_dims: int, batch_size: int):
-        key = jr.PRNGKey(123)
+        key = jr.key(123)
         batch = get_batch(train_data=self.data, batch_size=batch_size, key=key)
         self.objective(self.q, batch)

     def time_grad(self, n_datapoints: int, n_dims: int, batch_size: int):
-        key = jr.PRNGKey(123)
+        key = jr.key(123)
         batch = get_batch(train_data=self.data, batch_size=batch_size, key=key)
         jax.grad(self.objective)(self.q, batch)

docs/_static/jaxkern/main.py (2 changes: 1 addition & 1 deletion)

@@ -6,7 +6,7 @@

 # import gpjax.kernels as jk

-# key = jr.PRNGKey(123)
+# key = jr.key(123)


 # def set_font(font_path):

docs/examples/barycentres.py (2 changes: 1 addition & 1 deletion)

@@ -33,7 +33,7 @@
 import gpjax as gpx


-key = jr.PRNGKey(123)
+key = jr.key(123)
 plt.style.use(
     "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
 )

docs/examples/bayesian_optimisation.py (2 changes: 1 addition & 1 deletion)

@@ -28,7 +28,7 @@
 from gpjax.typing import Array, FunctionalSample, ScalarFloat
 from jaxopt import ScipyBoundedMinimize

-key = jr.PRNGKey(42)
+key = jr.key(42)
 plt.style.use(
     "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
 )

docs/examples/classification.py (2 changes: 1 addition & 1 deletion)

@@ -51,7 +51,7 @@

 tfd = tfp.distributions
 identity_matrix = jnp.eye
-key = jr.PRNGKey(123)
+key = jr.key(123)
 plt.style.use(
     "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
 )

docs/examples/collapsed_vi.py (2 changes: 1 addition & 1 deletion)

@@ -43,7 +43,7 @@
 with install_import_hook("gpjax", "beartype.beartype"):
     import gpjax as gpx

-key = jr.PRNGKey(123)
+key = jr.key(123)
 plt.style.use(
     "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
 )

docs/examples/constructing_new_kernels.py (2 changes: 1 addition & 1 deletion)

@@ -47,7 +47,7 @@
 import gpjax as gpx
 from gpjax.base.param import param_field

-key = jr.PRNGKey(123)
+key = jr.key(123)
 tfb = tfp.bijectors
 plt.style.use(
     "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"

docs/examples/decision_making.py (2 changes: 1 addition & 1 deletion)

@@ -49,7 +49,7 @@
     Float,
 )

-key = jr.PRNGKey(42)
+key = jr.key(42)
 plt.style.use(
     "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
 )

docs/examples/deep_kernels.py (4 changes: 2 additions & 2 deletions)

@@ -44,7 +44,7 @@
 from gpjax.kernels.base import AbstractKernel
 from gpjax.kernels.computations import AbstractKernelComputation

-key = jr.PRNGKey(123)
+key = jr.key(123)
 plt.style.use(
     "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
 )
@@ -103,7 +103,7 @@ class DeepKernelFunction(AbstractKernel):
     base_kernel: AbstractKernel = None
     network: nn.Module = static_field(None)
     dummy_x: jax.Array = static_field(None)
-    key: jr.PRNGKeyArray = static_field(jr.PRNGKey(123))
+    key: jax.Array = static_field(jr.key(123))
     nn_params: Any = field(init=False, repr=False)

     def __post_init__(self):

docs/examples/graph_kernels.py (2 changes: 1 addition & 1 deletion)

@@ -27,7 +27,7 @@
 with install_import_hook("gpjax", "beartype.beartype"):
     import gpjax as gpx

-key = jr.PRNGKey(123)
+key = jr.key(123)
 plt.style.use(
     "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
 )

docs/examples/intro_to_gps.py (2 changes: 1 addition & 1 deletion)

@@ -158,7 +158,7 @@
 # determines the correlation of the multivariate Gaussian.

 # %%
-key = jr.PRNGKey(123)
+key = jr.key(123)

 d1 = tfd.MultivariateNormalDiag(loc=jnp.zeros(2), scale_diag=jnp.ones(2))
 d2 = tfd.MultivariateNormalTriL(

docs/examples/intro_to_kernels.py (13 changes: 10 additions & 3 deletions)

@@ -24,7 +24,7 @@
 from gpjax.typing import Array
 from sklearn.preprocessing import StandardScaler

-key = jr.PRNGKey(42)
+key = jr.key(42)
 plt.style.use(
     "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
 )
@@ -249,17 +249,24 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
 # with the optimised hyperparameters, and compare them to the predictions made using the
 # posterior with the default hyperparameters:


 # %%
 def plot_ribbon(ax, x, dist, color):
     mean = dist.mean()
     std = dist.stddev()
     ax.plot(x, mean, label="Predictive mean", color=color)
-    ax.fill_between(x.squeeze(), mean - 2 * std, mean + 2 * std, alpha=0.2, label="Two sigma", color=color)
+    ax.fill_between(
+        x.squeeze(),
+        mean - 2 * std,
+        mean + 2 * std,
+        alpha=0.2,
+        label="Two sigma",
+        color=color,
+    )
     ax.plot(x, mean - 2 * std, linestyle="--", linewidth=1, color=color)
     ax.plot(x, mean + 2 * std, linestyle="--", linewidth=1, color=color)



 # %%
 opt_latent_dist = opt_posterior.predict(test_x, train_data=D)
 opt_predictive_dist = opt_posterior.likelihood(opt_latent_dist)

docs/examples/likelihoods_guide.py (2 changes: 1 addition & 1 deletion)

@@ -66,7 +66,7 @@
     "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
 )
 cols = plt.rcParams["axes.prop_cycle"].by_key()["color"]
-key = jr.PRNGKey(123)
+key = jr.key(123)


 n = 50

docs/examples/oceanmodelling.py (2 changes: 1 addition & 1 deletion)

@@ -30,7 +30,7 @@
 import gpjax as gpx

 # Enable Float64 for more stable matrix inversions.
-key = jr.PRNGKey(123)
+key = jr.key(123)
 plt.style.use(
     "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
 )

docs/examples/poisson.py (2 changes: 1 addition & 1 deletion)

@@ -40,7 +40,7 @@
 # Enable Float64 for more stable matrix inversions.
 config.update("jax_enable_x64", True)
 tfd = tfp.distributions
-key = jr.PRNGKey(123)
+key = jr.key(123)
 plt.style.use(
     "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
 )

docs/examples/pytrees.md (6 changes: 3 additions & 3 deletions)

@@ -397,9 +397,9 @@ class RBF(Module):
         init=False, bijector=tfb.Softplus(), trainable=True
     )
     variance: float = param_field(init=False, bijector=tfb.Softplus(), trainable=True)
-    key: jr.KeyArray = field(default_factory = lambda: jr.PRNGKey(42))
+    key: jax.Array = field(default_factory = lambda: jr.key(42))
     # Note, for Python <3.11 you may use the following:
-    # key: jr.KeyArray = jr.PRNGKey(42)
+    # key: jax.Array = jr.key(42)

     def __post_init__(self):
         # Split key into two keys
@@ -444,7 +444,7 @@ class RBF(Module):
         init=False, bijector=tfb.Softplus(), trainable=True
     )
     variance: float = param_field(init=False, bijector=tfb.Softplus(), trainable=True)
-    key: jr.KeyArray = static_field(default_factory=lambda: jr.PRNGKey(42))
+    key: jax.Array = static_field(default_factory=lambda: jr.key(42))

     def __post_init__(self):
         # Split key into two keys

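The annotation change above reflects that keys produced by jr.key are ordinary jax.Array values. A minimal sketch of the same default_factory pattern in a plain dataclass (the RNGHolder class is hypothetical and purely illustrative, not part of GPJax):

    from dataclasses import dataclass, field

    import jax
    import jax.random as jr


    @dataclass
    class RNGHolder:
        # New-style keys are plain jax.Array values, hence the annotation.
        key: jax.Array = field(default_factory=lambda: jr.key(42))

        def __post_init__(self):
            # Split the stored key so later draws do not reuse the default key directly.
            self.key, self.subkey = jr.split(self.key)


    holder = RNGHolder()
    sample = jr.normal(holder.subkey, shape=(3,))
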
docs/examples/regression.py (2 changes: 1 addition & 1 deletion)

@@ -37,7 +37,7 @@
 with install_import_hook("gpjax", "beartype.beartype"):
     import gpjax as gpx

-key = jr.PRNGKey(123)
+key = jr.key(123)
 plt.style.use(
     "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
 )

0 comments on commit 50a4b96