Update README #220

Merged · 3 commits · Apr 10, 2023
5 changes: 0 additions & 5 deletions .mailmap

This file was deleted.

118 changes: 58 additions & 60 deletions README.md
| [**Documentation**](https://gpjax.readthedocs.io/en/latest/)
| [**Slack Community**](https://join.slack.com/t/gpjax/shared_invite/zt-1da57pmjn-rdBCVg9kApirEEn2E5Q2Zw)

GPJax aims to provide a low-level interface to Gaussian process (GP) models in
[Jax](https://github.com/google/jax), structured to give researchers maximum
flexibility in extending the code to suit their own needs. The idea is that the
code should be as close as possible to the maths we write on paper when working
with GP models.

# Package support

GPJax was founded by [Thomas Pinder](https://github.com/thomaspinder). Today,
the maintenance of GPJax is undertaken by [Thomas
Pinder](https://github.com/thomaspinder) and [Daniel
Dodd](https://github.com/Daniel-Dodd).

We would be delighted to receive contributions from interested individuals and
groups. To learn how you can get involved, please read our [guide for
contributing](https://github.com/JaxGaussianProcesses/GPJax/blob/master/CONTRIBUTING.md).
If you have any questions, we encourage you to [open an
issue](https://github.com/JaxGaussianProcesses/GPJax/issues/new/choose). For
broader conversations, such as best GP fitting practices or questions about the
mathematics of GPs, we invite you to [open a
discussion](https://github.com/JaxGaussianProcesses/GPJax/discussions).

Feel free to join our [Slack
Channel](https://join.slack.com/t/gpjax/shared_invite/zt-1da57pmjn-rdBCVg9kApirEEn2E5Q2Zw),
where we can discuss the development of GPJax and broader support for Gaussian
process modelling.

# Supported methods and interfaces

> - [**Sparse Variational Inference**](https://gpjax.readthedocs.io/en/latest/examples/uncollapsed_vi.html)
> - [**BlackJax Integration**](https://gpjax.readthedocs.io/en/latest/examples/classification.html)
> - [**Laplace Approximation**](https://gpjax.readthedocs.io/en/latest/examples/classification.html#Laplace-approximation)
> - [**TensorFlow Probability Integration**](https://gpjax.readthedocs.io/en/latest/examples/tfp_integration.html)
> - [**Inference on Non-Euclidean Spaces**](https://gpjax.readthedocs.io/en/latest/examples/kernels.html#Custom-Kernel)
> - [**Inference on Graphs**](https://gpjax.readthedocs.io/en/latest/examples/graph_kernels.html)
> - [**Learning Gaussian Process Barycentres**](https://gpjax.readthedocs.io/en/latest/examples/barycentres.html)
> - [**Deep Kernel Regression**](https://gpjax.readthedocs.io/en/latest/examples/haiku.html)
> - [**Natural Gradients**](https://gpjax.readthedocs.io/en/latest/examples/natgrads.html)

## Guides for customisation
>
> - [**Custom kernels**](https://gpjax.readthedocs.io/en/latest/examples/kernels.html#Custom-Kernel)
> - [**UCI regression**](https://gpjax.readthedocs.io/en/latest/examples/yacht.html)

## Conversion between `.ipynb` and `.py`
The above examples are stored in the [examples](examples) directory in the
double percent (`py:percent`) format. Check out the [jupytext CLI
documentation](https://jupytext.readthedocs.io/en/latest/using-cli.html) for
more info.

* To convert `example.py` to `example.ipynb`, run:
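
  The command itself is collapsed in this diff view; assuming the standard
  [jupytext](https://jupytext.readthedocs.io/en/latest/using-cli.html) CLI, it
  would be along the lines of:

  ```bash
  jupytext --to notebook example.py
  ```

  The reverse direction, `jupytext --to py:percent example.ipynb`, converts a
  notebook back to the double-percent script format.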

# Simple example

```python
import gpjax as gpx
from jax import grad, jit
import jax.numpy as jnp
import jax.random as jr
import optax as ox

key = jr.PRNGKey(123)
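
# The definition of the latent function f is collapsed in this diff view; any
# sinusoid works for the demo. The choice below is an assumption, not
# necessarily the original:
f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)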

n = 50
x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,1)).sort()
y = f(x) + jr.normal(key, shape=(n,1))
D = gpx.Dataset(X=x, y=y)
```

The function of interest here, $f(\cdot)$, is sinusoidal, but our observations
of it have been perturbed by Gaussian noise. We aim to utilise a Gaussian
process to recover this latent function.

## 1. Constructing the prior and posterior

We begin by defining a zero-mean Gaussian process prior with a radial basis
function kernel and assume the likelihood to be Gaussian.

```python
# Construct the prior
meanf = gpx.mean_functions.Zero()
kernel = gpx.kernels.RBF()
prior = gpx.Prior(mean_function=meanf, kernel=kernel)

# Define a likelihood
likelihood = gpx.Gaussian(num_datapoints=n)
```

Similar to how we would write it on paper, the posterior is constructed as the
product of our prior and our likelihood.

```python
# Construct the posterior
posterior = prior * likelihood
```

## 2. Learning hyperparameters

Equipped with the posterior, we seek to learn the model's hyperparameters
through gradient-based optimisation of the marginal log-likelihood. We do this
below, using optax's Adam optimiser and adding Jax's
[just-in-time (JIT)](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)
compilation to accelerate training.

```python
# Define an optimiser
optimiser = ox.adam(learning_rate=1e-2)

# Define the marginal log-likelihood
negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True))

# Obtain Type 2 MLEs of the hyperparameters
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=negative_mll,
    train_data=D,
    optim=optimiser,
    num_iters=500,
    safe=True,
)
```
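
A possible sanity check, assuming (as the tuple return above suggests) that
`history` holds the per-iteration objective values:

```python
# Inspect the final objective value to confirm that optimisation has settled
print(f"Final negative marginal log-likelihood: {history[-1]:.3f}")
```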


## 3. Making predictions

Using our learned hyperparameters, we can obtain the posterior distribution of the latent function at novel test points.

```python
# Infer the predictive posterior distribution
xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)
latent_dist = opt_posterior(xtest, D)
predictive_dist = opt_posterior.likelihood(latent_dist)

# Obtain the predictive mean and standard deviation
pred_mean = predictive_dist.mean()
pred_std = predictive_dist.stddev()
```
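
As a quick illustration (a sketch, not part of this diff), the predictive mean
and standard deviation can be combined into pointwise credible bands:

```python
# Approximate 95% credible interval around the predictive mean
lower = pred_mean - 1.96 * pred_std
upper = pred_mean + 1.96 * pred_std
```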

# Installation

## Stable version

The latest stable version of GPJax can be installed via pip:

```bash
pip install gpjax
```
## Development version

> **Warning**
>
> This version is possibly unstable and may contain bugs.

Clone a copy of the repository to your local machine and install the package
in development mode:
```bash
git clone https://github.com/JaxGaussianProcesses/GPJax.git
cd GPJax
poetry install
```

> **Note**
>
> We recommend you check your installation passes the supplied unit tests:
>
> ```bash
> poetry run pytest
> ```

# Citing GPJax