Secret preparation modifier for resizing #521

Merged 2 commits on Sep 30, 2024
1 change: 1 addition & 0 deletions DifferentiationInterface/src/DifferentiationInterface.jl
@@ -53,6 +53,7 @@ include("second_order/hvp.jl")
include("second_order/hessian.jl")

include("fallbacks/no_prep.jl")
include("fallbacks/change_prep.jl")

include("misc/differentiate_with.jl")
include("misc/from_primitive.jl")
125 changes: 125 additions & 0 deletions DifferentiationInterface/src/fallbacks/change_prep.jl
@@ -0,0 +1,125 @@
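# Generic fallbacks for the preparation modifiers: `prepare!_op(f, old_prep, backend, x, ...)`
# ignores the stale preparation and simply builds a fresh one with `prepare_op`, while
# `prepare_op_same_point` (for the seed-taking operators) reuses an existing preparation as-is.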
for op in [
    :derivative,
    :gradient,
    :jacobian,
    :second_derivative,
    :hessian,
    :pushforward,
    :pullback,
    :hvp,
]
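    # Derive the related names, e.g. for :gradient these are `gradient!`, `value_and_gradient`,
    # `prepare_gradient`, `prepare!_gradient` and `prepare_gradient_same_point`.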
    op! = Symbol(op, "!")
    val_and_op = if op == :second_derivative
        :value_derivative_and_second_derivative
    elseif op == :hessian
        :value_gradient_and_hessian
    elseif op == :hvp
        nothing
    else
        Symbol("value_and_", op)
    end
    val_and_op! = Symbol(val_and_op, "!")
    prep_op = Symbol("prepare_", op)
    prep_op! = Symbol("prepare!_", op)
    prep_op_same_point = Symbol("prepare_", op, "_same_point")
    P = if op == :derivative
        DerivativePrep
    elseif op == :gradient
        GradientPrep
    elseif op == :jacobian
        JacobianPrep
    elseif op == :second_derivative
        SecondDerivativePrep
    elseif op == :hessian
        HessianPrep
    elseif op == :pushforward
        PushforwardPrep
    elseif op == :pullback
        PullbackPrep
    elseif op == :hvp
        HVPPrep
    end

    if op in (:derivative, :gradient, :jacobian)
        # 1-arg
        @eval function $prep_op!(
            f::F, ::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}
        ) where {F,C}
            return $prep_op(f, backend, x, contexts...)
        end
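        # `gradient` has no two-argument (mutating `f!`) form, so skip its 2-arg fallback.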
        op == :gradient && continue
        # 2-arg
        @eval function $prep_op!(
            f!::F, y, ::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}
        ) where {F,C}
            return $prep_op(f!, y, backend, x, contexts...)
        end

    elseif op in (:second_derivative, :hessian)
        # 1-arg
        @eval function $prep_op!(
            f::F, ::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}
        ) where {F,C}
            return $prep_op(f, backend, x, contexts...)
        end

    elseif op in (:pushforward, :pullback, :hvp)
        # 1-arg
        @eval function $prep_op!(
            f::F,
            ::$P,
            backend::AbstractADType,
            x,
            seed::NTuple,
            contexts::Vararg{Context,C},
        ) where {F,C}
            return $prep_op(f, backend, x, seed, contexts...)
        end
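        # If the point and seed are unchanged, an existing preparation can be reused as-is.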
        @eval function $prep_op_same_point(
            f::F,
            prep::$P,
            backend::AbstractADType,
            x,
            seed::NTuple,
            contexts::Vararg{Context,C},
        ) where {F,C}
            return prep
        end
        @eval function $prep_op_same_point(
            f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C}
        ) where {F,C}
            prep = $prep_op(f, backend, x, seed, contexts...)
            return $prep_op_same_point(f, prep, backend, x, seed, contexts...)
        end
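        # `hvp` has no two-argument (mutating) form, so skip its 2-arg fallbacks.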
        op == :hvp && continue
        # 2-arg
        @eval function $prep_op!(
            f!::F,
            y,
            ::$P,
            backend::AbstractADType,
            x,
            seed::NTuple,
            contexts::Vararg{Context,C},
        ) where {F,C}
            return $prep_op(f!, y, backend, x, seed, contexts...)
        end
        @eval function $prep_op_same_point(
            f!::F,
            y,
            prep::$P,
            backend::AbstractADType,
            x,
            seed::NTuple,
            contexts::Vararg{Context,C},
        ) where {F,C}
            return prep
        end
        @eval function $prep_op_same_point(
            f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C}
        ) where {F,C}
            prep = $prep_op(f!, y, backend, x, seed, contexts...)
            return $prep_op_same_point(f!, y, prep, backend, x, seed, contexts...)
        end
    end
end
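
For context, here is a rough sketch of how the new modifier could be used when an input changes size. The backend choice, the example function, and the need to qualify `prepare!_gradient` with the module name are illustrative assumptions rather than part of this PR, and the calls assume the `gradient(f, prep, backend, x)` argument order used by DifferentiationInterface at this point.

```julia
# Sketch only: assumes ForwardDiff as the backend and the prep-second argument order.
using DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff

f(x) = sum(abs2, x)
backend = AutoForwardDiff()

x = rand(3)
prep = prepare_gradient(f, backend, x)  # preparation tailored to an input of length 3
gradient(f, prep, backend, x)

# The input is resized: the fallback above discards the stale preparation
# and simply calls `prepare_gradient` again for the new size.
x2 = rand(5)
prep2 = DifferentiationInterface.prepare!_gradient(f, prep, backend, x2)
gradient(f, prep2, backend, x2)
```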