Skip to content

Commit

Permalink
Remove BijectorsEnzymeExt on 1.11.1 (#337)
Browse files Browse the repository at this point in the history
* Disable fail-fast on CI

* Inline expanded frule and rrule in BijectorsEnzymeExt

* Bump patch version

* Remove BijectorsEnzymeExt on 1.11.1+

* Tapir -> Mooncake (#338)

* Tapir -> Mooncake

* Bump minor version

* Mark Mooncake test as broken

* Remove BijectorsEnzymeExt on 1.11.1+

* Increase tolerance on `ordered` test
  • Loading branch information
penelopeysm authored Oct 29, 2024
1 parent 9a19a37 commit e0f04fc
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 55 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/AD.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ on:
jobs:
test:
runs-on: ${{ matrix.os }}
continue-on-error: ${{ matrix.version == 'nightly' }}
strategy:
fail-fast: false
matrix:
version:
- '1.6'
Expand All @@ -23,13 +23,13 @@ jobs:
AD:
- Enzyme
- ForwardDiff
- Tapir
- Mooncake
- Tracker
- ReverseDiff
- Zygote
exclude:
- version: 1.6
AD: Tapir
AD: Mooncake
# TODO(mhauru) Hopefully can enable Enzyme on older versions at some point, see
# discussion in https://github.com/TuringLang/Bijectors.jl/pull.
- version: 1.6
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/Interface.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ on:
jobs:
test:
runs-on: ${{ matrix.os }}
continue-on-error: ${{ matrix.version == 'nightly' }}
strategy:
fail-fast: false
matrix:
version:
- '1.6'
Expand Down
15 changes: 9 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.13.18"
version = "0.14.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand All @@ -26,21 +26,22 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[weakdeps]
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
BijectorsDistributionsADExt = "DistributionsAD"
BijectorsEnzymeExt = "Enzyme"
BijectorsEnzymeExt = ["Enzyme", "EnzymeCore"]
BijectorsForwardDiffExt = "ForwardDiff"
BijectorsLazyArraysExt = "LazyArrays"
BijectorsReverseDiffExt = "ReverseDiff"
BijectorsMooncakeExt = "Mooncake"
BijectorsTrackerExt = "Tracker"
BijectorsTapirExt = "Tapir"
BijectorsZygoteExt = "Zygote"

[compat]
Expand All @@ -53,6 +54,7 @@ Distributions = "0.25.33"
DistributionsAD = "0.6"
DocStringExtensions = "0.9"
Enzyme = "0.12.22"
EnzymeCore = "0.7.8"
ForwardDiff = "0.10"
Functors = "0.1, 0.2, 0.3, 0.4"
InverseFunctions = "0.1"
Expand All @@ -65,17 +67,18 @@ Requires = "0.5, 1"
ReverseDiff = "1"
Roots = "1.3.4, 2"
Statistics = "1"
Tapir = "0.2.23"
Mooncake = "0.4.19"
Tracker = "0.2"
Zygote = "0.6.63"
julia = "1.6"

[extras]
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
14 changes: 9 additions & 5 deletions ext/BijectorsEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
module BijectorsEnzymeExt

if isdefined(Base, :get_extension)
using Enzyme: @import_frule, @import_rrule
using Enzyme: @import_rrule, @import_frule
using Bijectors: find_alpha
else
using ..Enzyme: @import_frule, @import_rrule
using ..Enzyme: @import_rrule, @import_frule
using ..Bijectors: find_alpha
end

@import_rrule typeof(find_alpha) Real Real Real
@import_frule typeof(find_alpha) Real Real Real

@static if v"1.11.1" <= VERSION < v"1.12"
@warn "Bijectors and Enzyme do not work together on Julia $VERSION"
else
@import_rrule typeof(find_alpha) Real Real Real
@import_frule typeof(find_alpha) Real Real Real
end

end # module
15 changes: 8 additions & 7 deletions ext/BijectorsTapirExt.jl → ext/BijectorsMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
module BijectorsTapirExt
module BijectorsMooncakeExt

if isdefined(Base, :get_extension)
using Tapir: @is_primitive, MinimalCtx, Tapir, CoDual, primal, tangent_type, @from_rrule
using Mooncake:
@is_primitive, MinimalCtx, Mooncake, CoDual, primal, tangent_type, @from_rrule
using Bijectors: find_alpha, ChainRulesCore
else
using ..Tapir: @is_primitive, MinimalCtx, Tapir, primal, tangent_type, @from_rrule
using ..Mooncake: @is_primitive, MinimalCtx, Mooncake, primal, tangent_type, @from_rrule
using ..Bijectors: find_alpha, ChainRulesCore
end

Expand All @@ -19,20 +20,20 @@ end
# unusual Integer type is encountered.
@is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat})

function Tapir.rrule!!(
function Mooncake.rrule!!(
::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I}
) where {P<:Base.IEEEFloat,I<:Integer}
# Require that the integer is non-differentiable.
if tangent_type(I) != Tapir.NoTangent
if tangent_type(I) != Mooncake.NoTangent
msg = "Integer argument has tangent type $(tangent_type(I)), should be NoTangent."
throw(ArgumentError(msg))
end
out, pb = ChainRulesCore.rrule(find_alpha, primal(x), primal(y), primal(z))
function find_alpha_pb(dout::P)
_, dx, dy, _ = pb(dout)
return Tapir.NoRData(), P(dx), P(dy), Tapir.NoRData()
return Mooncake.NoRData(), P(dx), P(dy), Mooncake.NoRData()
end
return Tapir.zero_fcodual(out), find_alpha_pb
return Mooncake.zero_fcodual(out), find_alpha_pb
end

