From 15fcd433b978e48064c541cbc991635ea12cc7f8 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Mon, 6 May 2024 14:12:29 -0500 Subject: [PATCH] Fix primal again --- Project.toml | 2 +- src/Rules/EnzymeRules.jl | 20 ++++++++++---------- test/multilevel.jl | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 7292f3d..4f99223 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Checkpointing" uuid = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca" authors = ["Michel Schanen ", "Sri Hari Krishna Narayanan "] -version = "0.9.2" +version = "0.9.3" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/Rules/EnzymeRules.jl b/src/Rules/EnzymeRules.jl index 97cebfc..28f360a 100644 --- a/src/Rules/EnzymeRules.jl +++ b/src/Rules/EnzymeRules.jl @@ -11,11 +11,12 @@ function augmented_primal( model, range, ) - primal = func.val(body.val, alg.val, deepcopy(model.val), range.val) + tape_model = deepcopy(model.val) + func.val(body.val, alg.val, model.val, range.val) if needs_primal(config) - return AugmentedReturn(primal, nothing, (model.val,)) + return AugmentedReturn(nothing, nothing, (tape_model,)) else - return AugmentedReturn(nothing, nothing, (model.val,)) + return AugmentedReturn(nothing, nothing, (tape_model,)) end end @@ -30,14 +31,13 @@ function reverse( range, ) (model_input,) = tape - model_final = Checkpointing.rev_checkpoint_struct_for( + Checkpointing.rev_checkpoint_struct_for( body.val, alg.val, model_input, model.dval, range.val, ) - copyto!(model.val, model_final) return (nothing, nothing, nothing, nothing) end @@ -50,11 +50,12 @@ function augmented_primal( model, condition, ) - primal = func.val(body.val, alg.val, deepcopy(model.val), condition.val) + tape_model = deepcopy(model.val) + func.val(body.val, alg.val, model.val, condition.val) if needs_primal(config) - return AugmentedReturn(primal, nothing, (model.val,)) + return AugmentedReturn(nothing, nothing, (tape_model,)) else - return AugmentedReturn(nothing, nothing, (model.val,)) + return AugmentedReturn(nothing, nothing, (tape_model,)) end end @@ -69,14 +70,13 @@ function reverse( condition, ) (model_input,) = tape - model_final = Checkpointing.rev_checkpoint_struct_while( + Checkpointing.rev_checkpoint_struct_while( body.val, alg.val, model_input, model.dval, condition.val, ) - copyto!(model.val, model_final) return (nothing, nothing, nothing, nothing) end diff --git a/test/multilevel.jl b/test/multilevel.jl index 2ea3533..55362bf 100644 --- a/test/multilevel.jl +++ b/test/multilevel.jl @@ -43,5 +43,5 @@ g = autodiff( ) # TODO: Primal is wrong only when multilevel checkpointing is used -@test_broken g[2] == primal +@test g[2] == primal @test all(dx.x .== [1024.0, 1024.0, 1024.0])