diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 8a6151a9d..80b9ff818 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -53,6 +53,7 @@ include("second_order/hvp.jl") include("second_order/hessian.jl") include("fallbacks/no_prep.jl") +include("fallbacks/change_prep.jl") include("misc/differentiate_with.jl") include("misc/from_primitive.jl") diff --git a/DifferentiationInterface/src/fallbacks/change_prep.jl b/DifferentiationInterface/src/fallbacks/change_prep.jl new file mode 100644 index 000000000..e399a240c --- /dev/null +++ b/DifferentiationInterface/src/fallbacks/change_prep.jl @@ -0,0 +1,125 @@ +for op in [ + :derivative, + :gradient, + :jacobian, + :second_derivative, + :hessian, + :pushforward, + :pullback, + :hvp, +] + op! = Symbol(op, "!") + val_and_op = if op == :second_derivative + :value_derivative_and_second_derivative + elseif op == :hessian + :value_gradient_and_hessian + elseif op == :hvp + nothing + else + Symbol("value_and_", op) + end + val_and_op! = Symbol(val_and_op, "!") + prep_op = Symbol("prepare_", op) + prep_op! = Symbol("prepare!_", op) + prep_op_same_point = Symbol("prepare_", op, "_same_point") + P = if op == :derivative + DerivativePrep + elseif op == :gradient + GradientPrep + elseif op == :jacobian + JacobianPrep + elseif op == :second_derivative + SecondDerivativePrep + elseif op == :hessian + HessianPrep + elseif op == :pushforward + PushforwardPrep + elseif op == :pullback + PullbackPrep + elseif op == :hvp + HVPPrep + end + + if op in (:derivative, :gradient, :jacobian) + # 1-arg + @eval function $prep_op!( + f::F, ::$P, backend::AbstractADType, x, contexts::Vararg{Context,C} + ) where {F,C} + return $prep_op(f, backend, x, contexts...) + end + op == :gradient && continue + # 2-arg + @eval function $prep_op!( + f!::F, y, ::$P, backend::AbstractADType, x, contexts::Vararg{Context,C} + ) where {F,C} + return $prep_op(f!, y, backend, x, contexts...) + end + + elseif op in (:second_derivative, :hessian) + # 1-arg + @eval function $prep_op!( + f::F, ::$P, backend::AbstractADType, x, contexts::Vararg{Context,C} + ) where {F,C} + return $prep_op(f, backend, x, contexts...) + end + + elseif op in (:pushforward, :pullback, :hvp) + # 1-arg + @eval function $prep_op!( + f::F, + ::$P, + backend::AbstractADType, + x, + seed::NTuple, + contexts::Vararg{Context,C}, + ) where {F,C} + return $prep_op(f, backend, x, seed, contexts...) + end + @eval function $prep_op_same_point( + f::F, + prep::$P, + backend::AbstractADType, + x, + seed::NTuple, + contexts::Vararg{Context,C}, + ) where {F,C} + return prep + end + @eval function $prep_op_same_point( + f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} + ) where {F,C} + prep = $prep_op(f, backend, x, seed, contexts...) + return $prep_op_same_point(f, prep, backend, x, seed, contexts...) + end + op == :hvp && continue + # 2-arg + @eval function $prep_op!( + f!::F, + y, + ::$P, + backend::AbstractADType, + x, + seed::NTuple, + contexts::Vararg{Context,C}, + ) where {F,C} + return $prep_op(f!, y, backend, x, seed, contexts...) + end + @eval function $prep_op_same_point( + f!::F, + y, + prep::$P, + backend::AbstractADType, + x, + seed::NTuple, + contexts::Vararg{Context,C}, + ) where {F,C} + return prep + end + @eval function $prep_op_same_point( + f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} + ) where {F,C} + prep = $prep_op(f!, y, backend, x, seed, contexts...) + return $prep_op_same_point(f!, y, prep, backend, x, seed, contexts...) + end + end +end diff --git a/DifferentiationInterface/src/fallbacks/no_prep.jl b/DifferentiationInterface/src/fallbacks/no_prep.jl index 469a5110a..509a51dae 100644 --- a/DifferentiationInterface/src/fallbacks/no_prep.jl +++ b/DifferentiationInterface/src/fallbacks/no_prep.jl @@ -1,217 +1,198 @@ -for op in (:derivative, :gradient, :jacobian) +for op in [ + :derivative, + :gradient, + :jacobian, + :second_derivative, + :hessian, + :pushforward, + :pullback, + :hvp, +] op! = Symbol(op, "!") - val_prefix = "value_and_" - val_and_op = Symbol(val_prefix, op) - val_and_op! = Symbol(val_prefix, op!) - prep_op = Symbol("prepare_", op) - # 1-arg - @eval function $op( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} - ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) - return $op(f, prep, backend, x, contexts...) - end - @eval function $op!( - f::F, result, backend::AbstractADType, x, contexts::Vararg{Context,C} - ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) - return $op!(f, result, prep, backend, x, contexts...) - end - @eval function $val_and_op( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} - ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) - return $val_and_op(f, prep, backend, x, contexts...) - end - @eval function $val_and_op!( - f::F, result, backend::AbstractADType, x, contexts::Vararg{Context,C} - ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) - return $val_and_op!(f, result, prep, backend, x, contexts...) - end - op == :gradient && continue - # 2-arg - @eval function $op( - f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} - ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...) - return $op(f!, y, prep, backend, x, contexts...) - end - @eval function $op!( - f!::F, y, result, backend::AbstractADType, x, contexts::Vararg{Context,C} - ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...) - return $op!(f!, y, result, prep, backend, x, contexts...) - end - @eval function $val_and_op( - f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} - ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...) - return $val_and_op(f!, y, prep, backend, x, contexts...) - end - @eval function $val_and_op!( - f!::F, y, result, backend::AbstractADType, x, contexts::Vararg{Context,C} - ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...) - return $val_and_op!(f!, y, result, prep, backend, x, contexts...) - end -end - -for op in (:second_derivative, :hessian) - op! = Symbol(op, "!") - val_prefix = if op == :second_derivative - "value_derivative_and_" + val_and_op = if op == :second_derivative + :value_derivative_and_second_derivative elseif op == :hessian - "value_gradient_and_" + :value_gradient_and_hessian + elseif op == :hvp + nothing + else + Symbol("value_and_", op) end - val_and_op = Symbol(val_prefix, op) - val_and_op! = Symbol(val_prefix, op!) - prep_op = Symbol("prepare_", op) - # 1-arg - @eval function $op( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} - ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) - return $op(f, prep, backend, x, contexts...) - end - @eval function $op!( - f::F, result2, backend::AbstractADType, x, contexts::Vararg{Context,C} - ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) - return $op!(f, result2, prep, backend, x, contexts...) - end - @eval function $val_and_op( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} - ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) - return $val_and_op(f, prep, backend, x, contexts...) - end - @eval function $val_and_op!( - f::F, result1, result2, backend::AbstractADType, x, contexts::Vararg{Context,C} - ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) - return $val_and_op!(f, result1, result2, prep, backend, x, contexts...) - end -end - -for op in (:pushforward, :pullback, :hvp) - op! = Symbol(op, "!") - val_prefix = "value_and_" - val_and_op = Symbol(val_prefix, op) - val_and_op! = Symbol(val_prefix, op!) + val_and_op! = Symbol(val_and_op, "!") prep_op = Symbol("prepare_", op) + prep_op! = Symbol("prepare!_", op) prep_op_same_point = Symbol("prepare_", op, "_same_point") - E = if startswith(string(op), "pushforward") + P = if op == :derivative + DerivativePrep + elseif op == :gradient + GradientPrep + elseif op == :jacobian + JacobianPrep + elseif op == :second_derivative + SecondDerivativePrep + elseif op == :hessian + HessianPrep + elseif op == :pushforward PushforwardPrep - elseif startswith(string(op), "pullback") + elseif op == :pullback PullbackPrep - elseif startswith(string(op), "hvp") + elseif op == :hvp HVPPrep end - # 1-arg - @eval function $prep_op_same_point( - f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} - ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...) - return $prep_op_same_point(f, prep, backend, x, seed, contexts...) - end - @eval function $prep_op_same_point( - f::F, - prep::$E, - backend::AbstractADType, - x, - seed::NTuple, - contexts::Vararg{Context,C}, - ) where {F,C} - return prep - end - @eval function $op( - f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} - ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...) - return $op(f, prep, backend, x, seed, contexts...) - end - @eval function $op!( - f::F, - result::NTuple, - backend::AbstractADType, - x, - seed::NTuple, - contexts::Vararg{Context,C}, - ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...) - return $op!(f, result, prep, backend, x, seed, contexts...) - end - op == :hvp && continue - @eval function $val_and_op( - f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} - ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...) - return $val_and_op(f, prep, backend, x, seed, contexts...) - end - @eval function $val_and_op!( - f::F, - result::NTuple, - backend::AbstractADType, - x, - seed::NTuple, - contexts::Vararg{Context,C}, - ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...) - return $val_and_op!(f, result, prep, backend, x, seed, contexts...) - end - # 2-arg - @eval function $prep_op_same_point( - f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} - ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...) - return $prep_op_same_point(f!, y, prep, backend, x, seed, contexts...) - end - @eval function $prep_op_same_point( - f!::F, - y, - prep::$E, - backend::AbstractADType, - x, - seed::NTuple, - contexts::Vararg{Context,C}, - ) where {F,C} - return prep - end - @eval function $op( - f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} - ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...) - return $op(f!, y, prep, backend, x, seed, contexts...) - end - @eval function $op!( - f!::F, - y, - result::NTuple, - backend::AbstractADType, - x, - seed::NTuple, - contexts::Vararg{Context,C}, - ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...) - return $op!(f!, y, result, prep, backend, x, seed, contexts...) - end - @eval function $val_and_op( - f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} - ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...) - return $val_and_op(f!, y, prep, backend, x, seed, contexts...) - end - @eval function $val_and_op!( - f!::F, - y, - result::NTuple, - backend::AbstractADType, - x, - seed::NTuple, - contexts::Vararg{Context,C}, - ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...) - return $val_and_op!(f!, y, result, prep, backend, x, seed, contexts...) + + if op in (:derivative, :jacobian, :gradient) + # 1-arg + @eval function $op( + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} + ) where {F,C} + prep = $prep_op(f, backend, x, contexts...) + return $op(f, prep, backend, x, contexts...) + end + @eval function $op!( + f::F, result, backend::AbstractADType, x, contexts::Vararg{Context,C} + ) where {F,C} + prep = $prep_op(f, backend, x, contexts...) + return $op!(f, result, prep, backend, x, contexts...) + end + @eval function $val_and_op( + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} + ) where {F,C} + prep = $prep_op(f, backend, x, contexts...) + return $val_and_op(f, prep, backend, x, contexts...) + end + @eval function $val_and_op!( + f::F, result, backend::AbstractADType, x, contexts::Vararg{Context,C} + ) where {F,C} + prep = $prep_op(f, backend, x, contexts...) + return $val_and_op!(f, result, prep, backend, x, contexts...) + end + op == :gradient && continue + # 2-arg + @eval function $op( + f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} + ) where {F,C} + prep = $prep_op(f!, y, backend, x, contexts...) + return $op(f!, y, prep, backend, x, contexts...) + end + @eval function $op!( + f!::F, y, result, backend::AbstractADType, x, contexts::Vararg{Context,C} + ) where {F,C} + prep = $prep_op(f!, y, backend, x, contexts...) + return $op!(f!, y, result, prep, backend, x, contexts...) + end + @eval function $val_and_op( + f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} + ) where {F,C} + prep = $prep_op(f!, y, backend, x, contexts...) + return $val_and_op(f!, y, prep, backend, x, contexts...) + end + @eval function $val_and_op!( + f!::F, y, result, backend::AbstractADType, x, contexts::Vararg{Context,C} + ) where {F,C} + prep = $prep_op(f!, y, backend, x, contexts...) + return $val_and_op!(f!, y, result, prep, backend, x, contexts...) + end + + elseif op in (:second_derivative, :hessian) + # 1-arg + @eval function $op( + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} + ) where {F,C} + prep = $prep_op(f, backend, x, contexts...) + return $op(f, prep, backend, x, contexts...) + end + @eval function $op!( + f::F, result2, backend::AbstractADType, x, contexts::Vararg{Context,C} + ) where {F,C} + prep = $prep_op(f, backend, x, contexts...) + return $op!(f, result2, prep, backend, x, contexts...) + end + @eval function $val_and_op( + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} + ) where {F,C} + prep = $prep_op(f, backend, x, contexts...) + return $val_and_op(f, prep, backend, x, contexts...) + end + @eval function $val_and_op!( + f::F, result1, result2, backend::AbstractADType, x, contexts::Vararg{Context,C} + ) where {F,C} + prep = $prep_op(f, backend, x, contexts...) + return $val_and_op!(f, result1, result2, prep, backend, x, contexts...) + end + + elseif op in (:pushforward, :pullback, :hvp) + @eval function $op( + f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} + ) where {F,C} + prep = $prep_op(f, backend, x, seed, contexts...) + return $op(f, prep, backend, x, seed, contexts...) + end + @eval function $op!( + f::F, + result::NTuple, + backend::AbstractADType, + x, + seed::NTuple, + contexts::Vararg{Context,C}, + ) where {F,C} + prep = $prep_op(f, backend, x, seed, contexts...) + return $op!(f, result, prep, backend, x, seed, contexts...) + end + + op == :hvp && continue + + @eval function $val_and_op( + f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} + ) where {F,C} + prep = $prep_op(f, backend, x, seed, contexts...) + return $val_and_op(f, prep, backend, x, seed, contexts...) + end + @eval function $val_and_op!( + f::F, + result::NTuple, + backend::AbstractADType, + x, + seed::NTuple, + contexts::Vararg{Context,C}, + ) where {F,C} + prep = $prep_op(f, backend, x, seed, contexts...) + return $val_and_op!(f, result, prep, backend, x, seed, contexts...) + end + @eval function $op( + f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} + ) where {F,C} + prep = $prep_op(f!, y, backend, x, seed, contexts...) + return $op(f!, y, prep, backend, x, seed, contexts...) + end + @eval function $op!( + f!::F, + y, + result::NTuple, + backend::AbstractADType, + x, + seed::NTuple, + contexts::Vararg{Context,C}, + ) where {F,C} + prep = $prep_op(f!, y, backend, x, seed, contexts...) + return $op!(f!, y, result, prep, backend, x, seed, contexts...) + end + @eval function $val_and_op( + f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} + ) where {F,C} + prep = $prep_op(f!, y, backend, x, seed, contexts...) + return $val_and_op(f!, y, prep, backend, x, seed, contexts...) + end + @eval function $val_and_op!( + f!::F, + y, + result::NTuple, + backend::AbstractADType, + x, + seed::NTuple, + contexts::Vararg{Context,C}, + ) where {F,C} + prep = $prep_op(f!, y, backend, x, seed, contexts...) + return $val_and_op!(f!, y, result, prep, backend, x, seed, contexts...) + end end end diff --git a/DifferentiationInterface/src/first_order/derivative.jl b/DifferentiationInterface/src/first_order/derivative.jl index 6737064c7..b5de7760d 100644 --- a/DifferentiationInterface/src/first_order/derivative.jl +++ b/DifferentiationInterface/src/first_order/derivative.jl @@ -12,6 +12,19 @@ Create a `prep` object that can be given to [`derivative`](@ref) and its variant """ function prepare_derivative end +""" + prepare!_derivative(f, prep, backend, x, [contexts...]) -> new_prep + prepare!_derivative(f!, y, prep, backend, x, [contexts...]) -> new_prep + +Same behavior as [`prepare_derivative`](@ref) but can modify an existing `prep` object to avoid some allocations. + +There is no guarantee that `prep` will be mutated, or that performance will be improved compared to preparation from scratch. + +!!! danger + For efficiency, this function needs to rely on backend package internals, therefore it not protected by semantic versioning. +""" +function prepare!_derivative end + """ value_and_derivative(f, [prep,] backend, x, [contexts...]) -> (y, der) value_and_derivative(f!, y, [prep,] backend, x, [contexts...]) -> (y, der) diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 4789c93e2..af7724d54 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -10,6 +10,18 @@ Create a `prep` object that can be given to [`gradient`](@ref) and its variants. """ function prepare_gradient end +""" + prepare!_gradient(f, prep, backend, x, [contexts...]) -> new_prep + +Same behavior as [`prepare_gradient`](@ref) but can modify an existing `prep` object to avoid some allocations. + +There is no guarantee that `prep` will be mutated, or that performance will be improved compared to preparation from scratch. + +!!! danger + For efficiency, this function needs to rely on backend package internals, therefore it not protected by semantic versioning. +""" +function prepare!_gradient end + """ value_and_gradient(f, [prep,] backend, x, [contexts...]) -> (y, grad) diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index acff2e9b7..04dc281df 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -12,6 +12,19 @@ Create a `prep` object that can be given to [`jacobian`](@ref) and its variants. """ function prepare_jacobian end +""" + prepare!_jacobian(f, prep, backend, x, [contexts...]) -> new_prep + prepare!_jacobian(f!, y, prep, backend, x, [contexts...]) -> new_prep + +Same behavior as [`prepare_jacobian`](@ref) but can modify an existing `prep` object to avoid some allocations. + +There is no guarantee that `prep` will be mutated, or that performance will be improved compared to preparation from scratch. + +!!! danger + For efficiency, this function needs to rely on backend package internals, therefore it not protected by semantic versioning. +""" +function prepare!_jacobian end + """ value_and_jacobian(f, [prep,] backend, x, [contexts...]) -> (y, jac) value_and_jacobian(f!, y, [prep,] backend, x, [contexts...]) -> (y, jac) diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 5828f16fe..90a166369 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -12,6 +12,19 @@ Create a `prep` object that can be given to [`pullback`](@ref) and its variants. """ function prepare_pullback end +""" + prepare!_pullback(f, prep, backend, x, ty, [contexts...]) -> new_prep + prepare!_pullback(f!, y, prep, backend, x, ty, [contexts...]) -> new_prep + +Same behavior as [`prepare_pullback`](@ref) but can modify an existing `prep` object to avoid some allocations. + +There is no guarantee that `prep` will be mutated, or that performance will be improved compared to preparation from scratch. + +!!! danger + For efficiency, this function needs to rely on backend package internals, therefore it not protected by semantic versioning. +""" +function prepare!_pullback end + """ prepare_pullback_same_point(f, backend, x, ty, [contexts...]) -> prep_same prepare_pullback_same_point(f!, y, backend, x, ty, [contexts...]) -> prep_same diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 2b29b3f09..c8a54c2bc 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -12,6 +12,19 @@ Create a `prep` object that can be given to [`pushforward`](@ref) and its varian """ function prepare_pushforward end +""" + prepare!_pushforward(f, prep, backend, x, tx, [contexts...]) -> new_prep + prepare!_pushforward(f!, y, prep, backend, x, tx, [contexts...]) -> new_prep + +Same behavior as [`prepare_pushforward`](@ref) but can modify an existing `prep` object to avoid some allocations. + +There is no guarantee that `prep` will be mutated, or that performance will be improved compared to preparation from scratch. + +!!! danger + For efficiency, this function needs to rely on backend package internals, therefore it not protected by semantic versioning. +""" +function prepare!_pushforward end + """ prepare_pushforward_same_point(f, backend, x, tx, [contexts...]) -> prep_same prepare_pushforward_same_point(f!, y, backend, x, tx, [contexts...]) -> prep_same diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index b5e9c6099..92b6d22d5 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -10,6 +10,18 @@ Create a `prep` object that can be given to [`hessian`](@ref) and its variants. """ function prepare_hessian end +""" + prepare!_hessian(f, backend, x, [contexts...]) -> new_prep + +Same behavior as [`prepare_hessian`](@ref) but can modify an existing `prep` object to avoid some allocations. + +There is no guarantee that `prep` will be mutated, or that performance will be improved compared to preparation from scratch. + +!!! danger + For efficiency, this function needs to rely on backend package internals, therefore it not protected by semantic versioning. +""" +function prepare!_hessian end + """ hessian(f, [prep,] backend, x, [contexts...]) -> hess diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 86b0ef2cd..97b1171e7 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -10,6 +10,18 @@ Create a `prep` object that can be given to [`hvp`](@ref) and its variants. """ function prepare_hvp end +""" + prepare!_hvp(f, backend, x, tx, [contexts...]) -> new_prep + +Same behavior as [`prepare_hvp`](@ref) but can modify an existing `prep` object to avoid some allocations. + +There is no guarantee that `prep` will be mutated, or that performance will be improved compared to preparation from scratch. + +!!! danger + For efficiency, this function needs to rely on backend package internals, therefore it not protected by semantic versioning. +""" +function prepare!_hvp end + """ prepare_hvp_same_point(f, backend, x, tx, [contexts...]) -> prep_same diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index 13d37cac4..73f20e8d7 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -28,6 +28,14 @@ using DifferentiationInterface: mode, outer, inplace_support, + prepare!_derivative, + prepare!_gradient, + prepare!_hessian, + prepare!_hvp, + prepare!_jacobian, + prepare!_pullback, + prepare!_pushforward, + prepare!_second_derivative, pushforward_performance, pullback_performance, unwrap diff --git a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl index 43378b5b2..901b70269 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl @@ -19,9 +19,10 @@ for op in [ val_and_op = Symbol(val_prefix, op) val_and_op! = Symbol(val_prefix, op!) prep_op = Symbol("prepare_", op) + prep_op! = Symbol("prepare!_", op) prep_op_same = Symbol("prepare_", op, "_same_point") - E = if op == :derivative + P = if op == :derivative DerivativePrep elseif op == :gradient GradientPrep @@ -58,7 +59,15 @@ for op in [ rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) preptup_cands_val, preptup_cands_noval = map(1:2) do _ - [(), ($prep_op(f, ba, xrand, contextsrand...),)] + prep = $prep_op(f, ba, xrand, contextsrand...) + prepprep = $prep_op!( + f, + $prep_op(f, ba, xrand, contextsrand...), + ba, + xrand, + contextsrand..., + ) + [(), (prep,), (prepprep,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_out1_val, res1_out1_val = $val_and_op( @@ -70,7 +79,7 @@ for op in [ res1_out1_noval = $op(f, preptup_noval..., ba, x, contexts...) res1_out2_noval = $op(f, preptup_noval..., ba, x, contexts...) let (≈)(x, y) = isapprox(x, y; atol, rtol) - @test isempty(preptup_noval) || only(preptup_noval) isa $E + @test isempty(preptup_noval) || only(preptup_noval) isa $P @test y_out1_val ≈ scen.y @test y_out2_val ≈ scen.y @test res1_out1_val ≈ scen.res1 @@ -96,7 +105,15 @@ for op in [ rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) preptup_cands_val, preptup_cands_noval = map(1:2) do _ - [(), ($prep_op(f, ba, xrand, contextsrand...),)] + prep = $prep_op(f, ba, xrand, contextsrand...) + prepprep = $prep_op!( + f, + $prep_op(f, ba, xrand, contextsrand...), + ba, + xrand, + contextsrand..., + ) + [(), (prep,), (prepprep,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res1_in1_val = mysimilar(res1) @@ -116,7 +133,7 @@ for op in [ f, res1_in2_noval, preptup_noval..., ba, x, contexts... ) let (≈)(x, y) = isapprox(x, y; atol, rtol) - @test isempty(preptup_noval) || only(preptup_noval) isa $E + @test isempty(preptup_noval) || only(preptup_noval) isa $P @test y_out1_val ≈ scen.y @test y_out2_val ≈ scen.y @test res1_in1_val ≈ scen.res1 @@ -148,7 +165,16 @@ for op in [ rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) preptup_cands_val, preptup_cands_noval = map(1:2) do _ - [(), ($prep_op(f, yrand, ba, xrand, contextsrand...),)] + prep = $prep_op(f, copy(yrand), ba, xrand, contextsrand...) + prepprep = $prep_op!( + f, + copy(yrand), + $prep_op(f, copy(yrand), ba, xrand, contextsrand...), + ba, + xrand, + contextsrand..., + ) + [(), (prep,), (prepprep,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_in1_val = mysimilar(y) @@ -164,7 +190,7 @@ for op in [ res1_out1_noval = $op(f, y_in1_noval, preptup_noval..., ba, x, contexts...) res1_out2_noval = $op(f, y_in2_noval, preptup_noval..., ba, x, contexts...) let (≈)(x, y) = isapprox(x, y; atol, rtol) - @test isempty(preptup_noval) || only(preptup_noval) isa $E + @test isempty(preptup_noval) || only(preptup_noval) isa $P @test y_in1_val ≈ scen.y @test y_in2_val ≈ scen.y @test y_out1_val ≈ scen.y @@ -192,7 +218,16 @@ for op in [ rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) preptup_cands_val, preptup_cands_noval = map(1:2) do _ - [(), ($prep_op(f, yrand, ba, xrand, contextsrand...),)] + prep = $prep_op(f, copy(yrand), ba, xrand, contextsrand...) + prepprep = $prep_op!( + f, + copy(yrand), + $prep_op(f, copy(yrand), ba, xrand, contextsrand...), + ba, + xrand, + contextsrand..., + ) + [(), (prep,), (prepprep,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_in1_val, res1_in1_val = mysimilar(y), mysimilar(res1) @@ -212,7 +247,7 @@ for op in [ f, y_in2_noval, res1_in2_noval, preptup_noval..., ba, x, contexts... ) let (≈)(x, y) = isapprox(x, y; atol, rtol) - @test isempty(preptup_noval) || only(preptup_noval) isa $E + @test isempty(preptup_noval) || only(preptup_noval) isa $P @test y_in1_val ≈ scen.y @test y_in2_val ≈ scen.y @test y_out1_val ≈ scen.y @@ -245,7 +280,15 @@ for op in [ rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) preptup_cands_val, preptup_cands_noval = map(1:2) do _ - [(), ($prep_op(f, ba, xrand, contextsrand...),)] + prep = $prep_op(f, ba, xrand, contextsrand...) + prepprep = $prep_op!( + f, + $prep_op(f, ba, xrand, contextsrand...), + ba, + xrand, + contextsrand..., + ) + [(), (prep,), (prepprep,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_out1_val, res1_out1_val, res2_out1_val = $val_and_op( @@ -257,7 +300,7 @@ for op in [ res2_out1_noval = $op(f, preptup_noval..., ba, x, contexts...) res2_out2_noval = $op(f, preptup_noval..., ba, x, contexts...) let (≈)(x, y) = isapprox(x, y; atol, rtol) - @test isempty(preptup_noval) || only(preptup_noval) isa $E + @test isempty(preptup_noval) || only(preptup_noval) isa $P @test y_out1_val ≈ scen.y @test y_out2_val ≈ scen.y @test res1_out1_val ≈ scen.res1 @@ -285,7 +328,15 @@ for op in [ rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) preptup_cands_val, preptup_cands_noval = map(1:2) do _ - [(), ($prep_op(f, ba, xrand, contextsrand...),)] + prep = $prep_op(f, ba, xrand, contextsrand...) + prepprep = $prep_op!( + f, + $prep_op(f, ba, xrand, contextsrand...), + ba, + xrand, + contextsrand..., + ) + [(), (prep,), (prepprep,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res1_in1_val, res2_in1_val = mysimilar(res1), mysimilar(res2) @@ -305,7 +356,7 @@ for op in [ f, res2_in2_noval, preptup_noval..., ba, x, contexts... ) let (≈)(x, y) = isapprox(x, y; atol, rtol) - @test isempty(preptup_noval) || only(preptup_noval) isa $E + @test isempty(preptup_noval) || only(preptup_noval) isa $P @test y_out1_val ≈ scen.y @test y_out2_val ≈ scen.y @test res1_in1_val ≈ scen.res1 @@ -340,11 +391,17 @@ for op in [ rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) preptup_cands_val, preptup_cands_noval = map(1:2) do _ - [ - (), - ($prep_op(f, ba, xrand, tangrand, contextsrand...),), - ($prep_op_same(f, ba, x, tangrand, contexts...),), - ] + prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) + prepprep = $prep_op!( + f, + $prep_op(f, ba, xrand, tangrand, contextsrand...), + ba, + xrand, + tangrand, + contextsrand..., + ) + prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) + [(), (prep,), (prepprep,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_out1_val, res1_out1_val = $val_and_op( @@ -356,7 +413,7 @@ for op in [ res1_out1_noval = $op(f, preptup_noval..., ba, x, tang, contexts...) res1_out2_noval = $op(f, preptup_noval..., ba, x, tang, contexts...) let (≈)(x, y) = isapprox(x, y; atol, rtol) - @test isempty(preptup_noval) || only(preptup_noval) isa $E + @test isempty(preptup_noval) || only(preptup_noval) isa $P @test y_out1_val ≈ scen.y @test y_out2_val ≈ scen.y @test all(res1_out1_val .≈ scen.res1) @@ -382,11 +439,17 @@ for op in [ rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) preptup_cands_val, preptup_cands_noval = map(1:2) do _ - [ - (), - ($prep_op(f, ba, xrand, tangrand, contextsrand...),), - ($prep_op_same(f, ba, x, tangrand, contexts...),), - ] + prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) + prepprep = $prep_op!( + f, + $prep_op(f, ba, xrand, tangrand, contextsrand...), + ba, + xrand, + tangrand, + contextsrand..., + ) + prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) + [(), (prep,), (prepprep,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res1_in1_val = mysimilar(res1) @@ -406,7 +469,7 @@ for op in [ f, res1_in2_noval, preptup_noval..., ba, x, tang, contexts... ) let (≈)(x, y) = isapprox(x, y; atol, rtol) - @test isempty(preptup_noval) || only(preptup_noval) isa $E + @test isempty(preptup_noval) || only(preptup_noval) isa $P @test y_out1_val ≈ scen.y @test y_out2_val ≈ scen.y @test all(res1_in1_val .≈ scen.res1) @@ -436,11 +499,18 @@ for op in [ rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) preptup_cands_val, preptup_cands_noval = map(1:2) do _ - [ - (), - ($prep_op(f, yrand, ba, xrand, tangrand, contextsrand...),), - ($prep_op_same(f, yrand, ba, x, tangrand, contexts...),), - ] + prep = $prep_op(f, copy(yrand), ba, xrand, tangrand, contextsrand...) + prepprep = $prep_op!( + f, + copy(yrand), + $prep_op(f, copy(yrand), ba, xrand, tangrand, contextsrand...), + ba, + xrand, + tangrand, + contextsrand..., + ) + prep_same = $prep_op_same(f, copy(yrand), ba, x, tangrand, contexts...) + [(), (prep,), (prepprep,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_in1_val = mysimilar(y) @@ -460,7 +530,7 @@ for op in [ f, y_in2_noval, preptup_noval..., ba, x, tang, contexts... ) let (≈)(x, y) = isapprox(x, y; atol, rtol) - @test isempty(preptup_noval) || only(preptup_noval) isa $E + @test isempty(preptup_noval) || only(preptup_noval) isa $P @test y_in1_val ≈ scen.y @test y_in2_val ≈ scen.y @test y_out1_val ≈ scen.y @@ -488,11 +558,18 @@ for op in [ rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) preptup_cands_val, preptup_cands_noval = map(1:2) do _ - [ - (), - ($prep_op(f, yrand, ba, xrand, tangrand, contextsrand...),), - ($prep_op_same(f, yrand, ba, x, tangrand, contexts...),), - ] + prep = $prep_op(f, copy(yrand), ba, xrand, tangrand, contextsrand...) + prepprep = $prep_op!( + f, + copy(yrand), + $prep_op(f, copy(yrand), ba, xrand, tangrand, contextsrand...), + ba, + xrand, + tangrand, + contextsrand..., + ) + prep_same = $prep_op_same(f, copy(yrand), ba, x, tangrand, contexts...) + [(), (prep,), (prepprep,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_in1_val, res1_in1_val = mysimilar(y), mysimilar(res1) @@ -526,7 +603,7 @@ for op in [ contexts..., ) let (≈)(x, y) = isapprox(x, y; atol, rtol) - @test isempty(preptup_noval) || only(preptup_noval) isa $E + @test isempty(preptup_noval) || only(preptup_noval) isa $P @test y_in1_val ≈ scen.y @test y_in2_val ≈ scen.y @test y_out1_val ≈ scen.y @@ -558,16 +635,22 @@ for op in [ xrand, tangrand = myrandom(x), myrandom(tang) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) - preptup_cands_noval = [ - (), - ($prep_op(f, ba, xrand, tangrand, contextsrand...),), - ($prep_op_same(f, ba, x, tangrand, contexts...),), - ] + prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) + prepprep = $prep_op!( + f, + $prep_op(f, ba, xrand, tangrand, contextsrand...), + ba, + xrand, + tangrand, + contextsrand..., + ) + prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) + preptup_cands_noval = [(), (prep,), (prepprep,), (prep_same,)] for preptup_noval in preptup_cands_noval res2_out1_noval = $op(f, preptup_noval..., ba, x, tang, contexts...) res2_out2_noval = $op(f, preptup_noval..., ba, x, tang, contexts...) let (≈)(x, y) = isapprox(x, y; atol, rtol) - @test isempty(preptup_noval) || only(preptup_noval) isa $E + @test isempty(preptup_noval) || only(preptup_noval) isa $P @test all(res2_out1_noval .≈ scen.res2) @test all(res2_out2_noval .≈ scen.res2) end @@ -588,11 +671,17 @@ for op in [ xrand, tangrand = myrandom(x), myrandom(tang) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) - preptup_cands_noval = [ - (), - ($prep_op(f, ba, xrand, tangrand, contextsrand...),), - ($prep_op_same(f, ba, x, tangrand, contexts...),), - ] + prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) + prepprep = $prep_op!( + f, + $prep_op(f, ba, xrand, tangrand, contextsrand...), + ba, + xrand, + tangrand, + contextsrand..., + ) + prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) + preptup_cands_noval = [(), (prep,), (prepprep,), (prep_same,)] for preptup_noval in preptup_cands_noval res2_in1_noval = mysimilar(res2) res2_in2_noval = mysimilar(res2) @@ -603,7 +692,7 @@ for op in [ f, res2_in2_noval, preptup_noval..., ba, x, tang, contexts... ) let (≈)(x, y) = isapprox(x, y; atol, rtol) - @test isempty(preptup_noval) || only(preptup_noval) isa $E + @test isempty(preptup_noval) || only(preptup_noval) isa $P @test all(res2_in1_noval .≈ scen.res2) @test all(res2_in2_noval .≈ scen.res2) @test all(res2_out1_noval .≈ scen.res2)