[DRAFT] Adding tagged parameters and updated notebooks #452

Merged · 14 commits · Jun 26, 2024
2 changes: 1 addition & 1 deletion .github/workflows/integration.yml
@@ -39,7 +39,7 @@ jobs:
# Install the dependencies
- name: Install Package
run: |
poetry install --all-extras --with docs
poetry install --with docs

# Run the unit tests and build the coverage report
- name: Run Integration Tests
22 changes: 11 additions & 11 deletions .pre-commit-config.yaml
@@ -46,14 +46,14 @@ repos:
language: system
types: [python]
exclude: examples/
- repo: https://github.com/econchick/interrogate
rev: 1.5.0
hooks:
- id: interrogate
args:
[
"gpjax",
"--config",
"pyproject.toml",
]
pass_filenames: false
# - repo: https://github.com/econchick/interrogate
# rev: 1.5.0
# hooks:
# - id: interrogate
# args:
# [
# "gpjax",
# "--config",
# "pyproject.toml",
# ]
# pass_filenames: false
17 changes: 17 additions & 0 deletions docs/examples/barycentres.py
@@ -1,3 +1,20 @@
# -*- coding: utf-8 -*-
# ---
# jupyter:
# jupytext:
# cell_metadata_filter: -all
# custom_cell_magics: kql
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: gpjax
# language: python
# name: python3
# ---

# %% [markdown]
# # Gaussian Processes Barycentres
#
37 changes: 26 additions & 11 deletions docs/examples/constructing_new_kernels.py
@@ -27,9 +27,6 @@

config.update("jax_enable_x64", True)

from dataclasses import dataclass

from jax import jit
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
@@ -84,11 +81,11 @@

meanf = gpx.mean_functions.Zero()

for k, ax in zip(kernels, axes.ravel()):
for k, ax, c in zip(kernels, axes.ravel(), cols):
prior = gpx.gps.Prior(mean_function=meanf, kernel=k)
rv = prior(x)
y = rv.sample(seed=key, sample_shape=(10,))
ax.plot(x, y.T, alpha=0.7)
ax.plot(x, y.T, alpha=0.7, color=c)
ax.set_title(k.name)

# %% [markdown]
@@ -205,24 +202,42 @@


# %%
from gpjax.kernels.computations import DenseKernelComputation
from gpjax.parameters import DEFAULT_BIJECTION, Static, PositiveReal


def angular_distance(x, y, c):
return jnp.abs((x - y + c) % (c * 2) - c)


bij = tfb.SoftClip(low=jnp.array(4.0, dtype=jnp.float64))

DEFAULT_BIJECTION["polar"] = bij


@dataclass
class Polar(gpx.kernels.AbstractKernel):
period: float = static_field(2 * jnp.pi)
tau: float = param_field(jnp.array([5.0]), bijector=bij)
period: Static
tau: PositiveReal

def __init__(
self,
tau: float = 5.0,
period: float = 2 * jnp.pi,
active_dims: list[int] | slice | None = None,
n_dims: int | None = None,
):
super().__init__(active_dims, n_dims, DenseKernelComputation())
self.period = Static(jnp.array(period))
self.tau = PositiveReal(jnp.array(tau), tag="polar")

def __call__(
self, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
) -> Float[Array, "1"]:
c = self.period / 2.0
c = self.period.value / 2.0
t = angular_distance(x, y, c)
K = (1 + self.tau * t / c) * jnp.clip(1 - t / c, 0, jnp.inf) ** self.tau
K = (1 + self.tau.value * t / c) * jnp.clip(
1 - t / c, 0, jnp.inf
) ** self.tau.value
return K.squeeze()


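Reviewer note: the block above is the core of the tagged-parameter API. Registering a bijector in `DEFAULT_BIJECTION` under the `"polar"` tag, then constructing `PositiveReal(..., tag="polar")`, is what routes the `SoftClip` constraint onto `tau` during optimisation. A minimal sketch of using the rewritten kernel (input values are illustrative, not from the PR):

```python
# Sketch only: exercising the Polar kernel defined above.
import jax.numpy as jnp

kern = Polar(tau=5.0)          # period defaults to 2*pi
x = jnp.array([[0.0]])
y = jnp.array([[jnp.pi / 2]])
print(kern(x, y))              # scalar covariance between two angles
```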
@@ -265,7 +280,7 @@ def __call__(
# Optimise GP's marginal log-likelihood using BFGS
opt_posterior, history = gpx.fit_scipy(
model=circular_posterior,
objective=jit(gpx.objectives.ConjugateMLL(negative=True)),
objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
train_data=D,
)

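The same substitution recurs throughout this PR: callable objective classes such as `ConjugateMLL(negative=True)` are replaced by plain functions of `(posterior, data)`. Where the inline lambda reads awkwardly, a named function is equivalent, as the intro_to_kernels notebook below also does; a sketch assuming only the `fit_scipy` signature shown above:

```python
# conjugate_mll returns the (positive) marginal log-likelihood,
# so we negate it to obtain a loss for the optimiser to minimise.
def negative_mll(posterior, data):
    return -gpx.objectives.conjugate_mll(posterior, data)


opt_posterior, history = gpx.fit_scipy(
    model=circular_posterior,
    objective=negative_mll,
    train_data=D,
)
```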
78 changes: 46 additions & 32 deletions docs/examples/deep_kernels.py
@@ -1,3 +1,20 @@
# -*- coding: utf-8 -*-
# ---
# jupyter:
# jupytext:
# cell_metadata_filter: -all
# custom_cell_magics: kql
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: gpjax
# language: python
# name: python3
# ---

# %% [markdown]
# # Deep Kernel Learning
#
@@ -18,10 +35,12 @@
dataclass,
field,
)
from typing import Any

import flax
from flax import linen as nn
from flax.experimental import nnx
from gpjax.kernels.computations import (
AbstractKernelComputation,
DenseKernelComputation,
)
import jax
import jax.numpy as jnp
import jax.random as jr
@@ -95,25 +114,17 @@
# %%
@dataclass
class DeepKernelFunction(AbstractKernel):
base_kernel: AbstractKernel = None
network: nn.Module = static_field(None)
dummy_x: jax.Array = static_field(None)
key: jr.PRNGKeyArray = static_field(jr.PRNGKey(123))
nn_params: Any = field(init=False, repr=False)

def __post_init__(self):
if self.base_kernel is None:
raise ValueError("base_kernel must be specified")
if self.network is None:
raise ValueError("network must be specified")
self.nn_params = flax.core.unfreeze(self.network.init(key, self.dummy_x))
base_kernel: AbstractKernel
network: nnx.Module
compute_engine: AbstractKernelComputation = field(
default_factory=lambda: DenseKernelComputation()
)

def __call__(
self, x: Float[Array, " D"], y: Float[Array, " D"]
) -> Float[Array, "1"]:
state = self.network.init(self.key, x)
xt = self.network.apply(state, x)
yt = self.network.apply(state, y)
xt = self.network(x)
yt = self.network(y)
return self.base_kernel(xt, yt)


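The simplification here is the main payoff of the flax `linen` to `nnx` migration: under `linen`, parameters were created with `network.init` and threaded through `network.apply` on every call (the removed lines above), whereas an `nnx` module owns its parameters and is called directly. A minimal sketch of the difference, assuming nothing beyond the `flax.experimental.nnx` import used in this file:

```python
from flax.experimental import nnx
import jax.numpy as jnp

lin = nnx.Linear(1, 3, rngs=nnx.Rngs(0))  # parameters live on the module
out = lin(jnp.ones((4, 1)))               # no separate init/apply step
print(out.shape)                          # (4, 3)
```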
@@ -135,20 +146,25 @@ def __call__(
feature_space_dim = 3


class Network(nn.Module):
"""A simple MLP."""
class Network(nnx.Module):
def __init__(
self, rngs: nnx.Rngs, *, input_dim: int, inner_dim: int, feature_space_dim: int
) -> None:
self.layer1 = nnx.Linear(input_dim, inner_dim, rngs=rngs)
self.output_layer = nnx.Linear(inner_dim, feature_space_dim, rngs=rngs)
self.rngs = rngs

@nn.compact
def __call__(self, x):
x = nn.Dense(features=32)(x)
x = nn.relu(x)
x = nn.Dense(features=64)(x)
x = nn.relu(x)
x = nn.Dense(features=feature_space_dim)(x)
def __call__(self, x: jax.Array) -> jax.Array:
x = x.reshape((x.shape[0], -1))
x = self.layer1(x)
x = jax.nn.relu(x)
x = self.output_layer(x).squeeze()
return x


forward_linear = Network()
forward_linear = Network(
nnx.Rngs(123), feature_space_dim=feature_space_dim, inner_dim=32, input_dim=1
)
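A quick shape check may help readers verify the wiring of the new network (illustrative only, not part of the PR):

```python
z = jnp.ones((5, 1))                # five 1-D inputs
print(forward_linear(z).shape)      # (5, 3): feature_space_dim features per point
```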

# %% [markdown]
# ## Defining a model
@@ -162,9 +178,7 @@ def __call__(self, x):
active_dims=list(range(feature_space_dim)),
lengthscale=jnp.ones((feature_space_dim,)),
)
kernel = DeepKernelFunction(
network=forward_linear, base_kernel=base_kernel, key=key, dummy_x=x
)
kernel = DeepKernelFunction(network=forward_linear, base_kernel=base_kernel)
meanf = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
@@ -202,7 +216,7 @@ def __call__(self, x):

opt_posterior, history = gpx.fit(
model=posterior,
objective=jax.jit(gpx.objectives.ConjugateMLL(negative=True)),
objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
train_data=D,
optim=optimiser,
num_iters=800,
19 changes: 18 additions & 1 deletion docs/examples/graph_kernels.py
@@ -1,3 +1,20 @@
# -*- coding: utf-8 -*-
# ---
# jupyter:
# jupytext:
# cell_metadata_filter: -all
# custom_cell_magics: kql
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: gpjax
# language: python
# name: python3
# ---

# %% [markdown]
# # Graph Kernels
#
@@ -154,7 +171,7 @@
# %%
opt_posterior, training_history = gpx.fit_scipy(
model=posterior,
objective=gpx.objectives.ConjugateMLL(negative=True),
objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
train_data=D,
)

36 changes: 28 additions & 8 deletions docs/examples/intro_to_kernels.py
@@ -1,3 +1,20 @@
# -*- coding: utf-8 -*-
# ---
# jupyter:
# jupytext:
# cell_metadata_filter: -all
# custom_cell_magics: kql
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: gpjax
# language: python
# name: python3
# ---

# %% [markdown]
# # Introduction to Kernels

@@ -213,6 +230,8 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:  # noqa: F821
# First we define our model, using the Matérn52 kernel, and construct our posterior *without* optimising the kernel hyperparameters:

# %%
from gpjax.parameters import PositiveReal

mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.Matern52(
lengthscale=jnp.array(0.1)
@@ -221,24 +240,22 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:  # noqa: F821
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)

likelihood = gpx.likelihoods.Gaussian(
num_datapoints=D.n, obs_stddev=jnp.array(1e-3)
num_datapoints=D.n, obs_stddev=PositiveReal(value=jnp.array(1e-3), tag="Static")
) # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
likelihood = likelihood.replace_trainable(obs_stddev=False)

no_opt_posterior = prior * likelihood
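Worth confirming in review: the `tag="Static"` on `obs_stddev` appears to take over from the removed `replace_trainable(obs_stddev=False)` call, i.e. the tag is now what keeps the observation noise fixed during fitting. The intended reading, under that assumption:

```python
# Assumed semantics of the "Static" tag (replacing the old call
#   likelihood = likelihood.replace_trainable(obs_stddev=False)):
# a parameter tagged "Static" keeps its initial value through gpx.fit*.
obs_stddev = PositiveReal(value=jnp.array(1e-3), tag="Static")
```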

# %% [markdown]
# We can then optimise the hyperparameters by minimising the negative log marginal likelihood of the data:

# %%
negative_mll = gpx.objectives.ConjugateMLL(negative=True)
negative_mll(no_opt_posterior, train_data=D)
gpx.objectives.conjugate_mll(no_opt_posterior, data=D)


# %%
opt_posterior, history = gpx.fit_scipy(
model=no_opt_posterior,
objective=negative_mll,
objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
train_data=D,
)

Expand Down Expand Up @@ -499,17 +516,20 @@ def plot_ribbon(ax, x, dist, color):

posterior = prior * likelihood


# %% [markdown]
# With our model constructed, let's now fit it to the data by minimising the negative log
# marginal likelihood:


# %%
negative_mll = gpx.objectives.ConjugateMLL(negative=True)
negative_mll(posterior, train_data=D)
def loss(posterior, data):
return -gpx.objectives.conjugate_mll(posterior, data)


opt_posterior, history = gpx.fit(
model=posterior,
objective=negative_mll,
objective=loss,
train_data=D,
optim=ox.adamw(learning_rate=1e-2),
num_iters=500,