Paper edits for JOSS #75

Merged 2 commits on Jun 15, 2022
paper.md: 4 changes (2 additions & 2 deletions)
@@ -34,13 +34,13 @@ Gaussian processes (GPs) [@rasmussen2006gaussian] are Bayesian nonparametric mod

`GPJax` has been carefully tailored to amalgamate with the Jax ecosystem. For efficient Markov Chain Monte Carlo inference, `GPJax` can utilise samplers from BlackJax [@blackjax2021] and TensorFlow Probability [@abadi2016tensorflow]. For gradient-based optimisation, `GPJax` integrates seamlessly with Optax [@deepmind2020jax], providing a vast suite of optimisers and learning rate schedules. To efficiently represent probability distributions, `GPJax` leverages Distrax [@deepmind2020jax] and TensorFlow Probability [@abadi2016tensorflow]. To combine GPs with deep learning methods, `GPJax` can incorporate the functionality provided within Haiku [@deepmind2020jax]. The `GPJax` documentation includes examples of each of these integrations.

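The Optax integration mentioned in the paragraph above follows the library's standard fitting loop. The sketch below is illustrative only: `negative_mll` is a hypothetical stand-in for a GP's negative marginal log-likelihood, not `GPJax`'s actual objective, and the parameter dictionary is likewise invented for the example.

```python
import jax
import jax.numpy as jnp
import optax

# Hypothetical stand-in for a GP's negative marginal log-likelihood;
# any scalar-valued function of the parameters would do here.
def negative_mll(params):
    return jnp.sum(params["lengthscale"] ** 2) + (params["variance"] - 1.0) ** 2

params = {"lengthscale": jnp.array(2.0), "variance": jnp.array(0.5)}

# Standard Optax pattern: build an optimiser, initialise its state,
# then alternate gradient computation with parameter updates.
optimiser = optax.adam(learning_rate=1e-2)
opt_state = optimiser.init(params)

for _ in range(500):
    grads = jax.grad(negative_mll)(params)
    updates, opt_state = optimiser.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
```
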
- The foundation of each abstraction given in `GPJax` is a Chex [deepmind2020jax] dataclass object. These require significantly less boilerplate code than regular Python classes, leading to a more readable codebase. Moreover, Chex dataclasses are registered as PyTree nodes, facilitating the applications of Jax operations such as just-in-time compilation and automatic differentiation to any `GPJax` object.
+ The foundation of each abstraction given in `GPJax` is a Chex [@deepmind2020jax] dataclass object. These require significantly less boilerplate code than regular Python classes, leading to a more readable codebase. Moreover, Chex dataclasses are registered as PyTree nodes, facilitating the applications of Jax operations such as just-in-time compilation and automatic differentiation to any `GPJax` object.

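The PyTree point in the changed lines above can be shown directly with Chex. The `RBFKernel` class below is a hypothetical example rather than a `GPJax` class; it demonstrates that a `@chex.dataclass` is a registered PyTree node, so `jax.jit` and `jax.grad` traverse its fields with no extra registration code.

```python
import jax
import jax.numpy as jnp
import chex

# Hypothetical kernel class (not a GPJax class): @chex.dataclass
# registers it as a PyTree node, so jit and grad handle it natively.
@chex.dataclass
class RBFKernel:
    lengthscale: chex.Array
    variance: chex.Array

@jax.jit
def evaluate(kern: RBFKernel, x: chex.Array, y: chex.Array) -> chex.Array:
    # Squared-exponential kernel, written to mirror the textbook formula.
    sq_dist = jnp.sum((x - y) ** 2)
    return kern.variance * jnp.exp(-0.5 * sq_dist / kern.lengthscale**2)

kern = RBFKernel(lengthscale=jnp.array(1.0), variance=jnp.array(1.0))
value = evaluate(kern, jnp.array([0.0]), jnp.array([1.0]))

# Differentiating with respect to the dataclass returns another
# RBFKernel holding the gradient of each field.
grads = jax.grad(evaluate)(kern, jnp.array([0.0]), jnp.array([1.0]))
```
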
The intimacy between `GPJax` and the underlying maths also makes `GPJax` an excellent package for people new to GP modelling. Having the ability to easily cross-reference the contents of a textbook with the code that one is writing is invaluable when trying to build an intuition for a new statistical method. We further support this effort in `GPJax` through documentation that provides detailed explanations of the operations conducted within each notebook.

# Wider Software Ecosystem

- From both an applied and methodological perspective, GPs are widely employed in the statistics and machine learning communities. High-quality software packages that promote GP modelling are accountable for much of their success. Within the Python community, the three most popular packages for GP modelling are GPFlow [@matthews2017gpflow], GPyTorch [@gardner2018gpytorch], and GPy [@gpy2014]. Despite each of these packages being indispensable tools for the community, none of them supports integration with a Jax-based workflow. `GPJax` seeks to resolve this issue. Furthermore, modern research from the GP literature, graph kernels [@borovitskiy2021matern] and Wasserstein barycentres for GPs [@mallasto2017learning], for example, are supported within `GPJax` but absent from all other packages.
+ From both an applied and methodological perspective, GPs are widely employed in the statistics and machine learning communities. High-quality software packages that promote GP modelling are accountable for much of their success. Within the Python community, the three most popular packages for GP modelling are GPFlow [@matthews2017gpflow], GPyTorch [@gardner2018gpytorch], and GPy [@gpy2014]. Despite each of these packages being indispensable tools for the community, none of them support integration with a Jax-based workflow. `GPJax` seeks to resolve this issue. Furthermore, modern research from the GP literature, graph kernels [@borovitskiy2021matern] and Wasserstein barycentres for GPs [@mallasto2017learning], for example, are supported within `GPJax` but absent from these other packages.

For completeness, packages written for languages other than Python include GPML [@rasmussen2010gaussian] in Matlab, GaussianProcesses.jl [@fairbrother2022gaussianprocesses], AugmentedGaussianProcesses.jl [@fajou20a] and Stheno [@stheno2022tebbutt] all in Julia.
