From fd8c2109103428cf6b6483cd469301d626ef8dd4 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 27 Oct 2024 09:46:43 +0100 Subject: [PATCH 1/4] Improve ForwardDiff tagging --- .../DifferentiationInterfaceForwardDiffExt.jl | 3 + .../secondorder.jl | 118 +++++++++++------- .../utils.jl | 5 + .../src/first_order/gradient.jl | 11 ++ DifferentiationInterface/src/utils/context.jl | 10 ++ 5 files changed, 101 insertions(+), 46 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl index e24b4eec8..b213acf6f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl @@ -7,7 +7,9 @@ using DifferentiationInterface: BatchSizeSettings, Cache, Constant, + PrepContext, Context, + FixTail, DerivativePrep, DifferentiateWith, GradientPrep, @@ -21,6 +23,7 @@ using DifferentiationInterface: SecondOrder, inner, outer, + shuffled_gradient, unwrap, with_contexts import ForwardDiff.DiffResults as DR diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl index e0c84459a..7ae66f7ba 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl @@ -1,23 +1,6 @@ -struct ForwardDiffOverSomethingHVPWrapper{F} - f::F -end - -""" - tag_backend_hvp(f, ::AutoForwardDiff, x) - -Return a new `AutoForwardDiff` backend with a fixed tag linked to `f`, so that we know how to prepare the inner gradient of the HVP without depending on what that gradient closure looks like. -""" -tag_backend_hvp(f, backend::AutoForwardDiff, x) = backend - -function tag_backend_hvp(f::F, ::AutoForwardDiff{chunksize,Nothing}, x) where {F,chunksize} - tag = ForwardDiff.Tag(ForwardDiffOverSomethingHVPWrapper(f), eltype(x)) - return AutoForwardDiff{chunksize,typeof(tag)}(tag) -end - -struct ForwardDiffOverSomethingHVPPrep{B<:AutoForwardDiff,G,E<:PushforwardPrep} <: HVPPrep - tagged_outer_backend::B - inner_gradient::G - outer_pushforward_prep::E +struct ForwardDiffOverSomethingHVPPrep{E1<:GradientPrep,E2<:PushforwardPrep} <: HVPPrep + inner_gradient_prep::E1 + outer_pushforward_prep::E2 end function DI.prepare_hvp( @@ -27,35 +10,42 @@ function DI.prepare_hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - rewrap = Rewrap(contexts...) - tagged_outer_backend = tag_backend_hvp(f, outer(backend), x) - T = tag_type(f, tagged_outer_backend, x) + T = tag_type(shuffled_gradient, outer(backend), x) xdual = make_dual(T, x, tx) - gradient_prep = DI.prepare_gradient(f, inner(backend), xdual, contexts...) - # TODO: get rid of closure? - function inner_gradient(x, unannotated_contexts...) - annotated_contexts = rewrap(unannotated_contexts...) - return DI.gradient(f, gradient_prep, inner(backend), x, annotated_contexts...) - end - outer_pushforward_prep = DI.prepare_pushforward( - inner_gradient, tagged_outer_backend, x, tx, contexts... + inner_gradient_prep = DI.prepare_gradient(f, inner(backend), xdual, contexts...) + rewrap = Rewrap(contexts...) + new_contexts = ( + Constant(f), + PrepContext(inner_gradient_prep), + Constant(inner(backend)), + Constant(rewrap), + contexts..., ) - return ForwardDiffOverSomethingHVPPrep( - tagged_outer_backend, inner_gradient, outer_pushforward_prep + outer_pushforward_prep = DI.prepare_pushforward( + shuffled_gradient, outer(backend), x, tx, new_contexts... ) + return ForwardDiffOverSomethingHVPPrep(inner_gradient_prep, outer_pushforward_prep) end function DI.hvp( f::F, prep::ForwardDiffOverSomethingHVPPrep, - ::SecondOrder{<:AutoForwardDiff}, + backend::SecondOrder{<:AutoForwardDiff}, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep + (; inner_gradient_prep, outer_pushforward_prep) = prep + rewrap = Rewrap(contexts...) + new_contexts = ( + Constant(f), + PrepContext(inner_gradient_prep), + Constant(inner(backend)), + Constant(rewrap), + contexts..., + ) return DI.pushforward( - inner_gradient, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts... + shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... ) end @@ -63,14 +53,28 @@ function DI.hvp!( f::F, tg::NTuple, prep::ForwardDiffOverSomethingHVPPrep, - ::SecondOrder{<:AutoForwardDiff}, + backend::SecondOrder{<:AutoForwardDiff}, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep - DI.pushforward!( - inner_gradient, tg, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts... + (; inner_gradient_prep, outer_pushforward_prep) = prep + rewrap = Rewrap(contexts...) + new_contexts = ( + Constant(f), + PrepContext(inner_gradient_prep), + Constant(inner(backend)), + Constant(rewrap), + contexts..., + ) + return DI.pushforward!( + shuffled_gradient, + tg, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., ) return tg end @@ -78,14 +82,22 @@ end function DI.gradient_and_hvp( f::F, prep::ForwardDiffOverSomethingHVPPrep, - ::SecondOrder{<:AutoForwardDiff}, + backend::SecondOrder{<:AutoForwardDiff}, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep + (; inner_gradient_prep, outer_pushforward_prep) = prep + rewrap = Rewrap(contexts...) + new_contexts = ( + Constant(f), + PrepContext(inner_gradient_prep), + Constant(inner(backend)), + Constant(rewrap), + contexts..., + ) return DI.value_and_pushforward( - inner_gradient, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts... + shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... ) end @@ -94,14 +106,28 @@ function DI.gradient_and_hvp!( grad, tg::NTuple, prep::ForwardDiffOverSomethingHVPPrep, - ::SecondOrder{<:AutoForwardDiff}, + backend::SecondOrder{<:AutoForwardDiff}, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep + (; inner_gradient_prep, outer_pushforward_prep) = prep + rewrap = Rewrap(contexts...) + new_contexts = ( + Constant(f), + PrepContext(inner_gradient_prep), + Constant(inner(backend)), + Constant(rewrap), + contexts..., + ) new_grad, _ = DI.value_and_pushforward!( - inner_gradient, tg, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts... + shuffled_gradient, + tg, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., ) return copyto!(grad, new_grad), tg end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index a5e90757f..65cf8a9f1 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -31,6 +31,10 @@ function get_tag(f::F, ::AutoForwardDiff{chunksize,Nothing}, x) where {F,chunksi return Tag(f, eltype(x)) end +function get_tag(ft::FixTail, ::AutoForwardDiff{chunksize,Nothing}, x) where {chunksize} + return Tag(ft.f, eltype(x)) +end + tag_type(f::F, backend::AutoForwardDiff, x) where {F} = typeof(get_tag(f, backend, x)) function make_dual_similar(::Type{T}, x::Number, tx::NTuple{B}) where {T,B} @@ -85,6 +89,7 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B} end _translate(::Type{T}, ::Val{B}, c::Constant) where {T,B} = unwrap(c) +_translate(::Type{T}, ::Val{B}, c::PrepContext) where {T,B} = unwrap(c) function _translate(::Type{T}, ::Val{B}, c::Cache) where {T,B} c0 = unwrap(c) diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 0bd10f381..13a260b02 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -128,3 +128,14 @@ function shuffled_gradient( ) where {F,C} return gradient(f, backend, x, rewrap(unannotated_contexts...)...) end + +function shuffled_gradient( + x, + f::F, + prep::GradientPrep, + backend::AbstractADType, + rewrap::Rewrap{C}, + unannotated_contexts::Vararg{Any,C}, +) where {F,C} + return gradient(f, prep, backend, x, rewrap(unannotated_contexts...)...) +end diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 3eb2ef879..5de2c21db 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -74,6 +74,16 @@ unwrap(c::Cache) = c.data Base.:(==)(c1::Cache, c2::Cache) = c1.data == c2.data +struct PrepContext{T<:Prep} <: Context + data::T +end + +prepcontext_maker(c) = PrepContext(c) +maker(::PrepContext) = prepcontext_maker +unwrap(c::PrepContext) = c.data + +Base.:(==)(c1::PrepContext, c2::PrepContext) = c1.data == c2.data + struct Rewrap{C,T} context_makers::T function Rewrap(contexts::Vararg{Context,C}) where {C} From 1cd2168f73486cb028a955e661d665c0aef3e7b2 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 27 Oct 2024 17:38:19 +0100 Subject: [PATCH 2/4] Remove tag unwrapping for FixTail --- .../ext/DifferentiationInterfaceForwardDiffExt/utils.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index 65cf8a9f1..5cf063410 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -31,10 +31,6 @@ function get_tag(f::F, ::AutoForwardDiff{chunksize,Nothing}, x) where {F,chunksi return Tag(f, eltype(x)) end -function get_tag(ft::FixTail, ::AutoForwardDiff{chunksize,Nothing}, x) where {chunksize} - return Tag(ft.f, eltype(x)) -end - tag_type(f::F, backend::AutoForwardDiff, x) where {F} = typeof(get_tag(f, backend, x)) function make_dual_similar(::Type{T}, x::Number, tx::NTuple{B}) where {T,B} From ec21e8911f93f7d2d839e9046bb900b56bb8d652 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 27 Oct 2024 18:27:35 +0100 Subject: [PATCH 3/4] Cov --- DifferentiationInterface/src/utils/context.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 5de2c21db..310017490 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -78,12 +78,8 @@ struct PrepContext{T<:Prep} <: Context data::T end -prepcontext_maker(c) = PrepContext(c) -maker(::PrepContext) = prepcontext_maker unwrap(c::PrepContext) = c.data -Base.:(==)(c1::PrepContext, c2::PrepContext) = c1.data == c2.data - struct Rewrap{C,T} context_makers::T function Rewrap(contexts::Vararg{Context,C}) where {C} From 04d28d32872dc4e9dc9eefe91283adf15ab7c663 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 30 Oct 2024 06:25:41 +0100 Subject: [PATCH 4/4] Bump DI --- DifferentiationInterface/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 04a526dc8..5a78d930b 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.17" +version = "0.6.18" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"