Skip to content

Commit

Permalink
Fix Revolve reverse rule
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed Nov 25, 2024
1 parent dd2a45a commit ab5caa3
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 40 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Checkpointing"
uuid = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca"
authors = ["Michel Schanen <[email protected]>", "Sri Hari Krishna Narayanan <[email protected]>"]
version = "0.9.6"
version = "0.9.7"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
1 change: 1 addition & 0 deletions src/Checkpointing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ abstract type AbstractStorage end

include("Storage/ArrayStorage.jl")
include("Storage/HDF5Storage.jl")
include("ChkpDump.jl")

export AbstractStorage, ArrayStorage, HDF5Storage

Expand Down
40 changes: 40 additions & 0 deletions src/ChkpDump.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
    ChkpDump(steps, period, filename)

Settings consumed by `dump_prim` and `dump_adj` when writing states to disk.

# Fields
- `steps`: step count handed over by whoever builds the object.
- `period`: dump interval; a step is written when `(step - 1) % period == 0`.
- `filename`: tag embedded in the dump file names
  (`prim_<filename>_<step>.chkp` and `adj_<filename>_<step>.chkp`).
"""
struct ChkpDump
    steps::Int64
    period::Int64
    filename::String
end

"""
    ChkpDump(steps, ::Val{false}, period = 1, filename = "chkp")

Dumping switched off: no settings object is built and `nothing` is returned,
regardless of the remaining arguments.
"""
function ChkpDump(steps, ::Val{false}, period = 1, filename = "chkp")
    return nothing
end

"""
    ChkpDump(steps, ::Val{true}, period = 1, filename = "chkp")

Build the dump settings used when checkpoint dumping is switched on.

# Arguments
- `steps`: stored in the `steps` field.
- `period`: dump interval, must be at least 1; a step is written when
  `(step - 1) % period == 0`.
- `filename`: tag embedded in the names of the dump files.

# Throws
- `ArgumentError` if `period` is not positive.
"""
function ChkpDump(steps, ::Val{true}, period = 1, filename = "chkp")
    # A zero period would otherwise only surface as a `DivideError` from the
    # `% period` test at the first dump attempt, long after construction.
    period > 0 || throw(ArgumentError("period must be positive, got $period"))
    return ChkpDump(steps, period, filename)
end

# No dump settings available (`nothing`): dumping the primal is a no-op.
function dump_prim(::Nothing, ::Any, ::Any)
    return nothing
end

"""
    dump_prim(chkp::ChkpDump, step, primal)

Write `serialize(primal)` to the file `prim_<chkp.filename>_<step>.chkp`
whenever `(step - 1) % chkp.period == 0`; on any other step nothing is written
and `nothing` is returned.
"""
function dump_prim(chkp::ChkpDump, step, primal)
    # Steps 1, 1 + period, 1 + 2period, ... are the ones that get dumped.
    (step - 1) % chkp.period == 0 || return nothing
    bytes = serialize(primal)
    path = "prim_$(chkp.filename)_$step.chkp"
    return open(io -> write(io, bytes), path, "w")
end

# No dump settings available (`nothing`): dumping the adjoint is a no-op.
function dump_adj(::Nothing, ::Any, ::Any)
    return nothing
end

"""
    dump_adj(chkp::ChkpDump, step, adjoint)

Write `serialize(adjoint)` to the file `adj_<chkp.filename>_<step>.chkp`
whenever `(step - 1) % chkp.period == 0`; on any other step nothing is written.

The function produces no console output.
"""
function dump_adj(chkp::ChkpDump, step, adjoint)
    # Same dump schedule as the primal: steps 1, 1 + period, 1 + 2period, ...
    if (step - 1) % chkp.period == 0
        blob = serialize(adjoint)
        open("adj_$(chkp.filename)_$step.chkp", "w") do file
            write(file, blob)
        end
    end
end

"""
    read_chkp_file(filename)

