Skip to content

Commit

Permalink
Add interface for general bases (#100)
Browse files Browse the repository at this point in the history
* add initial interface for general bases

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix least squares in tutorial

* fix SpatialDiscretization

* format

* fix docstring rendering

* fix TemporalDiscretization

* fix unit tests

* add unit tests

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix docstring rendering

* use StandardBasis in some examples

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
JoshuaLampert and github-actions[bot] authored Oct 30, 2024
1 parent cc418aa commit 73ef24a
Show file tree
Hide file tree
Showing 12 changed files with 287 additions and 100 deletions.
7 changes: 7 additions & 0 deletions docs/src/ref.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ Modules = [KernelInterpolation]
Pages = ["nodes.jl"]
```

## Bases

```@autodocs
Modules = [KernelInterpolation]
Pages = ["basis.jl"]
```

## Interpolation

```@autodocs
Expand Down
2 changes: 1 addition & 1 deletion docs/src/tutorial_noisy_data.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ approximation and the polynomial augmentation is not changed. In KernelInterpola
```@example noisy-itp
M = 81
centers = random_hypercube(M; dim = 2)
ls = interpolate(nodeset, centers, values_noisy, kernel)
ls = interpolate(centers, nodeset, values_noisy, kernel)
```

We plot the least-squares approximation and, again, see a better fit to the underlying target function.
Expand Down
3 changes: 2 additions & 1 deletion examples/interpolation/interpolation_2d_sphere.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ merge!(nodeset, nodeset_boundary)
values = f.(nodeset)

kernel = InverseMultiquadricKernel{dim(nodeset)}()
itp = interpolate(nodeset, values, kernel)
basis = StandardBasis(nodeset, kernel)
itp = interpolate(basis, values)

N = 500
many_nodes = random_hypersphere(N, r, center)
Expand Down
2 changes: 1 addition & 1 deletion examples/interpolation/least_squares_2d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ M = 81
centers = random_hypercube(M; dim = 2)

kernel = ThinPlateSplineKernel{dim(nodeset)}()
ls = interpolate(nodeset, centers, values, kernel)
ls = interpolate(StandardBasis(centers, kernel), values, nodeset)
itp = interpolate(nodeset, values, kernel)

N = 40
Expand Down
4 changes: 3 additions & 1 deletion src/KernelInterpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ using WriteVTK: WriteVTK, vtk_grid, paraview_collection, MeshCell, VTKCellTypes,

include("kernels/kernels.jl")
include("nodes.jl")
include("basis.jl")
include("regularization.jl")
include("differential_operators.jl")
include("equations.jl")
include("kernel_matrices.jl")
include("regularization.jl")
include("interpolation.jl")
include("discretization.jl")
include("callbacks_step/callbacks_step.jl")
Expand All @@ -48,6 +49,7 @@ export GaussKernel, MultiquadricKernel, InverseMultiquadricKernel,
RadialCharacteristicKernel, MaternKernel, Matern12Kernel, Matern32Kernel,
Matern52Kernel, Matern72Kernel, RieszKernel,
TransformationKernel, ProductKernel, SumKernel
export StandardBasis
export phi, Phi, order
export PartialDerivative, Gradient, Laplacian, EllipticOperator
export PoissonEquation, EllipticEquation, AdvectionEquation, HeatEquation,
Expand Down
73 changes: 73 additions & 0 deletions src/basis.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
AbstractBasis
Abstract type for a basis of a kernel function space. Every basis represents a
set of functions, which can be obtained by indexing the basis object. Every basis
object holds a kernel function and a [`NodeSet`](@ref) of centers and potentially
more fields depending on the concrete basis type.
"""
abstract type AbstractBasis end

function (basis::AbstractBasis)(x)
return [basis[i](x) for i in eachindex(basis)]
end

"""
interpolation_kernel(basis)
Return the kernel from a basis.
"""
interpolation_kernel(basis::AbstractBasis) = basis.kernel

"""
centers(basis)
Return the centers from a basis object.
"""
centers(basis::AbstractBasis) = basis.centers

"""
order(basis)
Return the order ``m`` of the polynomial, which is needed by this `basis` for
the interpolation, i.e., the polynomial degree plus 1. If ``m = 0``,
no polynomial is added.
"""
order(basis::AbstractBasis) = order(interpolation_kernel(basis))
dim(basis::AbstractBasis) = dim(basis.centers)
Base.length(basis::AbstractBasis) = length(centers(basis))
Base.eachindex(basis::AbstractBasis) = Base.OneTo(length(basis))
function Base.iterate(basis::AbstractBasis, state = 1)
state > length(basis) ? nothing : (basis[state], state + 1)
end
Base.collect(basis::AbstractBasis) = Function[basis[i] for i in 1:length(basis)]

function Base.show(io::IO, basis::AbstractBasis)
return print(io,
"$(nameof(typeof(basis))) with $(length(centers(basis))) centers and kernel $(interpolation_kernel(basis)).")
end

@doc raw"""
StandardBasis(centers, kernel)
The standard basis for a function space defined by a kernel and a [`NodeSet`](@ref) of `centers`.
The basis functions are given by
```math
b_j(x) = K(x, x_j)
```
where `K` is the kernel and `x_j` are the nodes in `centers`.
"""
struct StandardBasis{Kernel} <: AbstractBasis
centers::NodeSet
kernel::Kernel
function StandardBasis(centers::NodeSet, kernel::Kernel) where {Kernel}
if dim(kernel) != dim(centers)
throw(DimensionMismatch("The dimension of the kernel and the centers must be the same"))
end
new{typeof(kernel)}(centers, kernel)
end
end

Base.getindex(basis::StandardBasis, i) = x -> basis.kernel(x, centers(basis)[i])
49 changes: 33 additions & 16 deletions src/discretization.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""
SpatialDiscretization(equations, nodeset_inner, boundary_condition, nodeset_boundary, basis)
SpatialDiscretization(equations, nodeset_inner, boundary_condition, nodeset_boundary,
[centers,] kernel = GaussKernel{dim(nodeset_inner)}())
Expand All @@ -10,27 +11,35 @@ is a function describing the Dirichlet boundary conditions. The `centers` are th
See also [`Semidiscretization`](@ref), [`solve_stationary`](@ref).
"""
struct SpatialDiscretization{Dim, RealT, Equations, BoundaryCondition,
Kernel <: AbstractKernel{Dim}}
Basis <: AbstractBasis}
equations::Equations
nodeset_inner::NodeSet{Dim, RealT}
boundary_condition::BoundaryCondition
nodeset_boundary::NodeSet{Dim, RealT}
centers::NodeSet{Dim, RealT}
kernel::Kernel
basis::Basis

