Use DI for non-implemented ADTypes #39

Merged (14 commits) on Nov 2, 2024
Project.toml (6 changes: 5 additions & 1 deletion)
@@ -10,6 +10,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
[weakdeps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -19,6 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
LogDensityProblemsADADTypesExt = "ADTypes"
LogDensityProblemsADDifferentiationInterfaceExt = ["ADTypes", "DifferentiationInterface"]
LogDensityProblemsADEnzymeExt = "Enzyme"
LogDensityProblemsADFiniteDifferencesExt = "FiniteDifferences"
LogDensityProblemsADForwardDiffBenchmarkToolsExt = ["BenchmarkTools", "ForwardDiff"]
@@ -29,6 +31,7 @@ LogDensityProblemsADZygoteExt = "Zygote"

[compat]
ADTypes = "1.5"
DifferentiationInterface = "0.6.1"
BenchmarkTools = "1"
DocStringExtensions = "0.8, 0.9"
Enzyme = "0.13.3"
@@ -44,6 +47,7 @@ julia = "1.10"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -54,4 +58,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ADTypes", "BenchmarkTools", "ComponentArrays", "Enzyme", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "Test", "Tracker", "Zygote"]
test = ["ADTypes", "BenchmarkTools", "ComponentArrays", "DifferentiationInterface", "Enzyme", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "Test", "Tracker", "Zygote"]
README.md (4 changes: 3 additions & 1 deletion)
@@ -46,4 +46,6 @@ x = zeros(LogDensityProblems.dimension(ℓ)) # ℓ is your log density

5. [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl): finite differences are very robust, with small numerical error, but usually not fast enough to replace AD in practice on nontrivial problems. The backend in this package is mainly intended for checking and debugging results from other backends; note, however, that in most cases ForwardDiff is faster and more accurate.

Other AD frameworks are supported thanks to [ADTypes.jl](https://github.com/SciML/ADTypes.jl) and [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl).

PRs for remaining AD frameworks are welcome, even if they are WIP.
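For example (a hedged sketch, assuming `ℓ` is a log density implementing the LogDensityProblems interface with dimension 3):

```julia
import ADTypes, FiniteDifferences
import DifferentiationInterface  # enables ADTypes backends without custom support
import LogDensityProblemsAD: ADgradient, logdensity_and_gradient

# A backend with no custom implementation in this package:
backend = ADTypes.AutoFiniteDifferences(; fdm = FiniteDifferences.central_fdm(5, 1))

∇ℓ = ADgradient(backend, ℓ)                         # falls back to DifferentiationInterface
∇ℓ_prepared = ADgradient(backend, ℓ; x = zeros(3))  # prepared for length-3 Float64 input
logdensity_and_gradient(∇ℓ, randn(3))               # -> (value, gradient)
```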
ext/LogDensityProblemsADADTypesExt.jl (35 changes: 24 additions & 11 deletions)
@@ -9,7 +9,7 @@ else
end

"""
ADgradient(ad::ADTypes.AbstractADType, ℓ; x::Union{Nothing,AbstractVector}=nothing)

Wrap log density `ℓ` using automatic differentiation (AD) of type `ad` to obtain a gradient.

@@ -19,12 +19,19 @@ Currently,
- `ad::ADTypes.AutoReverseDiff`
- `ad::ADTypes.AutoTracker`
- `ad::ADTypes.AutoZygote`
are supported with custom implementations.
The AD configuration specified by `ad` is forwarded to the corresponding calls of `ADgradient(Val(...), ℓ)`.

Passing `x` as a keyword argument means that the gradient operator will be "prepared" for the specific type and size of the array `x`. This can speed up further evaluations on similar inputs, but will likely cause errors if the new inputs have a different type or size. With `AutoReverseDiff`, it can also yield incorrect results if the logdensity contains value-dependent control flow.

If you want to use another backend from [ADTypes.jl](https://github.com/SciML/ADTypes.jl) which is not in the list above, you need to load [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) first.
"""
LogDensityProblemsAD.ADgradient(::ADTypes.AbstractADType, ℓ)

function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoEnzyme, ℓ; x::Union{Nothing,AbstractVector}=nothing)
if x !== nothing
@warn "`ADgradient`: Keyword argument `x` is ignored"
end
if ad.mode === nothing
# Use default mode (Enzyme.Reverse)
return LogDensityProblemsAD.ADgradient(Val(:Enzyme), ℓ)
@@ -33,25 +40,31 @@ function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoEnzyme, ℓ)
end
end

function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoForwardDiff{C}, ℓ; x::Union{Nothing,AbstractVector}=nothing) where {C}
if C === nothing
# Use default chunk size
return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; tag = ad.tag, x=x)
else
return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk = C, tag = ad.tag, x=x)
end
end

function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff{T}, ℓ; x::Union{Nothing,AbstractVector}=nothing) where {T}
return LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), ℓ; compile = Val(T), x=x)
end

function LogDensityProblemsAD.ADgradient(::ADTypes.AutoTracker, ℓ; x::Union{Nothing,AbstractVector}=nothing)
if x !== nothing
@warn "`ADgradient`: Keyword argument `x` is ignored"
end
return LogDensityProblemsAD.ADgradient(Val(:Tracker), ℓ)
end


function LogDensityProblemsAD.ADgradient(::ADTypes.AutoZygote, ℓ; x::Union{Nothing,AbstractVector}=nothing)
if x !== nothing
@warn "`ADgradient`: Keyword argument `x` is ignored"
end
return LogDensityProblemsAD.ADgradient(Val(:Zygote), ℓ)
end

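As a quick illustration of how the `x` keyword flows through these methods (a sketch with a hypothetical `ℓ`; the tests further down exercise the same behavior):

```julia
import ADTypes
import LogDensityProblemsAD: ADgradient

# ForwardDiff and ReverseDiff forward `x`, so the gradient operator is
# prepared for inputs of this type and size:
ADgradient(ADTypes.AutoForwardDiff(), ℓ; x = zeros(3))
ADgradient(ADTypes.AutoReverseDiff(; compile = Val(true)), ℓ; x = zeros(3))

# Enzyme, Tracker, and Zygote have no preparation step; `x` is ignored
# and a warning is emitted:
ADgradient(ADTypes.AutoZygote(), ℓ; x = zeros(3))
```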
ext/LogDensityProblemsADDifferentiationInterfaceExt.jl (47 changes: 47 additions & 0 deletions)
@@ -0,0 +1,47 @@
module LogDensityProblemsADDifferentiationInterfaceExt

import LogDensityProblemsAD
import ADTypes
import DifferentiationInterface as DI

"""
DIGradient <: LogDensityProblemsAD.ADGradientWrapper

Gradient wrapper which uses [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl).

# Fields

- `backend::AbstractADType`: one of the autodiff backend types defined in [ADTypes.jl](https://github.com/SciML/ADTypes.jl), for example `ADTypes.AutoForwardDiff()`
- `prep`: either `nothing` or the output of `DifferentiationInterface.prepare_gradient` applied to the logdensity and the provided input
- `ℓ`: logdensity function, amenable to `LogDensityProblemsAD.logdensity(ℓ, x)`
"""
struct DIGradient{B<:ADTypes.AbstractADType,P,L} <: LogDensityProblemsAD.ADGradientWrapper
backend::B
prep::P
ℓ::L
end

function logdensity_switched(x, ℓ)
# active argument must come first in DI
return LogDensityProblemsAD.logdensity(ℓ, x)
end

function LogDensityProblemsAD.ADgradient(backend::ADTypes.AbstractADType, ℓ; x::Union{Nothing,AbstractVector}=nothing)
if x === nothing
prep = nothing
else
prep = DI.prepare_gradient(logdensity_switched, backend, x, DI.Constant(ℓ))
end
return DIGradient(backend, prep, ℓ)
end

function LogDensityProblemsAD.logdensity_and_gradient(∇ℓ::DIGradient, x::AbstractVector)
(; backend, prep, ℓ) = ∇ℓ
if prep === nothing
return DI.value_and_gradient(logdensity_switched, backend, x, DI.Constant(ℓ))
else
return DI.value_and_gradient(logdensity_switched, prep, backend, x, DI.Constant(ℓ))
end
end

end
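For readers new to DifferentiationInterface's context arguments, this is roughly what the wrapper above does (a sketch assuming `ℓ` and `backend` are defined as in the earlier examples):

```julia
import DifferentiationInterface as DI
import LogDensityProblemsAD

# DI differentiates with respect to the first argument, so the log density
# object is passed along as a constant context:
f(x, ℓ) = LogDensityProblemsAD.logdensity(ℓ, x)

# Unprepared path (what DIGradient does when constructed without `x`):
y, g = DI.value_and_gradient(f, backend, randn(3), DI.Constant(ℓ))

# Prepared path: preparation is computed once and reused for inputs of the
# same type and size:
prep = DI.prepare_gradient(f, backend, zeros(3), DI.Constant(ℓ))
y, g = DI.value_and_gradient(f, prep, backend, randn(3), DI.Constant(ℓ))
```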
test/runtests.jl (35 changes: 35 additions & 0 deletions)
@@ -6,6 +6,7 @@ import FiniteDifferences, ForwardDiff, Enzyme, Tracker, Zygote, ReverseDiff # ba
import ADTypes # load support for AD types with options
import BenchmarkTools # load the heuristic chunks code
using ComponentArrays: ComponentVector # test with other vector types
import DifferentiationInterface

struct EnzymeTestMode <: Enzyme.Mode{Enzyme.DefaultABI, false, false} end

@@ -91,6 +92,9 @@ ForwardDiff.checktag(::Type{ForwardDiff.Tag{TestTag, V}}, ::Base.Fix1{typeof(log

# ADTypes support
@test typeof(ADgradient(ADTypes.AutoReverseDiff(; compile = Val(true)), ℓ)) === typeof(∇ℓ_compile)
@test typeof(ADgradient(ADTypes.AutoReverseDiff(; compile = Val(true)), ℓ; x=rand(3))) === typeof(∇ℓ_compile_x)
@test nameof(typeof(ADgradient(ADTypes.AutoReverseDiff(), ℓ))) !== :DIGradient
@test nameof(typeof(ADgradient(ADTypes.AutoReverseDiff(), ℓ; x=rand(3)))) !== :DIGradient

for ∇ℓ in (∇ℓ_default, ∇ℓ_nocompile, ∇ℓ_compile, ∇ℓ_compile_x)
@test dimension(∇ℓ) == 3
@@ -127,6 +131,8 @@ end

# ADTypes support
@test ADgradient(ADTypes.AutoForwardDiff(), ℓ) === ∇ℓ
@test nameof(typeof(ADgradient(ADTypes.AutoForwardDiff(), ℓ))) !== :DIGradient
@test nameof(typeof(ADgradient(ADTypes.AutoForwardDiff(), ℓ; x=rand(3)))) !== :DIGradient

for _ in 1:100
x = randn(3)
@@ -175,6 +181,7 @@ end
# ADTypes support
@test ADgradient(ADTypes.AutoForwardDiff(; chunksize = 3), ℓ) === ADgradient(:ForwardDiff, ℓ; chunk = 3)
@test ADgradient(ADTypes.AutoForwardDiff(; chunksize = 3, tag = TestTag()), ℓ) === ADgradient(:ForwardDiff, ℓ; chunk = 3, tag = TestTag())
@test typeof(ADgradient(ADTypes.AutoForwardDiff(), ℓ; x=rand(3))) == typeof(ADgradient(:ForwardDiff, ℓ; x=rand(3)))
end

@testset "component vectors" begin
@@ -211,6 +218,8 @@ end

# ADTypes support
@test ADgradient(ADTypes.AutoTracker(), ℓ) === ∇ℓ
@test nameof(typeof(ADgradient(ADTypes.AutoTracker(), ℓ))) !== :DIGradient
@test nameof(typeof(ADgradient(ADTypes.AutoTracker(), ℓ; x=rand(3)))) !== :DIGradient
end

@testset "AD via Zygote" begin
@@ -227,6 +236,8 @@ end

# ADTypes support
@test ADgradient(ADTypes.AutoZygote(), ℓ) === ∇ℓ
@test nameof(typeof(ADgradient(ADTypes.AutoZygote(), ℓ))) !== :DIGradient
@test nameof(typeof(ADgradient(ADTypes.AutoZygote(), ℓ; x=rand(3)))) !== :DIGradient
end

@testset "AD via Enzyme" begin
@@ -241,6 +252,8 @@ end

# ADTypes support
@test ADgradient(ADTypes.AutoEnzyme(), ℓ) === ∇ℓ_reverse
@test nameof(typeof(ADgradient(ADTypes.AutoEnzyme(), ℓ))) !== :DIGradient
@test nameof(typeof(ADgradient(ADTypes.AutoEnzyme(), ℓ; x=rand(3)))) !== :DIGradient

∇ℓ_forward = ADgradient(:Enzyme, ℓ; mode=Enzyme.Forward)
@test ADgradient(ADTypes.AutoEnzyme(;mode=Enzyme.Forward), ℓ) === ∇ℓ_forward
@@ -291,3 +304,25 @@ end
@test b isa Vector{Pair{Int,Float64}}
@test length(b) ≤ 20
end

@testset verbose=true "DifferentiationInterface for unsupported ADTypes" begin
ℓ = TestLogDensity(test_logdensity1)
backends = [
ADTypes.AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(5, 1)),
]
∇ℓ_candidates = []
for backend in backends
push!(∇ℓ_candidates, ADgradient(backend, ℓ))
push!(∇ℓ_candidates, ADgradient(backend, ℓ; x=zeros(3)))
end
@testset "$(typeof(∇ℓ))" for ∇ℓ in ∇ℓ_candidates
@test nameof(typeof(∇ℓ)) == :DIGradient
@test dimension(∇ℓ) == 3
@test capabilities(∇ℓ) ≡ LogDensityOrder(1)
for _ in 1:100
x = randn(3)
@test @inferred(logdensity(∇ℓ, x)) ≅ test_logdensity1(x)
@test logdensity_and_gradient(∇ℓ, x) ≅ (test_logdensity1(x), test_gradient(x)) atol = 1e-5
end
end
end;