Skip to content
This repository has been archived by the owner on May 11, 2023. It is now read-only.

Subclass stationary kernels #52

Merged
merged 1 commit into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions jaxkern/approximations/rff.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ..base import AbstractKernel
from ..base import AbstractKernel, StationaryKernel
from ..computations import BasisFunctionComputation
from jax.random import KeyArray
from typing import Any
Expand Down Expand Up @@ -66,9 +66,15 @@ def _check_valid_base_kernel(self, kernel: AbstractKernel):
Args:
kernel (AbstractKernel): The kernel to be checked.
"""
error_msg = """
Base kernel must have a spectral density. Currently, only Matérn
and RBF kernels have implemented spectral densities.
"""
if kernel.spectral_density is None:
raise ValueError(error_msg)
if not isinstance(kernel, StationaryKernel):
raise TypeError(
f"""Random Fourier Features are only defined for stationary kernels.
{kernel.name} is non-stationary."""
)
else:
if kernel.spectral_density is None:
error_msg = """
Base kernel must have a spectral density. Currently, only Matérn
and RBF kernels have implemented spectral densities.
"""
raise ValueError(error_msg)
66 changes: 29 additions & 37 deletions jaxkern/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
# ==============================================================================

import abc
from typing import Callable, Dict, List, Optional, Sequence
from typing import Callable, List, Optional, Sequence

import deprecation
import jax.numpy as jnp
import jax.random
import jax
Expand All @@ -40,27 +39,20 @@ def __init__(
self,
compute_engine: AbstractKernelComputation = DenseKernelComputation,
active_dims: Optional[List[int]] = None,
spectral_density: Optional[dx.Distribution] = None,
name: Optional[str] = "AbstractKernel",
name: Optional[str] = "Abstrac tKernel",
) -> None:
self._compute_engine = compute_engine
self.active_dims = active_dims
self.spectral_density = spectral_density
self.name = name
self._stationary = False
self.ndims = 1 if not self.active_dims else len(self.active_dims)
compute_engine = self.compute_engine(kernel_fn=self.__call__)
self.gram = compute_engine.gram
self.cross_covariance = compute_engine.cross_covariance
self._spectral_density = None

@property
def stationary(self) -> bool:
"""Boolean property as to whether the kernel is stationary or not.

Returns:
bool: True if the kernel is stationary.
"""
return self._stationary
def spectral_density(self) -> dx.Distribution:
return self._spectral_density

@property
def compute_engine(self) -> AbstractKernelComputation:
Expand All @@ -81,14 +73,14 @@ def compute_engine(self, compute_engine: AbstractKernelComputation) -> None:
@abc.abstractmethod
def __call__(
self,
params: Dict,
params: Parameters,
x: Float[Array, "1 D"],
y: Float[Array, "1 D"],
) -> Float[Array, "1"]:
"""Evaluate the kernel on a pair of inputs.

Args:
params (Dict): Parameter set for which the kernel should be evaluated on.
params (Parameters): Parameter set for which the kernel should be evaluated on.
x (Float[Array, "1 D"]): The left hand argument of the kernel function's call.
y (Float[Array, "1 D"]): The right hand argument of the kernel function's call

Expand Down Expand Up @@ -147,25 +139,11 @@ def init_params(self, key: KeyArray) -> Parameters:
the kernel's parameters.

Returns:
Dict: A dictionary of the kernel's parameters.
"""
raise NotImplementedError

@deprecation.deprecated(
deprecated_in="0.0.3",
removed_in="0.1.0",
)
def _initialise_params(self, key: KeyArray) -> Dict:
"""A template dictionary of the kernel's parameter set.

Args:
key (KeyArray): A PRNG key to be used for initialising
the kernel's parameters.