end
14 changes: 7 additions & 7 deletions test/ad/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,37 +27,37 @@ end
test_frule(Bijectors.find_alpha, x, y, z)
test_rrule(Bijectors.find_alpha, x, y, z)

if @isdefined Tapir
if @isdefined Mooncake
rng = Xoshiro(123456)
Tapir.TestUtils.test_rule(
Mooncake.TestUtils.test_rule(
rng,
Bijectors.find_alpha,
x,
y,
z;
is_primitive=true,
perf_flag=:none,
interp=Tapir.TapirInterpreter(),
interp=Mooncake.MooncakeInterpreter(),
)
Tapir.TestUtils.test_rule(
Mooncake.TestUtils.test_rule(
rng,
Bijectors.find_alpha,
x,
y,
3;
is_primitive=true,
perf_flag=:none,
interp=Tapir.TapirInterpreter(),
interp=Mooncake.MooncakeInterpreter(),
)
Tapir.TestUtils.test_rule(
Mooncake.TestUtils.test_rule(
rng,
Bijectors.find_alpha,
x,
y,
UInt32(3);
is_primitive=true,
perf_flag=:none,
interp=Tapir.TapirInterpreter(),
interp=Mooncake.MooncakeInterpreter(),
)
end

Expand Down
54 changes: 33 additions & 21 deletions test/ad/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
b in (
:ForwardDiff,
:Zygote,
:Tapir,
:Mooncake,
:ReverseDiff,
:Enzyme,
:EnzymeForward,
Expand Down Expand Up @@ -78,27 +78,39 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
end
end

if (AD == "All" || AD == "Tapir") && VERSION >= v"1.10"
rule = Tapir.build_rrule(f, x; safety_on=false)
if :tapir in broken
@test_broken(
isapprox(
Tapir.value_and_gradient!!(rule, f, x)[2][2],
finitediff;
rtol=rtol,
atol=atol,
)
)
else
@test(
isapprox(
Tapir.value_and_gradient!!(rule, f, x)[2][2],
finitediff;
rtol=rtol,
atol=atol,
)
)
if (AD == "All" || AD == "Mooncake") && VERSION >= v"1.10"
try
Mooncake.build_rrule(f, x)
catch exc
# TODO(penelopeysm):
# @test_throws AssertionError (expr...) doesn't work, unclear why
@test exc isa AssertionError
end
# TODO: The above @test_throws happens because of
# https://github.com/compintell/Mooncake.jl/issues/319. If that test
# fails, it probably means that the issue was fixed, in which case
# we can remove that block and uncomment the following instead.

# rule = Mooncake.build_rrule(f, x)
# if :Mooncake in broken
# @test_broken (
# isapprox(
# Mooncake.value_and_gradient!!(rule, f, x)[2][2],
# finitediff;
# rtol=rtol,
# atol=atol,
# )
# )
# else
# @test(
# isapprox(
# Mooncake.value_and_gradient!!(rule, f, x)[2][2],
# finitediff;
# rtol=rtol,
# atol=atol,
# )
# )
# end
end

return nothing
Expand Down
4 changes: 2 additions & 2 deletions test/bijectors/ordered.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ end
end
end
# Check that the quantiles are reasonable, i.e. within
# 5 standard errors of the true quantiles (and that the MCSE is
# 6 standard errors of the true quantiles (and that the MCSE is
# not too large).
for i in 1:k
for j in 1:length(qts)
@test qs_mcse[i, j] < abs(qs_true[i, end] - qs_true[i, 1]) / 2
@test abs(qs[i, j] - qs_true[i, j]) < 5 * qs_mcse[i, j]
@test abs(qs[i, j] - qs_true[i, j]) < 6 * qs_mcse[i, j]
end
end
end
Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ if VERSION < v"1.9"
using Compat: stack
end

# Sadly, Tapir.jl cannot be installed on version 1.6, so we have to add it if we're testing
# Sadly, Mooncake.jl cannot be installed on version 1.6, so we have to add it if we're testing
# on at least version 1.10.
if VERSION >= v"1.10"
using Pkg
Pkg.add("Tapir")
using Tapir
Pkg.add("Mooncake")
using Mooncake
end

const GROUP = get(ENV, "GROUP", "All")
Expand Down

4 comments on commit e0f04fc

@penelopeysm
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/118313

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.14.0 -m "<description of version>" e0f04fc598bb67a615e6171cd8f85e6888d513a5
git push origin v0.14.0

@penelopeysm
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

  • Disable BijectorsEnzymeExt on Julia 1.11. This means Bijectors and Enzyme can't be used together on 1.11. (Note that as of the time of writing, Enzyme is not working on 1.11 anyway)
  • Replace Tapir with Mooncake.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/118313

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.14.0 -m "<description of version>" e0f04fc598bb67a615e6171cd8f85e6888d513a5
git push origin v0.14.0

Please sign in to comment.