Rename `TensorProduct` and implement `TensorCore.tensor` #232
@@ -0,0 +1,144 @@
""" | ||
KernelTensorProduct | ||
|
||
Tensor product of kernels. | ||
|
||
# Definition | ||
|
||
For inputs ``x = (x_1, \\ldots, x_n)`` and ``x' = (x'_1, \\ldots, x'_n)``, the tensor | ||
product of kernels ``k_1, \\ldots, k_n`` is defined as | ||
```math | ||
k(x, x'; k_1, \\ldots, k_n) = \\Big(\\bigotimes_{i=1}^n k_i\\Big)(x, x') = \\prod_{i=1}^n k_i(x_i, x'_i). | ||
``` | ||
|
||
# Construction | ||
|
||
The simplest way to specify a `KernelTensorProduct` is to use the overloaded `tensor` | ||
operator or its alias `⊗` (can be typed by `\\otimes<tab>`). | ||
```jldoctest tensorproduct | ||
julia> k1 = SqExponentialKernel(); k2 = LinearKernel(); X = rand(5, 2); | ||
|
||
julia> kernelmatrix(k1 ⊗ k2, RowVecs(X)) == kernelmatrix(k1, X[:, 1]) .* kernelmatrix(k2, X[:, 2]) | ||
true | ||
``` | ||
|
||
You can also specify a `KernelTensorProduct` by providing kernels as individual arguments | ||
or as an iterable data structure such as a `Tuple` or a `Vector`. Using a tuple or | ||
individual arguments guarantees that `KernelTensorProduct` is concretely typed but might | ||
lead to large compilation times if the number of kernels is large. | ||
```jldoctest tensorproduct | ||
julia> KernelTensorProduct(k1, k2) == k1 ⊗ k2 | ||
true | ||
|
||
julia> KernelTensorProduct((k1, k2)) == k1 ⊗ k2 | ||
true | ||
|
||
julia> KernelTensorProduct([k1, k2]) == k1 ⊗ k2 | ||
true | ||
``` | ||
""" | ||
struct KernelTensorProduct{K} <: Kernel
    kernels::K
end

function KernelTensorProduct(kernel::Kernel, kernels::Kernel...)
    return KernelTensorProduct((kernel, kernels...))
end

@functor KernelTensorProduct

Base.length(kernel::KernelTensorProduct) = length(kernel.kernels)

function (kernel::KernelTensorProduct)(x, y)
    if !(length(x) == length(y) == length(kernel))
        throw(DimensionMismatch("number of kernels and number of features are not consistent"))
    end
    return prod(k(xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y))
end
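To make the evaluation path concrete, here is a minimal sketch (not part of the diff) of calling a tensor product kernel on tuple inputs; it assumes KernelFunctions.jl is loaded and reuses the kernels from the docstring above.

```julia
using KernelFunctions

k = SqExponentialKernel() ⊗ LinearKernel()

# Tuples are iterable, so zip routes each coordinate pair to the matching
# kernel; the per-dimension values are then multiplied together by `prod`.
x, y = (1.0, 2.0), (3.0, 4.0)
k(x, y) == SqExponentialKernel()(1.0, 3.0) * LinearKernel()(2.0, 4.0)  # true

# Inputs whose length does not match the number of kernels hit the
# DimensionMismatch branch above:
# k((1.0,), (3.0, 4.0))  # throws DimensionMismatch
```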
function validate_domain(k::KernelTensorProduct, x::AbstractVector)
    return dim(x) == length(k) ||
        error("number of kernels and groups of features are not consistent")
end

# Utility for slicing up inputs.
slices(x::AbstractVector{<:Real}) = (x,)
slices(x::ColVecs) = eachrow(x.X)
slices(x::RowVecs) = eachcol(x.X)
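A quick sketch of what `slices` yields for the two wrapper types, assuming the KernelFunctions.jl conventions that `ColVecs` stores observations as columns and `RowVecs` as rows (`slices` is an internal helper, so it needs qualified access outside the module):

```julia
using KernelFunctions: ColVecs, RowVecs, slices

X = [1.0 2.0; 3.0 4.0]

# Each slice is the vector of values of one feature across all
# observations, which is exactly what each component kernel consumes.
collect(slices(ColVecs(X)))  # row views: [1.0, 2.0] and [3.0, 4.0]
collect(slices(RowVecs(X)))  # column views: [1.0, 3.0] and [2.0, 4.0]
```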
function kernelmatrix!(K::AbstractMatrix, k::KernelTensorProduct, x::AbstractVector)
    validate_inplace_dims(K, x)
    validate_domain(k, x)

    kernels_and_inputs = zip(k.kernels, slices(x))
    kernelmatrix!(K, first(kernels_and_inputs)...)
    for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
        K .*= kernelmatrix(k, xi)
    end

    return K
end

function kernelmatrix!(
    K::AbstractMatrix, k::KernelTensorProduct, x::AbstractVector, y::AbstractVector
)
    validate_inplace_dims(K, x, y)
    validate_domain(k, x)

    kernels_and_inputs = zip(k.kernels, slices(x), slices(y))
    kernelmatrix!(K, first(kernels_and_inputs)...)
    for (k, xi, yi) in Iterators.drop(kernels_and_inputs, 1)
        K .*= kernelmatrix(k, xi, yi)
    end

    return K
end

function kerneldiagmatrix!(K::AbstractVector, k::KernelTensorProduct, x::AbstractVector)
    validate_inplace_dims(K, x)
    validate_domain(k, x)

    kernels_and_inputs = zip(k.kernels, slices(x))
    kerneldiagmatrix!(K, first(kernels_and_inputs)...)
    for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
        K .*= kerneldiagmatrix(k, xi)
    end

    return K
end
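As a hedged usage sketch, the in-place variants can be driven with a preallocated buffer like this (the `validate_*` helpers are assumed to be defined elsewhere in the package):

```julia
using KernelFunctions

k = SqExponentialKernel() ⊗ LinearKernel()
x = RowVecs(rand(5, 2))

# Preallocate once: the first factor is written via kernelmatrix!, and
# every remaining factor is multiplied in elementwise.
K = Matrix{Float64}(undef, 5, 5)
kernelmatrix!(K, k, x)
K ≈ kernelmatrix(k, x)  # agrees with the allocating path below
```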
function kernelmatrix(k::KernelTensorProduct, x::AbstractVector)
    validate_domain(k, x)
    return mapreduce(kernelmatrix, hadamard, k.kernels, slices(x))
end

function kernelmatrix(k::KernelTensorProduct, x::AbstractVector, y::AbstractVector)
    validate_domain(k, x)
    return mapreduce(kernelmatrix, hadamard, k.kernels, slices(x), slices(y))
end

function kerneldiagmatrix(k::KernelTensorProduct, x::AbstractVector)
    validate_domain(k, x)
    return mapreduce(kerneldiagmatrix, hadamard, k.kernels, slices(x))
end
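The allocating methods reduce the per-kernel Gram matrices with `hadamard`, the elementwise product from TensorCore.jl. A sketch of the equivalence this encodes, mirroring the docstring example:

```julia
using KernelFunctions
using TensorCore: hadamard

k1, k2 = SqExponentialKernel(), LinearKernel()
X = rand(5, 2)

# The Gram matrix of the tensor product is the Hadamard product of the
# per-dimension Gram matrices.
K = kernelmatrix(k1 ⊗ k2, RowVecs(X))
K == hadamard(kernelmatrix(k1, X[:, 1]), kernelmatrix(k2, X[:, 2]))  # true
```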
Base.show(io::IO, kernel::KernelTensorProduct) = printshifted(io, kernel, 0)

function Base.:(==)(x::KernelTensorProduct, y::KernelTensorProduct)
    return (
        length(x.kernels) == length(y.kernels) &&
        all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels))
    )
end

function printshifted(io::IO, kernel::KernelTensorProduct, shift::Int)
    print(io, "Tensor product of ", length(kernel), " kernels:")
    for k in kernel.kernels
        print(io, "\n")
        for _ in 1:(shift + 1)
            print(io, "\t")
        end
        printshifted(io, k, shift + 2)
    end
end
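For orientation, the representation produced by `printshifted` looks roughly as follows; the exact kernel strings and indentation depend on each component kernel's own `show`/`printshifted` methods, so treat the output below as an assumption:

```
julia> SqExponentialKernel() ⊗ LinearKernel()
Tensor product of 2 kernels:
	Squared Exponential Kernel
	Linear Kernel (c = 0.0)
```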
@@ -0,0 +1,22 @@
for (M, op, T) in (
    (:Base, :+, :KernelSum),
    (:Base, :*, :KernelProduct),
    (:TensorCore, :tensor, :KernelTensorProduct),
)
    @eval begin
        $M.$op(k1::Kernel, k2::Kernel) = $T(k1, k2)

        $M.$op(k1::$T, k2::$T) = $T(k1.kernels..., k2.kernels...)
        function $M.$op(
            k1::$T{<:AbstractVector{<:Kernel}}, k2::$T{<:AbstractVector{<:Kernel}}
        )
            return $T(vcat(k1.kernels, k2.kernels))
        end

        $M.$op(k::Kernel, ks::$T) = $T(k, ks.kernels...)
        $M.$op(k::Kernel, ks::$T{<:AbstractVector{<:Kernel}}) = $T(vcat(k, ks.kernels))

        $M.$op(ks::$T, k::Kernel) = $T(ks.kernels..., k)
        $M.$op(ks::$T{<:AbstractVector{<:Kernel}}, k::Kernel) = $T(vcat(ks.kernels, k))
    end
end

Review comments on `$M.$op(k1::$T, k2::$T) = $T(k1.kernels..., k2.kernels...)`:

> The default behavior is/was that if we have one kernel

> The nice thing about this default is that you don't end up with abstract types unexpectedly but, of course, it might not be friendly to the compiler. I haven't seen any instances of it in our examples and tests, so I am not sure how often such a combination happens in practice.
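A sketch of how the generated methods behave in practice; the kernel names are assumed from KernelFunctions.jl, and `WhiteKernel` is simply a convenient third kernel:

```julia
using KernelFunctions

k1, k2, k3 = SqExponentialKernel(), LinearKernel(), WhiteKernel()

# Tuple-backed combinations are splatted and re-wrapped, so nesting flattens:
length((k1 ⊗ k2) ⊗ k3)  # 3: a single KernelTensorProduct of three kernels

# Vector-backed combinations stay vector-backed via vcat instead of splatting:
KernelTensorProduct([k1, k2]) ⊗ k3  # still holds its kernels in a Vector

# The same pattern is generated for + (KernelSum) and * (KernelProduct):
(k1 + k2) + k3  # a flat KernelSum of three kernels, not a nested sum
```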
Review comment (PR author):

> I only changed `TensorProduct` to `KernelTensorProduct` and added a new docstring. I renamed the file in a separate PR but somehow, in contrast to the tests, GitHub does not display the changes nicely.