Introduce Wrapper (instead of specializing kernel evaluation on DiffPt) #10

Merged 6 commits on Jun 5, 2023
2 changes: 1 addition & 1 deletion src/DifferentiableKernelFunctions.jl
@@ -4,7 +4,7 @@ using Reexport

@reexport using KernelFunctions

export partial
export partial, EnableDiff

include("multiOutput.jl")
include("partial.jl")
62 changes: 41 additions & 21 deletions src/diffKernel.jl
@@ -3,31 +3,51 @@ import LinearAlgebra as LA
using KernelFunctions: SimpleKernel, Kernel

"""
_evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel}
diffKernelCall(k::T, (x,px)::DiffPt, (y,py)::DiffPt) where {Dim, T<:Kernel}

implements `(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim})` for all kernel types. But since
generics are not allowed in the syntax above by the dispatch system, this
redirection over `_evaluate` is necessary

unboxes the partial instructions from DiffPt and applies them to k,
evaluates them at the positions of DiffPt
specialization for DiffPt. Unboxes the partial instructions from DiffPt and
applies them to k, evaluates them at the positions of DiffPt
"""
function _evaluate(k::T, (x,px)::DiffPt, (y,py)::DiffPt) where {T<:Kernel}
function diffKernelCall(k::T, (x,px)::DiffPt, (y,py)::DiffPt) where {T<:Kernel}
return apply_partial(k, px.indices, py.indices)(x, y)
end

#=
This is a hack to work around the fact that the `where {T<:Kernel}` clause is
not allowed for the `(::T)(x,y)` syntax. If we were to only implement
```julia
(::Kernel)(::DiffPt,::DiffPt)
```
then julia would not know whether to use
`(::SpecialKernel)(x,y)` or `(::Kernel)(x::DiffPt, y::DiffPt)`
"""
EnableDiff

A thin wrapper around Kernels enabling the machinery which allows you to
input (x, ∂ᵢ), (y, ∂ⱼ) where ∂ᵢ, ∂ⱼ are of `Partial` type (see [partial](@ref)) in order
to calculate
``
k((x, ∂ᵢ), (y,∂ⱼ)) = \\text{Cov}(\\partial_i Z(x), \\partial_j Z(y))
``
for ``Z`` with ``k(x,y) = \\text{Cov}(Z(x), Z(y))``.

!!! warning Only apply this wrapper at the very end. Kernel transformations
should be applied beforehand.

!!! info While this machinery could in principle be enabled for all `Kernel`s by default,
the covariance of derivatives of an isotropic kernel is no longer isotropic.
This forces the use of less specialized methods. So for now you have to opt in
with this wrapper.

Example:

```jldoctest
julia> k = EnableDiff(SEKernel());

julia> k((0, partial(1)), 0) # calculate Cov(∂₁Z(0), Z(0))
0.0

julia> k(0,0) # normal input still works
1.0
```
=#
for T in [SimpleKernel, Kernel] #subtypes(Kernel)
(k::T)(x::DiffPt, y::DiffPt) = _evaluate(k, x, y)
(k::T)(x::DiffPt, y) = _evaluate(k, x,(y, partial()))
(k::T)(x, y::DiffPt) = _evaluate(k, (x, partial()), y)
"""
struct EnableDiff{T<:Kernel} <: Kernel
kernel::T
end
(k::EnableDiff)(x::DiffPt, y::DiffPt) = diffKernelCall(k.kernel, x, y)
(k::EnableDiff)(x::DiffPt, y) = diffKernelCall(k.kernel, x,(y, partial()))
(k::EnableDiff)(x, y::DiffPt) = diffKernelCall(k.kernel, (x, partial()), y)
(k::EnableDiff)(x, y) = k.kernel(x,y) # Fall through case
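
The opt-in wrapper pattern introduced in this file can be illustrated without KernelFunctions. The following is a minimal sketch, not the package's implementation: `ToyKernel`, `ToySE`, `D1`, and `ToyEnableDiff` are hypothetical names invented for this example, the inputs are restricted to one dimension, and the derivatives of the squared-exponential kernel are written out by hand instead of going through `apply_partial`.

```julia
# Hypothetical stand-ins; none of these names come from KernelFunctions.jl or this PR.
abstract type ToyKernel end

# 1-D squared-exponential kernel: k(x, y) = exp(-(x - y)^2 / 2)
struct ToySE <: ToyKernel end
(::ToySE)(x::Real, y::Real) = exp(-(x - y)^2 / 2)

# Tags an input position as "take the first derivative with respect to this argument".
struct D1
    pos::Float64
end

# Opt-in wrapper: only wrapped kernels accept D1 inputs.
struct ToyEnableDiff{T<:ToyKernel} <: ToyKernel
    kernel::T
end

# Hand-derived partial derivatives of the 1-D squared-exponential kernel.
(k::ToyEnableDiff{ToySE})(x::D1, y::Real) = -(x.pos - y) * k.kernel(x.pos, y)       # ∂ₓ k
(k::ToyEnableDiff{ToySE})(x::Real, y::D1) = (x - y.pos) * k.kernel(x, y.pos)        # ∂ᵧ k
(k::ToyEnableDiff{ToySE})(x::D1, y::D1) =
    (1 - (x.pos - y.pos)^2) * k.kernel(x.pos, y.pos)                                # ∂ₓ∂ᵧ k

# Fall-through case: plain inputs are forwarded to the wrapped kernel unchanged.
(k::ToyEnableDiff)(x::Real, y::Real) = k.kernel(x, y)

k = ToyEnableDiff(ToySE())
k(0.0, 0.0)            # plain evaluation, forwarded to ToySE
k(D1(0.0), 0.0)        # derivative in the first argument
k(D1(0.0), D1(0.0))    # derivative in both arguments
```

In this sketch an unwrapped `ToySE()` has no method for `D1` inputs, so calling it with one raises a `MethodError`. Derivative-aware dispatch is only paid for by kernels that were explicitly wrapped, which is the trade-off the `EnableDiff` docstring describes.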

2 changes: 1 addition & 1 deletion src/partial.jl
@@ -105,6 +105,6 @@ i.e. 2*dim dimensional input
function apply_partial(
k, partials_x::Tuple{Vararg{T}}, partials_y::Tuple{Vararg{T}}
) where {T<:IndexType}
local f(x, y) = apply_partial(t -> k(t, y), partials_x...)(x)
f(x, y) = apply_partial(t -> k(t, y), partials_x...)(x)
return (x, y) -> apply_partial(t -> f(x, t), partials_y...)(y)
end
32 changes: 23 additions & 9 deletions test/diffKernel.jl
@@ -1,25 +1,39 @@
@testset "diffKernel" begin
@testset "smoke test" begin
k = MaternKernel()
k(1, 1)
k = EnableDiff(MaternKernel())
k2 = MaternKernel()
@test k(1, 1) == k2(1, 1)
k(1, (1, partial(1, 1))) # Cov(Z(x), ∂₁∂₁Z(y)) where x=1, y=1
k(([1], partial(1)), [2]) # Cov(∂₁Z(x), Z(y)) where x=[1], y=[2]
k(([1, 2], partial(1)), ([1, 2], partial(2)))# Cov(∂₁Z(x), ∂₂Z(y)) where x=[1,2], y=[1,2]
end

@testset "Sanity Checks with $k" for k in [SEKernel()]
@testset "Sanity Checks with $k1" for k1 in [
SEKernel(),
MaternKernel(ν=5),
RationalQuadraticKernel(),
SEKernel() + RationalQuadraticKernel()
]
k = EnableDiff(k1)
for x in [0, 1, -1, 42]
# for stationary kernels Cov(∂Z(x) , Z(x)) = 0
@test k((x, partial(1)), x) ≈ 0
# correlation with self should be positive
## This fails for Matern and RationalQuadraticKernel
# because its implementation branches on x == y resulting in a zero derivative
# (cf. https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/issues/517)
@test k((x, partial(1)), (x, partial(1))) > 0

# the slope should be positively correlated with a point further down
@test k(
(x, partial(1)), # slope
x + 1e-1, # point further down
x + 1e-2, # point further down
) > 0

# correlation with self should be positive
@test k((x, partial(1)), (x, partial(1))) > 0
@testset "Stationary Tests" begin
@test k((x, partial(1)), x) == 0 # expect Cov(∂Z(x) , Z(x)) == 0

@testset "Isotropic Tests" begin
@test k(([1, 2], partial(1)), ([1, 2], partial(2))) == 0 # cross covariance should be zero
end
end
end
end
end
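
The sanity checks in this test file can be motivated by a short derivation. This is a sketch under assumptions that the PR does not state: the input is one-dimensional, the kernel is stationary with ``k(x,y) = \varphi(x-y)`` (the symbol ``\varphi`` is introduced here for illustration), ``\varphi`` is differentiable, and differentiation may be interchanged with the covariance.

```latex
% Symmetry of the covariance, k(x,y) = k(y,x), makes \varphi even:
\varphi(h) = \varphi(-h)
  \quad\Longrightarrow\quad \varphi'(h) = -\varphi'(-h)
  \quad\Longrightarrow\quad \varphi'(0) = 0 .

% Covariance of the slope at x with the value at x:
\text{Cov}(\partial Z(x), Z(x))
  = \partial_x k(x,y)\big|_{y=x}
  = \varphi'(0)
  = 0 .

% Covariance of the slope at x with the value at x + h, for h > 0:
\text{Cov}(\partial Z(x), Z(x+h))
  = \partial_x k(x,y)\big|_{y=x+h}
  = \varphi'(-h)
  = -\varphi'(h) .
```

Under these assumptions the first check, `k((x, partial(1)), x) ≈ 0`, holds for any such kernel. The slope check is positive whenever ``\varphi`` is decreasing just to the right of zero, which is the case for the squared-exponential kernel.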
4 changes: 2 additions & 2 deletions test/runtests.jl
@@ -1,5 +1,5 @@
using KernelFunctions: KernelFunctions as KF, MaternKernel, SEKernel
using DifferentiableKernelFunctions: DifferentiableKernelFunctions as DKF, DiffPt, partial
using KernelFunctions: KernelFunctions as KF, MaternKernel, SEKernel, RationalQuadraticKernel
using DifferentiableKernelFunctions: DifferentiableKernelFunctions as DKF, EnableDiff, partial
using ProductArrays: productArray
using Test
