diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 270363e..64009d3 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -27,6 +27,7 @@ jobs: - x64 group: - Others + - Enzyme - ForwardDiff - Tracker - ReverseDiff @@ -36,6 +37,10 @@ jobs: os: macOS-latest arch: x64 group: Others + - version: '1' + os: macOS-latest + arch: x64 + group: Enzyme - version: '1' os: macOS-latest arch: x64 @@ -53,23 +58,14 @@ jobs: arch: x64 group: Zygote steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/cache@v1 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- - - uses: julia-actions/julia-buildpkg@latest - - uses: julia-actions/julia-runtest@latest + - uses: julia-actions/cache@v1 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 with: coverage: false env: diff --git a/test/Project.toml b/test/Project.toml index 086a88d..3f4bf7e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" @@ -21,6 +22,7 @@ ChainRulesCore = "1" ChainRulesTestUtils = "1.9.2" Combinatorics = "1.0.2" Distributions = "0.25.15" +Enzyme = "0.12" FiniteDifferences = "0.11.3, 0.12" ForwardDiff = "0.10.12" LazyArrays = "1" diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index 8baa50c..699bc7a 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -61,8 +61,9 @@ DistSpec(Poisson, (0.5,), 1), DistSpec(Poisson, (0.5,), [1, 1]), - DistSpec(Skellam, (1.0, 2.0), -2), - DistSpec(Skellam, (1.0, 2.0), [-2, -2]), + # Enzyme: no forward rule for ccall + DistSpec(Skellam, (1.0, 2.0), -2, broken=(:Enzyme,)), + DistSpec(Skellam, (1.0, 2.0), [-2, -2], broken=(:Enzyme,)), DistSpec(PoissonBinomial, ([0.5, 0.5],), 0), @@ -159,10 +160,10 @@ DistSpec(LogNormal, (1.0, 2.0), 0.5), # Dispatch error caused by ccall - DistSpec(NoncentralBeta, (1.0, 2.0, 1.0), 0.5, broken=(:Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), - DistSpec(NoncentralChisq, (1.0, 2.0), 0.5, broken=(:Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), - DistSpec(NoncentralF, (1.0, 2.0, 1.0), 0.5, broken=(:Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), - DistSpec(NoncentralT, (1.0, 2.0), 0.5, broken=(:Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), + DistSpec(NoncentralBeta, (1.0, 2.0, 1.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff, :Enzyme)), + DistSpec(NoncentralChisq, (1.0, 2.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff, :Enzyme)), + DistSpec(NoncentralF, (1.0, 2.0, 1.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff, :Enzyme)), + DistSpec(NoncentralT, (1.0, 2.0), 0.5, broken=(:Enzyme, :Tracker, :ForwardDiff, :Zygote, :ReverseDiff, :Enzyme)), DistSpec(Normal, (), 0.5), DistSpec(Normal, (1.0,), 0.5), diff --git a/test/ad/others.jl b/test/ad/others.jl index 1c70558..94d1492 100644 --- a/test/ad/others.jl +++ b/test/ad/others.jl @@ -14,7 +14,7 @@ A = to_posdef(rand(3, 3)) B = to_posdef(rand(3, 3)) - test_reverse_mode_ad(randn(3, 3), A, B) do A, B + test_reverse_mode_ad(randn(3, 3), A, B; broken = (:Enzyme,)) do A, B return DistributionsAD.zygote_ldiv(A, B) end end diff --git a/test/ad/utils.jl b/test/ad/utils.jl index a4dcd6e..275b19c 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -5,6 +5,11 @@ using FiniteDifferences const FDM = FiniteDifferences # Load AD backends +if GROUP == "All" || GROUP == "Enzyme" + @eval begin + using Enzyme + end +end if GROUP == "All" || GROUP == "ForwardDiff" @eval using ForwardDiff end @@ -18,13 +23,63 @@ if GROUP == "All" || GROUP == "Tracker" @eval using Tracker end -function test_reverse_mode_ad(f, ȳ, x...; rtol=1e-6, atol=1e-6) +function test_reverse_mode_ad(f, ȳ, x...; rtol=1e-6, atol=1e-6, broken=()) # Perform a regular forwards-pass. y = f(x...) # Use finite differencing to compute reverse-mode sensitivities. x̄s_fdm = FDM.j′vp(central_fdm(5, 1), f, ȳ, x...) + if GROUP == "All" || GROUP == "Enzyme" + enzyme_broken = :Enzyme in broken + io = enzyme_broken ? devnull : stdout + testset = redirect_stdout(io) do + # Use Enzyme to compute reverse-mode sensitivities. + @testset "Enzyme: Reverse-mode AD of $f" begin + x̄s_enzyme_init = map(x) do xi + xi isa Real ? nothing : zero(xi) + end + enzyme_autodiff_args = map(x, x̄s_enzyme_init) do xi, x̄si + return if x̄si === nothing + @assert xi isa Real + Active(xi) + else + @assert typeof(xi) === typeof(x̄si) + Duplicated(xi, x̄si) + end + end + dot_f_ȳ(args...) = dot(f(args...), ȳ) + x̄s_enzyme_autodiff, y_dot_ȳ_enzyme = + Enzyme.autodiff(ReverseWithPrimal, Const(dot_f_ȳ), Active, enzyme_autodiff_args...) + x̄s_enzyme = map(x̄s_enzyme_init, x̄s_enzyme_autodiff) do x̄s_init_i, x̄s_autodiff_i + return if x̄s_init_i === nothing + @assert x̄s_autodiff_i isa Real + x̄s_autodiff_i + else + @assert x̄s_autodiff_i === nothing + x̄s_init_i + end + end + + # Check that Enzyme primal is correct. + @test dot(y, ȳ) ≈ y_dot_ȳ_enzyme atol=atol rtol=rtol + + # Check that Enzyme reverse-mode sensitivities are correct. + @test all(zip(x̄s_enzyme, x̄s_fdm)) do (x̄_enzyme, x̄_fdm) + return isapprox(x̄_enzyme, x̄_fdm; atol=atol, rtol=rtol) + end + end + end + + # change errors and fails to broken results, and count number of errors and fails + efs = errors_to_broken!(testset) + + # ensure that passing tests are not marked as broken + if iszero(efs) && enzyme_broken + error("Enzyme tests of $f passed unexpectedly, please mark not as broken") + end + end + if GROUP == "All" || GROUP == "Zygote" # Use Zygote to compute reverse-mode sensitivities. y_zygote, back_zygote = Zygote.pullback(f, x...) @@ -350,6 +405,19 @@ end function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) finitediff = FDM.grad(central_fdm(5, 1), f, x)[1] + if GROUP == "All" || GROUP == "Enzyme" + if (:Enzyme in broken) || (:EnzymeForward in broken) + @test_broken collect(Enzyme.gradient(Enzyme.Forward, Const(f), x)) ≈ finitediff rtol=rtol atol=atol + else + @test collect(Enzyme.gradient(Enzyme.Forward, Const(f), x)) ≈ finitediff rtol=rtol atol=atol + end + if (:Enzyme in broken) || (:EnzymeReverse in broken) + @test_broken Enzyme.gradient(Enzyme.Reverse, Const(f), x) ≈ finitediff rtol=rtol atol=atol + else + @test Enzyme.gradient(Enzyme.Reverse, Const(f), x) ≈ finitediff rtol=rtol atol=atol + end + end + if GROUP == "All" || GROUP == "Tracker" if :Tracker in broken @test_broken Tracker.data(Tracker.gradient(f, x)[1]) ≈ finitediff rtol=rtol atol=atol @@ -398,7 +466,7 @@ end function testset_zygote_broken(distspec, args...; kwargs...) # don't show test errors - tests are known to be broken :) - testset = suppress_stdout() do + testset = redirect_stdout(devnull) do testset_zygote(distspec, args...; kwargs...) end @@ -417,17 +485,6 @@ function testset_zygote_broken(distspec, args...; kwargs...) return testset end -# `redirect_stdout(f, devnull)` is only available in Julia >= 1.6 -function suppress_stdout(f) - @static if VERSION < v"1.6" - open((@static Sys.iswindows() ? "NUL" : "/dev/null"), "w") do devnull - redirect_stdout(f, devnull) - end - else - redirect_stdout(f, devnull) - end -end - # change test errors and failures to broken results function errors_to_broken!(ts::Test.DefaultTestSet) results = ts.results diff --git a/test/runtests.jl b/test/runtests.jl index baee8ad..4075503 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,6 +12,13 @@ using DistributionsAD: TuringMvNormal, TuringMvLogNormal, TuringPoissonBinomial, TuringDirichlet using StatsFuns: StatsFuns, logsumexp, logistic +import Enzyme +Enzyme.API.typeWarning!(false) +# Enable runtime activity (workaround) +Enzyme.API.runtimeActivity!(true) +# Supress excessive type deduce failures may result in incorrect gradients. +# https://enzyme.mit.edu/julia/stable/api/#Enzyme.API.looseTypeAnalysis!-Tuple{Any} +Enzyme.API.looseTypeAnalysis!(true) @static if VERSION >= v"1.8" using Pkg; Pkg.status(outdated=true) # show reasons why packages are held back end @@ -26,7 +33,7 @@ if GROUP == "All" || GROUP == "Others" include("others.jl") end -if GROUP == "All" || GROUP in ("ForwardDiff", "Zygote", "ReverseDiff", "Tracker") +if GROUP == "All" || GROUP in ("Enzyme", "ForwardDiff", "Zygote", "ReverseDiff", "Tracker") include("ad/utils.jl") include("ad/others.jl") include("ad/distributions.jl")