From 4782dca38009aebd7fc25445c41134be3a76c9ae Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Thu, 15 Sep 2022 17:28:25 -0500 Subject: [PATCH] Bugfixes * Add nothing to the generated function of the loop body to not return any value * Set all entries of shadowmodel to zero --- src/Checkpointing.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/Checkpointing.jl b/src/Checkpointing.jl index 799c8f3..6dde595 100644 --- a/src/Checkpointing.jl +++ b/src/Checkpointing.jl @@ -115,6 +115,22 @@ function create_tangent(shadowmodel::MT) where {MT} return Tangent{MT,typeof(shadowtuple)}(shadowtuple) end +function set_zero!(nestedmodel::MT) where {MT} + if length(fieldnames(MT)) == 0 + if isreal(nestedmodel) + if isa(nestedmodel, Number) + nestedmodel = zero(MT) + else + fill!(nestedmodel, zero(eltype(nestedmodel))) + end + end + else + for name in fieldnames(MT) + set_zero!(getfield(nestedmodel, name)) + end + end +end + function ChainRulesCore.rrule( ::typeof(Checkpointing.checkpoint_struct_for), body::Function, @@ -128,6 +144,7 @@ function ChainRulesCore.rrule( body(model) end function checkpoint_struct_pullback(dmodel) + set_zero!(shadowmodel) copyto!(shadowmodel, dmodel) model = checkpoint_struct_for(body, alg, model_input, shadowmodel, range) dshadowmodel = create_tangent(shadowmodel) @@ -189,6 +206,7 @@ macro checkpoint_struct(alg, model, loop) end $model = checkpoint_struct_while($alg, $model, shadowmodel, condition) do $model $(loop.args[2]) + nothing end end else @@ -217,6 +235,7 @@ macro checkpoint_struct(alg, model, shadowmodel, loop) end $model = checkpoint_struct_for($alg, $model, $shadowmodel, range) do $model $(loop.args[2]) + nothing end end esc(ex)