Merge pull request #77 from thomaspinder/joss_response
Joss response
thomaspinder authored Jun 15, 2022
2 parents 4a813c9 + 2d3d8e4 commit 97ab331
Showing 5 changed files with 39 additions and 4 deletions.
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,10 @@
repos:
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.10.1
hooks:
- id: isort
args: ["--profile", "black"]
6 changes: 6 additions & 0 deletions README.md
@@ -13,6 +13,12 @@

GPJax aims to provide a low-level interface to Gaussian process (GP) models in [Jax](https://github.com/google/jax), structured to give researchers maximum flexibility in extending the code to suit their own needs. We define a GP prior in GPJax by specifying a mean and kernel function and multiply this by a likelihood function to construct the posterior. The idea is that the code should be as close as possible to the maths we write on paper when working with GP models.

## Package support

GPJax was created by [Thomas Pinder](https://github.com/thomaspinder). Today, the maintenance of GPJax is undertaken by Thomas and [Daniel Dodd](https://github.com/Daniel-Dodd).

We would be delighted to review pull requests (PRs) from new contributors. Before contributing, please read our [guide for contributing](https://github.com/thomaspinder/GPJax/blob/master/CONTRIBUTING.md). If you do not have the capacity to open a PR, or you would like guidance on how best to structure a PR, then please [open an issue](https://github.com/thomaspinder/GPJax/issues/new/choose). For broader discussions on best practices for fitting GPs, technical questions surrounding the mathematics of GPs, or anything else that you feel doesn't quite constitute an issue, please start a discussion thread in our [discussion tracker](https://github.com/thomaspinder/GPJax/discussions).

## Supported methods and interfaces

### Examples
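The README's "specify a mean and kernel, multiply by a likelihood" description rests on evaluating a kernel function over the training inputs to build the prior covariance. A pure-Python sketch of that kernel-evaluation step (the `rbf` and `prior_covariance` helpers here are illustrative stand-ins, not part of the GPJax API):

```python
import math

def rbf(x1, x2, lengthscale=1.0, variance=1.0):
    """Squared-exponential (RBF) kernel on scalar inputs."""
    return variance * math.exp(-0.5 * ((x1 - x2) / lengthscale) ** 2)

def prior_covariance(xs, kernel=rbf):
    """Gram matrix K_ff of the GP prior evaluated at the training inputs xs."""
    return [[kernel(a, b) for b in xs] for a in xs]

xs = [0.0, 0.5, 1.0]
K = prior_covariance(xs)  # 3x3, symmetric, unit diagonal for variance=1.0
```

Nearby inputs yield covariances close to `variance`, and distant inputs yield values near zero; this Gram matrix is what the `\Kff` symbol in the docs denotes.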
15 changes: 14 additions & 1 deletion docs/latex_symbols.tex
@@ -1,3 +1,16 @@
\newcommand{\bbR}{\mathbb{R}}
\DeclareMathOperator*{\argmax}{arg\,max}
\DeclareMathOperator*{\argmin}{arg\,min}

\newcommand{\x}{\mathbf{x}}
\newcommand{\X}{\mathbf{X}}
\newcommand{\xx}{\mathbf{x}^{\star}}
\newcommand{\Xx}{\mathbf{X}^{\star}}
\newcommand{\Y}{\mathbf{Y}}
\newcommand{\y}{\mathbf{y}}
\newcommand{\by}{\mathbf{y}}
\newcommand{\chol}{\operatorname{chol}}
\newcommand{\GP}{\mathcal{GP}}

\newcommand{\Kff}{\mathbf{K}_{\mathbf{ff}}}
\newcommand{\IdentityMatrix}{\mathbf{I}}
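With these macros loaded, standard GP expressions become compact. A hedged usage sketch, assuming `amsmath` is loaded alongside this symbols file and using $\sigma^2$ for the observation-noise variance (not a macro defined above):

```latex
% assumes docs/latex_symbols.tex has been \input and amsmath is loaded
\begin{align*}
  f(\cdot) &\sim \GP\!\left(m(\cdot),\, k(\cdot, \cdot)\right) \\
  \y \mid \X &\sim \mathcal{N}\!\left(\mathbf{0},\, \Kff + \sigma^2 \IdentityMatrix\right) \\
  \mathbf{L} &= \chol\!\left(\Kff + \sigma^2 \IdentityMatrix\right)
\end{align*}
```

The Cholesky factor $\mathbf{L}$ is what implementations use in place of a direct inverse of $\Kff + \sigma^2 \IdentityMatrix$ for numerical stability.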
9 changes: 7 additions & 2 deletions docs/nbs/classification.pct.py
@@ -68,6 +68,7 @@
params = gpx.transform(params, unconstrainer)

mll = jax.jit(posterior.marginal_log_likelihood(D, constrainer, negative=True))

# %% [markdown]
# We can obtain a MAP estimate by optimising the marginal log-likelihood with Optax's optimisers.
# %%
@@ -79,6 +80,9 @@
opt,
n_iters=500,
)
ps = gpx.transform(map_estimate, constrainer)
latent_dist = posterior(D, ps)(xtest)
predictive_dist = likelihood(latent_dist, ps)
# %% [markdown]
# However, as a point estimate, MAP estimation is severely limited for uncertainty quantification, providing only a single piece of information about the posterior. On the other hand, through approximate sampling, MCMC methods allow us to learn all information about the posterior distribution.
# %% [markdown]
@@ -95,7 +99,7 @@
# We begin by generating _sensible_ initial positions for our sampler before defining an inference loop and sampling 500 values from our Markov chain. In practice, drawing more samples will be necessary.
# %%
# Adapted from BlackJax's introduction notebook.
-num_adapt = 1000
+num_adapt = 500
num_samples = 500

mll = jax.jit(posterior.marginal_log_likelihood(D, constrainer, negative=False))
@@ -156,7 +160,8 @@ def one_step(state, rng_key):
ps["latent"] = states.position["latent"][i, :, :]
ps = gpx.transform(ps, constrainer)

-predictive_dist = likelihood(posterior(D, ps)(xtest), ps)
+latent_dist = posterior(D, ps)(xtest)
+predictive_dist = likelihood(latent_dist, ps)
samples.append(predictive_dist.sample(seed=key, sample_shape=(10,)))

samples = jnp.vstack(samples)
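The notebook's contrast between a single MAP point estimate and a full set of MCMC samples can be sketched in pure Python on a toy one-dimensional posterior (the `log_post` target and random-walk Metropolis sampler here are illustrative stand-ins, not the BlackJax sampler the notebook actually uses):

```python
import math
import random

def log_post(theta):
    """Toy log-posterior: a standard normal, so the MAP is exactly 0.0."""
    return -0.5 * theta ** 2

def metropolis(n_samples, step=1.0, seed=0):
    """Random-walk Metropolis sampler targeting log_post."""
    rng = random.Random(seed)
    theta = 0.0  # start the chain at the MAP point
    samples = []
    for _ in range(n_samples):
        proposal = theta + rng.gauss(0.0, step)
        accept_prob = math.exp(min(0.0, log_post(proposal) - log_post(theta)))
        if rng.random() < accept_prob:
            theta = proposal
        samples.append(theta)
    return samples

samples = metropolis(2000)
mean = sum(samples) / len(samples)
sd = math.sqrt(sum((s - mean) ** 2 for s in samples) / len(samples))
```

The MAP estimate alone says the posterior peaks at 0.0; the sample standard deviation recovered by the chain is the extra uncertainty information that a point estimate discards.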
3 changes: 2 additions & 1 deletion requirements-dev.txt
@@ -1,3 +1,4 @@
pytest
networkx
-pytest-cov
+pytest-cov
+pre-commit
