Rename TensorProduct and implement TensorCore.tensor #232

Merged · 13 commits · Jan 20, 2021
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.8.18"
version = "0.8.19"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -13,6 +13,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

Expand All @@ -24,5 +25,6 @@ Requires = "1.0.1"
SpecialFunctions = "0.8, 0.9, 0.10, 1"
StatsBase = "0.32, 0.33"
StatsFuns = "0.8, 0.9"
TensorCore = "0.1"
ZygoteRules = "0.2"
julia = "1.3"
2 changes: 1 addition & 1 deletion docs/src/kernels.md
@@ -124,7 +124,7 @@ transform(::Kernel, ::AbstractVector)
ScaledKernel
KernelSum
KernelProduct
TensorProduct
KernelTensorProduct
```

## Multi-output Kernels
10 changes: 7 additions & 3 deletions src/KernelFunctions.jl
@@ -29,9 +29,8 @@ export LinearKernel, PolynomialKernel
export RationalQuadraticKernel, GammaRationalQuadraticKernel
export GaborKernel, PiecewisePolynomialKernel
export PeriodicKernel, NeuralNetworkKernel
export KernelSum, KernelProduct
export KernelSum, KernelProduct, KernelTensorProduct
export TransformedKernel, ScaledKernel
export TensorProduct

export Transform,
SelectTransform,
@@ -52,6 +51,9 @@ export ColVecs, RowVecs
export MOInput
export IndependentMOKernel, LatentFactorMOKernel

# Reexports
export tensor, ⊗

using Compat
using Requires
using Distances, LinearAlgebra
@@ -61,6 +63,7 @@ using ZygoteRules: @adjoint, pullback
using StatsFuns: logtwo
using InteractiveUtils: subtypes
using StatsBase
using TensorCore

abstract type Kernel end
abstract type SimpleKernel <: Kernel end
@@ -100,7 +103,8 @@ include(joinpath("kernels", "scaledkernel.jl"))
include(joinpath("matrix", "kernelmatrix.jl"))
include(joinpath("kernels", "kernelsum.jl"))
include(joinpath("kernels", "kernelproduct.jl"))
include(joinpath("kernels", "tensorproduct.jl"))
include(joinpath("kernels", "kerneltensorproduct.jl"))
include(joinpath("kernels", "overloads.jl"))
include(joinpath("approximations", "nystrom.jl"))
include("generic.jl")

2 changes: 1 addition & 1 deletion src/basekernels/sm.jl
@@ -92,7 +92,7 @@ function spectral_mixture_product_kernel(
if !(size(αs) == size(γs) == size(ωs))
throw(DimensionMismatch("The dimensions of αs, γs, ans ωs do not match"))
end
return TensorProduct(
return KernelTensorProduct(
spectral_mixture_kernel(h, α, reshape(γ, 1, :), reshape(ω, 1, :)) for
(α, γ, ω) in zip(eachrow(αs), eachrow(γs), eachrow(ωs))
)
5 changes: 5 additions & 0 deletions src/deprecated.jl
@@ -7,3 +7,8 @@
@deprecate PiecewisePolynomialKernel{V}(A::AbstractMatrix{<:Real}) where {V} transform(
PiecewisePolynomialKernel{V}(size(A, 1)), LinearTransform(cholesky(A).U)
)

@deprecate TensorProduct(kernels) KernelTensorProduct(kernels)
@deprecate TensorProduct(kernel::Kernel, kernels::Kernel...) KernelTensorProduct(
kernel, kernels...
)
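
As a usage sketch (not part of the diff): with these `@deprecate` definitions, existing code that constructs a `TensorProduct` keeps working but is forwarded to the new type, with Julia's standard deprecation warning printed on first use. Variable names below are illustrative only.

```julia
using KernelFunctions

k1, k2 = SqExponentialKernel(), LinearKernel()

# Old constructor: still callable, but forwards to KernelTensorProduct
# (Julia prints its usual deprecation warning on first use).
k_old = TensorProduct(k1, k2)

k_old isa KernelTensorProduct          # true
k_old == KernelTensorProduct(k1, k2)   # true
```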
25 changes: 0 additions & 25 deletions src/kernels/kernelproduct.jl
@@ -41,31 +41,6 @@ end

@functor KernelProduct

Base.:*(k1::Kernel, k2::Kernel) = KernelProduct(k1, k2)

function Base.:*(
k1::KernelProduct{<:AbstractVector{<:Kernel}},
k2::KernelProduct{<:AbstractVector{<:Kernel}},
)
return KernelProduct(vcat(k1.kernels, k2.kernels))
end

function Base.:*(k1::KernelProduct, k2::KernelProduct)
return KernelProduct(k1.kernels..., k2.kernels...)
end

function Base.:*(k::Kernel, ks::KernelProduct{<:AbstractVector{<:Kernel}})
return KernelProduct(vcat(k, ks.kernels))
end

Base.:*(k::Kernel, kp::KernelProduct) = KernelProduct(k, kp.kernels...)

function Base.:*(ks::KernelProduct{<:AbstractVector{<:Kernel}}, k::Kernel)
return KernelProduct(vcat(ks.kernels, k))
end

Base.:*(kp::KernelProduct, k::Kernel) = KernelProduct(kp.kernels..., k)

Base.length(k::KernelProduct) = length(k.kernels)

(κ::KernelProduct)(x, y) = prod(k(x, y) for k in κ.kernels)
22 changes: 0 additions & 22 deletions src/kernels/kernelsum.jl
@@ -41,28 +41,6 @@ end

@functor KernelSum

Base.:+(k1::Kernel, k2::Kernel) = KernelSum(k1, k2)

function Base.:+(
k1::KernelSum{<:AbstractVector{<:Kernel}}, k2::KernelSum{<:AbstractVector{<:Kernel}}
)
return KernelSum(vcat(k1.kernels, k2.kernels))
end

Base.:+(k1::KernelSum, k2::KernelSum) = KernelSum(k1.kernels..., k2.kernels...)

function Base.:+(k::Kernel, ks::KernelSum{<:AbstractVector{<:Kernel}})
return KernelSum(vcat(k, ks.kernels))
end

Base.:+(k::Kernel, ks::KernelSum) = KernelSum(k, ks.kernels...)

function Base.:+(ks::KernelSum{<:AbstractVector{<:Kernel}}, k::Kernel)
return KernelSum(vcat(ks.kernels, k))
end

Base.:+(ks::KernelSum, k::Kernel) = KernelSum(ks.kernels..., k)

Base.length(k::KernelSum) = length(k.kernels)

(κ::KernelSum)(x, y) = sum(k(x, y) for k in κ.kernels)
144 changes: 144 additions & 0 deletions src/kernels/kerneltensorproduct.jl
@@ -0,0 +1,144 @@
"""
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only changed TensorProduct to KernelTensorProduct and added a new docstring. I renamed the file in a separate PR but somehow, in contrast to the tests, Github does not display the changes nicely.

KernelTensorProduct

Tensor product of kernels.

# Definition

For inputs ``x = (x_1, \\ldots, x_n)`` and ``x' = (x'_1, \\ldots, x'_n)``, the tensor
product of kernels ``k_1, \\ldots, k_n`` is defined as
```math
k(x, x'; k_1, \\ldots, k_n) = \\Big(\\bigotimes_{i=1}^n k_i\\Big)(x, x') = \\prod_{i=1}^n k_i(x_i, x'_i).
```

# Construction

The simplest way to specify a `KernelTensorProduct` is to use the overloaded `tensor`
operator or its alias `⊗` (can be typed by `\\otimes<tab>`).
```jldoctest tensorproduct
julia> k1 = SqExponentialKernel(); k2 = LinearKernel(); X = rand(5, 2);

julia> kernelmatrix(k1 ⊗ k2, RowVecs(X)) == kernelmatrix(k1, X[:, 1]) .* kernelmatrix(k2, X[:, 2])
true
```

You can also specify a `KernelTensorProduct` by providing kernels as individual arguments
or as an iterable data structure such as a `Tuple` or a `Vector`. Using a tuple or
individual arguments guarantees that `KernelTensorProduct` is concretely typed but might
lead to large compilation times if the number of kernels is large.
```jldoctest tensorproduct
julia> KernelTensorProduct(k1, k2) == k1 ⊗ k2
true

julia> KernelTensorProduct((k1, k2)) == k1 ⊗ k2
true

julia> KernelTensorProduct([k1, k2]) == k1 ⊗ k2
true
```
"""
struct KernelTensorProduct{K} <: Kernel
kernels::K
end

function KernelTensorProduct(kernel::Kernel, kernels::Kernel...)
return KernelTensorProduct((kernel, kernels...))
end

@functor KernelTensorProduct

Base.length(kernel::KernelTensorProduct) = length(kernel.kernels)

function (kernel::KernelTensorProduct)(x, y)
if !(length(x) == length(y) == length(kernel))
throw(DimensionMismatch("number of kernels and number of features
are not consistent"))
end
return prod(k(xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y))
end

function validate_domain(k::KernelTensorProduct, x::AbstractVector)
return dim(x) == length(k) ||
error("number of kernels and groups of features are not consistent")
end

# Utility for slicing up inputs.
slices(x::AbstractVector{<:Real}) = (x,)
slices(x::ColVecs) = eachrow(x.X)
slices(x::RowVecs) = eachcol(x.X)

function kernelmatrix!(K::AbstractMatrix, k::KernelTensorProduct, x::AbstractVector)
validate_inplace_dims(K, x)
validate_domain(k, x)

kernels_and_inputs = zip(k.kernels, slices(x))
kernelmatrix!(K, first(kernels_and_inputs)...)
for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
K .*= kernelmatrix(k, xi)
end

return K
end

function kernelmatrix!(
K::AbstractMatrix, k::KernelTensorProduct, x::AbstractVector, y::AbstractVector
)
validate_inplace_dims(K, x, y)
validate_domain(k, x)

kernels_and_inputs = zip(k.kernels, slices(x), slices(y))
kernelmatrix!(K, first(kernels_and_inputs)...)
for (k, xi, yi) in Iterators.drop(kernels_and_inputs, 1)
K .*= kernelmatrix(k, xi, yi)
end

return K
end

function kerneldiagmatrix!(K::AbstractVector, k::KernelTensorProduct, x::AbstractVector)
validate_inplace_dims(K, x)
validate_domain(k, x)

kernels_and_inputs = zip(k.kernels, slices(x))
kerneldiagmatrix!(K, first(kernels_and_inputs)...)
for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
K .*= kerneldiagmatrix(k, xi)
end

return K
end

function kernelmatrix(k::KernelTensorProduct, x::AbstractVector)
validate_domain(k, x)
return mapreduce(kernelmatrix, hadamard, k.kernels, slices(x))
end

function kernelmatrix(k::KernelTensorProduct, x::AbstractVector, y::AbstractVector)
validate_domain(k, x)
return mapreduce(kernelmatrix, hadamard, k.kernels, slices(x), slices(y))
end

function kerneldiagmatrix(k::KernelTensorProduct, x::AbstractVector)
validate_domain(k, x)
return mapreduce(kerneldiagmatrix, hadamard, k.kernels, slices(x))
end

Base.show(io::IO, kernel::KernelTensorProduct) = printshifted(io, kernel, 0)

function Base.:(==)(x::KernelTensorProduct, y::KernelTensorProduct)
return (
length(x.kernels) == length(y.kernels) &&
all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels))
)
end

function printshifted(io::IO, kernel::KernelTensorProduct, shift::Int)
print(io, "Tensor product of ", length(kernel), " kernels:")
for k in kernel.kernels
print(io, "\n")
for _ in 1:(shift + 1)
print(io, "\t")
end
printshifted(io, k, shift + 2)
end
end
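
A short usage sketch of the kernel-matrix code above (illustrative only, not part of the diff): for `RowVecs` input, `slices` yields one feature column per component kernel, so the full matrix is the Hadamard product of the per-dimension kernel matrices, and `kernelmatrix!` accumulates the same product into a preallocated buffer. Variable names are assumptions for the example.

```julia
using KernelFunctions

k = SqExponentialKernel() ⊗ LinearKernel()
X = rand(5, 2)          # 5 observations with 2 features each
x = RowVecs(X)          # rows are observations, columns are features

# Out-of-place: Hadamard product of the per-dimension kernel matrices.
K = kernelmatrix(k, x)
K ≈ kernelmatrix(SqExponentialKernel(), X[:, 1]) .* kernelmatrix(LinearKernel(), X[:, 2])  # true

# In-place: writes the same result into a preallocated buffer.
K2 = Matrix{Float64}(undef, 5, 5)
kernelmatrix!(K2, k, x)
K2 ≈ K                  # true
```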
22 changes: 22 additions & 0 deletions src/kernels/overloads.jl
@@ -0,0 +1,22 @@
for (M, op, T) in (
(:Base, :+, :KernelSum),
(:Base, :*, :KernelProduct),
(:TensorCore, :tensor, :KernelTensorProduct),
)
@eval begin
$M.$op(k1::Kernel, k2::Kernel) = $T(k1, k2)

$M.$op(k1::$T, k2::$T) = $T(k1.kernels..., k2.kernels...)
Member:
The default behavior is/was that if we have one kernel <:Tuple and one kernel <:AbstractVector, one will get a <:Tuple. Is it really a sensible choice?

Member Author:
The nice thing about this default is that you don't end up with abstract types unexpectedly but, of course, it might not be friendly to the compiler. I haven't seen any instances of it in our examples and tests, so I am not sure how often such a combination happens in practice.
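
For readers following the thread, a small sketch (not part of the PR) of the two storage modes being discussed, using only the methods in this file; the mixed tuple/vector case itself is left out since, as noted above, it does not show up in the examples or tests.

```julia
using KernelFunctions

k1, k2 = SqExponentialKernel(), LinearKernel()

# Tuple-backed (concretely typed): built by the operators or by splatting.
kt = k1 ⊗ k2
kt.kernels isa Tuple          # true

# Vector-backed: built by passing a vector of kernels explicitly.
kv = KernelTensorProduct([k1, k2])
kv.kernels isa Vector         # true

# Both compare equal and evaluate identically.
kt == kv                      # true

# Appending a kernel preserves the container kind of the composite:
(k1 ⊗ kv).kernels isa Vector  # true, via the vcat method above
(k1 ⊗ kt).kernels isa Tuple   # true, via splatting
```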

function $M.$op(
k1::$T{<:AbstractVector{<:Kernel}}, k2::$T{<:AbstractVector{<:Kernel}}
)
return $T(vcat(k1.kernels, k2.kernels))
end

$M.$op(k::Kernel, ks::$T) = $T(k, ks.kernels...)
$M.$op(k::Kernel, ks::$T{<:AbstractVector{<:Kernel}}) = $T(vcat(k, ks.kernels))

$M.$op(ks::$T, k::Kernel) = $T(ks.kernels..., k)
$M.$op(ks::$T{<:AbstractVector{<:Kernel}}, k::Kernel) = $T(vcat(ks.kernels, k))
end
end