From 346834e1f95a90b21a161620e53a3cf5bee95de0 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 30 Oct 2024 06:54:09 +0100 Subject: [PATCH] fix: improve ForwardDiff tagging for HVP (#596) * Improve ForwardDiff tagging * Remove tag unwrapping for FixTail * Cov * Bump DI --- DifferentiationInterface/Project.toml | 2 +- .../DifferentiationInterfaceForwardDiffExt.jl | 3 + .../secondorder.jl | 118 +++++++++++------- .../utils.jl | 1 + .../src/first_order/gradient.jl | 11 ++ DifferentiationInterface/src/utils/context.jl | 6 + 6 files changed, 94 insertions(+), 47 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 04a526dc8..5a78d930b 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.17" +version = "0.6.18" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl index e24b4eec8..b213acf6f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl @@ -7,7 +7,9 @@ using DifferentiationInterface: BatchSizeSettings, Cache, Constant, + PrepContext, Context, + FixTail, DerivativePrep, DifferentiateWith, GradientPrep, @@ -21,6 +23,7 @@ using DifferentiationInterface: SecondOrder, inner, outer, + shuffled_gradient, unwrap, with_contexts import ForwardDiff.DiffResults as DR diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl index e0c84459a..7ae66f7ba 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl @@ -1,23 +1,6 @@ -struct ForwardDiffOverSomethingHVPWrapper{F} - f::F -end - -""" - tag_backend_hvp(f, ::AutoForwardDiff, x) - -Return a new `AutoForwardDiff` backend with a fixed tag linked to `f`, so that we know how to prepare the inner gradient of the HVP without depending on what that gradient closure looks like. -""" -tag_backend_hvp(f, backend::AutoForwardDiff, x) = backend - -function tag_backend_hvp(f::F, ::AutoForwardDiff{chunksize,Nothing}, x) where {F,chunksize} - tag = ForwardDiff.Tag(ForwardDiffOverSomethingHVPWrapper(f), eltype(x)) - return AutoForwardDiff{chunksize,typeof(tag)}(tag) -end - -struct ForwardDiffOverSomethingHVPPrep{B<:AutoForwardDiff,G,E<:PushforwardPrep} <: HVPPrep - tagged_outer_backend::B - inner_gradient::G - outer_pushforward_prep::E +struct ForwardDiffOverSomethingHVPPrep{E1<:GradientPrep,E2<:PushforwardPrep} <: HVPPrep + inner_gradient_prep::E1 + outer_pushforward_prep::E2 end function DI.prepare_hvp( @@ -27,35 +10,42 @@ function DI.prepare_hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - rewrap = Rewrap(contexts...) - tagged_outer_backend = tag_backend_hvp(f, outer(backend), x) - T = tag_type(f, tagged_outer_backend, x) + T = tag_type(shuffled_gradient, outer(backend), x) xdual = make_dual(T, x, tx) - gradient_prep = DI.prepare_gradient(f, inner(backend), xdual, contexts...) - # TODO: get rid of closure? - function inner_gradient(x, unannotated_contexts...) - annotated_contexts = rewrap(unannotated_contexts...) - return DI.gradient(f, gradient_prep, inner(backend), x, annotated_contexts...) - end - outer_pushforward_prep = DI.prepare_pushforward( - inner_gradient, tagged_outer_backend, x, tx, contexts... + inner_gradient_prep = DI.prepare_gradient(f, inner(backend), xdual, contexts...) + rewrap = Rewrap(contexts...) + new_contexts = ( + Constant(f), + PrepContext(inner_gradient_prep), + Constant(inner(backend)), + Constant(rewrap), + contexts..., ) - return ForwardDiffOverSomethingHVPPrep( - tagged_outer_backend, inner_gradient, outer_pushforward_prep + outer_pushforward_prep = DI.prepare_pushforward( + shuffled_gradient, outer(backend), x, tx, new_contexts... ) + return ForwardDiffOverSomethingHVPPrep(inner_gradient_prep, outer_pushforward_prep) end function DI.hvp( f::F, prep::ForwardDiffOverSomethingHVPPrep, - ::SecondOrder{<:AutoForwardDiff}, + backend::SecondOrder{<:AutoForwardDiff}, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep + (; inner_gradient_prep, outer_pushforward_prep) = prep + rewrap = Rewrap(contexts...) + new_contexts = ( + Constant(f), + PrepContext(inner_gradient_prep), + Constant(inner(backend)), + Constant(rewrap), + contexts..., + ) return DI.pushforward( - inner_gradient, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts... + shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... ) end @@ -63,14 +53,28 @@ function DI.hvp!( f::F, tg::NTuple, prep::ForwardDiffOverSomethingHVPPrep, - ::SecondOrder{<:AutoForwardDiff}, + backend::SecondOrder{<:AutoForwardDiff}, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep - DI.pushforward!( - inner_gradient, tg, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts... + (; inner_gradient_prep, outer_pushforward_prep) = prep + rewrap = Rewrap(contexts...) + new_contexts = ( + Constant(f), + PrepContext(inner_gradient_prep), + Constant(inner(backend)), + Constant(rewrap), + contexts..., + ) + return DI.pushforward!( + shuffled_gradient, + tg, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., ) return tg end @@ -78,14 +82,22 @@ end function DI.gradient_and_hvp( f::F, prep::ForwardDiffOverSomethingHVPPrep, - ::SecondOrder{<:AutoForwardDiff}, + backend::SecondOrder{<:AutoForwardDiff}, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep + (; inner_gradient_prep, outer_pushforward_prep) = prep + rewrap = Rewrap(contexts...) + new_contexts = ( + Constant(f), + PrepContext(inner_gradient_prep), + Constant(inner(backend)), + Constant(rewrap), + contexts..., + ) return DI.value_and_pushforward( - inner_gradient, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts... + shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... ) end @@ -94,14 +106,28 @@ function DI.gradient_and_hvp!( grad, tg::NTuple, prep::ForwardDiffOverSomethingHVPPrep, - ::SecondOrder{<:AutoForwardDiff}, + backend::SecondOrder{<:AutoForwardDiff}, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep + (; inner_gradient_prep, outer_pushforward_prep) = prep + rewrap = Rewrap(contexts...) + new_contexts = ( + Constant(f), + PrepContext(inner_gradient_prep), + Constant(inner(backend)), + Constant(rewrap), + contexts..., + ) new_grad, _ = DI.value_and_pushforward!( - inner_gradient, tg, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts... + shuffled_gradient, + tg, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., ) return copyto!(grad, new_grad), tg end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index a5e90757f..5cf063410 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -85,6 +85,7 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B} end _translate(::Type{T}, ::Val{B}, c::Constant) where {T,B} = unwrap(c) +_translate(::Type{T}, ::Val{B}, c::PrepContext) where {T,B} = unwrap(c) function _translate(::Type{T}, ::Val{B}, c::Cache) where {T,B} c0 = unwrap(c) diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 0bd10f381..13a260b02 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -128,3 +128,14 @@ function shuffled_gradient( ) where {F,C} return gradient(f, backend, x, rewrap(unannotated_contexts...)...) end + +function shuffled_gradient( + x, + f::F, + prep::GradientPrep, + backend::AbstractADType, + rewrap::Rewrap{C}, + unannotated_contexts::Vararg{Any,C}, +) where {F,C} + return gradient(f, prep, backend, x, rewrap(unannotated_contexts...)...) +end diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 3eb2ef879..310017490 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -74,6 +74,12 @@ unwrap(c::Cache) = c.data Base.:(==)(c1::Cache, c2::Cache) = c1.data == c2.data +struct PrepContext{T<:Prep} <: Context + data::T +end + +unwrap(c::PrepContext) = c.data + struct Rewrap{C,T} context_makers::T function Rewrap(contexts::Vararg{Context,C}) where {C}