diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 9d5d80f5e..646141687 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -44,7 +44,7 @@ DifferentiationInterfaceTrackerExt = "Tracker" DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"] [compat] -ADTypes = "1.5.0" +ADTypes = "1.6.1" ChainRulesCore = "1.23.0" Compat = "3,4" Diffractor = "=0.2.6" diff --git a/DifferentiationInterface/docs/src/backends.md b/DifferentiationInterface/docs/src/backends.md index 408b90700..9820085b4 100644 --- a/DifferentiationInterface/docs/src/backends.md +++ b/DifferentiationInterface/docs/src/backends.md @@ -57,7 +57,7 @@ import Zygote backend_examples = [ AutoDiffractor(), - AutoEnzyme(), + AutoEnzyme(; constant_function=true), AutoFastDifferentiation(), AutoFiniteDiff(), AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)), diff --git a/DifferentiationInterface/docs/src/tutorial1.md b/DifferentiationInterface/docs/src/tutorial1.md index 0e3cd1205..d3c26c52c 100644 --- a/DifferentiationInterface/docs/src/tutorial1.md +++ b/DifferentiationInterface/docs/src/tutorial1.md @@ -116,7 +116,7 @@ Typically, for gradients, reverse mode AD might be a better fit, so let's try th ```@example tuto1 import Enzyme -backend2 = AutoEnzyme() +backend2 = AutoEnzyme(constant_function=true) ``` Once the backend is created, things run smoothly with exactly the same syntax as before: diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index 5095a36b4..3775a8ae6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -38,15 +38,19 @@ using Enzyme: make_zero, make_zero! -struct AutoDeferredEnzyme{M} <: ADTypes.AbstractADType +struct AutoDeferredEnzyme{M,constant_function} <: ADTypes.AbstractADType mode::M end ADTypes.mode(backend::AutoDeferredEnzyme) = ADTypes.mode(AutoEnzyme(backend.mode)) -DI.nested(backend::AutoEnzyme) = AutoDeferredEnzyme(backend.mode) +function DI.nested(backend::AutoEnzyme{M,constant_function}) where {M,constant_function} + return AutoDeferredEnzyme{M,constant_function}(backend.mode) +end -const AnyAutoEnzyme{M} = Union{AutoEnzyme{M},AutoDeferredEnzyme{M}} +const AnyAutoEnzyme{M,constant_function} = Union{ + AutoEnzyme{M,constant_function},AutoDeferredEnzyme{M,constant_function} +} # forward mode if possible forward_mode(backend::AnyAutoEnzyme{<:Mode}) = backend.mode @@ -68,6 +72,15 @@ function DI.basis(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T return b end +function get_f_and_df(f, ::AnyAutoEnzyme{M,true}) where {M} + return Const(f) +end + +function get_f_and_df(f, ::AnyAutoEnzyme{M,false}) where {M} + df = make_zero(f) + return Duplicated(f, df) +end + include("forward_onearg.jl") include("forward_twoarg.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index e5d86f9a8..395a1f0a0 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -7,12 +7,13 @@ end function DI.value_and_pushforward( f, backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx, ::NoPushforwardExtras ) + f_and_df = get_f_and_df(f, backend) dx_sametype = convert(typeof(x), dx) x_and_dx = Duplicated(x, dx_sametype) y, new_dy = if backend isa AutoDeferredEnzyme - autodiff_deferred(forward_mode(backend), f, Duplicated, x_and_dx) + autodiff_deferred(forward_mode(backend), f_and_df, Duplicated, x_and_dx) else - autodiff(forward_mode(backend), Const(f), Duplicated, x_and_dx) + autodiff(forward_mode(backend), f_and_df, Duplicated, x_and_dx) end return y, new_dy end @@ -20,12 +21,13 @@ end function DI.pushforward( f, backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx, ::NoPushforwardExtras ) + f_and_df = get_f_and_df(f, backend) dx_sametype = convert(typeof(x), dx) x_and_dx = Duplicated(x, dx_sametype) new_dy = if backend isa AutoDeferredEnzyme - only(autodiff_deferred(forward_mode(backend), f, DuplicatedNoNeed, x_and_dx)) + only(autodiff_deferred(forward_mode(backend), f_and_df, DuplicatedNoNeed, x_and_dx)) else - only(autodiff(forward_mode(backend), Const(f), DuplicatedNoNeed, x_and_dx)) + only(autodiff(forward_mode(backend), f_and_df, DuplicatedNoNeed, x_and_dx)) end return new_dy end @@ -61,34 +63,42 @@ struct EnzymeForwardGradientExtras{B,O} <: GradientExtras shadow::O end -function DI.prepare_gradient(f, backend::AutoEnzyme{<:ForwardMode}, x) +function DI.prepare_gradient(f, backend::AutoEnzyme{<:ForwardMode,true}, x) B = pick_batchsize(backend, length(x)) shadow = chunkedonehot(x, Val(B)) return EnzymeForwardGradientExtras{B,typeof(shadow)}(shadow) end function DI.gradient( - f, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{B} + f, backend::AutoEnzyme{<:ForwardMode,true}, x, extras::EnzymeForwardGradientExtras{B} ) where {B} grad_tup = gradient(forward_mode(backend), f, x, Val(B); shadow=extras.shadow) return reshape(collect(grad_tup), size(x)) end function DI.value_and_gradient( - f, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras + f, backend::AutoEnzyme{<:ForwardMode,true}, x, extras::EnzymeForwardGradientExtras ) return f(x), DI.gradient(f, backend, x, extras) end function DI.gradient!( - f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{B} + f, + grad, + backend::AutoEnzyme{<:ForwardMode,true}, + x, + extras::EnzymeForwardGradientExtras{B}, ) where {B} grad_tup = gradient(forward_mode(backend), f, x, Val(B); shadow=extras.shadow) return copyto!(grad, grad_tup) end function DI.value_and_gradient!( - f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{B} + f, + grad, + backend::AutoEnzyme{<:ForwardMode,true}, + x, + extras::EnzymeForwardGradientExtras{B}, ) where {B} grad_tup = gradient(forward_mode(backend), f, x, Val(B); shadow=extras.shadow) return f(x), copyto!(grad, grad_tup) @@ -100,7 +110,7 @@ struct EnzymeForwardOneArgJacobianExtras{B,O} <: JacobianExtras shadow::O end -function DI.prepare_jacobian(f, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x) +function DI.prepare_jacobian(f, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},true}, x) B = pick_batchsize(backend, length(x)) shadow = chunkedonehot(x, Val(B)) return EnzymeForwardOneArgJacobianExtras{B,typeof(shadow)}(shadow) @@ -108,7 +118,7 @@ end function DI.jacobian( f, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, + backend::AutoEnzyme{<:Union{ForwardMode,Nothing},true}, x, extras::EnzymeForwardOneArgJacobianExtras{B}, ) where {B} @@ -120,7 +130,7 @@ end function DI.value_and_jacobian( f, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, + backend::AutoEnzyme{<:Union{ForwardMode,Nothing},true}, x, extras::EnzymeForwardOneArgJacobianExtras, ) @@ -130,7 +140,7 @@ end function DI.jacobian!( f, jac, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, + backend::AutoEnzyme{<:Union{ForwardMode,Nothing},true}, x, extras::EnzymeForwardOneArgJacobianExtras, ) @@ -140,7 +150,7 @@ end function DI.value_and_jacobian!( f, jac, - backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, + backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},true}, x, extras::EnzymeForwardOneArgJacobianExtras, ) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl index ed242c50b..e05cb273f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl @@ -12,14 +12,15 @@ function DI.value_and_pushforward( dx, ::NoPushforwardExtras, ) + f!_and_df! = get_f_and_df(f!, backend) dx_sametype = convert(typeof(x), dx) dy_sametype = make_zero(y) y_and_dy = Duplicated(y, dy_sametype) x_and_dx = Duplicated(x, dx_sametype) if backend isa AutoDeferredEnzyme - autodiff_deferred(forward_mode(backend), f!, Const, y_and_dy, x_and_dx) + autodiff_deferred(forward_mode(backend), f!_and_df!, Const, y_and_dy, x_and_dx) else - autodiff(forward_mode(backend), Const(f!), Const, y_and_dy, x_and_dx) + autodiff(forward_mode(backend), f!_and_df!, Const, y_and_dy, x_and_dx) end return y, dy_sametype end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 071846ae7..35ed56113 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -13,10 +13,11 @@ function DI.value_and_pullback( dy::Number, ::NoPullbackExtras, ) + f_and_df = get_f_and_df(f, backend) der, y = if backend isa AutoDeferredEnzyme - autodiff_deferred(ReverseWithPrimal, f, Active, Active(x)) + autodiff_deferred(ReverseWithPrimal, f_and_df, Active, Active(x)) else - autodiff(ReverseWithPrimal, Const(f), Active, Active(x)) + autodiff(ReverseWithPrimal, f_and_df, Active, Active(x)) end new_dx = dy * only(der) return y, new_dx @@ -29,11 +30,13 @@ function DI.value_and_pullback( dy, ::NoPullbackExtras, ) - tf, tx = typeof(f), typeof(x) - forw, rev = autodiff_thunk(ReverseSplitWithPrimal, Const{tf}, Duplicated, Active{tx}) - tape, y, new_dy = forw(Const(f), Active(x)) + f_and_df = get_f_and_df(f, backend) + forw, rev = autodiff_thunk( + ReverseSplitWithPrimal, typeof(f_and_df), Duplicated, typeof(Active(x)) + ) + tape, y, new_dy = forw(f_and_df, Active(x)) copyto!(new_dy, dy) - new_dx = only(only(rev(Const(f), Active(x), tape))) + new_dx = only(only(rev(f_and_df, Active(x), tape))) return y, new_dx end @@ -44,12 +47,13 @@ function DI.value_and_pullback( dy::Number, ::NoPullbackExtras, ) + f_and_df = get_f_and_df(f, backend) dx_sametype = make_zero(x) x_and_dx = Duplicated(x, dx_sametype) _, y = if backend isa AutoDeferredEnzyme - autodiff_deferred(ReverseWithPrimal, Const(f), Active, x_and_dx) + autodiff_deferred(ReverseWithPrimal, f_and_df, Active, x_and_dx) else - autodiff(ReverseWithPrimal, Const(f), Active, x_and_dx) + autodiff(ReverseWithPrimal, f_and_df, Active, x_and_dx) end if !isone(dy) # TODO: generalize beyond Arrays? @@ -81,13 +85,14 @@ function DI.value_and_pullback!( dy::Number, ::NoPullbackExtras, ) + f_and_df = get_f_and_df(f, backend) dx_sametype = convert(typeof(x), dx) make_zero!(dx_sametype) x_and_dx = Duplicated(x, dx_sametype) _, y = if backend isa AutoDeferredEnzyme - autodiff_deferred(ReverseWithPrimal, Const(f), Active, x_and_dx) + autodiff_deferred(ReverseWithPrimal, f_and_df, Active, x_and_dx) else - autodiff(ReverseWithPrimal, Const(f), Active, x_and_dx) + autodiff(ReverseWithPrimal, f_and_df, Active, x_and_dx) end if !isone(dy) # TODO: generalize beyond Arrays? @@ -99,15 +104,16 @@ end function DI.value_and_pullback!( f, dx, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy, ::NoPullbackExtras ) - tf, tx = typeof(f), typeof(x) - forw, rev = autodiff_thunk( - ReverseSplitWithPrimal, Const{tf}, Duplicated, Duplicated{tx} - ) + f_and_df = get_f_and_df(f, backend) dx_sametype = convert(typeof(x), dx) make_zero!(dx_sametype) - tape, y, new_dy = forw(Const(f), Duplicated(x, dx_sametype)) + x_and_dx = Duplicated(x, dx_sametype) + forw, rev = autodiff_thunk( + ReverseSplitWithPrimal, typeof(f_and_df), Duplicated, typeof(x_and_dx) + ) + tape, y, new_dy = forw(f_and_df, x_and_dx) copyto!(new_dy, dy) - rev(Const(f), Duplicated(x, dx_sametype), tape) + rev(f_and_df, x_and_dx, tape) return y, copyto!(dx, dx_sametype) end @@ -124,12 +130,12 @@ end ## Gradient -function DI.prepare_gradient(f, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x) +function DI.prepare_gradient(f, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x) return NoGradientExtras() end function DI.gradient( - f, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, ::NoGradientExtras + f, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x, ::NoGradientExtras ) if backend isa AutoDeferredEnzyme grad = make_zero(x) @@ -143,7 +149,7 @@ end function DI.gradient!( f, grad, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x, extras::NoGradientExtras, ) @@ -158,13 +164,17 @@ function DI.gradient!( end function DI.value_and_gradient( - f, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, ::NoGradientExtras + f, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x, ::NoGradientExtras ) return DI.value_and_pullback(f, backend, x, true, NoPullbackExtras()) end function DI.value_and_gradient!( - f, grad, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, ::NoGradientExtras + f, + grad, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, + x, + ::NoGradientExtras, ) return DI.value_and_pullback!(f, grad, backend, x, true, NoPullbackExtras()) end @@ -172,59 +182,3 @@ end ## Jacobian # see https://github.com/EnzymeAD/Enzyme.jl/issues/1391 - -#= - -struct EnzymeReverseOneArgJacobianExtras{B,N} end - -function DI.prepare_jacobian(f, backend::AutoReverseEnzyme, x) - B = pick_batchsize(backend, length(x)) - y = f(x) - N = length(y) - return EnzymeReverseOneArgJacobianExtras{B,N}() -end - -function DI.jacobian( - f, - backend::AutoReverseEnzyme, - x::AbstractArray, - ::EnzymeReverseOneArgJacobianExtras{C,N}, -) where {B,N} - jac_wrongshape = jacobian(reverse_mode(backend), f, x, Val(N), Val(B)) - nx = length(x) - ny = length(jac_wrongshape) รท length(x) - jac_rightshape = reshape(jac_wrongshape, ny, nx) - return jac_rightshape -end - -function DI.value_and_jacobian( - f, - backend::AutoReverseEnzyme, - x::AbstractArray, - extras::EnzymeReverseOneArgJacobianExtras, -) - return f(x), DI.jacobian(f, backend, x, extras) -end - -function DI.jacobian!( - f, - jac, - backend::AutoReverseEnzyme, - x::AbstractArray, - extras::EnzymeReverseOneArgJacobianExtras, -) - return copyto!(jac, DI.jacobian(f, backend, x, extras)) -end - -function DI.value_and_jacobian!( - f, - jac, - backend::AutoReverseEnzyme, - x::AbstractArray, - extras::EnzymeReverseOneArgJacobianExtras, -) - y, new_jac = DI.value_and_jacobian(f, backend, x, extras) - return y, copyto!(jac, new_jac) -end - -=# diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index 4648e7b6b..c6c93651e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -12,12 +12,13 @@ function DI.value_and_pullback( dy, ::NoPullbackExtras, ) + f!_and_df! = get_f_and_df(f!, backend) dy_sametype = convert(typeof(y), copy(dy)) y_and_dy = Duplicated(y, dy_sametype) _, new_dx = if backend isa AutoDeferredEnzyme - only(autodiff_deferred(reverse_mode(backend), f!, Const, y_and_dy, Active(x))) + only(autodiff_deferred(reverse_mode(backend), f!_and_df!, Const, y_and_dy, Active(x))) else - only(autodiff(reverse_mode(backend), Const(f!), Const, y_and_dy, Active(x))) + only(autodiff(reverse_mode(backend), f!_and_df!, Const, y_and_dy, Active(x))) end return y, new_dx end @@ -30,14 +31,15 @@ function DI.value_and_pullback( dy, ::NoPullbackExtras, ) + f!_and_df! = get_f_and_df(f!, backend) dx_sametype = make_zero(x) dy_sametype = convert(typeof(y), copy(dy)) y_and_dy = Duplicated(y, dy_sametype) x_and_dx = Duplicated(x, dx_sametype) if backend isa AutoDeferredEnzyme - autodiff_deferred(reverse_mode(backend), f!, Const, y_and_dy, x_and_dx) + autodiff_deferred(reverse_mode(backend), f!_and_df!, Const, y_and_dy, x_and_dx) else - autodiff(reverse_mode(backend), Const(f!), Const, y_and_dy, x_and_dx) + autodiff(reverse_mode(backend), f!_and_df!, Const, y_and_dy, x_and_dx) end return y, dx_sametype end diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index 9a860a603..878cffec3 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -1,19 +1,24 @@ using ADTypes: ADTypes using DifferentiationInterface, DifferentiationInterfaceTest +import DifferentiationInterfaceTest as DIT using Enzyme: Enzyme using SparseConnectivityTracer, SparseMatrixColorings using StableRNGs using Test dense_backends = [ - AutoEnzyme(; mode=nothing), - AutoEnzyme(; mode=Enzyme.Forward), - AutoEnzyme(; mode=Enzyme.Reverse), + AutoEnzyme(; mode=nothing, constant_function=true), + AutoEnzyme(; mode=Enzyme.Forward, constant_function=true), + AutoEnzyme(; mode=Enzyme.Reverse, constant_function=true), ] nested_dense_backends = [ - DifferentiationInterface.nested(AutoEnzyme(; mode=Enzyme.Forward)), - DifferentiationInterface.nested(AutoEnzyme(; mode=Enzyme.Reverse)), + DifferentiationInterface.nested( + AutoEnzyme(; mode=Enzyme.Forward, constant_function=true) + ), + DifferentiationInterface.nested( + AutoEnzyme(; mode=Enzyme.Reverse, constant_function=true) + ), ] sparse_backends = @@ -42,10 +47,26 @@ test_differentiation( test_differentiation( [ - AutoEnzyme(; mode=nothing), - AutoEnzyme(; mode=Enzyme.Reverse), - SecondOrder(AutoEnzyme(; mode=Enzyme.Reverse), AutoEnzyme(; mode=Enzyme.Reverse)), - SecondOrder(AutoEnzyme(; mode=Enzyme.Forward), AutoEnzyme(; mode=Enzyme.Reverse)), + AutoEnzyme(; mode=Enzyme.Forward, constant_function=false), + AutoEnzyme(; mode=Enzyme.Reverse, constant_function=false), + ], + DIT.make_closure.(default_scenarios()); + second_order=false, + logging=LOGGING, +); + +test_differentiation( + [ + AutoEnzyme(; mode=nothing, constant_function=true), + AutoEnzyme(; mode=Enzyme.Reverse, constant_function=true), + SecondOrder( + AutoEnzyme(; mode=Enzyme.Reverse, constant_function=true), + AutoEnzyme(; mode=Enzyme.Reverse, constant_function=true), + ), + SecondOrder( + AutoEnzyme(; mode=Enzyme.Forward, constant_function=true), + AutoEnzyme(; mode=Enzyme.Reverse, constant_function=true), + ), ]; first_order=false, excluded=[:second_derivative], @@ -53,14 +74,17 @@ test_differentiation( ); test_differentiation( - [AutoEnzyme(; mode=nothing), AutoEnzyme(; mode=Enzyme.Forward)]; + [ + AutoEnzyme(; mode=nothing, constant_function=true), + AutoEnzyme(; mode=Enzyme.Forward, constant_function=true), + ]; first_order=false, excluded=[:hessian, :hvp], logging=LOGGING, ); test_differentiation( - AutoEnzyme(; mode=Enzyme.Forward); # TODO: add more + AutoEnzyme(; mode=Enzyme.Forward, constant_function=true); # TODO: add more correctness=false, type_stability=true, second_order=false, diff --git a/DifferentiationInterface/test/Back/SecondOrder/test.jl b/DifferentiationInterface/test/Back/SecondOrder/test.jl index 413d13c21..2ccf530c1 100644 --- a/DifferentiationInterface/test/Back/SecondOrder/test.jl +++ b/DifferentiationInterface/test/Back/SecondOrder/test.jl @@ -16,8 +16,12 @@ onearg_backends = [ ] twoarg_backends = [ - SecondOrder(AutoForwardDiff(), AutoEnzyme(; mode=Enzyme.Forward)), - SecondOrder(AutoEnzyme(; mode=Enzyme.Reverse), AutoForwardDiff()), + SecondOrder( + AutoForwardDiff(), AutoEnzyme(; mode=Enzyme.Forward, constant_function=true) + ), + SecondOrder( + AutoEnzyme(; mode=Enzyme.Reverse, constant_function=true), AutoForwardDiff() + ), ] for backend in vcat(onearg_backends, twoarg_backends) diff --git a/DifferentiationInterface/test/Down/Detector/detector.jl b/DifferentiationInterface/test/Down/Detector/detector.jl index 2125637bc..892c906b8 100644 --- a/DifferentiationInterface/test/Down/Detector/detector.jl +++ b/DifferentiationInterface/test/Down/Detector/detector.jl @@ -24,7 +24,7 @@ g(x::AbstractVector) = dot(x, Hc, x) g(x::AbstractMatrix) = g(vec(x)) @testset verbose = true "$(typeof(backend))" for backend in [ - AutoEnzyme(; mode=Enzyme.Reverse), AutoForwardDiff() + AutoEnzyme(; mode=Enzyme.Reverse, constant_function=true), AutoForwardDiff() ] @test_throws ArgumentError DenseSparsityDetector(backend; atol=1e-5, method=:random) @testset "$method" for method in (:iterative, :direct) diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index c12be09e5..9571ffe49 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -61,6 +61,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -77,4 +78,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "DifferentiationInterface", "FiniteDifferences", "Flux", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StaticArrays", "Test", "Zygote"] +test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "DifferentiationInterface", "FiniteDiff", "FiniteDifferences", "Flux", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StaticArrays", "Test", "Zygote"] diff --git a/DifferentiationInterfaceTest/docs/src/tutorial.md b/DifferentiationInterfaceTest/docs/src/tutorial.md index 6ab5f086d..8c88ebbdc 100644 --- a/DifferentiationInterfaceTest/docs/src/tutorial.md +++ b/DifferentiationInterfaceTest/docs/src/tutorial.md @@ -12,7 +12,7 @@ import ForwardDiff, Enzyme The AD backends we want to compare are [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) and [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl). ```@example tuto -backends = [AutoForwardDiff(), AutoEnzyme(; mode=Enzyme.Reverse)] +backends = [AutoForwardDiff(), AutoEnzyme(; mode=Enzyme.Reverse, constant_function=true)] ``` To do that, we are going to take gradients of a simple function: diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index 41eeb320d..4c7ef0fc2 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -76,6 +76,7 @@ include("scenarios/default.jl") include("scenarios/sparse.jl") include("scenarios/allocfree.jl") include("scenarios/extensions.jl") +include("scenarios/modify.jl") include("utils/zero_backends.jl") include("utils/misc.jl") diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl new file mode 100644 index 000000000..db939ccfb --- /dev/null +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -0,0 +1,33 @@ +struct MyClosure{args,F,X,Y} + f::F + x_buffer::Vector{X} + y_buffer::Vector{Y} +end + +function (mc::MyClosure{1})(x) + mc.x_buffer[1] = x + mc.y_buffer[1] = mc.f(x) + return copy(mc.y_buffer[1]) +end + +function (mc::MyClosure{2})(y, x) + mc.x_buffer[1] = x + mc.f(mc.y_buffer[1], mc.x_buffer[1]) + copyto!(y, mc.y_buffer[1]) + return nothing +end + +""" + make_closure(scen::Scenario) + +Return a new [`Scenario`](@ref) with a modified function `f` or `f!` that closes over differentiable data. +""" +function make_closure(scen::Scenario) + @compat (; f, x, y) = scen + x_buffer = [zero(x)] + y_buffer = [zero(y)] + closure_f = MyClosure{nb_args(scen),typeof(f),typeof(x),typeof(y)}( + f, x_buffer, y_buffer + ) + return change_function(scen, closure_f) +end diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index 6171e3a11..ba86d363b 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -3,6 +3,7 @@ using ComponentArrays: ComponentArrays using DifferentiationInterface using DifferentiationInterfaceTest import DifferentiationInterfaceTest as DIT +using FiniteDiff: FiniteDiff using FiniteDifferences: FiniteDifferences using Flux: Flux using ForwardDiff: ForwardDiff @@ -20,6 +21,13 @@ test_differentiation( logging=LOGGING, ) +test_differentiation( + AutoFiniteDiff(), + DIT.make_closure.(default_scenarios()); + second_order=false, + logging=LOGGING, +); + test_differentiation( AutoZygote(), gpu_scenarios(); correctness=true, second_order=false, logging=LOGGING )