Skip to content

Commit

Permalink
Add transformation kernels (#15)
Browse files Browse the repository at this point in the history
* add transformation kernels

* format

* remove unnecessary begin ... end
  • Loading branch information
JoshuaLampert authored Jan 10, 2024
1 parent 96bc4ff commit dfd2c42
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 35 deletions.
42 changes: 42 additions & 0 deletions examples/interpolation_2d_transformation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using KernelInterpolation
using Plots

# function to interpolate
f(x) = x[1] + x[2]^2 < 0.0 ? 1.0 : 0.1

x_min = -1.0
x_max = 1.0
n = 80
nodeset = random_hypercube(n, x_min, x_max; dim = 2)
values = f.(nodeset)

kernel = Matern12Kernel{dim(nodeset)}()
trafo(x) = [x[1] + x[2]^2, 0.0]
trafo_kernel = TransformationKernel{2}(kernel, trafo)
itp = interpolate(nodeset, values, trafo_kernel)
itp_base = interpolate(nodeset, values, kernel)

many_nodes = homogeneous_hypercube(20, x_min, x_max; dim = 2)

abs_diff_trafo = abs.(itp.(many_nodes) .- f.(many_nodes))
abs_diff = abs.(itp_base.(many_nodes) .- f.(many_nodes))

l1_error_trafo = sum(abs_diff_trafo)
l1_error = sum(abs_diff)
linf_error_trafo = maximum(abs_diff_trafo)
linf_error = maximum(abs_diff)

@show l1_error_trafo
@show l1_error
@show linf_error_trafo
@show linf_error

plot(layout = (1, 3))
plot!(many_nodes, trafo_kernel, subplot = 1, st = :surface, cbar = false,
c = cgrad(:grays, rev = true), camera = (0, 90), xguide = "x", yguide = "y")

plot!(many_nodes, itp, subplot = 2)
plot!(many_nodes, f, subplot = 2)

plot!(many_nodes, itp_base, subplot = 3)
plot!(many_nodes, f, subplot = 3)
5 changes: 3 additions & 2 deletions src/KernelInterpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using SpecialFunctions: besselk, loggamma
using StaticArrays
using TypedPolynomials: Variable, monomials, degree

include("kernels.jl")
include("kernels/kernels.jl")
include("nodes.jl")
include("interpolation.jl")
include("visualization.jl")
Expand All @@ -16,7 +16,8 @@ export get_name
export GaussKernel, MultiquadricKernel, InverseMultiquadricKernel,
PolyharmonicSplineKernel, ThinPlateSplineKernel, WendlandKernel,
RadialCharacteristicKernel, MaternKernel, Matern12Kernel, Matern32Kernel,
Matern52Kernel, Matern72Kernel, RieszKernel
Matern52Kernel, Matern72Kernel, RieszKernel,
TransformationKernel
export phi, Phi, order
export NodeSet, separation_distance, dim, values_along_dim, random_hypercube,
random_hypercube_boundary, homogeneous_hypercube, homogeneous_hypercube_boundary,
Expand Down
24 changes: 24 additions & 0 deletions src/kernels/kernels.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
AbstractKernel
An abstract supertype of kernels.
"""
abstract type AbstractKernel{Dim} end

"""
dim(kernel)
Return the dimension of a kernel, i.e. the size of the input vector.
"""
dim(kernel::AbstractKernel{Dim}) where {Dim} = Dim

"""
get_name(kernel::AbstractKernel)
Returns the canonical, human-readable name for the given system of equations.
"""
get_name(kernel::AbstractKernel) = string(nameof(typeof(kernel))) * "{" *
string(dim(kernel)) * "}"

include("radialsymmetric_kernel.jl")
include("special_kernel.jl")
32 changes: 8 additions & 24 deletions src/kernels.jl → src/kernels/radialsymmetric_kernel.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,3 @@
"""
AbstractKernel
An abstract supertype of kernels.
"""
abstract type AbstractKernel{Dim} end

"""
dim(kernel)
Return the dimension of a kernel, i.e. the size of the input vector.
"""
dim(kernel::AbstractKernel{Dim}) where {Dim} = Dim

"""
get_name(kernel::AbstractKernel)
Returns the canonical, human-readable name for the given system of equations.
"""
get_name(kernel::AbstractKernel) = string(nameof(typeof(kernel))) * "{" *
string(dim(kernel)) * "}"

@doc raw"""
RadialSymmetricKernel
Expand All @@ -38,15 +16,21 @@ The kernel is then defined by
abstract type RadialSymmetricKernel{Dim} <: AbstractKernel{Dim} end

function Phi(kernel::RadialSymmetricKernel{Dim}, x) where {Dim}
@assert length(x) == Dim
return phi(kernel, norm(x))
end

function (kernel::RadialSymmetricKernel{Dim})(x, y) where {Dim}
function (kernel::RadialSymmetricKernel)(x, y)
@assert length(x) == length(y)
return Phi(kernel, x .- y)
end

"""
order(kernel)
Return order of kernel.
"""
function order end

@doc raw"""
GaussKernel{Dim}(; shape_parameter = 1.0)
Expand Down
33 changes: 33 additions & 0 deletions src/kernels/special_kernel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
@doc raw"""
TransformationKernel(kernel, transformation)
Given a base `kernel` and a bijective `transformation` function, construct
a new kernel that applies the transformation to both arguments ``x`` and ``y``,
i.e. the new kernel ``K_T`` is given by
```math
K_T(x, y) = K(Tx, Ty),
```
where ``K`` is the base kernel and ``T`` the transformation.
"""
struct TransformationKernel{Dim, Kernel, Transformation} <: AbstractKernel{Dim}
kernel::Kernel
trafo::Transformation
end

function TransformationKernel{Dim}(kernel, transformation) where {Dim}
return TransformationKernel{Dim, typeof(kernel), typeof(transformation)}(kernel,
transformation)
end

function (kernel::TransformationKernel)(x, y)
@assert length(x) == length(y)
K = kernel.kernel
T = kernel.trafo
return K(T(x), T(y))
end

function Base.show(io::IO, kernel::TransformationKernel{Dim}) where {Dim}
return print(io, "TransformationKernel{", Dim, "}(kernel = ", kernel.kernel, ")")
end

order(kernel::TransformationKernel) = order(kernel.kernel)
16 changes: 7 additions & 9 deletions src/visualization.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
@recipe function f(x::AbstractVector, kernel::RadialSymmetricKernel)
@series begin
xguide --> "r"
title --> get_name(kernel)
x, phi.(Ref(kernel), abs.(x))
end
@recipe function f(x::AbstractVector, kernel::AbstractKernel)
xguide --> "r"
title --> get_name(kernel)
x, kernel.(Ref(0.0), x)
end

@recipe function f(nodeset::NodeSet, kernel::RadialSymmetricKernel)
@recipe function f(nodeset::NodeSet, kernel::AbstractKernel)
if dim(nodeset) == 1
x = values_along_dim(nodeset, 1)
title --> get_name(kernel)
x, phi.(Ref(kernel), norm.(nodeset))
x, kernel.(Ref(0.0), x)
elseif dim(nodeset) == 2
x = values_along_dim(nodeset, 1)
y = values_along_dim(nodeset, 2)
seriestype --> :scatter
label --> "nodes"
x, y, phi.(Ref(kernel), norm.(nodeset))
x, y, kernel.(Ref([0.0, 0.0]), nodeset)
else
@error("Plotting a kernel is only supported for dimension up to 2, but the set has dimension $(dim(nodeset))")
end
Expand Down
5 changes: 5 additions & 0 deletions test/test_examples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ EXAMPLES_DIR = examples_dir()
l2=0.05394435588953249, linf=0.028279879132924693)
end

@ki_testset "interpolation_2d_transformation.jl" begin
@test_include_example(joinpath(EXAMPLES_DIR, "interpolation_2d_transformation.jl"),
l2=0.8382891350633075, linf=0.3927098382304266)
end

@ki_testset "interpolation_5d.jl" begin
@test_include_example(joinpath(EXAMPLES_DIR, "interpolation_5d.jl"),
l2=0.4308925377778874, linf=0.06402624845465965)
Expand Down
14 changes: 14 additions & 0 deletions test/test_unit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,15 @@ using Plots
@test isapprox(phi(kernel9, 0.0), 0.0)
@test isapprox(phi(kernel9, 0.5), -0.4665164957684037)
@test isapprox(kernel9(x, y), -0.26877021157823217)

trafo(x) = [x[1] + x[2]^2 + 2 * x[3] * x[2], x[3] - x[1]]
kernel10 = @test_nowarn TransformationKernel{3}(kernel1, trafo)
@test_nowarn println(kernel10)
@test_nowarn display(kernel10)
@test order(kernel10) == 0
x3 = [-1.0, 2.0, pi / 8]
y3 = [2.3, 4.2, -12.3]
@test isapprox(kernel10(x3, y3), kernel1(trafo(x3), trafo(y3)))
end

@testset "NodeSet" begin
Expand Down Expand Up @@ -542,12 +551,17 @@ using Plots
@testset "Visualization" begin
f = sum
kernel = GaussKernel{3}(shape_parameter = 0.5)
trafo_kernel = TransformationKernel{2}(kernel, x -> [x[1] + x[2]^2, x[1]])
@test_nowarn plot(-1.0:0.1:1.0, kernel)
for dim in 1:3
nodes = homogeneous_hypercube(5; dim = dim)
@test_nowarn plot(nodes)
if dim < 3
@test_nowarn plot(nodes, kernel)
# Transformtion kernel can only be plotted in the dimension of the input of the trafo
if dim == 2
@test_nowarn plot(nodes, trafo_kernel)
end
ff = f.(nodes)
itp = interpolate(nodes, ff)
nodes_fine = homogeneous_hypercube(10; dim = dim)
Expand Down

0 comments on commit dfd2c42

Please sign in to comment.