Second derivative of Matern in zero is wrong #517

Open
FelixBenning opened this issue Jun 5, 2023 · 23 comments

Comments

@FelixBenning

FelixBenning commented Jun 5, 2023

julia> using KernelFunctions: MaternKernel

julia> k = MaternKernel(ν=5)
Matern Kernel (ν = 5, metric = Distances.Euclidean(0.0))

julia> import ForwardDiff as FD

julia> kx(x,y) = FD.derivative(t -> k(x+t, y), 0)
kx (generic function with 1 method)

julia> dk(x,y) = FD.derivative(t -> kx(x, y+t), 0)
dk (generic function with 1 method)

julia> dk(0,0)
0.0

This is wrong, because for a centered GP $Z$ with covariance function $k$

$$dk(x,y) = \partial_x \partial_y k(x,y) = \partial_x \partial_y \mathbb{E}[Z(x)Z(y)] = \mathbb{E}[\partial_x Z(x) \partial_y Z(y)] = \text{Cov}(Z'(x), Z'(y))$$

And $\text{Cov}(Z'(0), Z'(0)) = \text{Var}(Z'(0)) > 0$.

$\nu=5$ should leave plenty of room for numerical errors, since it implies the GP is four times (mean-square) differentiable, well beyond the $\nu > 1$ needed for the first derivative.
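
A quick sanity check of what the value at zero should roughly be (a finite-difference sketch, not the package API; it assumes the k from the snippet above with unit lengthscale and variance, for which the expansion of the Matérn at zero gives ν/(ν-1)):

# Finite-difference sketch: approximate the mixed second derivative of k at (0, 0).
# For ν = 5 with unit lengthscale this should come out close to ν/(ν - 1) = 1.25, not 0.
h = 1e-4
d2k_approx = (k(h, h) - k(h, 0) - k(0, h) + k(0, 0)) / h^2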

@FelixBenning
Author

This is most likely due to this:

function _matern(ν::Real, d::Real)
    if iszero(d)
        return one(d)
    else
        y = sqrt(2ν) * d
        b = log(besselk(ν, y))
        return exp((one(d) - ν) * oftype(y, logtwo) - loggamma(ν) + ν * log(y) + b)
    end
end

https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/blob/master/src/basekernels/matern.jl#L43-L45

The if block should not return a constant but rather a Taylor polynomial, so that autodiff works in this branch. I am looking for a reference...
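
Roughly, a sketch of the idea (hypothetical, not a proposed patch; _matern_taylor is a made-up name and the coefficient comes from the second-order expansion of the Matérn at zero, valid for ν > 1):

# Hypothetical sketch: a low-order Taylor polynomial instead of the constant,
# so dual numbers passing through this branch still carry curvature information.
# Assumes ν > 1; uses k(d) = 1 - ν/(2(ν - 1)) d² + O(d⁴) near d = 0.
function _matern_taylor(ν::Real, d::Real)
    return one(d) - ν / (2 * (ν - one(ν))) * d^2
end

Whether this alone is enough also depends on how d itself is differentiated at x == y.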

@devmotion
Member

devmotion commented Jun 5, 2023

Might be due to the hardcoded (constant) value for x = y. Hence possibly the problem doesn't exist with ForwardDiff 0.11.

@devmotion
Member

The if block should not be constant but rather a taylor polynomial so that autodiff in this branch works.

Shouldn't be necessary in ForwardDiff 0.11: It skips measure zero branches (which fixes some problems but broke existing code).

@FelixBenning
Author

Shouldn't be necessary in ForwardDiff 0.11: It skips measure zero branches (which fixes some problems but broke existing code).

wow neat - when is that going to be released?

@devmotion
Member

The package and paper are mainly concerned with derivatives w.r.t. the order ν, which does not seem to be the issue in the OP.

when is that going to be released?

I don't know. Initially the change was released in a non-breaking 0.10.x version, but it broke a lot of downstream packages that rely on the current behaviour. So it was reverted and re-applied on the master branch; recent releases are only made from a separate 0.10.x branch without this change. I don't think anyone plans to release a 0.11 any time soon, because the same thing will happen and nobody wants to invest the time to fix all the broken downstream code.

@FelixBenning
Author

FelixBenning commented Jun 5, 2023

@devmotion you are probably right about the references - damn, I thought I had seen that somewhere. Can't find it at the moment. Given that the variance of the derivatives is really important for everything, this is not really a corner case and will break working with derivatives of GPs...

FelixBenning added a commit to FelixBenning/DifferentiableKernelFunctions.jl that referenced this issue Jun 5, 2023
…t) (#10)

* enableDiffWrap instead of general

* jldoctest

* trim jldoctest

* warning docs, shorten name

* identify broken tests

* push broken state (upstream issue)

JuliaGaussianProcesses/KernelFunctions.jl#517
@FelixBenning
Author

FelixBenning commented Jun 5, 2023

Are you interested in a Taylor expansion for the if x == y branch? Given the AD promise:

Automatic Differentiation compatibility: all kernel functions which ought to be differentiable using AD packages like ForwardDiff.jl or Zygote.jl should be.

Same issue with the RationalQuadraticKernel btw (although that does not appear to branch in the same way):

julia> using KernelFunctions: RationalQuadraticKernel

julia> k = RationalQuadraticKernel()
Rational Quadratic Kernel (α = 2.0, metric = Distances.Euclidean(0.0))

julia> kx(x,y) = FD.derivative(t->k(x+t, y), 0)
kx (generic function with 1 method)

julia> dk(x,y) = FD.derivative(t->kx(x,y+t), 0)
dk (generic function with 1 method)

julia> dk(0,0)
0.0

EDIT: Wikipedia https://en.wikipedia.org/wiki/Mat%C3%A9rn_covariance_function#Taylor_series_at_zero_and_spectral_moments
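
If I read that section correctly, for $\nu > 1$ the expansion there (lengthscale $\rho$, variance $\sigma^2$) is

$$k_\nu(d) = \sigma^2\left(1 - \frac{\nu}{2(\nu - 1)}\Big(\frac{d}{\rho}\Big)^2 + O(d^4)\right),$$

so the mixed second derivative at zero should be $\partial_x \partial_y k(x,y)\big|_{x=y} = \sigma^2 \nu / ((\nu - 1)\rho^2) > 0$ (≈ 1.25 for $\nu = 5$ with unit parameters), not 0.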

@Crown421
Member

Crown421 commented Jun 6, 2023

I think the Matern kernel is in fact correct:
EDIT: Made a mistake here, Matern is incorrect, but RationalQuadratic is fine.

using KernelFunctions, Enzyme, Plots
k = Matern52Kernel()
mk(d) = k(1.0, 1.0 + d)

mk(0.1)

dr = range(0.0, 0.1; length=200)

p = plot(layout=(1, 2))
plot!(p[1], dr, mk.(dr), label="matern", legend=:topright)

dmkd(x, y) = only(autodiff(
    Forward,
    yt -> only(autodiff_deferred(Forward, k, DuplicatedNoNeed, Duplicated(x, 1.0), yt)),
    DuplicatedNoNeed,
    Duplicated(y, 1.0)))
dmkd(1.0, 1.0 + 0.1)

plot!(p[2], dr, dmkd.(1.0, 1.0 .+ dr), label="matern", legend=:topright, title="d^2/(dx1 dx2) ker")

[plot: left panel "matern" kernel values, right panel "d^2/(dx1 dx2) ker"]

but for RationalQuadratic

using KernelFunctions, Enzyme, Plots
k = RationalQuadraticKernel()
mk(d) = k(1.0, 1.0 + d)

mk(0.1)

dr = range(0.0, 0.1; length=200)

p = plot(layout=(1, 2))
plot!(p[1], dr, mk.(dr), label="RQ", legend=:topright)

dmkd(x, y) = only(autodiff(
    Forward,
    yt -> only(autodiff_deferred(Forward, k, DuplicatedNoNeed, Duplicated(x, 1.0), yt)),
    DuplicatedNoNeed,
    Duplicated(y, 1.0)))
dmkd(1.0, 1.0 + 0.1)

plot!(p[2], dr, dmkd.(1.0, 1.0 .+ dr), label="RQ", legend=:right, title="d^2/(dx1 dx2) ker")

[plot: left panel "RQ" kernel values, right panel "d^2/(dx1 dx2) ker"]

@FelixBenning
Author

I think the Matern kernel is in fact correct: EDIT: Made a mistake here, Matern is incorrect, but RationalQuadratic is fine.

Neat, so switching to Enzyme would at least fix the RationalQuadratic.

@Crown421
Member

I have been working on this for a bit, and it seems the issue is not the Matern kernel, or that a Taylor expansion is needed, but instead it seems to be related to the Euclidean distance. In the following code I did it all by hand, and get:

using KernelFunctions
using Enzyme
using Plots

k = Matern52Kernel()
mk(d) = k(1.0, 1.0 + d)

kappa(d) = (1 + sqrt(5) * d + 5 * d^2 / 3) * exp(-sqrt(5) * d)

r = range(0, 0.2, length=30)

begin
    dist1(x, y) = sqrt((x - y)^2)
    dist2(x, y) = abs(x - y)
    ck1(x, y) = kappa(dist1(x, y))
    ck2(x, y) = kappa(dist2(x, y))

    p = plot(layout=(2, 1))
    plot!(p[1], r, ck1.(1.0, 1.0 .+ r), label="dist1")
    plot!(p[1], r, ck2.(1.0, 1.0 .+ r), label="dist2")
    plot!(p[1], r, mk.(r), label="ref")

    dmkd1(x, y) = only(autodiff(
        Forward,
        yt -> only(autodiff_deferred(Forward, ck1, DuplicatedNoNeed, Duplicated(x, 1.0), yt)),
        DuplicatedNoNeed,
        Duplicated(y, 1.0)))
    dmkd2(x, y) = only(autodiff(
        Forward,
        yt -> only(autodiff_deferred(Forward, ck2, DuplicatedNoNeed, Duplicated(x, 1.0), yt)),
        DuplicatedNoNeed,
        Duplicated(y, 1.0)))

    plot!(p[2], r, dmkd1.(1.0, 1.0 .+ r), label="dist1")
    plot!(p[2], r, dmkd2.(1.0, 1.0 .+ r), label="dist2")
end

[plot: top panel kernel values for dist1, dist2 and the reference; bottom panel mixed derivative via dist1 vs dist2]

Repeating the same derivative computation with just the distance functions works correctly, so there seems to be some weird issue with the chain rule.
[plot: derivatives of the distance functions alone]

@FelixBenning
Author

FelixBenning commented Aug 24, 2023

If I understand your code correctly, you reimplemented kappa (in the first plot you verify that it matches the reference implementation) and then take the derivative of this kappa.

But your kappa does not have an if iszero(d) branch like the general Matern implementation does, cf.

https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/blob/master/src/basekernels/matern.jl#L43-L45

As autodiff simply takes the derivative of the branch it finds itself in, it will differentiate this if branch (when d is zero), and the derivative of a constant is zero.
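
A toy version of that behaviour (assuming ForwardDiff 0.10; f below is just a stand-in with a hard-coded value on the zero branch, not the kernel code):

using ForwardDiff

# Stand-in for a kernel-like function with a hard-coded value on the zero branch.
f(d) = iszero(d) ? one(d) : exp(-d)

ForwardDiff.derivative(f, 0.0)    # 0.0: the constant branch gets differentiated
ForwardDiff.derivative(f, 1e-12)  # ≈ -1.0: the smooth branch gets differentiated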

@devmotion
Member

As autodiff simply takes the derivative of the branch it finds itself in

That's not true in general. ForwardDiff#master is supposed to ignore branches of measure zero.

@FelixBenning
Author

That's not true in general. ForwardDiff#master is supposed to ignore branches of measure zero.

But it doesn't at the moment, and won't for the foreseeable future, as I understood from the reactions when I asked around about it. And the point of my comment was to explain what the issue is at the moment (and this is probably it).

@Crown421
Member

Crown421 commented Aug 25, 2023

But your kappa does not have an if iszero(d) like the general matern implementation does

From my understanding and experiments, the general Matern implementation doesn't matter for Matern52: it has its own specialized implementation, and _matern never enters into it.

Further down in the file you linked, you can see the code for Matern52.
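
For context, the specialized evaluation boils down to something like the closed-form ν = 5/2 expression (the same one reimplemented as kappa above; see the linked file for the actual definition):

# Roughly the specialized Matern 5/2 form; note there is no iszero(d) branch here,
# so the zero-branch explanation above cannot be the (only) problem for Matern52Kernel.
kappa52(d) = (1 + sqrt(5) * d + 5 * d^2 / 3) * exp(-sqrt(5) * d)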

@FelixBenning
Author

I didn't know that Matern52 also breaks. I thought only the general version was a problem. I guess in this case we have two problems: the branch in the general case AND the distance function.

@Crown421
Member

Crown421 commented Sep 6, 2023

Continuing to look into it, it seems the issue is with sqrt. The Euclidean distance is (effectively) defined as sqrt((x - y)^2), which then causes issues at x == y. See also this long discussion in Enzyme.

This also explains why ForwardDiff.jl failed. KernelFunctions defines a custom rule for this case (via ChainRules.jl), but ForwardDiff does not use ChainRules as far as I can tell.
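
A minimal illustration of the chain-rule problem (assuming ForwardDiff's standard sqrt and abs rules; this is not the KernelFunctions code path, which goes through Distances.jl, just a toy reproduction):

using ForwardDiff

dist1(x, y) = sqrt((x - y)^2)  # how the Euclidean distance is effectively computed
dist2(x, y) = abs(x - y)

# At x == y the inner (x - y)^2 has a zero partial, but sqrt' blows up at 0,
# so the chain rule yields 0 * Inf = NaN for dist1, while abs has a finite rule.
ForwardDiff.derivative(t -> dist1(1.0 + t, 1.0), 0.0)  # NaN
ForwardDiff.derivative(t -> dist2(1.0 + t, 1.0), 0.0)  # finite (±1, per the abs rule)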

@wsmoses

wsmoses commented Sep 6, 2023

Apologies, I haven't read further up and got linked here from Slack. Regarding the linked Enzyme issue: the derivative of sqrt is undefined at 0. We chose to have it be 0 there instead of NaN, which the linked research says is good for a variety of reasons, though it adds an extra instruction.

Cc @martinjm97

@devmotion
Member

devmotion commented Sep 6, 2023

No, ForwardDiff uses its own definitions (functions with Dual arguments) and DiffRules. The sqrt example reminds me of JuliaDiff/DiffRules.jl#100, which reverted a ForwardDiff bug introduced by defining the derivative of abs(x) as sign(x) instead of signbit(x) ? -one(x) : one(x). Clearly the only difference is how the AD derivative at 0 is defined (obviously the function is not differentiable at 0) - but setting it to 0 (which is a valid subderivative) breaks the hessian example in the linked PR (even on ForwardDiff#master). Hence generally I think it's better to define "derivatives" at non-differentiable points by evaluating them as the left- or right-hand side derivative rather than setting them to some constant (!) subderivative (as argued above in this issue as well, IIRC).

@Crown421
Member

Crown421 commented Sep 6, 2023

@devmotion Do you think adding something like this special rule would be the solution in this context?

@wsmoses

wsmoses commented Sep 6, 2023

@devmotion fwiw that's how Enzyme defines its abs derivative (see below). Sqrt, however, cannot have this.

julia> using Enzyme
julia> Enzyme.API.printall!(true)

julia> Enzyme.autodiff(Forward, abs, Duplicated(0.0, 1.0))
after simplification :
; Function Attrs: mustprogress nofree nosync readnone willreturn
define double @preprocess_julia_abs_615_inner.1(double %0) local_unnamed_addr #3 !dbg !10 {
entry:
  %1 = call {}*** @julia.get_pgcstack() #4
  %2 = call double @llvm.fabs.f64(double %0) #4, !dbg !11
  ret double %2, !dbg !13
}

; Function Attrs: mustprogress nofree nosync readnone willreturn
define internal double @fwddiffejulia_abs_615_inner.1(double %0, double %"'") local_unnamed_addr #3 !dbg !14 {
entry:
  %1 = call {}*** @julia.get_pgcstack() #4
  %2 = fcmp fast olt double %0, 0.000000e+00, !dbg !15
  %3 = select fast i1 %2, double -1.000000e+00, double 1.000000e+00, !dbg !15
  %4 = fmul fast double %"'", %3, !dbg !15
  ret double %4
}

(1.0,)

@Crown421
Member

Crown421 commented Sep 6, 2023

@wsmoses A bit of a side issue, but would it make sense to have that API.printall! function in the Enzyme documentation? I was initially looking for something like this and could not find it.

@wsmoses

wsmoses commented Sep 6, 2023

Oh yeah go for it, contributions welcome!
