Skip to content

Commit

Permalink
Fix primal again
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed May 6, 2024
1 parent a84a075 commit 15fcd43
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Checkpointing"
uuid = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca"
authors = ["Michel Schanen <[email protected]>", "Sri Hari Krishna Narayanan <[email protected]>"]
version = "0.9.2"
version = "0.9.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
20 changes: 10 additions & 10 deletions src/Rules/EnzymeRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ function augmented_primal(
model,
range,
)
primal = func.val(body.val, alg.val, deepcopy(model.val), range.val)
tape_model = deepcopy(model.val)
func.val(body.val, alg.val, model.val, range.val)
if needs_primal(config)
return AugmentedReturn(primal, nothing, (model.val,))
return AugmentedReturn(nothing, nothing, (tape_model,))
else
return AugmentedReturn(nothing, nothing, (model.val,))
return AugmentedReturn(nothing, nothing, (tape_model,))
end
end

Expand All @@ -30,14 +31,13 @@ function reverse(
range,
)
(model_input,) = tape
model_final = Checkpointing.rev_checkpoint_struct_for(
Checkpointing.rev_checkpoint_struct_for(
body.val,
alg.val,
model_input,
model.dval,
range.val,
)
copyto!(model.val, model_final)
return (nothing, nothing, nothing, nothing)
end

Expand All @@ -50,11 +50,12 @@ function augmented_primal(
model,
condition,
)
primal = func.val(body.val, alg.val, deepcopy(model.val), condition.val)
tape_model = deepcopy(model.val)
func.val(body.val, alg.val, model.val, condition.val)
if needs_primal(config)
return AugmentedReturn(primal, nothing, (model.val,))
return AugmentedReturn(nothing, nothing, (tape_model,))
else
return AugmentedReturn(nothing, nothing, (model.val,))
return AugmentedReturn(nothing, nothing, (tape_model,))
end
end

Expand All @@ -69,14 +70,13 @@ function reverse(
condition,
)
(model_input,) = tape
model_final = Checkpointing.rev_checkpoint_struct_while(
Checkpointing.rev_checkpoint_struct_while(
body.val,
alg.val,
model_input,
model.dval,
condition.val,
)
copyto!(model.val, model_final)
return (nothing, nothing, nothing, nothing)
end

Expand Down
2 changes: 1 addition & 1 deletion test/multilevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,5 @@ g = autodiff(
)

# TODO: Primal is wrong only when multilevel checkpointing is used
@test_broken g[2] == primal
@test g[2] == primal
@test all(dx.x .== [1024.0, 1024.0, 1024.0])

0 comments on commit 15fcd43

Please sign in to comment.