Returns:
Dict: A dictionary of the kernel's parameters.
Parameters: A dictionary of the kernel's parameters.
"""
raise NotImplementedError
raise NotImplementedError(
f"`init_params` not implemented for {self.name} kernel"
)


class CombinationKernel(AbstractKernel):
Expand All @@ -185,8 +163,6 @@ def __init__(

if not all(isinstance(k, AbstractKernel) for k in self.kernel_set):
raise TypeError("can only combine Kernel instances") # pragma: no cover
if all(k.stationary for k in self.kernel_set):
self._stationary = True
self._set_kernels(self.kernel_set)

def _set_kernels(self, kernels: Sequence[AbstractKernel]) -> None:
Expand All @@ -212,14 +188,14 @@ def init_params(self, key: KeyArray) -> Parameters:

def __call__(
self,
params: Dict,
params: Parameters,
x: Float[Array, "1 D"],
y: Float[Array, "1 D"],
) -> Float[Array, "1"]:
"""Evaluate combination kernel on a pair of inputs.

Args:
params (Dict): Parameter set for which the kernel should be evaluated on.
params (Parameters): Parameter set for which the kernel should be evaluated on.
x (Float[Array, "1 D"]): The left hand argument of the kernel function's call.
y (Float[Array, "1 D"]): The right hand argument of the kernel function's call

Expand Down Expand Up @@ -257,3 +233,19 @@ def __init__(
) -> None:
super().__init__(kernel_set, compute_engine, active_dims, name)
self.combination_fn: Optional[Callable] = jnp.prod


class StationaryKernel(AbstractKernel):
"""
Stationary kernels are subclass of kernels that are translation invariant.
For a kernel :math:`k(x, y)` to be stationary, it must be a function of
:math:k(x, y) = k(x - y)`.
"""

