From ab5caa31a9783872e5748b80eae0d9fcd01f2627 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Mon, 25 Nov 2024 13:31:03 -0600 Subject: [PATCH] Fix Revolve reverse rule --- Project.toml | 2 +- src/Checkpointing.jl | 1 + src/ChkpDump.jl | 40 +++++++++++++++++++++++++++++++++++++ src/Schemes/Revolve.jl | 45 +++++++++++------------------------------- test/output_chkp.jl | 6 +----- 5 files changed, 54 insertions(+), 40 deletions(-) create mode 100644 src/ChkpDump.jl diff --git a/Project.toml b/Project.toml index 3c4bb83..225a3b6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Checkpointing" uuid = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca" authors = ["Michel Schanen ", "Sri Hari Krishna Narayanan "] -version = "0.9.6" +version = "0.9.7" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/Checkpointing.jl b/src/Checkpointing.jl index eac728b..69540e8 100644 --- a/src/Checkpointing.jl +++ b/src/Checkpointing.jl @@ -75,6 +75,7 @@ abstract type AbstractStorage end include("Storage/ArrayStorage.jl") include("Storage/HDF5Storage.jl") +include("ChkpDump.jl") export AbstractStorage, ArrayStorage, HDF5Storage diff --git a/src/ChkpDump.jl b/src/ChkpDump.jl new file mode 100644 index 0000000..c47d3c3 --- /dev/null +++ b/src/ChkpDump.jl @@ -0,0 +1,40 @@ +struct ChkpDump + steps::Int64 + period::Int64 + filename::String +end + +ChkpDump(steps, ::Val{false}, period = 1, filename = "chkp") = nothing + +function ChkpDump(steps, ::Val{true}, period = 1, filename = "chkp") + return ChkpDump( + steps, period, filename, + ) +end + +dump_prim(::Nothing, _, _) = nothing + +function dump_prim(chkp::ChkpDump, step, primal) + if (step-1) % chkp.period == 0 + blob = serialize(primal) + open("prim_$(chkp.filename)_$step.chkp", "w") do file + write(file, blob) + end + end +end + +dump_adj(::Nothing, _, _) = nothing + +function dump_adj(chkp::ChkpDump, step, adjoint) + @show step + @show chkp.period + @show step % chkp.period + if (step-1) % chkp.period == 0 + blob = serialize(adjoint) + open("adj_$(chkp.filename)_$step.chkp", "w") do file + write(file, blob) + end + end +end + +read_chkp_file(filename) = deserialize(read(filename)) \ No newline at end of file diff --git a/src/Schemes/Revolve.jl b/src/Schemes/Revolve.jl index 0ceeabc..70c5854 100644 --- a/src/Schemes/Revolve.jl +++ b/src/Schemes/Revolve.jl @@ -25,9 +25,7 @@ mutable struct Revolve{MT} <: Scheme where {MT} frestore::Union{Function,Nothing} storage::AbstractStorage gc::Bool - write_checkpoints::Bool - write_checkpoints_filename::String - write_checkpoints_period::Int + chkp_dump::Union{Nothing, ChkpDump} end function Revolve{MT}( @@ -41,8 +39,8 @@ function Revolve{MT}( verbose::Int = 0, gc::Bool = true, write_checkpoints::Bool = false, - write_checkpoints_filename::String = "chkp.h5", write_checkpoints_period::Int = 1, + write_checkpoints_filename::String = "chkp", ) where {MT} if !isa(anActionInstance, Nothing) # same as default init above @@ -102,9 +100,10 @@ function Revolve{MT}( frestore, storage, gc, - write_checkpoints, - write_checkpoints_filename, - write_checkpoints_period, + ChkpDump( + steps, Val(write_checkpoints), + write_checkpoints_period, write_checkpoints_filename + ), ) if verbose > 0 @@ -437,16 +436,6 @@ function rev_checkpoint_struct_for( if !alg.gc GC.enable(false) end - if alg.write_checkpoints - prim_output = HDF5Storage{MT}( - alg.steps; - filename = "primal_$(alg.write_checkpoints_filename).h5", - ) - adj_output = HDF5Storage{MT}( - alg.steps; - filename = "adjoint_$(alg.write_checkpoints_filename).h5", - ) - end step = alg.steps while true next_action = next_action!(alg) @@ -461,29 +450,21 @@ function rev_checkpoint_struct_for( elseif (next_action.actionflag == Checkpointing.firstuturn) body(model) model_final = deepcopy(model) - if alg.write_checkpoints && step % alg.write_checkpoints_period == 1 - prim_output[step] = model_final - end + dump_prim(alg.chkp_dump, step, model_final) if alg.verbose > 0 @info "Revolve: First Uturn" @info "Size of total storage: $(Base.format_bytes(Base.summarysize(alg.storage)))" end Enzyme.autodiff(Reverse, Const(body), Duplicated(model, shadowmodel)) - if alg.write_checkpoints && step % alg.write_checkpoints_period == 1 - adj_output[step] = shadowmodel - end + dump_adj(alg.chkp_dump, step, shadowmodel) step -= 1 if !alg.gc GC.gc() end elseif (next_action.actionflag == Checkpointing.uturn) - if alg.write_checkpoints && step % alg.write_checkpoints_period == 1 - prim_output[step] = model - end + dump_prim(alg.chkp_dump, step, model) Enzyme.autodiff(Reverse, Const(body), Duplicated(model, shadowmodel)) - if alg.write_checkpoints && step % alg.write_checkpoints_period == 1 - adj_output[step] = shadowmodel - end + dump_adj(alg.chkp_dump, step, shadowmodel) step -= 1 if !alg.gc GC.gc() @@ -505,9 +486,5 @@ function rev_checkpoint_struct_for( if !alg.gc GC.enable(true) end - if alg.write_checkpoints - close(prim_output.fid) - close(adj_output.fid) - end - return model_final + return nothing end diff --git a/test/output_chkp.jl b/test/output_chkp.jl index 909815a..7b6a46c 100644 --- a/test/output_chkp.jl +++ b/test/output_chkp.jl @@ -34,11 +34,7 @@ dx = ChkpOut([0.0, 0.0, 0.0]) g = autodiff(Enzyme.Reverse, loops, Active, Duplicated(x, dx), Const(revolve), Const(iters)) -fid = Checkpointing.HDF5.h5open("adjoint_chkp.h5", "r") +chkp = Checkpointing.deserialize(read("adj_chkp_1.chkp")) # List all checkpoints -saved_chkp = sort(parse.(Int, (keys(fid)))) -println("Checkpoints saved: $saved_chkp") -chkp = Checkpointing.deserialize(read(fid["1"])) @test isa(chkp, ChkpOut) @test all(dx .== chkp.x) -close(fid)