Read the whole content of `filename` and return the result of passing it to
`deserialize`.
"""
function read_chkp_file(filename)
    raw = read(filename)
    return deserialize(raw)
end
45 changes: 11 additions & 34 deletions src/Schemes/Revolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ mutable struct Revolve{MT} <: Scheme where {MT}
frestore::Union{Function,Nothing}
storage::AbstractStorage
gc::Bool
write_checkpoints::Bool
write_checkpoints_filename::String
write_checkpoints_period::Int
chkp_dump::Union{Nothing, ChkpDump}
end

function Revolve{MT}(
Expand All @@ -41,8 +39,8 @@ function Revolve{MT}(
verbose::Int = 0,
gc::Bool = true,
write_checkpoints::Bool = false,
write_checkpoints_filename::String = "chkp.h5",
write_checkpoints_period::Int = 1,
write_checkpoints_filename::String = "chkp",
) where {MT}
if !isa(anActionInstance, Nothing)
# same as default init above
Expand Down Expand Up @@ -102,9 +100,10 @@ function Revolve{MT}(
frestore,
storage,
gc,
write_checkpoints,
write_checkpoints_filename,
write_checkpoints_period,
ChkpDump(
steps, Val(write_checkpoints),
write_checkpoints_period, write_checkpoints_filename
),
)

if verbose > 0
Expand Down Expand Up @@ -437,16 +436,6 @@ function rev_checkpoint_struct_for(
if !alg.gc
GC.enable(false)
end
if alg.write_checkpoints
prim_output = HDF5Storage{MT}(
alg.steps;
filename = "primal_$(alg.write_checkpoints_filename).h5",
)
adj_output = HDF5Storage{MT}(
alg.steps;
filename = "adjoint_$(alg.write_checkpoints_filename).h5",
)
end
step = alg.steps
while true
next_action = next_action!(alg)
Expand All @@ -461,29 +450,21 @@ function rev_checkpoint_struct_for(
elseif (next_action.actionflag == Checkpointing.firstuturn)
body(model)
model_final = deepcopy(model)
if alg.write_checkpoints && step % alg.write_checkpoints_period == 1
prim_output[step] = model_final
end
dump_prim(alg.chkp_dump, step, model_final)
if alg.verbose > 0
@info "Revolve: First Uturn"
@info "Size of total storage: $(Base.format_bytes(Base.summarysize(alg.storage)))"
end
Enzyme.autodiff(Reverse, Const(body), Duplicated(model, shadowmodel))
if alg.write_checkpoints && step % alg.write_checkpoints_period == 1
adj_output[step] = shadowmodel
end
dump_adj(alg.chkp_dump, step, shadowmodel)
step -= 1
if !alg.gc
GC.gc()
end
elseif (next_action.actionflag == Checkpointing.uturn)
if alg.write_checkpoints && step % alg.write_checkpoints_period == 1
prim_output[step] = model
end
dump_prim(alg.chkp_dump, step, model)
Enzyme.autodiff(Reverse, Const(body), Duplicated(model, shadowmodel))
if alg.write_checkpoints && step % alg.write_checkpoints_period == 1
adj_output[step] = shadowmodel
end
dump_adj(alg.chkp_dump, step, shadowmodel)
step -= 1
if !alg.gc
GC.gc()
Expand All @@ -505,9 +486,5 @@ function rev_checkpoint_struct_for(
if !alg.gc
GC.enable(true)
end
if alg.write_checkpoints
close(prim_output.fid)
close(adj_output.fid)
end
return model_final
return nothing
end
6 changes: 1 addition & 5 deletions test/output_chkp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@ dx = ChkpOut([0.0, 0.0, 0.0])

g = autodiff(Enzyme.Reverse, loops, Active, Duplicated(x, dx), Const(revolve), Const(iters))

fid = Checkpointing.HDF5.h5open("adjoint_chkp.h5", "r")
chkp = Checkpointing.deserialize(read("adj_chkp_1.chkp"))
# List all checkpoints
saved_chkp = sort(parse.(Int, (keys(fid))))
println("Checkpoints saved: $saved_chkp")
chkp = Checkpointing.deserialize(read(fid["1"]))
@test isa(chkp, ChkpOut)
@test all(dx .== chkp.x)
close(fid)

0 comments on commit ab5caa3

Please sign in to comment.