From ab5caa31a9783872e5748b80eae0d9fcd01f2627 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Mon, 25 Nov 2024 13:31:03 -0600 Subject: [PATCH 1/7] Fix Revolve reverse rule --- Project.toml | 2 +- src/Checkpointing.jl | 1 + src/ChkpDump.jl | 40 +++++++++++++++++++++++++++++++++++++ src/Schemes/Revolve.jl | 45 +++++++++++------------------------------- test/output_chkp.jl | 6 +----- 5 files changed, 54 insertions(+), 40 deletions(-) create mode 100644 src/ChkpDump.jl diff --git a/Project.toml b/Project.toml index 3c4bb83..225a3b6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Checkpointing" uuid = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca" authors = ["Michel Schanen ", "Sri Hari Krishna Narayanan "] -version = "0.9.6" +version = "0.9.7" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/Checkpointing.jl b/src/Checkpointing.jl index eac728b..69540e8 100644 --- a/src/Checkpointing.jl +++ b/src/Checkpointing.jl @@ -75,6 +75,7 @@ abstract type AbstractStorage end include("Storage/ArrayStorage.jl") include("Storage/HDF5Storage.jl") +include("ChkpDump.jl") export AbstractStorage, ArrayStorage, HDF5Storage diff --git a/src/ChkpDump.jl b/src/ChkpDump.jl new file mode 100644 index 0000000..c47d3c3 --- /dev/null +++ b/src/ChkpDump.jl @@ -0,0 +1,40 @@ +struct ChkpDump + steps::Int64 + period::Int64 + filename::String +end + +ChkpDump(steps, ::Val{false}, period = 1, filename = "chkp") = nothing + +function ChkpDump(steps, ::Val{true}, period = 1, filename = "chkp") + return ChkpDump( + steps, period, filename, + ) +end + +dump_prim(::Nothing, _, _) = nothing + +function dump_prim(chkp::ChkpDump, step, primal) + if (step-1) % chkp.period == 0 + blob = serialize(primal) + open("prim_$(chkp.filename)_$step.chkp", "w") do file + write(file, blob) + end + end +end + +dump_adj(::Nothing, _, _) = nothing + +function dump_adj(chkp::ChkpDump, step, adjoint) + @show step + @show chkp.period + @show step % chkp.period + if (step-1) % chkp.period == 0 + blob = serialize(adjoint) + open("adj_$(chkp.filename)_$step.chkp", "w") do file + write(file, blob) + end + end +end + +read_chkp_file(filename) = deserialize(read(filename)) \ No newline at end of file diff --git a/src/Schemes/Revolve.jl b/src/Schemes/Revolve.jl index 0ceeabc..70c5854 100644 --- a/src/Schemes/Revolve.jl +++ b/src/Schemes/Revolve.jl @@ -25,9 +25,7 @@ mutable struct Revolve{MT} <: Scheme where {MT} frestore::Union{Function,Nothing} storage::AbstractStorage gc::Bool - write_checkpoints::Bool - write_checkpoints_filename::String - write_checkpoints_period::Int + chkp_dump::Union{Nothing, ChkpDump} end function Revolve{MT}( @@ -41,8 +39,8 @@ function Revolve{MT}( verbose::Int = 0, gc::Bool = true, write_checkpoints::Bool = false, - write_checkpoints_filename::String = "chkp.h5", write_checkpoints_period::Int = 1, + write_checkpoints_filename::String = "chkp", ) where {MT} if !isa(anActionInstance, Nothing) # same as default init above @@ -102,9 +100,10 @@ function Revolve{MT}( frestore, storage, gc, - write_checkpoints, - write_checkpoints_filename, - write_checkpoints_period, + ChkpDump( + steps, Val(write_checkpoints), + write_checkpoints_period, write_checkpoints_filename + ), ) if verbose > 0 @@ -437,16 +436,6 @@ function rev_checkpoint_struct_for( if !alg.gc GC.enable(false) end - if alg.write_checkpoints - prim_output = HDF5Storage{MT}( - alg.steps; - filename = "primal_$(alg.write_checkpoints_filename).h5", - ) - adj_output = HDF5Storage{MT}( - alg.steps; - filename = "adjoint_$(alg.write_checkpoints_filename).h5", - ) - end step = alg.steps while true next_action = next_action!(alg) @@ -461,29 +450,21 @@ function rev_checkpoint_struct_for( elseif (next_action.actionflag == Checkpointing.firstuturn) body(model) model_final = deepcopy(model) - if alg.write_checkpoints && step % alg.write_checkpoints_period == 1 - prim_output[step] = model_final - end + dump_prim(alg.chkp_dump, step, model_final) if alg.verbose > 0 @info "Revolve: First Uturn" @info "Size of total storage: $(Base.format_bytes(Base.summarysize(alg.storage)))" end Enzyme.autodiff(Reverse, Const(body), Duplicated(model, shadowmodel)) - if alg.write_checkpoints && step % alg.write_checkpoints_period == 1 - adj_output[step] = shadowmodel - end + dump_adj(alg.chkp_dump, step, shadowmodel) step -= 1 if !alg.gc GC.gc() end elseif (next_action.actionflag == Checkpointing.uturn) - if alg.write_checkpoints && step % alg.write_checkpoints_period == 1 - prim_output[step] = model - end + dump_prim(alg.chkp_dump, step, model) Enzyme.autodiff(Reverse, Const(body), Duplicated(model, shadowmodel)) - if alg.write_checkpoints && step % alg.write_checkpoints_period == 1 - adj_output[step] = shadowmodel - end + dump_adj(alg.chkp_dump, step, shadowmodel) step -= 1 if !alg.gc GC.gc() @@ -505,9 +486,5 @@ function rev_checkpoint_struct_for( if !alg.gc GC.enable(true) end - if alg.write_checkpoints - close(prim_output.fid) - close(adj_output.fid) - end - return model_final + return nothing end diff --git a/test/output_chkp.jl b/test/output_chkp.jl index 909815a..7b6a46c 100644 --- a/test/output_chkp.jl +++ b/test/output_chkp.jl @@ -34,11 +34,7 @@ dx = ChkpOut([0.0, 0.0, 0.0]) g = autodiff(Enzyme.Reverse, loops, Active, Duplicated(x, dx), Const(revolve), Const(iters)) -fid = Checkpointing.HDF5.h5open("adjoint_chkp.h5", "r") +chkp = Checkpointing.deserialize(read("adj_chkp_1.chkp")) # List all checkpoints -saved_chkp = sort(parse.(Int, (keys(fid)))) -println("Checkpoints saved: $saved_chkp") -chkp = Checkpointing.deserialize(read(fid["1"])) @test isa(chkp, ChkpOut) @test all(dx .== chkp.x) -close(fid) From 86e874991fe80ae5503b102cb629d7170bac4358 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Mon, 25 Nov 2024 13:38:14 -0600 Subject: [PATCH 2/7] Update src/ChkpDump.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/ChkpDump.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/ChkpDump.jl b/src/ChkpDump.jl index c47d3c3..86c6dff 100644 --- a/src/ChkpDump.jl +++ b/src/ChkpDump.jl @@ -7,9 +7,7 @@ end ChkpDump(steps, ::Val{false}, period = 1, filename = "chkp") = nothing function ChkpDump(steps, ::Val{true}, period = 1, filename = "chkp") - return ChkpDump( - steps, period, filename, - ) + return ChkpDump(steps, period, filename) end dump_prim(::Nothing, _, _) = nothing From 02938ef2dfc420d09417023879008b93a2629fc1 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Mon, 25 Nov 2024 13:38:30 -0600 Subject: [PATCH 3/7] Update src/ChkpDump.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/ChkpDump.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ChkpDump.jl b/src/ChkpDump.jl index 86c6dff..0270746 100644 --- a/src/ChkpDump.jl +++ b/src/ChkpDump.jl @@ -13,7 +13,7 @@ end dump_prim(::Nothing, _, _) = nothing function dump_prim(chkp::ChkpDump, step, primal) - if (step-1) % chkp.period == 0 + if (step - 1) % chkp.period == 0 blob = serialize(primal) open("prim_$(chkp.filename)_$step.chkp", "w") do file write(file, blob) From 53a3dc328e3fbd8f821dccb76fd1697381348322 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Mon, 25 Nov 2024 13:38:34 -0600 Subject: [PATCH 4/7] Update src/ChkpDump.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/ChkpDump.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ChkpDump.jl b/src/ChkpDump.jl index 0270746..8a28079 100644 --- a/src/ChkpDump.jl +++ b/src/ChkpDump.jl @@ -27,7 +27,7 @@ function dump_adj(chkp::ChkpDump, step, adjoint) @show step @show chkp.period @show step % chkp.period - if (step-1) % chkp.period == 0 + if (step - 1) % chkp.period == 0 blob = serialize(adjoint) open("adj_$(chkp.filename)_$step.chkp", "w") do file write(file, blob) From e579962d4bfd54e14f2df4ffc971448f16f47c67 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Mon, 25 Nov 2024 13:38:41 -0600 Subject: [PATCH 5/7] Update src/Schemes/Revolve.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/Schemes/Revolve.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Schemes/Revolve.jl b/src/Schemes/Revolve.jl index 70c5854..a0145fa 100644 --- a/src/Schemes/Revolve.jl +++ b/src/Schemes/Revolve.jl @@ -25,7 +25,7 @@ mutable struct Revolve{MT} <: Scheme where {MT} frestore::Union{Function,Nothing} storage::AbstractStorage gc::Bool - chkp_dump::Union{Nothing, ChkpDump} + chkp_dump::Union{Nothing,ChkpDump} end function Revolve{MT}( From f2735ba87e93c7a55a6d4fcedb5e1f6319a68c50 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Mon, 25 Nov 2024 13:38:48 -0600 Subject: [PATCH 6/7] Update src/Schemes/Revolve.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/Schemes/Revolve.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Schemes/Revolve.jl b/src/Schemes/Revolve.jl index a0145fa..db69239 100644 --- a/src/Schemes/Revolve.jl +++ b/src/Schemes/Revolve.jl @@ -101,8 +101,10 @@ function Revolve{MT}( storage, gc, ChkpDump( - steps, Val(write_checkpoints), - write_checkpoints_period, write_checkpoints_filename + steps, + Val(write_checkpoints), + write_checkpoints_period, + write_checkpoints_filename, ), ) From f9ae22e7cd8cb9c9d4851f55e66a8a1283ac732f Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Mon, 25 Nov 2024 13:41:44 -0600 Subject: [PATCH 7/7] Fix --- src/ChkpDump.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ChkpDump.jl b/src/ChkpDump.jl index 8a28079..42898ac 100644 --- a/src/ChkpDump.jl +++ b/src/ChkpDump.jl @@ -35,4 +35,4 @@ function dump_adj(chkp::ChkpDump, step, adjoint) end end -read_chkp_file(filename) = deserialize(read(filename)) \ No newline at end of file +read_chkp_file(filename) = deserialize(read(filename))