diff --git a/docs/examples/graph_kernels.py b/docs/examples/graph_kernels.py index a07222e17..f64c7b382 100644 --- a/docs/examples/graph_kernels.py +++ b/docs/examples/graph_kernels.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # %% [markdown] # # Graph Kernels # @@ -23,6 +22,7 @@ import matplotlib as mpl import matplotlib.pyplot as plt import networkx as nx +import optax as ox with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx @@ -154,11 +154,14 @@ # With a posterior defined, we can now optimise the model's hyperparameters. # %% -opt_posterior, training_history = gpx.fit_scipy( +opt_posterior, training_history = gpx.fit( model=posterior, - objective=jit(gpx.ConjugateMLL(negative=True)), + objective=gpx.ConjugateMLL(negative=True), train_data=D, -) + optim=ox.adam(learning_rate=0.01), + num_iters=1000, + key=key + ) # %% [markdown] # diff --git a/docs/examples/regression_mo.py b/docs/examples/regression_mo.py deleted file mode 100644 index 7ef0f43e0..000000000 --- a/docs/examples/regression_mo.py +++ /dev/null @@ -1,308 +0,0 @@ -# --- -# jupyter: -# jupytext: -# cell_metadata_filter: -all -# custom_cell_magics: kql -# text_representation: -# extension: .py -# format_name: percent -# format_version: '1.3' -# jupytext_version: 1.11.2 -# kernelspec: -# display_name: gpjax -# language: python -# name: python3 -# --- - -# %% [markdown] -# # Regression with multiple outputs (EXPERIMENTAL) -# -# In this notebook we demonstate how to fit a Gaussian process regression model with multiple correlated outputs. -# This feature is still experimental. - -# %% -# Enable Float64 for more stable matrix inversions. -from jax import config - -config.update("jax_enable_x64", True) - -from jax import jit -import jax.numpy as jnp -import jax.random as jr -from jaxtyping import install_import_hook -import matplotlib as mpl -import matplotlib.pyplot as plt -import optax as ox -from docs.examples.utils import clean_legend - -# with install_import_hook("gpjax", "beartype.beartype"): -import gpjax as gpx - -key = jr.PRNGKey(123) -# plt.style.use( -# "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -# ) -cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] - -# %% [markdown] -# ## Dataset -# -# With the necessary modules imported, we simulate a dataset -# $\mathcal{D} = (\boldsymbol{x}, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{100}$ with inputs $\boldsymbol{x}$ -# sampled uniformly on $(-3., 3)$ and corresponding independent noisy outputs -# -# $$\boldsymbol{y} \sim \mathcal{N} \left(\left[\sin(4\boldsymbol{x}) + \cos(2 \boldsymbol{x}), \sin(4\boldsymbol{x}) + \cos(3 \boldsymbol{x})\right], \textbf{I} * 0.3^2 \right).$$ -# -# We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs and labels -# for later. - -# %% -n = 100 -noise = 0.3 - -key, subkey = jr.split(key) -x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,)).reshape(-1, 1) -f = lambda x: jnp.sin(4 * x) + jnp.array([jnp.cos(2 * x), jnp.cos(3 * x)]).T.squeeze() -signal = f(x) -y = signal + jr.normal(subkey, shape=signal.shape) * noise - -D = gpx.Dataset(X=x, y=y) - -xtest = jnp.linspace(-3.5, 3.5, 500).reshape(-1, 1) -ytest = f(xtest) - -# %% [markdown] -# To better understand what we have simulated, we plot both the underlying latent -# function and the observed data that is subject to Gaussian noise. 
- -# %% -fig, ax = plt.subplots(nrows=2, figsize=(7.5, 5)) -for i in range(2): - ax[i].plot(x, y[:, i], "x", label="Observations", color=cols[0]) - ax[i].plot(xtest, ytest[:, i], "--", label="Latent function", color=cols[1]) - ax[i].legend(loc="best") - -# %% [markdown] -# Our aim in this tutorial will be to reconstruct the latent function from our noisy -# observations $\mathcal{D}$ via Gaussian process regression. We begin by defining a -# Gaussian process prior in the next section. -# -# ## Defining the prior -# -# A zero-mean Gaussian process (GP) places a prior distribution over real-valued -# functions $f(\cdot)$ where -# $f(\boldsymbol{x}) \sim \mathcal{N}(0, \mathbf{K}_{\boldsymbol{x}\boldsymbol{x}})$ -# for any finite collection of inputs $\boldsymbol{x}$. -# -# Here $\mathbf{K}_{\boldsymbol{x}\boldsymbol{x}}$ is the Gram matrix generated by a -# user-specified symmetric, non-negative definite kernel function $k(\cdot, \cdot')$ -# with $[\mathbf{K}_{\boldsymbol{x}\boldsymbol{x}}]_{i, j} = k(x_i, x_j)$. -# The choice of kernel function is critical as, among other things, it governs the -# smoothness of the outputs that our GP can generate. -# -# For simplicity, we consider a radial basis function (RBF) kernel to model similarity of inputs: -# $$k_\mathrm{inp}(x_\mathrm{inp}, x_\mathrm{inp}') = \sigma^2 \exp\left(-\frac{\lVert x - x' \rVert_2^2}{2 \ell^2}\right),$$ -# and a categorical kernel to model similarity of outputs: -# $$k_\mathrm{idx}(x_\mathrm{idx}, x_\mathrm{idx}') = G_{x_\mathrm{idx}, x_\mathrm{idx}'}.$$ -# Here, $G$ is an explicit gram matrix and $x_\mathrm{idx}, x_\mathrm{idx}'$ are indices to the output dimension and to $G$. -# For example $G_{1,2}$ contains the covariance between output dimensions $1$ and $2$, as does $G_{2,1} = G_{1,2}$. -# -# The overall kernel then is defined as -# $$k([x_\mathrm{inp}, x_\mathrm{idx}], [x_\mathrm{inp}', x_\mathrm{idx}']) = k_\mathrm{inp}(x_\mathrm{inp}, x_\mathrm{inp}') k_\mathrm{idx}(x_\mathrm{idx}, x_\mathrm{idx}').$$ -# In the standard GPJax implementation, we never explicitly handle output dimension indices such as $x_\mathrm{idx}$. -# Rather, we simply define a dataset with multiple output columns. -# -# On paper a GP is written as $f(\cdot) \sim \mathcal{GP}(\textbf{0}, k(\cdot, \cdot'))$, -# we can reciprocate this process in GPJax via defining a `Prior` with our chosen `RBF` -# kernel. - -# %% -kernel = gpx.kernels.RBF() -catkernel_params = gpx.kernels.CatKernel.gram_to_stddev_cholesky_lower(jnp.eye(2)) -out_kernel = gpx.kernels.CatKernel( - stddev=catkernel_params.stddev, cholesky_lower=catkernel_params.cholesky_lower -) -# out_kernel = gpx.kernels.White(variance=1.0) -meanf = gpx.mean_functions.Constant(jnp.array([0.0, 1.0])) -prior = gpx.Prior(mean_function=meanf, kernel=kernel, out_kernel=out_kernel) - -# %% [markdown] -# -# The above construction forms the foundation for GPJax's models. Moreover, the GP prior -# we have just defined can be represented by a -# [TensorFlow Probability](https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax) -# multivariate Gaussian distribution. Such functionality enables trivial sampling, and -# the evaluation of the GP's mean and covariance . 
- -# %% -prior_dist = prior.predict(xtest) - -prior_mean = prior_dist.mean() -prior_std = prior_dist.variance() -samples = prior_dist.sample(seed=key, sample_shape=(20,)) - - -fig, ax = plt.subplots(nrows=2, figsize=(7.5, 5)) -for i in range(2): - ax[i].plot(xtest, samples.T[i], alpha=0.5, color=cols[0], label="Prior samples") - ax[i].plot(xtest, prior_mean[:, i], color=cols[1], label="Prior mean") - ax[i].fill_between( - xtest.flatten(), - prior_mean[:, i] - prior_std[:, i], - prior_mean[:, i] + prior_std[:, i], - alpha=0.3, - color=cols[1], - label="Prior variance", - ) - ax[i].legend(loc="best") - ax[i] = clean_legend(ax[i]) - -# %% [markdown] -# ## Constructing the posterior -# -# Having defined our GP, we proceed to define a description of our data -# $\mathcal{D}$ conditional on our knowledge of $f(\cdot)$ --- this is exactly the -# notion of a likelihood function $p(\mathcal{D} | f(\cdot))$. While the choice of -# likelihood is a critical in Bayesian modelling, for simplicity we consider a -# Gaussian with noise parameter $\alpha$ -# $$p(\mathcal{D} | f(\cdot)) = \mathcal{N}(\boldsymbol{y}; f(\boldsymbol{x}), \textbf{I} \alpha^2).$$ -# This is defined in GPJax through calling a `Gaussian` instance. - -# %% -likelihood = gpx.Gaussian(num_datapoints=D.n) - -# %% [markdown] -# The posterior is proportional to the prior multiplied by the likelihood, written as -# -# $$ p(f(\cdot) | \mathcal{D}) \propto p(f(\cdot)) * p(\mathcal{D} | f(\cdot)). $$ -# -# Mimicking this construct, the posterior is established in GPJax through the `*` operator. - -# %% -posterior = prior * likelihood - -# %% [markdown] -# -# -# ## Parameter state -# -# As outlined in the [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html) -# documentation, parameters are contained within the model and for the leaves of the -# PyTree. Consequently, in this particular model, we have three parameters: the -# kernel lengthscale, kernel variance and the observation noise variance. Whilst -# we have initialised each of these to 1, we can learn Type 2 MLEs for each of -# these parameters by optimising the marginal log-likelihood (MLL). - -# %% -negative_mll = gpx.objectives.ConjugateMLL(negative=True) -negative_mll(posterior, train_data=D) - - -# static_tree = jax.tree_map(lambda x: not(x), posterior.trainables) -# optim = ox.chain( -# ox.adam(learning_rate=0.01), -# ox.masked(ox.set_to_zero(), static_tree) -# ) -# %% [markdown] -# For researchers, GPJax has the capacity to print the bibtex citation for objects such -# as the marginal log-likelihood through the `cite()` function. - -# %% -print(gpx.cite(negative_mll)) - -# %% [markdown] -# JIT-compiling expensive-to-compute functions such as the marginal log-likelihood is -# advisable. This can be achieved by wrapping the function in `jax.jit()`. - -# %% -negative_mll = jit(negative_mll) - -# %% [markdown] -# Since most optimisers (including here) minimise a given function, we have realised -# the negative -# marginal log-likelihood and just-in-time (JIT) compiled this to -# accelerate training. - -# %% [markdown] -# We can now define an optimiser with `scipy`. For this example we'll use the `BFGS` -# optimiser. - -# %% -opt_posterior, history = gpx.fit_scipy( - model=posterior, - objective=negative_mll, - train_data=D, -) - -# %% [markdown] -# ## Prediction -# -# Equipped with the posterior and a set of optimised hyperparameter values, we are now -# in a position to query our GP's predictive distribution at novel test inputs. 
To do -# this, we use our defined `posterior` and `likelihood` at our test inputs to obtain -# the predictive distribution as a multivariate Gaussian upon which `mean` -# and `stddev` can be used to extract the predictive mean and standard deviatation. - -# %% -latent_dist = opt_posterior.predict(xtest, train_data=D) -predictive_dist = opt_posterior.likelihood(latent_dist) - -predictive_mean = predictive_dist.mean() -predictive_std = predictive_dist.stddev() - -# %% [markdown] -# With the predictions and their uncertainty acquired, we illustrate the GP's -# performance at explaining the data $\mathcal{D}$ and recovering the underlying -# latent function of interest. - -# %% - -fig, ax = plt.subplots(nrows=2, figsize=(7.5, 5)) -for i in range(2): - ax[i].plot(x, y[:, i], "x", label="Observations", color=cols[0], alpha=0.5) - - ax[i].fill_between( - xtest.squeeze(), - predictive_mean[:, i] - 2 * predictive_std[:, i], - predictive_mean[:, i] + 2 * predictive_std[:, i], - alpha=0.2, - label="Two sigma", - color=cols[1], - ) - ax[i].plot( - xtest, - predictive_mean[:, i] - 2 * predictive_std[:, i], - linestyle="--", - linewidth=1, - color=cols[1], - ) - ax[i].plot( - xtest, - predictive_mean[:, i] + 2 * predictive_std[:, i], - linestyle="--", - linewidth=1, - color=cols[1], - ) - ax[i].plot( - xtest, - ytest[:, i], - label="Latent function", - color=cols[0], - linestyle="--", - linewidth=2, - ) - ax[i].plot(xtest, predictive_mean[:, i], label="Predictive mean", color=cols[1]) - ax[i].legend(loc="center left", bbox_to_anchor=(0.975, 0.5)) - -# %% [markdown] -# ## System configuration - -# %% -# %reload_ext watermark -# %watermark -n -u -v -iv -w -a 'Thomas Pinder & Daniel Dodd' - -# %% diff --git a/docs/examples/spatial.py b/docs/examples/spatial.py deleted file mode 100644 index ba0e7b109..000000000 --- a/docs/examples/spatial.py +++ /dev/null @@ -1,287 +0,0 @@ -# -*- coding: utf-8 -*- -# --- -# jupyter: -# jupytext: -# cell_metadata_filter: -all -# custom_cell_magics: kql -# text_representation: -# extension: .py -# format_name: percent -# format_version: '1.3' -# jupytext_version: 1.11.2 -# kernelspec: -# display_name: gpjax_beartype -# language: python -# name: python3 -# --- - -# %% [markdown] -# # Pathwise Sampling for Spatial Modelling -# In this notebook, we demonstrate an application of Gaussian Processes -# to a spatial interpolation problem. We will show how -# to efficiently sample from a GP posterior as shown in . -# -# ## Data loading -# We'll use open-source data from -# [SwissMetNet](https://www.meteoswiss.admin.ch/services-and-publications/applications/measurement-values-and-measuring-networks.html#lang=en¶m=messnetz-automatisch), -# the surface weather monitoring network of the Swiss national weather service, -# and digital elevation model (DEM) data from Copernicus, accessible -# [here](https://planetarycomputer.microsoft.com/dataset/cop-dem-glo-90) -# via the Planetary Computer data catalog. -# We will coarsen this data by a factor of 10 (going from 90m to 900m resolution), but feel free to change this. -# -# Our variable of interest is the maximum daily temperature, observed on the 4th of April 2023 at -# 150 weather stations, and we'll try to interpolate it on a spatial grid using geographical coordinates -# (latitude and longitude) and elevation as input variables. -# -# %% -# Enable Float64 for more stable matrix inversions. 
-from jax import config - -config.update("jax_enable_x64", True) - -from dataclasses import dataclass - -import fsspec -import geopandas as gpd -import jax -import jax.numpy as jnp -import jax.random as jr -from jaxtyping import ( - Array, - Float, - install_import_hook, -) -import matplotlib as mpl -import matplotlib.pyplot as plt -import optax as ox -import pandas as pd -import planetary_computer -import pystac_client -import rioxarray as rio -from rioxarray.merge import merge_arrays -import xarray as xr - -with install_import_hook("gpjax", "beartype.beartype"): - import gpjax as gpx - from gpjax.base import param_field - from gpjax.dataset import Dataset - - -key = jr.PRNGKey(123) -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) -cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] - -# Observed temperature data -try: - temperature = pd.read_csv("data/max_tempeature_switzerland.csv") -except FileNotFoundError: - temperature = pd.read_csv("docs/examples/data/max_tempeature_switzerland.csv") - -temperature = gpd.GeoDataFrame( - temperature, - geometry=gpd.points_from_xy(temperature.longitude, temperature.latitude), -).dropna(how="any") - -# Country borders shapefile -path = "simplecache::https://www.naturalearthdata.com/http//www.naturalearthdata.com/download/10m/cultural/ne_10m_admin_0_countries.zip" -with fsspec.open(path) as file: - ch_shp = gpd.read_file(file).query("ADMIN == 'Switzerland'") - - -# Read DEM data and clip it to switzerland -catalog = pystac_client.Client.open( - "https://planetarycomputer.microsoft.com/api/stac/v1", - modifier=planetary_computer.sign_inplace, -) -search = catalog.search(collections=["cop-dem-glo-90"], bbox=[5.5, 45.5, 10.0, 48.5]) -items = list(search.get_all_items()) -tiles = [rio.open_rasterio(i.assets["data"].href).squeeze().drop("band") for i in items] -dem = merge_arrays(tiles).coarsen(x=10, y=10).mean().rio.clip(ch_shp["geometry"]) - -# %% [markdown] -# Let us take a look at the data. The topography of Switzerland is quite complex, and there -# are sometimes very large height differences over short distances. This measuring network is fairly dense, -# and you may already notice that there's a dependency between maximum daily temperature and elevation. -# %% -fig, ax = plt.subplots(figsize=(8, 5), layout="constrained") -dem.plot( - cmap="terrain", cbar_kwargs={"aspect": 50, "pad": 0.02, "label": "Elevation [m]"} -) -temperature.plot("t_max", ax=ax, cmap="RdBu_r", vmin=-15, vmax=15, edgecolor="k", s=50) -ax.set(title="Switzerland's topography and SwissMetNet stations", aspect="auto") -cb = fig.colorbar(ax.collections[-1], aspect=50, pad=0.02) -cb.set_label("Max. daily temperature [°C]", labelpad=-2) - - -# %% [markdown] -# As always, we store our training data in a `Dataset` object. -# %% -x = temperature[["latitude", "longitude", "elevation"]].values -y = temperature[["t_max"]].values -D = Dataset( - X=jnp.array(x), - y=jnp.array(y), -) - -# %% [markdown] -# ## ARD Kernel -# As temperature decreases with height -# (at a rate of approximately -6.5 °C/km in average conditions), we can expect that using the geographical distance -# alone isn't enough to to a decent job at interpolating this data. Therefore, we can also use elevation and optimize -# the parameters of our kernel such that more relevance should be given to elevation. This is possible by using a -# kernel that has one length-scale parameter per input dimension: an automatic relevance determination (ARD) kernel. 
-# See our [kernel notebook](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/) for more an introduction to -# kernels in GPJax. - -# %% -kernel = gpx.kernels.RBF( - active_dims=[0, 1, 2], - lengthscale=jnp.array([0.1, 0.1, 100.0]), -) - -# %% [markdown] -# ## Mean function -# As stated before, we already know that temperature strongly depends on elevation. -# So why not use it for our mean function? GPJax lets you define custom mean functions; -# simply subclass `AbstractMeanFunction`. - - -# %% -@dataclass -class MeanFunction(gpx.gps.AbstractMeanFunction): - w: Float[Array, "1"] = param_field(jnp.array([0.0])) - b: Float[Array, "1"] = param_field(jnp.array([0.0])) - - def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]: - elevation = x[:, 2:3] - out = elevation * self.w + self.b - return out - - -# %% [markdown] -# Now we can define our prior. We'll also choose a Gaussian likelihood. - -# %% -mean_function = MeanFunction() -prior = gpx.Prior(kernel=kernel, mean_function=mean_function) -likelihood = gpx.Gaussian(D.n) - -# %% [markdown] -# Finally, we construct the posterior. -# %% -posterior = prior * likelihood - - -# %% [markdown] -# ## Model fitting -# We proceed to train our model. Because we used a Gaussian likelihood, the resulting posterior is -# a `ConjugatePosterior`, which allows us to optimize the analytically expressed marginal loglikelihood. -# -# As always, we can jit-compile the objective function to speed things up. -# %% -negative_mll = jax.jit(gpx.objectives.ConjugateMLL(negative=True)) -negative_mll(posterior, train_data=D) - -# %% -optim = ox.chain(ox.adam(learning_rate=0.1), ox.clip(1.0)) -posterior, history = gpx.fit( - model=posterior, - objective=negative_mll, - train_data=D, - optim=optim, - num_iters=3000, - safe=True, - key=key, -) -posterior: gpx.gps.ConjugatePosterior -# %% [markdown] -# ## Sampling on a grid -# Now comes the cool part. In a standard GP implementation, for n test points, we have a $\mathcal{O}(n^2)$ -# computational complexity and $\mathcal{O}(n^2)$ memory requirement. We want to make predictions on a total -# of roughly 70'000 pixels, and that would require us to compute a covariance matrix of `70000 ** 2 = 4900000000` elements. -# If these are `float64`s, as it is often the case in GPJax, it would be equivalent to more than 36 Gigabytes of memory. And -# that's for a fairly coarse and tiny grid. If we were to make predictions on a 1000x1000 grid, the total memory required -# would be 8 _Terabytes_ of memory, which is intractable. -# Fortunately, the pathwise conditioning method allows us to sample from our posterior in linear complexity, -# $\mathcal{O}(n)$, with the number of pixels. -# -# GPJax provides the `sample_approx` method to generate random conditioned samples from our posterior. - -# %% -# select the target pixels and exclude nans -xtest = dem.drop("spatial_ref").stack(p=["y", "x"]).to_dataframe(name="dem") -mask = jnp.any(jnp.isnan(xtest.values), axis=-1) - -# generate 50 samples -ytest = posterior.sample_approx(50, D, key, num_features=200)( - jnp.array(xtest.values[~mask]) -) -# %% [markdown] -# Let's take a look at the results. We start with the mean and standard deviation. - -# %% -predtest = xr.zeros_like(dem.stack(p=["y", "x"])) * jnp.nan -predtest[~mask] = ytest.mean(axis=-1) -predtest = predtest.unstack() - -predtest.plot( - vmin=-15.0, - vmax=15.0, - cmap="RdBu_r", - cbar_kwargs={"aspect": 50, "pad": 0.02, "label": "Max. 
daily temperature [°C]"}, -) -plt.gca().set_title("Interpolated maximum daily temperature") -# %% -predtest = xr.zeros_like(dem.stack(p=["y", "x"])) * jnp.nan -predtest[~mask] = ytest.std(axis=-1) -predtest = predtest.unstack() - -# plot -predtest.plot( - cbar_kwargs={"aspect": 50, "pad": 0.02, "label": "Standard deviation [°C]"}, -) -plt.gca().set_title("Standard deviation") -# %% [markdown] -# And now some individual realizations of our GP posterior. -# %% -predtest = ( - xr.zeros_like(dem.stack(p=["y", "x"])) - .expand_dims(realization=range(9)) - .transpose("p", "realization") - .copy() -) -predtest[~mask] = ytest[:, :9] -predtest = predtest.unstack() -predtest.plot( - col="realization", - col_wrap=3, - cbar_kwargs={"aspect": 50, "pad": 0.02, "label": "Max. daily temperature [°C]"}, -) -# %% [markdown] -# Remember when we said that on average the temperature decreases with height at a rate -# of approximately -6.5°C/km? That's -0.0065°C/m. The `w` parameter of our mean function -# is very close: we have learned the environmental lapse rate! - -# %% -print(posterior.prior.mean_function) -# %% [markdown] -# That's it! We've successfully interpolated an observed meteorological parameter on a grid. -# We have used several components of GPJax and adapted them to our needs: a custom mean function -# that modelled the average temperature lapse rate; an ARD kernel that learned to give more relevance -# to elevation rather than horizontal distance; an efficient sampling technique to produce -# probabilistic realizations of our posterior on a large number of test points, which is important for -# many spatiotemporal modelling applications. -# If you're interested in a more elaborate work on temperature interpolation for the same domain used here, refer -# to [Frei 2014](https://rmets.onlinelibrary.wiley.com/doi/full/10.1002/joc.3786). 
- -# %% [markdown] -# ## System configuration - -# %% -# %reload_ext watermark -# %watermark -n -u -v -iv -w -a 'Francesco Zanetta' - -# %% diff --git a/gpjax/__init__.py b/gpjax/__init__.py index 27cbf986b..6e45ae72d 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -35,7 +35,6 @@ RFF, AbstractKernel, BasisFunctionComputation, - CatKernel, ConstantDiagonalKernelComputation, DenseKernelComputation, DiagonalKernelComputation, @@ -123,7 +122,6 @@ "CollapsedELBO", "ELBO", "AbstractKernel", - "CatKernel", "Linear", "DenseKernelComputation", "DiagonalKernelComputation", diff --git a/gpjax/base/module.py b/gpjax/base/module.py index 3c4419007..3d381d3c7 100644 --- a/gpjax/base/module.py +++ b/gpjax/base/module.py @@ -308,7 +308,7 @@ def _unpack_metadata( yield meta_leaf return - for metadata, leaf in zip(leaves_meta, leaves_values): + for metadata, leaf in zip(leaves_meta, leaves_values, strict=True): yield from _unpack_metadata((metadata, leaf), leaf, is_leaf) return list(_unpack_metadata(pytree, pytree, is_leaf)) @@ -352,7 +352,7 @@ def meta_map( """ leaves, treedef = meta_flatten(pytree, is_leaf=is_leaf) all_leaves = [leaves] + [treedef.treedef.flatten_up_to(r) for r in rest] - return treedef.unflatten(f(*xs) for xs in zip(*all_leaves)) + return treedef.unflatten(f(*xs) for xs in zip(*all_leaves, strict=True)) def meta(pytree: Module, *, is_leaf: Optional[Callable[[Any], bool]] = None) -> Module: diff --git a/gpjax/dataset.py b/gpjax/dataset.py index b77a3a37e..5fcc71baf 100644 --- a/gpjax/dataset.py +++ b/gpjax/dataset.py @@ -16,16 +16,9 @@ from dataclasses import dataclass import warnings -from beartype.typing import ( - Literal, - Optional, - Union, -) +from beartype.typing import Optional import jax.numpy as jnp -from jaxtyping import ( - Bool, - Num, -) +from jaxtyping import Num from simple_pytree import Pytree from gpjax.typing import Array @@ -39,17 +32,10 @@ class Dataset(Pytree): ---------- X (Optional[Num[Array, "N D"]]): input data. y (Optional[Num[Array, "N Q"]]): output data. - mask (Optional[Union[Bool[Array, "N Q"], Literal["infer automatically"]]]): mask for the output data. - Users can optionally specify a pre-computed mask, or explicitly pass `None` which - means no mask will be used. Defaults to `"infer automatically"`, which means that - the mask will be computed from the output data, or set to `None` if no output data is provided. """ X: Optional[Num[Array, "N D"]] = None y: Optional[Num[Array, "N Q"]] = None - mask: Optional[ - Union[Bool[Array, "N Q"], Literal["infer automatically"]] - ] = "infer automatically" def __post_init__(self) -> None: r"""Checks that the shapes of $`X`$ and $`y`$ are compatible, @@ -57,27 +43,9 @@ def __post_init__(self) -> None: _check_shape(self.X, self.y) _check_precision(self.X, self.y) - if isinstance(self.mask, str): - if not self.mask == "infer automatically": - raise ValueError( - f"mask must be either the string 'infer automatically', None, or a boolean array." - f" Got mask={self.mask}." 
- ) - elif self.y is not None: - mask = jnp.isnan(self.y) - if jnp.any(mask): - self.mask = mask - else: - self.mask = None - else: - self.mask = None - def __repr__(self) -> str: r"""Returns a string representation of the dataset.""" - repr = ( - f"- Number of observations: {self.n}\n- Input dimension:" - f" {self.in_dim}\n- Output dimension: {self.out_dim}" - ) + repr = f"- Number of observations: {self.n}\n- Input dimension: {self.in_dim}" return repr def is_supervised(self) -> bool: @@ -92,7 +60,6 @@ def __add__(self, other: "Dataset") -> "Dataset": r"""Combine two datasets. Right hand dataset is stacked beneath the left.""" X = None y = None - mask = None if self.X is not None and other.X is not None: X = jnp.concatenate((self.X, other.X)) @@ -100,14 +67,7 @@ def __add__(self, other: "Dataset") -> "Dataset": if self.y is not None and other.y is not None: y = jnp.concatenate((self.y, other.y)) - self_m_exists = self.mask is not None - other_m_exists = other.mask is not None - self_m = self.mask if self_m_exists else jnp.zeros(self.y.shape, dtype=bool) - other_m = other.mask if other_m_exists else jnp.zeros(other.y.shape, dtype=bool) - if self_m_exists or other_m_exists: - mask = jnp.concatenate((self_m, other_m)) - - return Dataset(X=X, y=y, mask=mask) + return Dataset(X=X, y=y) @property def n(self) -> int: @@ -119,11 +79,6 @@ def in_dim(self) -> int: r"""Dimension of the inputs, $`X`$.""" return self.X.shape[1] - @property - def out_dim(self) -> int: - r"""Dimension of the outputs, $`y`$.""" - return self.y.shape[1] - def _check_shape( X: Optional[Num[Array, "..."]], y: Optional[Num[Array, "..."]] diff --git a/gpjax/distributions.py b/gpjax/distributions.py index f638cef04..24bf0a601 100644 --- a/gpjax/distributions.py +++ b/gpjax/distributions.py @@ -16,25 +16,19 @@ from beartype.typing import ( Any, - Generic, Optional, Tuple, TypeVar, - Union, ) import cola from cola.ops import ( - Dense, Identity, LinearOperator, ) from jax import vmap import jax.numpy as jnp import jax.random as jr -from jaxtyping import ( - Bool, - Float, -) +from jaxtyping import Float import tensorflow_probability.substrates.jax as tfp from gpjax.lower_cholesky import lower_cholesky @@ -161,14 +155,11 @@ def entropy(self) -> ScalarFloat: + cola.logdet(self.scale, Cholesky(), Cholesky()) ) - def log_prob( - self, y: Float[Array, " N"], mask: Optional[Bool[Array, " N"]] = None - ) -> ScalarFloat: + def log_prob(self, y: Float[Array, " N"]) -> ScalarFloat: r"""Calculates the log pdf of the multivariate Gaussian. Args: y (Optional[Float[Array, " N"]]): the value of which to calculate the log probability. - mask: (Optional[Bool[Array, " N"]]): the mask for missing values in y. 
Returns ------- @@ -178,14 +169,6 @@ def log_prob( sigma = self.scale n = mu.shape[-1] - if mask is not None: - y = jnp.where(mask, 0.0, y) - mu = jnp.where(mask, 0.0, mu) - sigma_masked = jnp.where(mask[None] + mask[:, None], 0.0, sigma.to_dense()) - sigma = cola.PSD( - Dense(jnp.where(jnp.diag(mask), 1 / (2 * jnp.pi), sigma_masked)) - ) - # diff, y - µ diff = y - mu @@ -233,68 +216,6 @@ def kl_divergence(self, other: "GaussianDistribution") -> ScalarFloat: DistrT = TypeVar("DistrT", bound=tfd.Distribution) -class ReshapedDistribution(tfd.Distribution, Generic[DistrT]): - def __init__(self, distribution: tfd.Distribution, output_shape: Tuple[int, ...]): - self._distribution = distribution - self._output_shape = output_shape - - def mean(self) -> Float[Array, " N ..."]: - r"""Mean of the base distribution, reshaped to the output shape.""" - return jnp.reshape(self._distribution.mean(), self._output_shape) - - def median(self) -> Float[Array, " N ..."]: - r"""Median of the base distribution, reshaped to the output shape""" - return jnp.reshape(self._distribution.median(), self._output_shape) - - def mode(self) -> Float[Array, " N ..."]: - r"""Mode of the base distribution, reshaped to the output shape""" - return jnp.reshape(self._distribution.mode(), self._output_shape) - - def covariance(self) -> Float[Array, " N ..."]: - r"""Covariance of the base distribution, reshaped to the squared output shape""" - return jnp.reshape( - self._distribution.covariance(), self._output_shape + self._output_shape - ) - - def variance(self) -> Float[Array, " N ..."]: - r"""Variances of the base distribution, reshaped to the output shape""" - return jnp.reshape(self._distribution.variance(), self._output_shape) - - def stddev(self) -> Float[Array, " N ..."]: - r"""Standard deviations of the base distribution, reshaped to the output shape""" - return jnp.reshape(self._distribution.stddev(), self._output_shape) - - def entropy(self) -> ScalarFloat: - r"""Entropy of the base distribution.""" - return self._distribution.entropy() - - def log_prob( - self, y: Float[Array, " N ..."], mask: Optional[Bool[Array, " N ..."]] - ) -> ScalarFloat: - r"""Calculates the log probability.""" - return self._distribution.log_prob( - y.reshape(-1), mask if mask is None else mask.reshape(-1) - ) - - def sample( - self, seed: Any, sample_shape: Tuple[int, ...] 
= () - ) -> Float[Array, " n N ..."]: - r"""Draws samples from the distribution and reshapes them to the output shape.""" - sample = self._distribution.sample(seed, sample_shape) - return jnp.reshape(sample, sample_shape + self._output_shape) - - def kl_divergence(self, other: "ReshapedDistribution") -> ScalarFloat: - r"""Calculates the Kullback-Leibler divergence.""" - other_flat = tfd.Distribution( - loc=other._distribution.loc, scale=other._distribution.scale - ) - return tfd.kl_divergence(self._distribution, other_flat) - - @property - def event_shape(self) -> Tuple: - return self._output_shape - - def _check_and_return_dimension( q: GaussianDistribution, p: GaussianDistribution ) -> int: @@ -364,11 +285,6 @@ def _kl_divergence(q: GaussianDistribution, p: GaussianDistribution) -> ScalarFl ) / 2.0 -ReshapedGaussianDistribution = Union[ - GaussianDistribution, ReshapedDistribution[GaussianDistribution] -] __all__ = [ "GaussianDistribution", - "ReshapedDistribution", - "ReshapedGaussianDistribution", ] diff --git a/gpjax/fit.py b/gpjax/fit.py index 5e91eec44..6986549bc 100644 --- a/gpjax/fit.py +++ b/gpjax/fit.py @@ -254,12 +254,12 @@ def get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset: ------- Dataset: The batched dataset. """ - x, y, n, mask = train_data.X, train_data.y, train_data.n, train_data.mask + x, y, n = train_data.X, train_data.y, train_data.n # Subsample mini-batch indices with replacement. indices = jr.choice(key, n, (batch_size,), replace=True) - return Dataset(X=x[indices], y=y[indices], mask=mask[indices] if mask else None) + return Dataset(X=x[indices], y=y[indices]) def _check_model(model: Any) -> None: diff --git a/gpjax/gps.py b/gpjax/gps.py index 638f408cd..5928ef491 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -14,10 +14,7 @@ # ============================================================================== from abc import abstractmethod -from dataclasses import ( - dataclass, - field, -) +from dataclasses import dataclass from typing import overload from beartype.typing import ( @@ -28,7 +25,6 @@ ) import cola from cola.linalg.decompositions.decompositions import Cholesky -from cola.ops import Dense import jax.numpy as jnp from jax.random import ( PRNGKey, @@ -45,15 +41,8 @@ static_field, ) from gpjax.dataset import Dataset -from gpjax.distributions import ( - GaussianDistribution, - ReshapedDistribution, - ReshapedGaussianDistribution, -) -from gpjax.kernels import ( - RFF, - White, -) +from gpjax.distributions import GaussianDistribution +from gpjax.kernels import RFF from gpjax.kernels.base import AbstractKernel from gpjax.likelihoods import ( AbstractLikelihood, @@ -77,12 +66,7 @@ class AbstractPrior(Module): mean_function: AbstractMeanFunction jitter: float = static_field(1e-6) - # TODO: when letting kernels be responsible for certain features, like - # RBF(features=["outp_idx"]), this can be folded into the kernel, - # just not sure how to ensure Kronecker structure then - out_kernel: AbstractKernel = field(default_factory=White) - - def __call__(self, *args: Any, **kwargs: Any) -> ReshapedGaussianDistribution: + def __call__(self, *args: Any, **kwargs: Any) -> GaussianDistribution: r"""Evaluate the Gaussian process at the given points. The output of this function is a @@ -100,13 +84,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReshapedGaussianDistribution: Returns ------- - ReshapedGaussianDistribution: A multivariate normal random variable representation - of the Gaussian process, possibly with reshaped events. 
+ GaussianDistribution: A multivariate normal random variable representation + of the Gaussian process. """ return self.predict(*args, **kwargs) @abstractmethod - def predict(self, *args: Any, **kwargs: Any) -> ReshapedGaussianDistribution: + def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: r"""Evaluate the predictive distribution. Compute the latent function's multivariate normal distribution for a @@ -119,8 +103,8 @@ def predict(self, *args: Any, **kwargs: Any) -> ReshapedGaussianDistribution: Returns ------- - ReshapedGaussianDistribution: A multivariate normal random variable representation - of the Gaussian process, possibly with reshaped events. + GaussianDistribution: A multivariate normal random variable representation + of the Gaussian process. """ raise NotImplementedError @@ -229,7 +213,7 @@ def __rmul__(self, other): """ return self.__mul__(other) - def predict(self, test_inputs: Num[Array, "N D"]) -> ReshapedGaussianDistribution: + def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution: r"""Compute the predictive prior distribution for a given set of parameters. The output of this function is a function that computes a TFP distribution for a given set of inputs. @@ -255,21 +239,16 @@ def predict(self, test_inputs: Num[Array, "N D"]) -> ReshapedGaussianDistributio Returns ------- - ReshapedGaussianDistribution: A multivariate normal random variable representation - of the Gaussian process, possibly with reshaped events. + GaussianDistribution: A multivariate normal random variable representation + of the Gaussian process. """ x = test_inputs - mx = jnp.atleast_1d(self.mean_function(x)) + mx = self.mean_function(x) Kxx = self.kernel.gram(x) - Kyy = self.out_kernel.gram(jnp.arange(mx.shape[1])[:, jnp.newaxis]) - Sigma = cola.ops.Kronecker(Kxx, Kyy) - Sigma += cola.ops.I_like(Sigma) * self.jitter + Kxx += cola.ops.I_like(Kxx) * self.jitter + Kxx = cola.PSD(Kxx) - prior_distr = GaussianDistribution(mx.flatten(), Sigma) - if mx.shape[1] == 1: - return prior_distr - else: - return ReshapedDistribution(prior_distr, mx.shape) + return GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Kxx) def sample_approx( self, @@ -360,7 +339,7 @@ class AbstractPosterior(Module): likelihood: AbstractLikelihood jitter: float = static_field(1e-6) - def __call__(self, *args: Any, **kwargs: Any) -> ReshapedGaussianDistribution: + def __call__(self, *args: Any, **kwargs: Any) -> GaussianDistribution: r"""Evaluate the Gaussian process posterior at the given points. The output of this function is a @@ -378,13 +357,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReshapedGaussianDistribution: Returns ------- - ReshapedGaussianDistribution: A multivariate normal random variable representation - of the Gaussian process, possibly with reshaped events. + GaussianDistribution: A multivariate normal random variable representation + of the Gaussian process. """ return self.predict(*args, **kwargs) @abstractmethod - def predict(self, *args: Any, **kwargs: Any) -> ReshapedGaussianDistribution: + def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: r"""Compute the latent function's multivariate normal distribution for a given set of parameters. For any class inheriting the `AbstractPrior` class, this method must be implemented. 
@@ -395,8 +374,8 @@ def predict(self, *args: Any, **kwargs: Any) -> ReshapedGaussianDistribution: Returns ------- - ReshapedGaussianDistribution: A multivariate normal random variable representation - of the Gaussian process, possibly with reshaped events. + GaussianDistribution: A multivariate normal random variable representation + of the Gaussian process. """ raise NotImplementedError @@ -448,7 +427,7 @@ def predict( self, test_inputs: Num[Array, "N D"], train_data: Dataset, - ) -> ReshapedGaussianDistribution: + ) -> GaussianDistribution: r"""Query the predictive posterior distribution. Conditional on a training data set, compute the GP's posterior @@ -495,66 +474,41 @@ def predict( Returns ------- - ReshapedGaussianDistribution: A - function that accepts an input array and returns the predictive - distribution as a `GaussianDistribution` or a `ReshapedDistribution[GaussianDistribution]`. + GaussianDistribution: A function that accepts an input array and + returns the predictive distribution as a `GaussianDistribution`. """ # Unpack training data - x, y, n_train, mask = train_data.X, train_data.y, train_data.n, train_data.mask - m = y.shape[1] - if m > 1 and mask is not None: - mask = mask.flatten() + x, y = train_data.X, train_data.y + # Unpack test inputs t = test_inputs - n_test = len(test_inputs) # Observation noise o² - obs_var = self.likelihood.obs_stddev**2 + obs_noise = self.likelihood.obs_stddev**2 mx = self.prior.mean_function(x) # Precompute Gram matrix, Kxx, at training inputs, x Kxx = self.prior.kernel.gram(x) - Kyy = self.prior.out_kernel.gram(jnp.arange(m)[:, jnp.newaxis]) + Kxx += cola.ops.I_like(Kxx) * self.jitter # Σ = Kxx + Io² - Sigma = cola.ops.Kronecker(Kxx, Kyy) - Sigma += cola.ops.I_like(Sigma) * (obs_var + self.jitter) + Sigma = Kxx + cola.ops.I_like(Kxx) * obs_noise Sigma = cola.PSD(Sigma) - if mask is not None: - y = jnp.where(mask, 0.0, y) - mx = jnp.where(mask, 0.0, mx) - Sigma_masked = jnp.where(mask + mask.T, 0.0, Sigma.to_dense()) - Sigma = cola.PSD( - Dense( - jnp.where( - jnp.diag(jnp.squeeze(mask)), 1 / (2 * jnp.pi), Sigma_masked - ) - ) - ) - mean_t = self.prior.mean_function(t) - Ktt = cola.ops.Kronecker(self.prior.kernel.gram(t), Kyy) - Ktt = cola.PSD(Ktt) - Kxt = cola.ops.Kronecker(self.prior.kernel.cross_covariance(x, t), Kyy) - - # Σ⁻¹ Kxt - if mask is not None: - Kxt = jnp.where(mask * jnp.ones((1, n_train), dtype=bool), 0.0, Kxt) - Sigma_inv_Kxt = cola.solve(Sigma, Kxt, Cholesky()) + Ktt = self.prior.kernel.gram(t) + Kxt = self.prior.kernel.cross_covariance(x, t) + Sigma_inv_Kxt = cola.solve(Sigma, Kxt) # μt + Ktx (Kxx + Io²)⁻¹ (y - μx) - mean = mean_t.flatten() + Sigma_inv_Kxt.T @ (y - mx).flatten() + mean = mean_t + jnp.matmul(Sigma_inv_Kxt.T, y - mx) # Ktt - Ktx (Kxx + Io²)⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently. 
- covariance = Ktt - Kxt.T @ Sigma_inv_Kxt + covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt) covariance += cola.ops.I_like(covariance) * self.prior.jitter covariance = cola.PSD(covariance) - rval = GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance) - if m == 1: - return rval - else: - return ReshapedDistribution(rval, (n_test, m)) + + return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance) def sample_approx( self, diff --git a/gpjax/integrators.py b/gpjax/integrators.py index 3ab1b529f..634ca0d3b 100644 --- a/gpjax/integrators.py +++ b/gpjax/integrators.py @@ -155,7 +155,7 @@ def integrate( log2pi = jnp.log(2.0 * jnp.pi) val = jnp.sum( log2pi + jnp.log(obs_stddev**2) + (sq_error + variance) / obs_stddev**2, - axis=1 + axis=1, ) return -0.5 * val diff --git a/gpjax/kernels/__init__.py b/gpjax/kernels/__init__.py index 3e01404e9..a3f86352f 100644 --- a/gpjax/kernels/__init__.py +++ b/gpjax/kernels/__init__.py @@ -28,10 +28,7 @@ DiagonalKernelComputation, EigenKernelComputation, ) -from gpjax.kernels.non_euclidean import ( - CatKernel, - GraphKernel, -) +from gpjax.kernels.non_euclidean import GraphKernel from gpjax.kernels.nonstationary import ( ArcCosine, Linear, @@ -54,7 +51,6 @@ "Constant", "RBF", "GraphKernel", - "CatKernel", "Matern12", "Matern32", "Matern52", diff --git a/gpjax/kernels/base.py b/gpjax/kernels/base.py index 775e469ff..ff9e7f8b6 100644 --- a/gpjax/kernels/base.py +++ b/gpjax/kernels/base.py @@ -24,7 +24,10 @@ Union, ) import jax.numpy as jnp -from jaxtyping import Num +from jaxtyping import ( + Float, + Num, +) import tensorflow_probability.substrates.jax.distributions as tfd from gpjax.base import ( @@ -60,18 +63,18 @@ def cross_covariance(self, x: Num[Array, "N D"], y: Num[Array, "M D"]): def gram(self, x: Num[Array, "N D"]): return self.compute_engine.gram(self, x) - def slice_input(self, x: Num[Array, "... D"]) -> Num[Array, "... Q"]: + def slice_input(self, x: Float[Array, "... D"]) -> Float[Array, "... Q"]: r"""Slice out the relevant columns of the input matrix. Select the relevant columns of the supplied matrix to be used within the kernel's evaluation. Args: - x (Num[Array, "... D"]): The matrix or vector that is to be sliced. + x (Float[Array, "... D"]): The matrix or vector that is to be sliced. Returns ------- - Num[Array, "... Q"]: A sliced form of the input matrix. + Float[Array, "... Q"]: A sliced form of the input matrix. """ return x[..., self.active_dims] if self.active_dims is not None else x @@ -147,12 +150,12 @@ class Constant(AbstractKernel): constant: ScalarFloat = param_field(jnp.array(0.0)) - def __call__(self, x: Num[Array, " D"], y: Num[Array, " D"]) -> ScalarFloat: + def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: r"""Evaluate the kernel on a pair of inputs. Args: - x (Num[Array, " D"]): The left hand input of the kernel function. - y (Num[Array, " D"]): The right hand input of the kernel function. + x (Float[Array, " D"]): The left hand input of the kernel function. + y (Float[Array, " D"]): The right hand input of the kernel function. Returns ------- @@ -185,14 +188,14 @@ def __post_init__(self): def __call__( self, - x: Num[Array, " D"], - y: Num[Array, " D"], + x: Float[Array, " D"], + y: Float[Array, " D"], ) -> ScalarFloat: r"""Evaluate the kernel on a pair of inputs. Args: - x (Num[Array, " D"]): The left hand input of the kernel function. - y (Num[Array, " D"]): The right hand input of the kernel function. 
+ x (Float[Array, " D"]): The left hand input of the kernel function. + y (Float[Array, " D"]): The right hand input of the kernel function. Returns ------- diff --git a/gpjax/kernels/computations/base.py b/gpjax/kernels/computations/base.py index 2b93a1233..ac48b8101 100644 --- a/gpjax/kernels/computations/base.py +++ b/gpjax/kernels/computations/base.py @@ -47,7 +47,7 @@ def gram( Args: kernel (AbstractKernel): the kernel function. - x (Float[Array, "N N"]): The inputs to the kernel function. + x (Num[Array, "N N"]): The inputs to the kernel function. Returns ------- @@ -65,8 +65,8 @@ def cross_covariance( Args: kernel (AbstractKernel): the kernel function. - x (Float[Array,"N D"]): The first input matrix. - y (Float[Array,"M D"]): The second input matrix. + x (Num[Array,"N D"]): The first input matrix. + y (Num[Array,"M D"]): The second input matrix. Returns ------- diff --git a/gpjax/kernels/computations/constant_diagonal.py b/gpjax/kernels/computations/constant_diagonal.py index 0ddbd0251..c7dd8639e 100644 --- a/gpjax/kernels/computations/constant_diagonal.py +++ b/gpjax/kernels/computations/constant_diagonal.py @@ -23,10 +23,7 @@ ) from jax import vmap import jax.numpy as jnp -from jaxtyping import ( - Float, - Num, -) +from jaxtyping import Float from gpjax.kernels.computations import AbstractKernelComputation from gpjax.typing import Array @@ -35,14 +32,14 @@ class ConstantDiagonalKernelComputation(AbstractKernelComputation): - def gram(self, kernel: Kernel, x: Num[Array, "N D"]) -> LinearOperator: + def gram(self, kernel: Kernel, x: Float[Array, "N D"]) -> LinearOperator: r"""Compute the Gram matrix. Compute Gram covariance operator of the kernel function. Args: kernel (Kernel): the kernel function. - x (Num[Array, "N D"]): The inputs to the kernel function. + x (Float[Array, "N D"]): The inputs to the kernel function. Returns ------- @@ -54,7 +51,7 @@ def gram(self, kernel: Kernel, x: Num[Array, "N D"]) -> LinearOperator: return PSD(jnp.atleast_1d(value) * Identity(shape=shape, dtype=dtype)) - def diagonal(self, kernel: Kernel, inputs: Num[Array, "N D"]) -> Diagonal: + def diagonal(self, kernel: Kernel, inputs: Float[Array, "N D"]) -> Diagonal: r"""Compute the diagonal Gram matrix's entries. For a given kernel, compute the elementwise diagonal of the @@ -62,7 +59,7 @@ def diagonal(self, kernel: Kernel, inputs: Num[Array, "N D"]) -> Diagonal: Args: kernel (Kernel): the kernel function. - inputs (Num[Array, "N D"]): The input matrix. + inputs (Float[Array, "N D"]): The input matrix. Returns ------- @@ -73,7 +70,7 @@ def diagonal(self, kernel: Kernel, inputs: Num[Array, "N D"]) -> Diagonal: return PSD(Diagonal(diag=diag)) def cross_covariance( - self, kernel: Kernel, x: Num[Array, "N D"], y: Num[Array, "M D"] + self, kernel: Kernel, x: Float[Array, "N D"], y: Float[Array, "M D"] ) -> Float[Array, "N M"]: r"""Compute the cross-covariance matrix. @@ -82,8 +79,8 @@ def cross_covariance( Args: kernel (Kernel): the kernel function. - x (Num[Array,"N D"]): The input matrix. - y (Num[Array,"M D"]): The input matrix. + x (Float[Array,"N D"]): The input matrix. + y (Float[Array,"M D"]): The input matrix. 
Returns ------- diff --git a/gpjax/kernels/computations/dense.py b/gpjax/kernels/computations/dense.py index e95e21d8f..3ad958a26 100644 --- a/gpjax/kernels/computations/dense.py +++ b/gpjax/kernels/computations/dense.py @@ -15,10 +15,7 @@ import beartype.typing as tp from jax import vmap -from jaxtyping import ( - Float, - Num, -) +from jaxtyping import Float from gpjax.kernels.computations.base import AbstractKernelComputation from gpjax.typing import Array @@ -32,7 +29,7 @@ class DenseKernelComputation(AbstractKernelComputation): """ def cross_covariance( - self, kernel: Kernel, x: Num[Array, "N D"], y: Num[Array, "M D"] + self, kernel: Kernel, x: Float[Array, "N D"], y: Float[Array, "M D"] ) -> Float[Array, "N M"]: r"""Compute the cross-covariance matrix. diff --git a/gpjax/kernels/computations/diagonal.py b/gpjax/kernels/computations/diagonal.py index 71343d400..d4c323da8 100644 --- a/gpjax/kernels/computations/diagonal.py +++ b/gpjax/kernels/computations/diagonal.py @@ -20,10 +20,7 @@ LinearOperator, ) from jax import vmap -from jaxtyping import ( - Float, - Num, -) +from jaxtyping import Float from gpjax.kernels.computations import AbstractKernelComputation from gpjax.typing import Array @@ -36,7 +33,7 @@ class DiagonalKernelComputation(AbstractKernelComputation): a diagonal Gram matrix. """ - def gram(self, kernel: Kernel, x: Num[Array, "N D"]) -> LinearOperator: + def gram(self, kernel: Kernel, x: Float[Array, "N D"]) -> LinearOperator: r"""Compute the Gram matrix. For a kernel with diagonal structure, compute the $`N\times N`$ Gram matrix on @@ -44,7 +41,7 @@ def gram(self, kernel: Kernel, x: Num[Array, "N D"]) -> LinearOperator: Args: kernel (Kernel): the kernel function. - x (Num[Array, "N D"]): The input matrix. + x (Float[Array, "N D"]): The input matrix. Returns ------- @@ -53,7 +50,7 @@ def gram(self, kernel: Kernel, x: Num[Array, "N D"]) -> LinearOperator: return PSD(Diagonal(diag=vmap(lambda x: kernel(x, x))(x))) def cross_covariance( - self, kernel: Kernel, x: Num[Array, "N D"], y: Num[Array, "M D"] + self, kernel: Kernel, x: Float[Array, "N D"], y: Float[Array, "M D"] ) -> Float[Array, "N M"]: r"""Compute the cross-covariance matrix. @@ -62,8 +59,8 @@ def cross_covariance( Args: kernel (Kernel): the kernel function. - x (Num[Array,"N D"]): The input matrix. - y (Num[Array,"M D"]): The input matrix. + x (Float[Array,"N D"]): The input matrix. + y (Float[Array,"M D"]): The input matrix. Returns ------- diff --git a/gpjax/kernels/non_euclidean/__init__.py b/gpjax/kernels/non_euclidean/__init__.py index ee45287b0..d364bc71b 100644 --- a/gpjax/kernels/non_euclidean/__init__.py +++ b/gpjax/kernels/non_euclidean/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================== -from gpjax.kernels.non_euclidean.categorical import CatKernel from gpjax.kernels.non_euclidean.graph import GraphKernel -__all__ = ["GraphKernel", "CatKernel"] +__all__ = ["GraphKernel"] diff --git a/gpjax/kernels/non_euclidean/categorical.py b/gpjax/kernels/non_euclidean/categorical.py deleted file mode 100644 index 1b0c9ead2..000000000 --- a/gpjax/kernels/non_euclidean/categorical.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - - -from dataclasses import dataclass -from typing import ( - NamedTuple, - Union, -) - -import jax.numpy as jnp -from jaxtyping import ( - Float, - Int, -) -import tensorflow_probability.substrates.jax as tfp - -from gpjax.base import ( - param_field, - static_field, -) -from gpjax.kernels.base import AbstractKernel -from gpjax.typing import ( - Array, - ScalarInt, -) - -tfb = tfp.bijectors - -CatKernelParams = NamedTuple( - "CatKernelParams", - [("stddev", Float[Array, "N 1"]), ("cholesky_lower", Float[Array, " N*(N-1)//2"])], -) - - -@dataclass -class CatKernel(AbstractKernel): - r"""The categorical kernel is defined for a fixed number of values of categorical input. - - It stores a standard dev for each input value (i.e. the diagonal of the gram), and a lower cholesky factor for correlations. - It returns the corresponding values from an the gram matrix when called. - - Args: - stddev (Float[Array, "N"]): The standard deviation parameters, one for each input space value. - cholesky_lower (Float[Array, "N*(N-1)//2 N"]): The parameters for the Cholesky factor of the gram matrix. - inspace_vals (list): The values in the input space this CatKernel works for. Stored for order reference, making clear the indices used for each input space value. - name (str): The name of the kernel. - input_1hot (bool): If True, the kernel expect to be called with a 1-hot encoding of the input space values. If False, it expects the indices of the input space values. - - Raises: - ValueError: If the number of diagonal variance parameters does not match the number of input space values. - """ - - stddev: Float[Array, " N"] = param_field(jnp.ones((2,)), bijector=tfb.Softplus()) - cholesky_lower: Float[Array, "N N"] = param_field( - jnp.eye(2), bijector=tfb.CorrelationCholesky() - ) - inspace_vals: Union[list, None] = static_field(None) - name: str = "Categorical Kernel" - input_1hot: bool = static_field(False) - - def __post_init__(self): - if self.inspace_vals is not None and len(self.inspace_vals) != len(self.stddev): - raise ValueError( - f"The number of stddev parameters ({len(self.stddev)}) has to match the number of input space values ({len(self.inspace_vals)}), unless inspace_vals is None." - ) - - @property - def explicit_gram(self) -> Float[Array, "N N"]: - """Access the PSD gram matrix resulting from the parameters. - - Returns: - Float[Array, "N N"]: The gram matrix. - """ - L = self.stddev.reshape(-1, 1) * self.cholesky_lower - return L @ L.T - - def __call__( # TODO not consistent with general kernel interface - self, - x: Union[ScalarInt, Int[Array, " N"]], - y: Union[ScalarInt, Int[Array, " N"]], - ): - r"""Compute the (co)variance between a pair of dictionary indices. - - Args: - x (Union[ScalarInt, Int[Array, "N"]]): The index of the first dictionary entry, or its one-hot encoding. - y (Union[ScalarInt, Int[Array, "N"]]): The index of the second dictionary entry, or its one-hot encoding. - - Returns - ------- - ScalarFloat: The value of $k(v_i, v_j)$. 
- """ - try: - x = x.squeeze() - y = y.squeeze() - except AttributeError: - pass - if self.input_1hot: - return self.explicit_gram[jnp.outer(x, y) == 1] - else: - return self.explicit_gram[x, y] - - @staticmethod - def num_cholesky_lower_params(num_inspace_vals: ScalarInt) -> ScalarInt: - """Compute the number of parameters required to store the lower triangular Cholesky factor of the gram matrix. - - Args: - num_inspace_vals (ScalarInt): The number of values in the input space. - - Returns: - ScalarInt: The number of parameters required to store the lower triangle of the Cholesky factor of the gram matrix. - """ - return num_inspace_vals * (num_inspace_vals - 1) // 2 - - @staticmethod - def gram_to_stddev_cholesky_lower(gram: Float[Array, "N N"]) -> CatKernelParams: - """Compute the standard deviation and lower triangular Cholesky factor of the gram matrix. - - Args: - gram (Float[Array, "N N"]): The gram matrix. - - Returns: - tuple[Float[Array, "N"], Float[Array, "N N"]]: The standard deviation and lower triangular Cholesky factor of the gram matrix, where the latter is scaled to result in unit variances. - """ - stddev = jnp.sqrt(jnp.diag(gram)) - L = jnp.linalg.cholesky(gram) / stddev.reshape(-1, 1) - return CatKernelParams(stddev, L) diff --git a/gpjax/kernels/non_euclidean/utils.py b/gpjax/kernels/non_euclidean/utils.py index 2d8f2189d..eeda01350 100644 --- a/gpjax/kernels/non_euclidean/utils.py +++ b/gpjax/kernels/non_euclidean/utils.py @@ -14,23 +14,23 @@ # ============================================================================== from jaxtyping import ( + Float, Int, - Num, ) from gpjax.typing import Array def jax_gather_nd( - params: Num[Array, " N *rest"], indices: Int[Array, " M 1"] -) -> Num[Array, " M *rest"]: + params: Float[Array, " N *rest"], indices: Int[Array, " M 1"] +) -> Float[Array, " M *rest"]: r"""Slice a `params` array at a set of `indices`. This is a reimplementation of TensorFlow's `gather_nd` function: [link](https://www.tensorflow.org/api_docs/python/tf/gather_nd) Args: - params (Num[Array]): An arbitrary array with leading axes of length $N$ upon + params (Float[Array]): An arbitrary array with leading axes of length $N$ upon which we shall slice. indices (Float[Int]): An integer array of length $M$ with values in the range $[0, N)$ whose value at index $i$ will be used to slice `params` at @@ -38,7 +38,7 @@ def jax_gather_nd( Returns ------- - Num[Array: An arbitrary array with leading axes of length $M$. + Float[Array: An arbitrary array with leading axes of length $M$. """ tuple_indices = tuple(indices[..., i] for i in range(indices.shape[-1])) return params[tuple_indices] diff --git a/gpjax/kernels/stationary/matern32.py b/gpjax/kernels/stationary/matern32.py index d7fa2aaa7..ac3b79699 100644 --- a/gpjax/kernels/stationary/matern32.py +++ b/gpjax/kernels/stationary/matern32.py @@ -17,10 +17,7 @@ from beartype.typing import Union import jax.numpy as jnp -from jaxtyping import ( - Float, - Num, -) +from jaxtyping import Float import tensorflow_probability.substrates.jax.bijectors as tfb import tensorflow_probability.substrates.jax.distributions as tfd @@ -48,8 +45,8 @@ class Matern32(AbstractKernel): def __call__( self, - x: Num[Array, " D"], - y: Num[Array, " D"], + x: Float[Array, " D"], + y: Float[Array, " D"], ) -> ScalarFloat: r"""Compute the Matérn 3/2 kernel between a pair of arrays. @@ -61,8 +58,8 @@ def __call__( ``` Args: - x (Num[Array, " D"]): The left hand argument of the kernel function's call. 
- y (Num[Array, " D"]): The right hand argument of the kernel function's call. + x (Float[Array, " D"]): The left hand argument of the kernel function's call. + y (Float[Array, " D"]): The right hand argument of the kernel function's call. Returns ------- diff --git a/gpjax/kernels/stationary/matern52.py b/gpjax/kernels/stationary/matern52.py index 2a8ed0564..6a57813c2 100644 --- a/gpjax/kernels/stationary/matern52.py +++ b/gpjax/kernels/stationary/matern52.py @@ -17,10 +17,7 @@ from beartype.typing import Union import jax.numpy as jnp -from jaxtyping import ( - Float, - Num, -) +from jaxtyping import Float import tensorflow_probability.substrates.jax.bijectors as tfb import tensorflow_probability.substrates.jax.distributions as tfd @@ -46,7 +43,7 @@ class Matern52(AbstractKernel): variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) name: str = "Matérn52" - def __call__(self, x: Num[Array, " D"], y: Num[Array, " D"]) -> ScalarFloat: + def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: r"""Compute the Matérn 5/2 kernel between a pair of arrays. Evaluate the kernel on a pair of inputs $`(x, y)`$ with @@ -56,8 +53,8 @@ def __call__(self, x: Num[Array, " D"], y: Num[Array, " D"]) -> ScalarFloat: ``` Args: - x (Num[Array, " D"]): The left hand argument of the kernel function's call. - y (Num[Array, " D"]): The right hand argument of the kernel function's call. + x (Float[Array, " D"]): The left hand argument of the kernel function's call. + y (Float[Array, " D"]): The right hand argument of the kernel function's call. Returns ------- diff --git a/gpjax/kernels/stationary/periodic.py b/gpjax/kernels/stationary/periodic.py index 8419d528a..1753b82c7 100644 --- a/gpjax/kernels/stationary/periodic.py +++ b/gpjax/kernels/stationary/periodic.py @@ -17,10 +17,7 @@ from beartype.typing import Union import jax.numpy as jnp -from jaxtyping import ( - Float, - Num, -) +from jaxtyping import Float import tensorflow_probability.substrates.jax.bijectors as tfb from gpjax.base import param_field @@ -45,7 +42,7 @@ class Periodic(AbstractKernel): period: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) name: str = "Periodic" - def __call__(self, x: Num[Array, " D"], y: Num[Array, " D"]) -> ScalarFloat: + def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: r"""Compute the Periodic kernel between a pair of arrays. Evaluate the kernel on a pair of inputs $`(x, y)`$ with length-scale parameter $`\ell`$, variance $`\sigma^2`$ @@ -55,8 +52,8 @@ def __call__(self, x: Num[Array, " D"], y: Num[Array, " D"]) -> ScalarFloat: ``` Args: - x (Num[Array, " D"]): The left hand argument of the kernel function's call. - y (Num[Array, " D"]): The right hand argument of the kernel function's call + x (Float[Array, " D"]): The left hand argument of the kernel function's call. + y (Float[Array, " D"]): The right hand argument of the kernel function's call Returns: ScalarFloat: The value of $`k(x, y)`$. 
""" diff --git a/gpjax/kernels/stationary/rbf.py b/gpjax/kernels/stationary/rbf.py index f7a09900f..6f2cd2b56 100644 --- a/gpjax/kernels/stationary/rbf.py +++ b/gpjax/kernels/stationary/rbf.py @@ -17,10 +17,7 @@ from beartype.typing import Union import jax.numpy as jnp -from jaxtyping import ( - Float, - Num, -) +from jaxtyping import Float import tensorflow_probability.substrates.jax.bijectors as tfb import tensorflow_probability.substrates.jax.distributions as tfd @@ -43,7 +40,7 @@ class RBF(AbstractKernel): variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) name: str = "RBF" - def __call__(self, x: Num[Array, " D"], y: Num[Array, " D"]) -> ScalarFloat: + def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: r"""Compute the RBF kernel between a pair of arrays. Evaluate the kernel on a pair of inputs $`(x, y)`$ with lengthscale parameter @@ -53,8 +50,8 @@ def __call__(self, x: Num[Array, " D"], y: Num[Array, " D"]) -> ScalarFloat: ``` Args: - x (Num[Array, " D"]): The left hand argument of the kernel function's call. - y (Num[Array, " D"]): The right hand argument of the kernel function's call. + x (Float[Array, " D"]): The left hand argument of the kernel function's call. + y (Float[Array, " D"]): The right hand argument of the kernel function's call. Returns: ScalarFloat: The value of $`k(x, y)`$. diff --git a/gpjax/kernels/stationary/white.py b/gpjax/kernels/stationary/white.py index 4a1f8da65..355649317 100644 --- a/gpjax/kernels/stationary/white.py +++ b/gpjax/kernels/stationary/white.py @@ -16,7 +16,7 @@ from dataclasses import dataclass import jax.numpy as jnp -from jaxtyping import Num +from jaxtyping import Float import tensorflow_probability.substrates.jax.bijectors as tfb from gpjax.base import ( @@ -42,7 +42,7 @@ class White(AbstractKernel): ) name: str = "White" - def __call__(self, x: Num[Array, " D"], y: Num[Array, " D"]) -> ScalarFloat: + def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: r"""Compute the White noise kernel between a pair of arrays. Evaluate the kernel on a pair of inputs $`(x, y)`$ with variance $`\sigma^2`$: @@ -51,8 +51,8 @@ def __call__(self, x: Num[Array, " D"], y: Num[Array, " D"]) -> ScalarFloat: ``` Args: - x (Num[Array, " D"]): The left hand argument of the kernel function's call. - y (Num[Array, " D"]): The right hand argument of the kernel function's call. + x (Float[Array, " D"]): The left hand argument of the kernel function's call. + y (Float[Array, " D"]): The right hand argument of the kernel function's call. Returns ------- diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index b110b88f5..b7fcb724b 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -29,10 +29,7 @@ param_field, static_field, ) -from gpjax.distributions import ( - GaussianDistribution, - ReshapedDistribution, -) +from gpjax.distributions import GaussianDistribution from gpjax.integrators import ( AbstractIntegrator, AnalyticalGaussianIntegrator, @@ -167,22 +164,11 @@ def predict( ------- tfd.Distribution: The predictive distribution. 
""" - mean = dist.mean() - event_size = mean.size - + n_data = dist.event_shape[0] cov = dist.covariance() + noisy_cov = cov.at[jnp.diag_indices(n_data)].add(self.obs_stddev**2) - # reshape for handling multi-output case - mean = mean.flatten() - cov = cov.reshape([mean.size] * 2) - - noisy_cov = cov.at[jnp.diag_indices(event_size)].add(self.obs_stddev**2) - - likelihood_distr = tfd.MultivariateNormalFullCovariance(mean, noisy_cov) - if len(dist.event_shape) == 1: - return likelihood_distr - else: - return ReshapedDistribution(likelihood_distr, dist.event_shape) + return tfd.MultivariateNormalFullCovariance(dist.mean(), noisy_cov) @dataclass diff --git a/gpjax/objectives.py b/gpjax/objectives.py index afb485e5e..c07290c48 100644 --- a/gpjax/objectives.py +++ b/gpjax/objectives.py @@ -132,30 +132,22 @@ def step( ScalarFloat: The marginal log-likelihood of the Gaussian process for the current parameter set. """ - x, y, mask = train_data.X, train_data.y, train_data.mask - m = y.shape[1] - if m > 1 and mask is not None: - mask = mask.flatten() + x, y = train_data.X, train_data.y # Observation noise o² - obs_var = posterior.likelihood.obs_stddev**2 - + obs_noise = posterior.likelihood.obs_stddev**2 mx = posterior.prior.mean_function(x) # Σ = (Kxx + Io²) = LLᵀ Kxx = posterior.prior.kernel.gram(x) - Kyy = posterior.prior.out_kernel.gram(jnp.arange(m)[:, jnp.newaxis]) - - Sigma = cola.ops.Kronecker(Kxx, Kyy) - Sigma = Sigma + cola.ops.I_like(Sigma) * (obs_var + posterior.prior.jitter) + Kxx += cola.ops.I_like(Kxx) * posterior.prior.jitter + Sigma = Kxx + cola.ops.I_like(Kxx) * obs_noise Sigma = cola.PSD(Sigma) - # flatten to handle multi-output case, then calculate # p(y | x, θ), where θ are the model hyperparameters: - mll = GaussianDistribution(jnp.atleast_1d(mx.flatten()), Sigma) + mll = GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Sigma) - rval = mll.log_prob(jnp.atleast_1d(y.flatten()), mask=mask).squeeze() - return self.constant * rval + return self.constant * (mll.log_prob(jnp.atleast_1d(y.squeeze())).squeeze()) class ConjugateLOOCV(AbstractObjective): @@ -220,15 +212,8 @@ def step( ScalarFloat: The leave-one-out log predictive probability of the Gaussian process for the current parameter set. 
""" - x, y, mask = train_data.X, train_data.y, train_data.mask - m = y.shape[1] - - if mask is not None: - raise NotImplementedError("ConjugateLOOCV does not yet support masking") - if m > 1: - raise NotImplementedError( - "ConjugateLOOCV does not yet support multi-output" - ) + x, y = train_data.X, train_data.y + y.shape[1] # Observation noise o² obs_var = posterior.likelihood.obs_stddev**2 @@ -237,8 +222,6 @@ def step( # Σ = (Kxx + Io²) Kxx = posterior.prior.kernel.gram(x) - Kyy = posterior.prior.out_kernel.gram(jnp.arange(m)[:, jnp.newaxis]) - Sigma = cola.ops.Kronecker(Kxx, Kyy) Sigma = Kxx + cola.ops.I_like(Kxx) * (obs_var + posterior.prior.jitter) Sigma = cola.PSD(Sigma) # [N, N] diff --git a/mkdocs.yml b/mkdocs.yml index e62a4dfd6..b0f5d6b0a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -27,7 +27,6 @@ nav: - Graph kernels: examples/graph_kernels.py - Sparse GPs: examples/uncollapsed_vi.py - Stochastic sparse GPs: examples/collapsed_vi.py - - Pathwise Sampling for Spatial Modelling: examples/spatial.py - Bayesian Optimisation: examples/bayesian_optimisation.py - Decision Making: examples/decision_making.py - Multi-output GPs for Ocean Modelling: examples/oceanmodelling.py diff --git a/tests/test_dataset.py b/tests/test_dataset.py index d65cce636..9567219d7 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -24,7 +24,6 @@ from jax import config import jax.numpy as jnp -import jax.random as jr import jax.tree_util as jtu import pytest @@ -34,26 +33,20 @@ @pytest.mark.parametrize("n", [1, 2, 10]) -@pytest.mark.parametrize("out_dim", [1, 2, 10]) @pytest.mark.parametrize("in_dim", [1, 2, 10]) -def test_dataset_init(n: int, in_dim: int, out_dim: int) -> None: +def test_dataset_init(n: int, in_dim: int) -> None: # Create dataset x = jnp.ones((n, in_dim)) - y = jnp.ones((n, out_dim)) + y = jnp.ones((n, 1)) D = Dataset(X=x, y=y) # Test dataset shapes assert D.n == n assert D.in_dim == in_dim - assert D.out_dim == out_dim # Test representation - assert ( - D.__repr__() - == f"- Number of observations: {n}\n- Input dimension: {in_dim}\n- Output" - f" dimension: {out_dim}" - ) + assert D.__repr__() == f"- Number of observations: {n}\n- Input dimension: {in_dim}" # Ensure dataclass assert is_dataclass(D) @@ -68,17 +61,16 @@ def test_dataset_init(n: int, in_dim: int, out_dim: int) -> None: @pytest.mark.parametrize("n1", [1, 2, 10]) @pytest.mark.parametrize("n2", [1, 2, 10]) -@pytest.mark.parametrize("out_dim", [1, 2, 10]) @pytest.mark.parametrize("in_dim", [1, 2, 10]) -def test_dataset_add(n1: int, n2: int, in_dim: int, out_dim: int) -> None: +def test_dataset_add(n1: int, n2: int, in_dim: int) -> None: # Create first dataset x1 = jnp.ones((n1, in_dim)) - y1 = jnp.ones((n1, out_dim)) + y1 = jnp.ones((n1, 1)) D1 = Dataset(X=x1, y=y1) # Create second dataset x2 = 2 * jnp.ones((n2, in_dim)) - y2 = 2 * jnp.ones((n2, out_dim)) + y2 = 2 * jnp.ones((n2, 1)) D2 = Dataset(X=x2, y=y2) # Add datasets @@ -87,13 +79,11 @@ def test_dataset_add(n1: int, n2: int, in_dim: int, out_dim: int) -> None: # Test shapes assert D.n == n1 + n2 assert D.in_dim == in_dim - assert D.out_dim == out_dim # Test representation assert ( D.__repr__() - == f"- Number of observations: {n1 + n2}\n- Input dimension: {in_dim}\n- Output" - f" dimension: {out_dim}" + == f"- Number of observations: {n1 + n2}\n- Input dimension: {in_dim}" ) # Ensure dataclass @@ -111,12 +101,11 @@ def test_dataset_add(n1: int, n2: int, in_dim: int, out_dim: int) -> None: @pytest.mark.parametrize(("nx", "ny"), [(1, 2), (2, 1), (10, 5), (5, 10)]) 
-@pytest.mark.parametrize("out_dim", [1, 2, 10]) @pytest.mark.parametrize("in_dim", [1, 2, 10]) -def test_dataset_incorrect_lengths(nx: int, ny: int, out_dim: int, in_dim: int) -> None: +def test_dataset_incorrect_lengths(nx: int, ny: int, in_dim: int) -> None: # Create input and output pairs of different lengths x = jnp.ones((nx, in_dim)) - y = jnp.ones((ny, out_dim)) + y = jnp.ones((ny, 1)) # Ensure error is raised upon dataset creation with pytest.raises(ValidationErrors): @@ -124,9 +113,8 @@ def test_dataset_incorrect_lengths(nx: int, ny: int, out_dim: int, in_dim: int) @pytest.mark.parametrize("n", [1, 2, 10]) -@pytest.mark.parametrize("out_dim", [1, 2, 10]) @pytest.mark.parametrize("in_dim", [1, 2, 10]) -def test_2d_inputs(n: int, out_dim: int, in_dim: int) -> None: +def test_2d_inputs(n: int, in_dim: int) -> None: # Create dataset where output dimension is incorrectly not 2D x = jnp.ones((n, in_dim)) y = jnp.ones((n,)) @@ -137,7 +125,7 @@ def test_2d_inputs(n: int, out_dim: int, in_dim: int) -> None: # Create dataset where input dimension is incorrectly not 2D x = jnp.ones((n,)) - y = jnp.ones((n, out_dim)) + y = jnp.ones((n, 1)) # Ensure error is raised upon dataset creation with pytest.raises(ValidationErrors): @@ -161,45 +149,6 @@ def test_y_none(n: int, in_dim: int) -> None: assert jtu.tree_leaves(D) == [x] -@pytest.mark.parametrize("n", [1, 2, 10]) -@pytest.mark.parametrize("out_dim", [1, 2, 10]) -@pytest.mark.parametrize("in_dim", [1, 2, 10]) -def test_dataset_missing(n: int, in_dim: int, out_dim: int) -> None: - # Create dataset - x = jnp.ones((n, in_dim)) - y = jr.normal(jr.PRNGKey(123), (n, out_dim)) - y = y.at[y < 0].set(jnp.nan) - mask = jnp.isnan(y) - D = Dataset(X=x, y=y) - - # Check mask - assert D.mask is not None - assert jnp.array_equal(D.mask, mask) - - # Create second dataset - x2 = 2 * jnp.ones((n, in_dim)) - y2 = 2 * jnp.ones((n, out_dim)) - D2 = Dataset(X=x2, y=y2) - - # Add datasets - D2 = D + D2 - - # Check mask - assert jnp.sum(D2.mask) == jnp.sum(D.mask) - - # Test dataset shapes - assert D.n == n - assert D.in_dim == in_dim - assert D.out_dim == out_dim - - # Check tree flatten - # lexicographic order: uppercase "X" comes before lowercase "m" - x_, mask_, y_ = jtu.tree_leaves(D) - assert jnp.allclose(x, x_) - assert jnp.array_equal(mask, mask_) - assert jnp.allclose(y, y_, equal_nan=True) - - @pytest.mark.parametrize( ("prec_x", "prec_y"), [ @@ -210,13 +159,12 @@ def test_dataset_missing(n: int, in_dim: int, out_dim: int) -> None: ) @pytest.mark.parametrize("n", [1, 2, 10]) @pytest.mark.parametrize("in_dim", [1, 2, 10]) -@pytest.mark.parametrize("out_dim", [1, 2, 10]) def test_precision_warning( - n: int, in_dim: int, out_dim: int, prec_x: jnp.dtype, prec_y: jnp.dtype + n: int, in_dim: int, prec_x: jnp.dtype, prec_y: jnp.dtype ) -> None: # Create dataset x = jnp.ones((n, in_dim)).astype(prec_x) - y = jnp.ones((n, out_dim)).astype(prec_y) + y = jnp.ones((n, 1)).astype(prec_y) # Check for warnings if dtypes are not float64 expected_warnings = 0 diff --git a/tests/test_gaussian_distribution.py b/tests/test_gaussian_distribution.py index 0df92af16..db8586c51 100644 --- a/tests/test_gaussian_distribution.py +++ b/tests/test_gaussian_distribution.py @@ -161,24 +161,3 @@ def test_kl_divergence(n: int) -> None: with pytest.raises(ValueError): incompatible = GaussianDistribution(loc=jnp.ones((2 * n,))) incompatible.kl_divergence(dist_a) - - -@pytest.mark.parametrize("n", [5, 100]) -def test_masked_log_prob(n): - key_mean, key_sqrt = jr.split(_key, 2) - mean = 
jr.uniform(key_mean, shape=(n,)) - sqrt = jr.uniform(key_sqrt, shape=(n, n)) - covariance = sqrt @ sqrt.T - y = jr.normal(_key, shape=(n,)) - y = y.at[jr.choice(key_sqrt, y.shape[0], (1,))].set(jnp.nan) - mask = jnp.isnan(y) - - # check that cholesky does not error - _L = jnp.linalg.cholesky(covariance) # noqa: F841 - - # check that masked log_prob is equal to tfp log_prob with missing values removed - dist = GaussianDistribution(loc=mean, scale=Dense(covariance)) - tfp_dist = MultivariateNormalFullCovariance( - loc=mean[~mask], covariance_matrix=covariance[~mask][:, ~mask] - ) - assert approx_equal(dist.log_prob(y, mask=mask), tfp_dist.log_prob(y[~mask])) diff --git a/tests/test_gps.py b/tests/test_gps.py index 897eff7d9..59af2bb91 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -48,7 +48,6 @@ RBF, AbstractKernel, Matern52, - White, ) from gpjax.likelihoods import ( AbstractLikelihood, @@ -81,15 +80,13 @@ def test_abstract_posterior(): @pytest.mark.parametrize("num_datapoints", [1, 10]) @pytest.mark.parametrize("kernel", [RBF(), Matern52()]) @pytest.mark.parametrize("mean_function", [Zero(), Constant()]) -@pytest.mark.parametrize("out_kernel", [RBF(), White()]) def test_prior( num_datapoints: int, kernel: AbstractKernel, mean_function: AbstractMeanFunction, - out_kernel: AbstractKernel, ) -> None: # Create prior. - prior = Prior(mean_function=mean_function, kernel=kernel, out_kernel=out_kernel) + prior = Prior(mean_function=mean_function, kernel=kernel) # Check types. assert isinstance(prior, Prior) @@ -99,7 +96,7 @@ def test_prior( # Check pytree. assert jtu.tree_leaves(prior) == jtu.tree_leaves(kernel) + jtu.tree_leaves( mean_function - ) + jtu.tree_leaves(out_kernel) + ) # Query a marginal distribution at some inputs. inputs = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) @@ -119,12 +116,10 @@ def test_prior( @pytest.mark.parametrize("num_datapoints", [1, 10]) @pytest.mark.parametrize("kernel", [RBF(), Matern52()]) @pytest.mark.parametrize("mean_function", [Zero(), Constant()]) -@pytest.mark.parametrize("out_kernel", [RBF(), White()]) def test_conjugate_posterior( num_datapoints: int, mean_function: AbstractMeanFunction, kernel: AbstractKernel, - out_kernel: AbstractKernel, ) -> None: # Create a dataset. key = jr.PRNGKey(123) @@ -133,7 +128,7 @@ def test_conjugate_posterior( D = Dataset(X=x, y=y) # Define prior. - prior = Prior(mean_function=mean_function, kernel=kernel, out_kernel=out_kernel) + prior = Prior(mean_function=mean_function, kernel=kernel) # Define a likelihood. likelihood = Gaussian(num_datapoints=num_datapoints) @@ -148,7 +143,7 @@ def test_conjugate_posterior( # Check tree flattening. assert jtu.tree_leaves(posterior) == jtu.tree_leaves(likelihood) + jtu.tree_leaves( kernel - ) + jtu.tree_leaves(mean_function) + jtu.tree_leaves(out_kernel) + ) + jtu.tree_leaves(mean_function) # Query a marginal distribution of the posterior at some inputs. 
inputs = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) @@ -165,66 +160,13 @@ def test_conjugate_posterior( assert sigma.shape == (num_datapoints, num_datapoints) -@pytest.mark.parametrize("num_datapoints", [1, 10]) -@pytest.mark.parametrize("kernel", [RBF(), Matern52()]) -@pytest.mark.parametrize( - "mean_function", - [Constant(constant=jnp.zeros((2,))), Constant(constant=jnp.ones((2,)))], -) -@pytest.mark.parametrize("out_kernel", [RBF(), Matern52()]) -def test_conjugate_posterior_mo( - num_datapoints: int, - mean_function: AbstractMeanFunction, - kernel: AbstractKernel, - out_kernel: AbstractKernel, -) -> None: - # Create a dataset. - key = jr.PRNGKey(123) - x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 1)) - y = ( - jnp.hstack([jnp.sin(x), jnp.cos(x)]) - + jr.normal(key=key, shape=(num_datapoints, 2)) * 0.1 - ) - D = Dataset(X=x, y=y) - - # Define prior. - prior = Prior(mean_function=mean_function, kernel=kernel, out_kernel=out_kernel) - - # Define a likelihood. - likelihood = Gaussian(num_datapoints=num_datapoints) - - # Construct the posterior via the class. - posterior = ConjugatePosterior(prior=prior, likelihood=likelihood) - - # Check types. - assert isinstance(posterior, ConjugatePosterior) - assert is_dataclass(posterior) - - # Check tree flattening. - assert jtu.tree_leaves(posterior) == jtu.tree_leaves(likelihood) + jtu.tree_leaves( - kernel - ) + jtu.tree_leaves(mean_function) + jtu.tree_leaves(out_kernel) - - # Query a marginal distribution of the posterior at some inputs. - inputs = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) - marginal_distribution = posterior(inputs, D) - - # Ensure that the marginal distribution has the correct shape. - mu = marginal_distribution.mean() - sigma = marginal_distribution.covariance() - assert mu.shape == (num_datapoints, 2) - assert sigma.shape == (num_datapoints, 2, num_datapoints, 2) - - @pytest.mark.parametrize("num_datapoints", [1, 10]) @pytest.mark.parametrize("kernel", [RBF(), Matern52()]) @pytest.mark.parametrize("mean_function", [Zero(), Constant()]) -@pytest.mark.parametrize("out_kernel", [RBF(), White()]) def test_nonconjugate_posterior( num_datapoints: int, mean_function: AbstractMeanFunction, kernel: AbstractKernel, - out_kernel: AbstractKernel, ) -> None: # Create a dataset. key = jr.PRNGKey(123) @@ -258,7 +200,7 @@ def test_nonconjugate_posterior( ] leaves = jtu.tree_leaves(posterior) - for l1, l2 in zip(leaves, true_leaves): + for l1, l2 in zip(leaves, true_leaves, strict=True): assert (l1 == l2).all() # Query a marginal distribution of the posterior at some inputs. 
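The zip(..., strict=True) calls introduced in these tests (and in the remaining test files below) require Python 3.10 or newer. A minimal sketch of the behaviour the leaf-by-leaf comparisons rely on, using hypothetical lists:

    # By default zip truncates to the shortest iterable, so a pytree that
    # gained or lost a leaf could still pass the comparison silently.
    # strict=True instead raises ValueError on a length mismatch.
    leaves = [1, 2, 3]
    true_leaves = [1, 2]
    list(zip(leaves, true_leaves))               # [(1, 1), (2, 2)]
    list(zip(leaves, true_leaves, strict=True))  # raises ValueError
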
@@ -310,7 +252,9 @@ def test_posterior_construct( leaves_rmul = jtu.tree_leaves(posterior_rmul) leaves_manual = jtu.tree_leaves(posterior_manual) - for leaf_mul, leaf_rmul, leaf_man in zip(leaves_mul, leaves_rmul, leaves_manual): + for leaf_mul, leaf_rmul, leaf_man in zip( + leaves_mul, leaves_rmul, leaves_manual, strict=True + ): assert (leaf_mul == leaf_rmul).all() assert (leaf_rmul == leaf_man).all() diff --git a/tests/test_kernels/test_non_euclidean.py b/tests/test_kernels/test_non_euclidean.py index f8b997d8c..4ed6d68a6 100644 --- a/tests/test_kernels/test_non_euclidean.py +++ b/tests/test_kernels/test_non_euclidean.py @@ -13,13 +13,9 @@ from cola.ops import I_like from jax import config import jax.numpy as jnp -import jax.random as jr import networkx as nx -from gpjax.kernels.non_euclidean import ( - CatKernel, - GraphKernel, -) +from gpjax.kernels.non_euclidean import GraphKernel # # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) @@ -50,76 +46,3 @@ def test_graph_kernel(): Kxx += I_like(Kxx) * 1e-6 eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) assert all(eigen_values > 0) - - -def test_cat_kernel(): - x = jr.normal(jr.PRNGKey(123), (5000, 3)) - gram = jnp.cov(x.T) - params = CatKernel.gram_to_stddev_cholesky_lower(gram) - dk = CatKernel( - inspace_vals=list(range(len(gram))), - stddev=params.stddev, - cholesky_lower=params.cholesky_lower, - ) - assert jnp.allclose(dk.explicit_gram, gram) - - sdev = jnp.ones((2,)) - cholesky_lower = jnp.eye(2) - inspace_vals = [0.0, 1.0] - - # Initialize CatKernel object - dict_kernel = CatKernel( - stddev=sdev, cholesky_lower=cholesky_lower, inspace_vals=inspace_vals - ) - - assert dict_kernel.stddev.shape == sdev.shape - assert jnp.allclose(dict_kernel.stddev, sdev) - assert jnp.allclose(dict_kernel.cholesky_lower, cholesky_lower) - assert dict_kernel.inspace_vals == inspace_vals - - -def test_cat_kernel_gram_to_stddev_cholesky_lower(): - gram = jnp.array([[1.0, 0.5], [0.5, 1.0]]) - sdev_expected = jnp.array([1.0, 1.0]) - cholesky_lower_expected = jnp.array([[1.0, 0.0], [0.5, 0.8660254]]) - - # Compute sdev and cholesky_lower from gram - sdev, cholesky_lower = CatKernel.gram_to_stddev_cholesky_lower(gram) - - assert jnp.allclose(sdev, sdev_expected) - assert jnp.allclose(cholesky_lower, cholesky_lower_expected) - - -def test_cat_kernel_call(): - sdev = jnp.ones((2,)) - cholesky_lower = jnp.eye(2) - inspace_vals = [0.0, 1.0] - - # Initialize CatKernel object - dict_kernel = CatKernel( - stddev=sdev, cholesky_lower=cholesky_lower, inspace_vals=inspace_vals - ) - - # Compute kernel value for pair of inputs - kernel_value = dict_kernel.__call__(0, 1) - - assert jnp.allclose(kernel_value, 0.0) # since cholesky_lower is identity matrix - - -def test_cat_kernel_explicit_gram(): - sdev = jnp.ones((2,)) - cholesky_lower = jnp.eye(2) - inspace_vals = [0.0, 1.0] - - # Initialize CatKernel object - dict_kernel = CatKernel( - stddev=sdev, cholesky_lower=cholesky_lower, inspace_vals=inspace_vals - ) - - # Compute explicit gram matrix - explicit_gram = dict_kernel.explicit_gram - - assert explicit_gram.shape == (2, 2) - assert jnp.allclose( - explicit_gram, jnp.eye(2) - ) # since sdev are ones and cholesky_lower is identity matrix diff --git a/tests/test_kernels/test_nonstationary.py b/tests/test_kernels/test_nonstationary.py index d3513a6f9..803e44d23 100644 --- a/tests/test_kernels/test_nonstationary.py +++ b/tests/test_kernels/test_nonstationary.py @@ -147,7 +147,9 @@ def test_cross_covariance(self, n_a: int, n_b: 
int, dim: int) -> None: def prod(inp): - return [dict(zip(inp.keys(), values)) for values in product(*inp.values())] + return [ + dict(zip(inp.keys(), values, strict=True)) for values in product(*inp.values()) + ] class TestLinear(BaseTestKernel): diff --git a/tests/test_kernels/test_stationary.py b/tests/test_kernels/test_stationary.py index 2a7a27305..3e214b45b 100644 --- a/tests/test_kernels/test_stationary.py +++ b/tests/test_kernels/test_stationary.py @@ -172,7 +172,9 @@ def test_isotropic(self, dim: int): def prod(inp): - return [dict(zip(inp.keys(), values)) for values in product(*inp.values())] + return [ + dict(zip(inp.keys(), values, strict=True)) for values in product(*inp.values()) + ] class TestRBF(BaseTestKernel): diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py index 2cded3de1..62e829712 100644 --- a/tests/test_likelihoods.py +++ b/tests/test_likelihoods.py @@ -141,7 +141,9 @@ def _test_call_check(likelihood, latent_mean, latent_cov, latent_dist): def prod(inp): - return [dict(zip(inp.keys(), values)) for values in product(*inp.values())] + return [ + dict(zip(inp.keys(), values, strict=True)) for values in product(*inp.values()) + ] class TestGaussian(BaseTestLikelihood): diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py index d2a27308a..8a576c2c7 100644 --- a/tests/test_variational_families.py +++ b/tests/test_variational_families.py @@ -139,7 +139,7 @@ def test_variational_gaussians( + [diag_matrix_val(1.0)(n_inducing)] ) - for l1, l2 in zip(jtu.tree_leaves(q), true_leaves): + for l1, l2 in zip(jtu.tree_leaves(q), true_leaves, strict=True): assert (l1 == l2).all() elif isinstance(q, WhitenedVariationalGaussian): @@ -155,7 +155,7 @@ def test_variational_gaussians( + [diag_matrix_val(1.0)(n_inducing)] ) - for l1, l2 in zip(jtu.tree_leaves(q), true_leaves): + for l1, l2 in zip(jtu.tree_leaves(q), true_leaves, strict=True): assert (l1 == l2).all() elif isinstance(q, NaturalVariationalGaussian): @@ -172,7 +172,7 @@ def test_variational_gaussians( + jtu.tree_leaves(posterior) ) - for l1, l2 in zip(jtu.tree_leaves(q), true_leaves): + for l1, l2 in zip(jtu.tree_leaves(q), true_leaves, strict=True): assert (l1 == l2).all() elif isinstance(q, ExpectationVariationalGaussian): @@ -189,7 +189,7 @@ def test_variational_gaussians( + jtu.tree_leaves(posterior) ) - for l1, l2 in zip(jtu.tree_leaves(q), true_leaves): + for l1, l2 in zip(jtu.tree_leaves(q), true_leaves, strict=True): assert (l1 == l2).all() # Test KL @@ -264,6 +264,6 @@ def test_collapsed_variational_gaussian( # Test pytree structure (nodes are alphabetically flattened, hence the ordering) true_leaves = [inducing_inputs, *jtu.tree_leaves(posterior)] - for l1, l2 in zip(jtu.tree_leaves(variational_family), true_leaves): + for l1, l2 in zip(jtu.tree_leaves(variational_family), true_leaves, strict=True): assert l1.shape == l2.shape assert (l1 == l2).all()
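
The rewritten ConjugateMLL.step in gpjax/objectives.py reduces to the textbook single-output marginal log-likelihood, log N(y | m(x), Kxx + I*jitter + I*sigma^2). A minimal dense-matrix sketch of the same computation, assuming user-supplied kernel_fn, mean_fn and obs_stddev (it does not use GPJax's GaussianDistribution or cola operators):

    import jax.numpy as jnp
    from jax.scipy.linalg import cho_solve

    def conjugate_mll(x, y, kernel_fn, mean_fn, obs_stddev, jitter=1e-6):
        # Sigma = Kxx + I*jitter + I*sigma^2, as in the rewritten step()
        n = x.shape[0]
        Sigma = kernel_fn(x, x) + jnp.eye(n) * (jitter + obs_stddev**2)
        # log N(y | m(x), Sigma) via a Cholesky factorisation of Sigma
        resid = y.squeeze() - mean_fn(x).squeeze()
        L = jnp.linalg.cholesky(Sigma)
        quad = resid @ cho_solve((L, True), resid)
        logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
        return -0.5 * (quad + logdet + n * jnp.log(2.0 * jnp.pi))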