Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Revolve reverse rule #60

Merged
merged 7 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Checkpointing"
uuid = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca"
authors = ["Michel Schanen <[email protected]>", "Sri Hari Krishna Narayanan <[email protected]>"]
version = "0.9.6"
version = "0.9.7"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
1 change: 1 addition & 0 deletions src/Checkpointing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ abstract type AbstractStorage end

include("Storage/ArrayStorage.jl")
include("Storage/HDF5Storage.jl")
include("ChkpDump.jl")

export AbstractStorage, ArrayStorage, HDF5Storage

Expand Down
38 changes: 38 additions & 0 deletions src/ChkpDump.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
struct ChkpDump
steps::Int64
period::Int64
filename::String
end

ChkpDump(steps, ::Val{false}, period = 1, filename = "chkp") = nothing

function ChkpDump(steps, ::Val{true}, period = 1, filename = "chkp")
return ChkpDump(steps, period, filename)
end

dump_prim(::Nothing, _, _) = nothing

function dump_prim(chkp::ChkpDump, step, primal)
if (step - 1) % chkp.period == 0
blob = serialize(primal)
open("prim_$(chkp.filename)_$step.chkp", "w") do file
write(file, blob)
end
end
end

dump_adj(::Nothing, _, _) = nothing

function dump_adj(chkp::ChkpDump, step, adjoint)
@show step
@show chkp.period
@show step % chkp.period
if (step - 1) % chkp.period == 0
blob = serialize(adjoint)
open("adj_$(chkp.filename)_$step.chkp", "w") do file
write(file, blob)
end
end
end

read_chkp_file(filename) = deserialize(read(filename))
47 changes: 13 additions & 34 deletions src/Schemes/Revolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ mutable struct Revolve{MT} <: Scheme where {MT}
frestore::Union{Function,Nothing}
storage::AbstractStorage
gc::Bool
write_checkpoints::Bool
write_checkpoints_filename::String
write_checkpoints_period::Int
chkp_dump::Union{Nothing,ChkpDump}
end

function Revolve{MT}(
Expand All @@ -41,8 +39,8 @@ function Revolve{MT}(
verbose::Int = 0,
gc::Bool = true,
write_checkpoints::Bool = false,
write_checkpoints_filename::String = "chkp.h5",
write_checkpoints_period::Int = 1,
write_checkpoints_filename::String = "chkp",
) where {MT}
if !isa(anActionInstance, Nothing)
# same as default init above
Expand Down Expand Up @@ -102,9 +100,12 @@ function Revolve{MT}(
frestore,
storage,
gc,
write_checkpoints,
write_checkpoints_filename,
write_checkpoints_period,
ChkpDump(
steps,
Val(write_checkpoints),
write_checkpoints_period,
write_checkpoints_filename,
),
)

if verbose > 0
Expand Down Expand Up @@ -437,16 +438,6 @@ function rev_checkpoint_struct_for(
if !alg.gc
GC.enable(false)
end
if alg.write_checkpoints
prim_output = HDF5Storage{MT}(
alg.steps;
filename = "primal_$(alg.write_checkpoints_filename).h5",
)
adj_output = HDF5Storage{MT}(
alg.steps;
filename = "adjoint_$(alg.write_checkpoints_filename).h5",
)
end
step = alg.steps
while true
next_action = next_action!(alg)
Expand All @@ -461,29 +452,21 @@ function rev_checkpoint_struct_for(
elseif (next_action.actionflag == Checkpointing.firstuturn)
body(model)
model_final = deepcopy(model)
if alg.write_checkpoints && step % alg.write_checkpoints_period == 1
prim_output[step] = model_final
end
dump_prim(alg.chkp_dump, step, model_final)
if alg.verbose > 0
@info "Revolve: First Uturn"
@info "Size of total storage: $(Base.format_bytes(Base.summarysize(alg.storage)))"
end
Enzyme.autodiff(Reverse, Const(body), Duplicated(model, shadowmodel))
if alg.write_checkpoints && step % alg.write_checkpoints_period == 1
adj_output[step] = shadowmodel
end
dump_adj(alg.chkp_dump, step, shadowmodel)
step -= 1
if !alg.gc
GC.gc()
end
elseif (next_action.actionflag == Checkpointing.uturn)
if alg.write_checkpoints && step % alg.write_checkpoints_period == 1
prim_output[step] = model
end
dump_prim(alg.chkp_dump, step, model)
Enzyme.autodiff(Reverse, Const(body), Duplicated(model, shadowmodel))
if alg.write_checkpoints && step % alg.write_checkpoints_period == 1
adj_output[step] = shadowmodel
end
dump_adj(alg.chkp_dump, step, shadowmodel)
step -= 1
if !alg.gc
GC.gc()
Expand All @@ -505,9 +488,5 @@ function rev_checkpoint_struct_for(
if !alg.gc
GC.enable(true)
end
if alg.write_checkpoints
close(prim_output.fid)
close(adj_output.fid)
end
return model_final
return nothing
end
6 changes: 1 addition & 5 deletions test/output_chkp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@ dx = ChkpOut([0.0, 0.0, 0.0])

g = autodiff(Enzyme.Reverse, loops, Active, Duplicated(x, dx), Const(revolve), Const(iters))

fid = Checkpointing.HDF5.h5open("adjoint_chkp.h5", "r")
chkp = Checkpointing.deserialize(read("adj_chkp_1.chkp"))
# List all checkpoints
saved_chkp = sort(parse.(Int, (keys(fid))))
println("Checkpoints saved: $saved_chkp")
chkp = Checkpointing.deserialize(read(fid["1"]))
@test isa(chkp, ChkpOut)
@test all(dx .== chkp.x)
close(fid)
Loading