def __init__(
self,
compute_engine: AbstractKernelComputation = DenseKernelComputation,
active_dims: Optional[List[int]] = None,
name: Optional[str] = "Stationary Kernel",
) -> None:
super().__init__(compute_engine, active_dims, name)
1 change: 0 additions & 1 deletion jaxkern/non_euclidean/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def __init__(
super().__init__(
EigenKernelComputation,
active_dims,
spectral_density=None,
name=name,
)
self.laplacian = laplacian
Expand Down
1 change: 0 additions & 1 deletion jaxkern/nonstationary/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def __init__(
super().__init__(
DenseKernelComputation,
active_dims,
spectral_density=None,
name=name,
)
self._stationary = False
Expand Down
2 changes: 0 additions & 2 deletions jaxkern/nonstationary/polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,11 @@ def __init__(
self,
degree: int = 1,
active_dims: Optional[List[int]] = None,
stationary: Optional[bool] = False,
name: Optional[str] = "Polynomial",
) -> None:
super().__init__(
DenseKernelComputation,
active_dims,
spectral_density=None,
name=name,
)
self.degree = degree
Expand Down
9 changes: 4 additions & 5 deletions jaxkern/stationary/matern12.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,23 @@
from jaxtyping import Array, Float
from jaxutils import Parameters, Softplus

from ..base import AbstractKernel
from ..base import StationaryKernel
from ..computations import (
DenseKernelComputation,
)
from .utils import euclidean_distance, build_student_t_distribution


class Matern12(AbstractKernel):
class Matern12(StationaryKernel):
"""The Matérn kernel with smoothness parameter fixed at 0.5."""

def __init__(
self,
active_dims: Optional[List[int]] = None,
name: Optional[str] = "Matérn 1/2 kernel",
) -> None:
spectral_density = build_student_t_distribution(nu=1)
super().__init__(DenseKernelComputation, active_dims, spectral_density, name)
self._stationary = True
super().__init__(DenseKernelComputation, active_dims, name)
self._spectral_density = build_student_t_distribution(1.0)

def __call__(
self,
Expand Down
9 changes: 4 additions & 5 deletions jaxkern/stationary/matern32.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,23 @@
from jaxtyping import Array, Float
from jaxutils import Parameters, Softplus

from ..base import AbstractKernel
from ..base import StationaryKernel
from ..computations import (
DenseKernelComputation,
)
from .utils import euclidean_distance, build_student_t_distribution


class Matern32(AbstractKernel):
class Matern32(StationaryKernel):
"""The Matérn kernel with smoothness parameter fixed at 1.5."""

def __init__(
self,
active_dims: Optional[List[int]] = None,
name: Optional[str] = "Matern 3/2",
) -> None:
spectral_density = build_student_t_distribution(nu=3)
super().__init__(DenseKernelComputation, active_dims, spectral_density, name)
self._stationary = True
super().__init__(DenseKernelComputation, active_dims, name)
self._spectral_density = build_student_t_distribution(nu=3)

def __call__(
self,
Expand Down
8 changes: 4 additions & 4 deletions jaxkern/stationary/matern52.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,23 @@
from jax.random import KeyArray
from jaxtyping import Array, Float
from jaxutils import Parameters, Softplus
from ..base import AbstractKernel
from ..base import StationaryKernel
from ..computations import (
DenseKernelComputation,
)
from .utils import euclidean_distance, build_student_t_distribution


class Matern52(AbstractKernel):
class Matern52(StationaryKernel):
"""The Matérn kernel with smoothness parameter fixed at 2.5."""

def __init__(
self,
active_dims: Optional[List[int]] = None,
name: Optional[str] = "Matern 5/2",
) -> None:
spectral_density = build_student_t_distribution(nu=5)
super().__init__(DenseKernelComputation, active_dims, spectral_density, name)
super().__init__(DenseKernelComputation, active_dims, name)
self._spectral_density = build_student_t_distribution(nu=5)

def __call__(
self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
Expand Down
5 changes: 2 additions & 3 deletions jaxkern/stationary/periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
from jaxtyping import Array

from jaxutils import Parameters, Softplus
from ..base import AbstractKernel
from ..base import StationaryKernel
from ..computations import (
DenseKernelComputation,
)


class Periodic(AbstractKernel):
class Periodic(StationaryKernel):
"""The periodic kernel.

Key reference is MacKay 1998 - "Introduction to Gaussian processes".
Expand All @@ -41,7 +41,6 @@ def __init__(
super().__init__(
DenseKernelComputation,
active_dims,
spectral_density=None,
name=name,
)
self._stationary = True
Expand Down
5 changes: 2 additions & 3 deletions jaxkern/stationary/powered_exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
from jax.random import KeyArray
from jaxtyping import Array
from jaxutils import Parameters, Softplus
from ..base import AbstractKernel
from ..base import StationaryKernel
from ..computations import (
DenseKernelComputation,
)
from .utils import euclidean_distance


class PoweredExponential(AbstractKernel):
class PoweredExponential(StationaryKernel):
"""The powered exponential family of kernels.

Key reference is Diggle and Ribeiro (2007) - "Model-based Geostatistics".
Expand All @@ -42,7 +42,6 @@ def __init__(
super().__init__(
DenseKernelComputation,
active_dims,
spectral_density=None,
name=name,
)
self._stationary = True
Expand Down
5 changes: 2 additions & 3 deletions jaxkern/stationary/rational_quadratic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
from jax.random import KeyArray
from jaxtyping import Array
from jaxutils import Parameters, Softplus
from ..base import AbstractKernel
from ..base import StationaryKernel
from ..computations import (
DenseKernelComputation,
)
from .utils import squared_distance


class RationalQuadratic(AbstractKernel):
class RationalQuadratic(StationaryKernel):
def __init__(
self,
active_dims: Optional[List[int]] = None,
Expand All @@ -36,7 +36,6 @@ def __init__(
super().__init__(
DenseKernelComputation,
active_dims,
spectral_density=None,
name=name,
)
self._stationary = True
Expand Down
7 changes: 3 additions & 4 deletions jaxkern/stationary/rbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
from jaxtyping import Array, Float
from jaxutils import Parameters, Softplus

from ..base import AbstractKernel
from ..base import StationaryKernel
from ..computations import (
DenseKernelComputation,
)
from .utils import squared_distance
import distrax as dx


class RBF(AbstractKernel):
class RBF(StationaryKernel):
"""The Radial Basis Function (RBF) kernel."""

def __init__(
Expand All @@ -39,10 +39,9 @@ def __init__(
super().__init__(
DenseKernelComputation,
active_dims,
spectral_density=dx.Normal(loc=0.0, scale=1.0),
name=name,
)
self._stationary = True
self._spectral_density = dx.Normal(loc=0.0, scale=1.0)

def __call__(
self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
Expand Down
Loading