diff --git a/.github/workflows/ForwardDiff_Tracker.yml b/.github/workflows/AD.yml similarity index 81% rename from .github/workflows/ForwardDiff_Tracker.yml rename to .github/workflows/AD.yml index cf30e2b5..ebbd9179 100644 --- a/.github/workflows/ForwardDiff_Tracker.yml +++ b/.github/workflows/AD.yml @@ -1,4 +1,4 @@ -name: ForwardDiff and Tracker tests +name: AD tests on: push: @@ -20,6 +20,11 @@ jobs: - macOS-latest arch: - x64 + AD: + - ForwardDiff + - Tracker + - ReverseDiff + - Zygote steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 @@ -30,4 +35,4 @@ jobs: - uses: julia-actions/julia-runtest@latest env: GROUP: AD - AD: ForwardDiff_Tracker + AD: ${{ matrix.AD }} diff --git a/.github/workflows/ReverseDiff.yml b/.github/workflows/ReverseDiff.yml deleted file mode 100644 index 6c5abb98..00000000 --- a/.github/workflows/ReverseDiff.yml +++ /dev/null @@ -1,33 +0,0 @@ -name: ReverseDiff tests - -on: - push: - branches: - - master - pull_request: - -jobs: - test: - runs-on: ${{ matrix.os }} - continue-on-error: ${{ matrix.version == 'nightly' }} - strategy: - matrix: - version: - - '1.3' - - '1' - os: - - ubuntu-latest - - macOS-latest - arch: - - x64 - steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 - with: - version: ${{ matrix.version }} - arch: ${{ matrix.arch }} - - uses: julia-actions/julia-buildpkg@latest - - uses: julia-actions/julia-runtest@latest - env: - GROUP: AD - AD: ReverseDiff diff --git a/.github/workflows/Zygote.yml b/.github/workflows/Zygote.yml deleted file mode 100644 index f03c17ab..00000000 --- a/.github/workflows/Zygote.yml +++ /dev/null @@ -1,33 +0,0 @@ -name: Zygote tests - -on: - push: - branches: - - master - pull_request: - -jobs: - test: - runs-on: ${{ matrix.os }} - continue-on-error: ${{ matrix.version == 'nightly' }} - strategy: - matrix: - version: - - '1.3' - - '1' - os: - - ubuntu-latest - - macOS-latest - arch: - - x64 - steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 - with: - version: ${{ matrix.version }} - arch: ${{ matrix.arch }} - - uses: julia-actions/julia-buildpkg@latest - - uses: julia-actions/julia-runtest@latest - env: - GROUP: AD - AD: Zygote diff --git a/Project.toml b/Project.toml index d13bc7d1..726f47ce 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.8.7" +version = "0.8.8" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" @@ -20,7 +20,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] ArgCheck = "1, 2" Compat = "3" -Distributions = "0.23.3" +Distributions = "0.23.3, 0.24" MappedArrays = "0.2.2, 0.3" NNlib = "0.6, 0.7" Reexport = "0.2" diff --git a/test/Project.toml b/test/Project.toml index d3d44d70..25db89ce 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,9 +1,10 @@ [deps] Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -13,8 +14,9 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Combinatorics = "1.0.2" DistributionsAD = "0.6.3" -FiniteDiff = "2.6" +FiniteDifferences = "0.11" ForwardDiff = "0.10.12" +NNlib = "0.7" ReverseDiff = "1.4.2" Tracker = "0.2.11" Zygote = "0.5.4" diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index 7f017fe0..506d95d7 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -10,15 +10,16 @@ B = rand(dim, dim) C = rand(dim, dim) - dim_big = 10 + # Create random numbers + alpha = rand() + beta = rand() + gamma = rand() # Some LKJ problems may be hidden when test matrix is too small - A_big = rand(dim_big, dim_big) + dim_big = 10 + A_big = rand(dim_big, dim_big) B_big = rand(dim_big, dim_big) - # Create a random number - alpha = rand() - # Create positive definite matrix to_posdef(A::AbstractMatrix) = A * A' + I to_posdef_diagonal(a::AbstractVector) = Diagonal(a.^2 .+ 1) @@ -34,13 +35,21 @@ return S, pullback end - # Create matrix `X` such that `X` and `I - X` are positive definite + # Create matrix `X` such that `X` and `I - X` are positive definite if `A ≠ 0`. function to_beta_mat(A) S = A * A' + I invL = inv(cholesky(S).L) return invL * invL' end + # Create positive values. + to_positive(x) = exp.(x) + to_positive(x::AbstractArray{<:AbstractArray}) = to_positive.(x) + + # Create vectors in probability simplex. + to_simplex(x::AbstractArray; dims=1) = NNlib.softmax(x; dims=dims) + to_simplex(x::AbstractArray{<:AbstractArray}; dims=1) = to_simplex.(x; dims=dims) + function to_corr(x) y = to_posdef(x) d = 1 ./ sqrt.(diag(y)) @@ -48,6 +57,8 @@ return (y2 + y2') / 2 end + # Tests that have a `broken` field can be executed but, according to FiniteDifferences, + # fail to produce the correct result. These tests can be checked with `@test_broken`. univariate_distributions = DistSpec[ ## Univariate discrete distributions @@ -105,6 +116,8 @@ DistSpec(Cauchy, (1.0,), 0.5), DistSpec(Cauchy, (1.0, 2.0), 0.5), + DistSpec(Chernoff, (), 0.5, broken=(:Zygote,)), + DistSpec(Chi, (1.0,), 0.5), DistSpec(Chisq, (1.0,), 0.5), @@ -169,6 +182,12 @@ DistSpec(LogNormal, (1.0,), 0.5), DistSpec(LogNormal, (1.0, 2.0), 0.5), + # Dispatch error caused by ccall + DistSpec(NoncentralBeta, (1.0, 2.0, 1.0), 0.5, broken=(:Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), + DistSpec(NoncentralChisq, (1.0, 2.0), 0.5, broken=(:Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), + DistSpec(NoncentralF, (1.0, 2.0, 1.0), 0.5, broken=(:Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), + DistSpec(NoncentralT, (1.0, 2.0), 0.5, broken=(:Tracker, :ForwardDiff, :Zygote, :ReverseDiff)), + DistSpec(Normal, (), 0.5), DistSpec(Normal, (1.0,), 0.5), DistSpec(Normal, (1.0, 2.0), 0.5), @@ -205,10 +224,10 @@ ), DistSpec(Uniform, (), 0.5), - DistSpec(Uniform, (0.0, 1.0), 0.5), + DistSpec(Uniform, (alpha, alpha + beta), alpha + beta * gamma), DistSpec(TuringUniform, (), 0.5), - DistSpec(TuringUniform, (0.0, 1.0), 0.5), + DistSpec(TuringUniform, (alpha, alpha + beta), alpha + beta * gamma), DistSpec(VonMises, (), 1.0), @@ -217,31 +236,26 @@ DistSpec(Weibull, (1.0, 1.0), 1.0), ] + # Tests cannot be executed, so cannot be checked with `@test_broken`. broken_univariate_distributions = DistSpec[ - # Zygote - DistSpec(Chernoff, (), 0.5), - # Broken in Distributions even without autodiff - DistSpec(() -> KSDist(1), (), 0.5), - DistSpec(() -> KSOneSided(1), (), 0.5), - DistSpec(StudentizedRange, (1.0, 2.0), 0.5), - - # Dispatch error caused by ccall - DistSpec(NoncentralBeta, (1.0, 2.0, 1.0), 0.5), - DistSpec(NoncentralChisq, (1.0, 2.0), 0.5), - DistSpec(NoncentralF, (1, 2, 1), 0.5), - DistSpec(NoncentralT, (1, 2), 0.5), + DistSpec(() -> KSDist(1), (), 0.5), # `pdf` method not defined + DistSpec(() -> KSOneSided(1), (), 0.5), # `pdf` method not defined + DistSpec(StudentizedRange, (1.0, 2.0), 0.5), # `srdistlogpdf` method not defined # Stackoverflow caused by SpecialFunctions.besselix DistSpec(VonMises, (1.0,), 1.0), DistSpec(VonMises, (1, 1), 1), ] + # Tests that have a `broken` field can be executed but, according to FiniteDifferences, + # fail to produce the correct result. These tests can be checked with `@test_broken`. multivariate_distributions = DistSpec[ ## Multivariate discrete distributions # Vector x DistSpec(p -> Multinomial(2, p ./ sum(p)), (fill(0.5, 2),), [2, 0]), + DistSpec(p -> Multinomial(2, p ./ sum(p)), (fill(0.5, 2),), [2 1; 0 1]), # Vector x DistSpec((m, A) -> MvNormal(m, to_posdef(A)), (a, A), b), @@ -262,16 +276,16 @@ DistSpec(TuringMvNormal, (a,), b), DistSpec(s -> TuringMvNormal(to_posdef_diagonal(s)), (a,), b), DistSpec(s -> TuringMvNormal(dim, s), (alpha,), a), - DistSpec((m, A) -> MvLogNormal(m, to_posdef(A)), (a, A), b), - DistSpec(MvLogNormal, (a, b), c), - DistSpec((m, s) -> MvLogNormal(m, to_posdef_diagonal(s)), (a, b), c), - DistSpec(MvLogNormal, (a, alpha), b), - DistSpec(A -> MvLogNormal(to_posdef(A)), (A,), a), - DistSpec(MvLogNormal, (a,), b), - DistSpec(s -> MvLogNormal(to_posdef_diagonal(s)), (a,), b), - DistSpec(s -> MvLogNormal(dim, s), (alpha,), a), + DistSpec((m, A) -> MvLogNormal(m, to_posdef(A)), (a, A), b, to_positive), + DistSpec(MvLogNormal, (a, b), c, to_positive), + DistSpec((m, s) -> MvLogNormal(m, to_posdef_diagonal(s)), (a, b), c, to_positive), + DistSpec(MvLogNormal, (a, alpha), b, to_positive), + DistSpec(A -> MvLogNormal(to_posdef(A)), (A,), a, to_positive), + DistSpec(MvLogNormal, (a,), b, to_positive), + DistSpec(s -> MvLogNormal(to_posdef_diagonal(s)), (a,), b, to_positive), + DistSpec(s -> MvLogNormal(dim, s), (alpha,), a, to_positive), - DistSpec(Dirichlet, (ones(dim),), b ./ sum(b)), + DistSpec(alpha -> Dirichlet(to_positive(alpha)), (a,), b, to_simplex), # Matrix case DistSpec(MvNormal, (a, b), A), @@ -281,18 +295,22 @@ DistSpec(MvNormal, (a,), A), DistSpec(s -> MvNormal(to_posdef_diagonal(s)), (a,), A), DistSpec(s -> MvNormal(dim, s), (alpha,), A), - DistSpec(MvLogNormal, (a, b), A), - DistSpec((m, s) -> MvLogNormal(m, to_posdef_diagonal(s)), (a, b), A), - DistSpec(MvLogNormal, (a, alpha), A), - DistSpec(MvLogNormal, (a,), A), - DistSpec(s -> MvLogNormal(to_posdef_diagonal(s)), (a,), A), - DistSpec(s -> MvLogNormal(dim, s), (alpha,), A), - - DistSpec(Dirichlet, (ones(dim),), B ./ sum(B; dims=1)), + DistSpec((m, A) -> MvNormal(m, to_posdef(A)), (a, A), B), + DistSpec(A -> MvNormal(to_posdef(A)), (A,), B), + DistSpec(MvLogNormal, (a, b), A, to_positive), + DistSpec((m, s) -> MvLogNormal(m, to_posdef_diagonal(s)), (a, b), A, to_positive), + DistSpec(MvLogNormal, (a, alpha), A, to_positive), + DistSpec(MvLogNormal, (a,), A, to_positive), + DistSpec(s -> MvLogNormal(to_posdef_diagonal(s)), (a,), A, to_positive), + DistSpec(s -> MvLogNormal(dim, s), (alpha,), A, to_positive), + DistSpec((m, A) -> MvLogNormal(m, to_posdef(A)), (a, A), B, to_positive), + DistSpec(A -> MvLogNormal(to_posdef(A)), (A,), B, to_positive), + + DistSpec(alpha -> Dirichlet(to_positive(alpha)), (a,), A, to_simplex), ] + # Tests cannot be executed, so cannot be checked with `@test_broken`. broken_multivariate_distributions = DistSpec[ - DistSpec(p -> Multinomial(2, p ./ sum(p)), (fill(0.5, 2),), [2 1; 0 1]), # Dispatch error DistSpec((m, A) -> MvNormalCanon(m, to_posdef(A)), (a, A), b), DistSpec(MvNormalCanon, (a, b), c), @@ -306,13 +324,10 @@ DistSpec(A -> MvNormalCanon(to_posdef(A)), (A,), B), DistSpec(MvNormalCanon, (a,), A), DistSpec(s -> MvNormalCanon(dim, s), (alpha,), A), - # Test failure - DistSpec((m, A) -> MvNormal(m, to_posdef(A)), (a, A), B), - DistSpec(A -> MvNormal(to_posdef(A)), (A,), B), - DistSpec((m, A) -> MvLogNormal(m, to_posdef(A)), (a, A), B), - DistSpec(A -> MvLogNormal(to_posdef(A)), (A,), B), ] + # Tests that have a `broken` field can be executed but, according to FiniteDifferences, + # fail to produce the correct result. These tests can be checked with `@test_broken`. matrixvariate_distributions = DistSpec[ # Matrix x DistSpec((n1, n2) -> MatrixBeta(dim, n1, n2), (3.0, 3.0), A, to_beta_mat), @@ -361,21 +376,25 @@ ) ] + # Tests cannot be executed, so cannot be checked with `@test_broken`. broken_matrixvariate_distributions = DistSpec[ - # Other + # TODO no bijector for MatrixNormal + DistSpec(() -> MatrixNormal(dim, dim), (), A, to_posdef, broken=(:Zygote,)), + # TODO different tests are broken on different combinations of backends DistSpec( (A, B, C) -> MatrixNormal(A, to_posdef(B), to_posdef(C)), (A, B, B), C, to_posdef, ), - DistSpec(() -> MatrixNormal(dim, dim), (), A, to_posdef), + # TODO different tests are broken on different combinations of backends DistSpec( (df, A, B, C) -> MatrixTDist(df, A, to_posdef(B), to_posdef(C)), (1.0, A, B, B), C, to_posdef, ), + # TODO different tests are broken on different combinations of backends DistSpec( (n1, n2, A) -> MatrixFDist(n1, n2, to_posdef(A)), (3.0, 3.0, A), @@ -409,6 +428,11 @@ # Broken distributions d.f(d.θ...) isa Union{VonMises,TriangularDist} && continue + # Skellam only fails in these tests with ReverseDiff + # Ref: https://github.com/TuringLang/DistributionsAD.jl/issues/126 + filldist_broken = d.f(d.θ...) isa Skellam ? (:ReverseDiff,) : d.broken + arraydist_broken = d.broken + # Create `filldist` distribution f_filldist = (θ...,) -> filldist(d.f(θ...), n) d_filldist = f_filldist(d.θ...) @@ -429,10 +453,24 @@ # Test AD test_ad( - DistSpec(Symbol(:filldist, " (", d.name, ", $sz)"), f_filldist, d.θ, x) + DistSpec( + Symbol(:filldist, " (", d.name, ", $sz)"), + f_filldist, + d.θ, + x, + d.xtrans; + broken=filldist_broken, + ) ) test_ad( - DistSpec(Symbol(:arraydist, " (", d.name, ", $sz)"), f_arraydist, d.θ, x) + DistSpec( + Symbol(:arraydist, " (", d.name, ", $sz)"), + f_arraydist, + d.θ, + x, + d.xtrans; + broken=arraydist_broken, + ) ) end end @@ -465,10 +503,24 @@ # Test AD test_ad( - DistSpec(Symbol(:filldist, " (", d.name, ", $n)"), f_filldist, d.θ, x_mat) + DistSpec( + Symbol(:filldist, " (", d.name, ", $n)"), + f_filldist, + d.θ, + x_mat, + d.xtrans; + broken=d.broken, + ) ) test_ad( - DistSpec(Symbol(:arraydist, " (", d.name, ", $n)"), f_arraydist, d.θ, x_mat) + DistSpec( + Symbol(:arraydist, " (", d.name, ", $n)"), + f_arraydist, + d.θ, + x_mat, + d.xtrans; + broken=d.broken, + ) ) # Vector of matrices `x` @@ -481,6 +533,8 @@ f_filldist, d.θ, x_vec_of_mat, + d.xtrans; + broken=d.broken, ) ) test_ad( @@ -489,6 +543,8 @@ f_arraydist, d.θ, x_vec_of_mat, + d.xtrans; + broken=d.broken, ) ) end @@ -522,10 +578,24 @@ # Test AD test_ad( - DistSpec(Symbol(:filldist, " (", d.name, ", $n)"), f_filldist, d.θ, x_mat) + DistSpec( + Symbol(:filldist, " (", d.name, ", $n)"), + f_filldist, + d.θ, + x_mat, + d.xtrans; + broken=d.broken, + ) ) test_ad( - DistSpec(Symbol(:arraydist, " (", d.name, ", $n)"), f_arraydist, d.θ, x_mat) + DistSpec( + Symbol(:arraydist, " (", d.name, ", $n)"), + f_arraydist, + d.θ, + x_mat, + d.xtrans; + broken=d.broken, + ) ) # Vector of matrices `x` @@ -538,6 +608,8 @@ f_filldist, d.θ, x_vec_of_mat, + d.xtrans; + broken=d.broken, ) ) test_ad( @@ -546,6 +618,8 @@ f_arraydist, d.θ, x_vec_of_mat, + d.xtrans; + broken=d.broken, ) ) end diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 117d916c..6c9bf4a4 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -2,7 +2,7 @@ const AD = get(ENV, "AD", "All") # Struct of distribution, corresponding parameters, and a sample. -struct DistSpec{VF<:VariateForm,VS<:ValueSupport,F,T,X,G} +struct DistSpec{VF<:VariateForm,VS<:ValueSupport,F,T,X,G,B<:Tuple} name::Symbol f::F "Distribution parameters." @@ -11,19 +11,21 @@ struct DistSpec{VF<:VariateForm,VS<:ValueSupport,F,T,X,G} x::X "Transformation of sample `x`." xtrans::G + "Broken backends" + broken::B end -function DistSpec(f, θ, x, xtrans=nothing) +function DistSpec(f, θ, x, xtrans=nothing; broken=()) name = f isa Distribution ? nameof(typeof(f)) : nameof(typeof(f(θ...))) - return DistSpec(name, f, θ, x, xtrans) + return DistSpec(name, f, θ, x, xtrans; broken=broken) end -function DistSpec(name::Symbol, f, θ, x, xtrans=nothing) +function DistSpec(name::Symbol, f, θ, x, xtrans=nothing; broken=()) F = f isa Distribution ? typeof(f) : typeof(f(θ...)) VF = Distributions.variate_form(F) VS = Distributions.value_support(F) - return DistSpec{VF,VS,typeof(f),typeof(θ),typeof(x),typeof(xtrans)}( - name, f, θ, x, xtrans + return DistSpec{VF,VS,typeof(f),typeof(θ),typeof(x),typeof(xtrans),typeof(broken)}( + name, f, θ, x, xtrans, broken, ) end @@ -90,6 +92,7 @@ function test_ad(dist::DistSpec; kwargs...) θ = dist.θ x = dist.x g = dist.xtrans + broken = dist.broken # Test links d = f(θ...) @@ -107,58 +110,65 @@ function test_ad(dist::DistSpec; kwargs...) end end - if isempty(θ) - # In this case we can only test the gradient with respect to `x` - xtest = vectorize(x) - ftest = let xorig=x - x -> f_allargs(unpack(x, (1,), xorig)...) - end - test_ad(ftest, xtest; kwargs...) - else - # For all combinations of distribution parameters `θ` - for inds in combinations(2:(length(θ) + 1)) - # Test only distribution parameters + # For all combinations of distribution parameters `θ` + for inds in powerset(2:(length(θ) + 1)) + # Test only distribution parameters + if !isempty(inds) xtest = mapreduce(vcat, inds) do i vectorize(θ[i - 1]) end - ftest = let xorig=x, θorig=θ, inds=inds + f_test = let xorig=x, θorig=θ, inds=inds x -> f_allargs(unpack(x, inds, xorig, θorig...)...) end - test_ad(ftest, xtest; kwargs...) - - # Test derivative with respect to location `x` as well - # if the distribution is continuous - if Distributions.value_support(typeof(dist)) === Continuous - xtest = vcat(vectorize(x), xtest) - push!(inds, 1) - ftest = let xorig=x, θorig=θ, inds=inds - x -> f_allargs(unpack(x, inds, xorig, θorig...)...) - end - test_ad(ftest, xtest; kwargs...) + test_ad(f_test, xtest, broken; kwargs...) + end + + # Test derivative with respect to location `x` as well + # if the distribution is continuous + if Distributions.value_support(typeof(dist)) === Continuous + xtest = isempty(inds) ? vectorize(x) : vcat(vectorize(x), xtest) + push!(inds, 1) + f_test = let xorig=x, θorig=θ, inds=inds + x -> f_allargs(unpack(x, inds, xorig, θorig...)...) end + test_ad(f_test, xtest, broken; kwargs...) end end end -function test_ad(f, x; rtol = 1e-6, atol = 1e-6, ad = AD) - finitediff = FiniteDiff.finite_difference_gradient(f, x) +function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) + finitediff = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1] - if ad == "All" || ad == "ForwardDiff_Tracker" - tracker = Tracker.data(Tracker.gradient(f, x)[1]) - @test tracker ≈ finitediff rtol=rtol atol=atol + if AD == "All" || AD == "Tracker" + if :Tracker in broken + @test_broken Tracker.data(Tracker.gradient(f, x)[1]) ≈ finitediff rtol=rtol atol=atol + else + @test Tracker.data(Tracker.gradient(f, x)[1]) ≈ finitediff rtol=rtol atol=atol + end + end - forward = ForwardDiff.gradient(f, x) - @test forward ≈ finitediff rtol=rtol atol=atol + if AD == "All" || AD == "ForwardDiff" + if :ForwardDiff in broken + @test_broken ForwardDiff.gradient(f, x) ≈ finitediff rtol=rtol atol=atol + else + @test ForwardDiff.gradient(f, x) ≈ finitediff rtol=rtol atol=atol + end end - if ad == "All" || ad == "Zygote" - zygote = Zygote.gradient(f, x)[1] - @test zygote ≈ finitediff rtol=rtol atol=atol + if AD == "All" || AD == "Zygote" + if :Zygote in broken + @test_broken Zygote.gradient(f, x)[1] ≈ finitediff rtol=rtol atol=atol + else + @test Zygote.gradient(f, x)[1] ≈ finitediff rtol=rtol atol=atol + end end - if ad == "All" || ad == "ReverseDiff" - reversediff = ReverseDiff.gradient(f, x) - @test reversediff ≈ finitediff rtol=rtol atol=atol + if AD == "All" || AD == "ReverseDiff" + if :ReverseDiff in broken + @test_broken ReverseDiff.gradient(f, x) ≈ finitediff rtol=rtol atol=atol + else + @test ReverseDiff.gradient(f, x) ≈ finitediff rtol=rtol atol=atol + end end return diff --git a/test/runtests.jl b/test/runtests.jl index 47f36d95..b317cc4f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,7 @@ using Bijectors using Combinatorics using DistributionsAD -using FiniteDiff +using FiniteDifferences using ForwardDiff using ReverseDiff using Tracker @@ -16,6 +16,8 @@ using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Per using DistributionsAD: TuringUniform, TuringMvNormal, TuringMvLogNormal, TuringPoissonBinomial +import NNlib + const is_TRAVIS = haskey(ENV, "TRAVIS") const GROUP = get(ENV, "GROUP", "All")