function SpatialDiscretization(equations, nodeset_inner::NodeSet{Dim, RealT},
boundary_condition,
nodeset_boundary::NodeSet{Dim, RealT},
centers::NodeSet{Dim, RealT},
kernel = GaussKernel{Dim}()) where {Dim,
RealT}
basis::AbstractBasis) where {Dim,
RealT}
new{Dim, RealT, typeof(equations), typeof(boundary_condition),
typeof(kernel)}(equations, nodeset_inner,
boundary_condition, nodeset_boundary,
centers, kernel)
typeof(basis)}(equations, nodeset_inner,
boundary_condition, nodeset_boundary,
basis)
end
end

function SpatialDiscretization(equations, nodeset_inner::NodeSet{Dim, RealT},
boundary_condition,
nodeset_boundary::NodeSet{Dim, RealT},
centers::NodeSet{Dim, RealT},
kernel = GaussKernel{Dim}()) where {Dim,
RealT}
SpatialDiscretization(equations, nodeset_inner, boundary_condition,
nodeset_boundary, StandardBasis(centers, kernel))
end

function SpatialDiscretization(equations, nodeset_inner::NodeSet{Dim, RealT},
boundary_condition,
nodeset_boundary::NodeSet{Dim, RealT},
Expand All @@ -42,8 +51,11 @@ function SpatialDiscretization(equations, nodeset_inner::NodeSet{Dim, RealT},
end

function Base.show(io::IO, sd::SpatialDiscretization)
N_i = length(sd.nodeset_inner)
N_b = length(sd.nodeset_boundary)
k = interpolation_kernel(sd.basis)
print(io,
"SpatialDiscretization with $(dim(sd)) dimensions, $(length(sd.nodeset_inner)) inner nodes, $(length(sd.nodeset_boundary)) boundary nodes, and kernel $(sd.kernel)")
"SpatialDiscretization with $(dim(sd)) dimensions, $N_i inner nodes, $N_b boundary nodes, and kernel $k")
end

dim(::SpatialDiscretization{Dim}) where {Dim} = Dim
Expand All @@ -60,7 +72,8 @@ function solve_stationary(spatial_discretization::SpatialDiscretization{Dim, Rea
Dim,
RealT
}
@unpack equations, nodeset_inner, boundary_condition, nodeset_boundary, centers, kernel = spatial_discretization
@unpack equations, nodeset_inner, boundary_condition, nodeset_boundary, basis = spatial_discretization
@unpack centers, kernel = basis

system_matrix = pde_boundary_matrix(equations, nodeset_inner, nodeset_boundary, centers,
kernel)
Expand All @@ -71,7 +84,7 @@ function solve_stationary(spatial_discretization::SpatialDiscretization{Dim, Rea
xx = polyvars(Dim)
ps = monomials(xx, 0:-1)
nodeset = merge(nodeset_inner, nodeset_boundary)
return Interpolation(kernel, nodeset, centers, c, system_matrix,
return Interpolation(basis, nodeset, c, system_matrix,
ps, xx)
end

Expand All @@ -96,10 +109,11 @@ end

function Semidiscretization(spatial_discretization::SpatialDiscretization,
initial_condition)
@unpack equations, nodeset_inner, boundary_condition, nodeset_boundary, centers, kernel = spatial_discretization
@unpack equations, nodeset_inner, boundary_condition, nodeset_boundary, basis = spatial_discretization
@unpack centers, kernel = basis
@assert length(centers)==length(nodeset_inner) + length(nodeset_boundary) "The number of centers must be equal to the number of inner and boundary nodes."
k_matrix_inner = kernel_matrix(nodeset_inner, centers, kernel)
k_matrix_boundary = kernel_matrix(nodeset_boundary, centers, kernel)
k_matrix_inner = kernel_matrix(centers, nodeset_inner, kernel)
k_matrix_boundary = kernel_matrix(centers, nodeset_boundary, kernel)
# whole kernel matrix is not needed for rhs, but for initial condition
k_matrix = [k_matrix_inner
k_matrix_boundary]
Expand Down Expand Up @@ -133,8 +147,11 @@ function Semidiscretization(equations, nodeset_inner::NodeSet{Dim, RealT},
end

function Base.show(io::IO, semi::Semidiscretization)
N_i = length(semi.spatial_discretization.nodeset_inner)
N_b = length(semi.spatial_discretization.nodeset_boundary)
k = interpolation_kernel(semi.spatial_discretization.basis)
print(io,
"Semidiscretization with $(dim(semi)) dimensions, $(length(semi.spatial_discretization.nodeset_inner)) inner nodes, $(length(semi.spatial_discretization.nodeset_boundary)) boundary nodes, and kernel $(semi.spatial_discretization.kernel)")
"Semidiscretization with $(dim(semi)) dimensions, $N_i inner nodes, $N_b boundary nodes, and kernel $k")
end

dim(semi::Semidiscretization) = dim(semi.spatial_discretization)
Expand Down
Loading

0 comments on commit 73ef24a

Please sign in to comment.