Skip to content

Commit

Permalink
Bugfixes
Browse files Browse the repository at this point in the history
* Add nothing to the generated function of the loop body to not return
  any value
* Set all entries of shadowmodel to zero
  • Loading branch information
michel2323 committed Sep 15, 2022
1 parent 6fdf9e7 commit 4782dca
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/Checkpointing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,22 @@ function create_tangent(shadowmodel::MT) where {MT}
return Tangent{MT,typeof(shadowtuple)}(shadowtuple)
end

function set_zero!(nestedmodel::MT) where {MT}
if length(fieldnames(MT)) == 0
if isreal(nestedmodel)
if isa(nestedmodel, Number)
nestedmodel = zero(MT)
else
fill!(nestedmodel, zero(eltype(nestedmodel)))
end
end
else
for name in fieldnames(MT)
set_zero!(getfield(nestedmodel, name))
end
end
end

function ChainRulesCore.rrule(
::typeof(Checkpointing.checkpoint_struct_for),
body::Function,
Expand All @@ -128,6 +144,7 @@ function ChainRulesCore.rrule(
body(model)
end
function checkpoint_struct_pullback(dmodel)
set_zero!(shadowmodel)
copyto!(shadowmodel, dmodel)
model = checkpoint_struct_for(body, alg, model_input, shadowmodel, range)
dshadowmodel = create_tangent(shadowmodel)
Expand Down Expand Up @@ -189,6 +206,7 @@ macro checkpoint_struct(alg, model, loop)
end
$model = checkpoint_struct_while($alg, $model, shadowmodel, condition) do $model
$(loop.args[2])
nothing
end
end
else
Expand Down Expand Up @@ -217,6 +235,7 @@ macro checkpoint_struct(alg, model, shadowmodel, loop)
end
$model = checkpoint_struct_for($alg, $model, $shadowmodel, range) do $model
$(loop.args[2])
nothing
end
end
esc(ex)
Expand Down

0 comments on commit 4782dca

Please sign in to comment.