diff --git a/jaxkern/approximations/rff.py b/jaxkern/approximations/rff.py index d1065e5..9062e9e 100644 --- a/jaxkern/approximations/rff.py +++ b/jaxkern/approximations/rff.py @@ -1,4 +1,4 @@ -from ..base import AbstractKernel +from ..base import AbstractKernel, StationaryKernel from ..computations import BasisFunctionComputation from jax.random import KeyArray from typing import Any @@ -66,9 +66,15 @@ def _check_valid_base_kernel(self, kernel: AbstractKernel): Args: kernel (AbstractKernel): The kernel to be checked. """ - error_msg = """ - Base kernel must have a spectral density. Currently, only Matérn - and RBF kernels have implemented spectral densities. - """ - if kernel.spectral_density is None: - raise ValueError(error_msg) + if not isinstance(kernel, StationaryKernel): + raise TypeError( + f"""Random Fourier Features are only defined for stationary kernels. + {kernel.name} is non-stationary.""" + ) + else: + if kernel.spectral_density is None: + error_msg = """ + Base kernel must have a spectral density. Currently, only Matérn + and RBF kernels have implemented spectral densities. + """ + raise ValueError(error_msg) diff --git a/jaxkern/base.py b/jaxkern/base.py index 11a0f73..f6c6e27 100644 --- a/jaxkern/base.py +++ b/jaxkern/base.py @@ -14,9 +14,8 @@ # ============================================================================== import abc -from typing import Callable, Dict, List, Optional, Sequence +from typing import Callable, List, Optional, Sequence -import deprecation import jax.numpy as jnp import jax.random import jax @@ -40,27 +39,20 @@ def __init__( self, compute_engine: AbstractKernelComputation = DenseKernelComputation, active_dims: Optional[List[int]] = None, - spectral_density: Optional[dx.Distribution] = None, - name: Optional[str] = "AbstractKernel", + name: Optional[str] = "Abstrac tKernel", ) -> None: self._compute_engine = compute_engine self.active_dims = active_dims - self.spectral_density = spectral_density self.name = name - self._stationary = False self.ndims = 1 if not self.active_dims else len(self.active_dims) compute_engine = self.compute_engine(kernel_fn=self.__call__) self.gram = compute_engine.gram self.cross_covariance = compute_engine.cross_covariance + self._spectral_density = None @property - def stationary(self) -> bool: - """Boolean property as to whether the kernel is stationary or not. - - Returns: - bool: True if the kernel is stationary. - """ - return self._stationary + def spectral_density(self) -> dx.Distribution: + return self._spectral_density @property def compute_engine(self) -> AbstractKernelComputation: @@ -81,14 +73,14 @@ def compute_engine(self, compute_engine: AbstractKernelComputation) -> None: @abc.abstractmethod def __call__( self, - params: Dict, + params: Parameters, x: Float[Array, "1 D"], y: Float[Array, "1 D"], ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs. Args: - params (Dict): Parameter set for which the kernel should be evaluated on. + params (Parameters): Parameter set for which the kernel should be evaluated on. x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. y (Float[Array, "1 D"]): The right hand argument of the kernel function's call @@ -147,25 +139,11 @@ def init_params(self, key: KeyArray) -> Parameters: the kernel's parameters. Returns: - Dict: A dictionary of the kernel's parameters. - """ - raise NotImplementedError - - @deprecation.deprecated( - deprecated_in="0.0.3", - removed_in="0.1.0", - ) - def _initialise_params(self, key: KeyArray) -> Dict: - """A template dictionary of the kernel's parameter set. - - Args: - key (KeyArray): A PRNG key to be used for initialising - the kernel's parameters. - - Returns: - Dict: A dictionary of the kernel's parameters. + Parameters: A dictionary of the kernel's parameters. """ - raise NotImplementedError + raise NotImplementedError( + f"`init_params` not implemented for {self.name} kernel" + ) class CombinationKernel(AbstractKernel): @@ -185,8 +163,6 @@ def __init__( if not all(isinstance(k, AbstractKernel) for k in self.kernel_set): raise TypeError("can only combine Kernel instances") # pragma: no cover - if all(k.stationary for k in self.kernel_set): - self._stationary = True self._set_kernels(self.kernel_set) def _set_kernels(self, kernels: Sequence[AbstractKernel]) -> None: @@ -212,14 +188,14 @@ def init_params(self, key: KeyArray) -> Parameters: def __call__( self, - params: Dict, + params: Parameters, x: Float[Array, "1 D"], y: Float[Array, "1 D"], ) -> Float[Array, "1"]: """Evaluate combination kernel on a pair of inputs. Args: - params (Dict): Parameter set for which the kernel should be evaluated on. + params (Parameters): Parameter set for which the kernel should be evaluated on. x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. y (Float[Array, "1 D"]): The right hand argument of the kernel function's call @@ -257,3 +233,19 @@ def __init__( ) -> None: super().__init__(kernel_set, compute_engine, active_dims, name) self.combination_fn: Optional[Callable] = jnp.prod + + +class StationaryKernel(AbstractKernel): + """ + Stationary kernels are subclass of kernels that are translation invariant. + For a kernel :math:`k(x, y)` to be stationary, it must be a function of + :math:k(x, y) = k(x - y)`. + """ + + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + name: Optional[str] = "Stationary Kernel", + ) -> None: + super().__init__(compute_engine, active_dims, name) diff --git a/jaxkern/non_euclidean/graph.py b/jaxkern/non_euclidean/graph.py index e97e71f..7307844 100644 --- a/jaxkern/non_euclidean/graph.py +++ b/jaxkern/non_euclidean/graph.py @@ -49,7 +49,6 @@ def __init__( super().__init__( EigenKernelComputation, active_dims, - spectral_density=None, name=name, ) self.laplacian = laplacian diff --git a/jaxkern/nonstationary/linear.py b/jaxkern/nonstationary/linear.py index e0f5e42..9cac307 100644 --- a/jaxkern/nonstationary/linear.py +++ b/jaxkern/nonstationary/linear.py @@ -42,7 +42,6 @@ def __init__( super().__init__( DenseKernelComputation, active_dims, - spectral_density=None, name=name, ) self._stationary = False diff --git a/jaxkern/nonstationary/polynomial.py b/jaxkern/nonstationary/polynomial.py index eb473c8..b3e088b 100644 --- a/jaxkern/nonstationary/polynomial.py +++ b/jaxkern/nonstationary/polynomial.py @@ -33,13 +33,11 @@ def __init__( self, degree: int = 1, active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, name: Optional[str] = "Polynomial", ) -> None: super().__init__( DenseKernelComputation, active_dims, - spectral_density=None, name=name, ) self.degree = degree diff --git a/jaxkern/stationary/matern12.py b/jaxkern/stationary/matern12.py index 76f5572..64c48c9 100644 --- a/jaxkern/stationary/matern12.py +++ b/jaxkern/stationary/matern12.py @@ -20,14 +20,14 @@ from jaxtyping import Array, Float from jaxutils import Parameters, Softplus -from ..base import AbstractKernel +from ..base import StationaryKernel from ..computations import ( DenseKernelComputation, ) from .utils import euclidean_distance, build_student_t_distribution -class Matern12(AbstractKernel): +class Matern12(StationaryKernel): """The Matérn kernel with smoothness parameter fixed at 0.5.""" def __init__( @@ -35,9 +35,8 @@ def __init__( active_dims: Optional[List[int]] = None, name: Optional[str] = "Matérn 1/2 kernel", ) -> None: - spectral_density = build_student_t_distribution(nu=1) - super().__init__(DenseKernelComputation, active_dims, spectral_density, name) - self._stationary = True + super().__init__(DenseKernelComputation, active_dims, name) + self._spectral_density = build_student_t_distribution(1.0) def __call__( self, diff --git a/jaxkern/stationary/matern32.py b/jaxkern/stationary/matern32.py index 910cc57..26ad4d5 100644 --- a/jaxkern/stationary/matern32.py +++ b/jaxkern/stationary/matern32.py @@ -20,14 +20,14 @@ from jaxtyping import Array, Float from jaxutils import Parameters, Softplus -from ..base import AbstractKernel +from ..base import StationaryKernel from ..computations import ( DenseKernelComputation, ) from .utils import euclidean_distance, build_student_t_distribution -class Matern32(AbstractKernel): +class Matern32(StationaryKernel): """The Matérn kernel with smoothness parameter fixed at 1.5.""" def __init__( @@ -35,9 +35,8 @@ def __init__( active_dims: Optional[List[int]] = None, name: Optional[str] = "Matern 3/2", ) -> None: - spectral_density = build_student_t_distribution(nu=3) - super().__init__(DenseKernelComputation, active_dims, spectral_density, name) - self._stationary = True + super().__init__(DenseKernelComputation, active_dims, name) + self._spectral_density = build_student_t_distribution(nu=3) def __call__( self, diff --git a/jaxkern/stationary/matern52.py b/jaxkern/stationary/matern52.py index 2c3a176..0c42974 100644 --- a/jaxkern/stationary/matern52.py +++ b/jaxkern/stationary/matern52.py @@ -19,14 +19,14 @@ from jax.random import KeyArray from jaxtyping import Array, Float from jaxutils import Parameters, Softplus -from ..base import AbstractKernel +from ..base import StationaryKernel from ..computations import ( DenseKernelComputation, ) from .utils import euclidean_distance, build_student_t_distribution -class Matern52(AbstractKernel): +class Matern52(StationaryKernel): """The Matérn kernel with smoothness parameter fixed at 2.5.""" def __init__( @@ -34,8 +34,8 @@ def __init__( active_dims: Optional[List[int]] = None, name: Optional[str] = "Matern 5/2", ) -> None: - spectral_density = build_student_t_distribution(nu=5) - super().__init__(DenseKernelComputation, active_dims, spectral_density, name) + super().__init__(DenseKernelComputation, active_dims, name) + self._spectral_density = build_student_t_distribution(nu=5) def __call__( self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] diff --git a/jaxkern/stationary/periodic.py b/jaxkern/stationary/periodic.py index 0877ef2..7783d2a 100644 --- a/jaxkern/stationary/periodic.py +++ b/jaxkern/stationary/periodic.py @@ -21,13 +21,13 @@ from jaxtyping import Array from jaxutils import Parameters, Softplus -from ..base import AbstractKernel +from ..base import StationaryKernel from ..computations import ( DenseKernelComputation, ) -class Periodic(AbstractKernel): +class Periodic(StationaryKernel): """The periodic kernel. Key reference is MacKay 1998 - "Introduction to Gaussian processes". @@ -41,7 +41,6 @@ def __init__( super().__init__( DenseKernelComputation, active_dims, - spectral_density=None, name=name, ) self._stationary = True diff --git a/jaxkern/stationary/powered_exponential.py b/jaxkern/stationary/powered_exponential.py index 644b3a6..f7220cc 100644 --- a/jaxkern/stationary/powered_exponential.py +++ b/jaxkern/stationary/powered_exponential.py @@ -20,14 +20,14 @@ from jax.random import KeyArray from jaxtyping import Array from jaxutils import Parameters, Softplus -from ..base import AbstractKernel +from ..base import StationaryKernel from ..computations import ( DenseKernelComputation, ) from .utils import euclidean_distance -class PoweredExponential(AbstractKernel): +class PoweredExponential(StationaryKernel): """The powered exponential family of kernels. Key reference is Diggle and Ribeiro (2007) - "Model-based Geostatistics". @@ -42,7 +42,6 @@ def __init__( super().__init__( DenseKernelComputation, active_dims, - spectral_density=None, name=name, ) self._stationary = True diff --git a/jaxkern/stationary/rational_quadratic.py b/jaxkern/stationary/rational_quadratic.py index 41beb21..0cb0ea3 100644 --- a/jaxkern/stationary/rational_quadratic.py +++ b/jaxkern/stationary/rational_quadratic.py @@ -20,14 +20,14 @@ from jax.random import KeyArray from jaxtyping import Array from jaxutils import Parameters, Softplus -from ..base import AbstractKernel +from ..base import StationaryKernel from ..computations import ( DenseKernelComputation, ) from .utils import squared_distance -class RationalQuadratic(AbstractKernel): +class RationalQuadratic(StationaryKernel): def __init__( self, active_dims: Optional[List[int]] = None, @@ -36,7 +36,6 @@ def __init__( super().__init__( DenseKernelComputation, active_dims, - spectral_density=None, name=name, ) self._stationary = True diff --git a/jaxkern/stationary/rbf.py b/jaxkern/stationary/rbf.py index 965aa4f..98b5ce6 100644 --- a/jaxkern/stationary/rbf.py +++ b/jaxkern/stationary/rbf.py @@ -20,7 +20,7 @@ from jaxtyping import Array, Float from jaxutils import Parameters, Softplus -from ..base import AbstractKernel +from ..base import StationaryKernel from ..computations import ( DenseKernelComputation, ) @@ -28,7 +28,7 @@ import distrax as dx -class RBF(AbstractKernel): +class RBF(StationaryKernel): """The Radial Basis Function (RBF) kernel.""" def __init__( @@ -39,10 +39,9 @@ def __init__( super().__init__( DenseKernelComputation, active_dims, - spectral_density=dx.Normal(loc=0.0, scale=1.0), name=name, ) - self._stationary = True + self._spectral_density = dx.Normal(loc=0.0, scale=1.0) def __call__( self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] diff --git a/jaxkern/stationary/white.py b/jaxkern/stationary/white.py index 6b98d1d..3f3dc6f 100644 --- a/jaxkern/stationary/white.py +++ b/jaxkern/stationary/white.py @@ -18,22 +18,22 @@ import jax.numpy as jnp from jaxtyping import Array, Float from jaxutils import Parameters, Softplus -from ..base import AbstractKernel +from ..base import StationaryKernel from ..computations import ( ConstantDiagonalKernelComputation, AbstractKernelComputation, ) -class White(AbstractKernel): +class White(StationaryKernel): def __init__( self, compute_engine: AbstractKernelComputation = ConstantDiagonalKernelComputation, active_dims: Optional[List[int]] = None, name: Optional[str] = "White Noise Kernel", ) -> None: - super().__init__(compute_engine, active_dims, spectral_density=None, name=name) - self._stationary = True + super().__init__(compute_engine, active_dims, name=name) + self._spectral_density = None def __call__( self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] diff --git a/tests/test_approximations.py b/tests/test_approximations.py index 284a4f3..0cf1dda 100644 --- a/tests/test_approximations.py +++ b/tests/test_approximations.py @@ -8,6 +8,7 @@ RationalQuadratic, PoweredExponential, Periodic, + White, ) from jaxkern.nonstationary import Polynomial, Linear from jaxkern.base import AbstractKernel @@ -142,7 +143,16 @@ def test_exactness(kernel): @pytest.mark.parametrize( "kernel", - [RationalQuadratic, PoweredExponential, Polynomial, Linear, Periodic], + [Linear, Polynomial], +) +def test_type_error(kernel): + with pytest.raises(TypeError): + RFF(kernel(), num_basis_fns=10) + + +@pytest.mark.parametrize( + "kernel", + [RationalQuadratic, PoweredExponential, White, Periodic], ) def test_value_error(kernel): with pytest.raises(ValueError): diff --git a/tests/test_stationary.py b/tests/test_stationary.py index 5b9a64e..9e43c00 100644 --- a/tests/test_stationary.py +++ b/tests/test_stationary.py @@ -24,7 +24,7 @@ from jax.config import config from jaxlinop import LinearOperator, identity -from jaxkern.base import AbstractKernel +from jaxkern.base import AbstractKernel, StationaryKernel from jaxkern.stationary import ( RBF, Matern12, @@ -72,6 +72,21 @@ def test_gram(kernel: AbstractKernel, dim: int, n: int) -> None: assert Kxx.shape == (n, n) +@pytest.mark.parametrize( + "kernel", + [ + RBF(), + Matern12(), + Matern32(), + Matern52(), + RationalQuadratic(), + White(), + ], +) +def test_stationarity(kernel: AbstractKernel) -> None: + assert isinstance(kernel, StationaryKernel) + + @pytest.mark.parametrize( "kernel", [