Skip to content

Commit

Permalink
Add option write_checkpoints to disk
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed Jun 9, 2023
1 parent 0f225ef commit 3fb3969
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 8 deletions.
34 changes: 29 additions & 5 deletions src/Schemes/Periodic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ mutable struct Periodic{MT} <: Scheme where {MT}
fstore::Union{Function,Nothing}
frestore::Union{Function,Nothing}
storage::AbstractStorage
gc::Bool
write_checkpoints::Bool
end

function Periodic{MT}(
Expand All @@ -27,7 +29,9 @@ function Periodic{MT}(
storage::AbstractStorage = ArrayStorage{MT}(checkpoints),
anActionInstance::Union{Nothing,Action} = nothing,
bundle_::Union{Nothing,Int} = nothing,
verbose::Int = 0
verbose::Int = 0,
gc::Bool = true,
write_checkpoints::Bool = false
) where {MT}
if !isa(anActionInstance, Nothing)
# same as default init above
Expand All @@ -38,7 +42,11 @@ function Periodic{MT}(
acp = checkpoints
period = div(steps, checkpoints)

periodic = Periodic{MT}(steps, acp, period, verbose, fstore, frestore, storage)
periodic = Periodic{MT}(
steps, acp, period,verbose,
fstore, frestore, storage, gc,
write_checkpoints
)

forwardcount(periodic)
return periodic
Expand Down Expand Up @@ -66,7 +74,13 @@ function rev_checkpoint_struct_for(
model_check_outer = alg.storage
model_check_inner = Array{MT}(undef, alg.period)
check = 0
GC.enable(false)
if !alg.gc
GC.enable(false)
end
if alg.write_checkpoints
prim_output = HDF5Storage{MT}(alg.steps; filename="primal_chkp.h5")
adj_output = HDF5Storage{MT}(alg.steps; filename="adjoint_chkp.h5")
end
for i = 1:alg.acp
model_check_outer[i] = deepcopy(model)
for j= (i-1)*alg.period: (i)*alg.period-1
Expand All @@ -81,11 +95,21 @@ function rev_checkpoint_struct_for(
body(model)
end
for j= alg.period:-1:1
if alg.write_checkpoints
prim_output[check] = model
end
model = deepcopy(model_check_inner[j])
Enzyme.autodiff(Reverse, body, Duplicated(model,shadowmodel))
GC.gc()
if alg.write_checkpoints
adj_output[check] = shadowmodel
end
if !alg.gc
GC.gc()
end
end
end
GC.enable(true)
if !alg.gc
GC.enable(true)
end
return model_final
end
27 changes: 24 additions & 3 deletions src/Schemes/Revolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ mutable struct Revolve{MT} <: Scheme where {MT}
frestore::Union{Function,Nothing}
storage::AbstractStorage
gc::Bool
write_checkpoints::Bool
end

function Revolve{MT}(
Expand All @@ -36,7 +37,8 @@ function Revolve{MT}(
anActionInstance::Union{Nothing,Action} = nothing,
bundle_::Union{Nothing,Int} = nothing,
verbose::Int = 0,
gc::Bool = true
gc::Bool = true,
write_checkpoints::Bool = false
) where {MT}
if !isa(anActionInstance, Nothing)
# same as default init above
Expand Down Expand Up @@ -80,7 +82,8 @@ function Revolve{MT}(
revolve = Revolve{MT}(
steps, bundle, tail, acp, cstart, cend, numfwd,
numinv, numstore, rwcp, prevcend, firstuturned,
stepof, verbose, fstore, frestore, storage, gc
stepof, verbose, fstore, frestore, storage, gc,
write_checkpoints
)

if verbose > 0
Expand Down Expand Up @@ -407,6 +410,10 @@ function rev_checkpoint_struct_for(
if !alg.gc
GC.enable(false)
end
if alg.write_checkpoints
prim_output = HDF5Storage{MT}(alg.steps; filename="primal_chkp.h5")
adj_output = HDF5Storage{MT}(alg.steps; filename="adjoint_chkp.h5")
end
while true
next_action = next_action!(alg)
if (next_action.actionflag == Checkpointing.store)
Expand All @@ -420,16 +427,28 @@ function rev_checkpoint_struct_for(
elseif (next_action.actionflag == Checkpointing.firstuturn)
body(model)
model_final = deepcopy(model)
if alg.write_checkpoints
prim_output[check] = model_final
end
if alg.verbose > 0
@info "Revolve: First Uturn"
@info "Size of total storage: $(Base.format_bytes(Base.summarysize(alg.storage)))"
end
Enzyme.autodiff(Reverse, body, Duplicated(model,shadowmodel))
if alg.write_checkpoints
adj_output[check] = shadowmodel
end
if !alg.gc
GC.gc()
end
elseif (next_action.actionflag == Checkpointing.uturn)
if alg.write_checkpoints
prim_output[check] = model
end
Enzyme.autodiff(Reverse, body, Duplicated(model,shadowmodel))
if alg.write_checkpoints
adj_output[check] = shadowmodel
end
if !alg.gc
GC.gc()
end
Expand All @@ -447,6 +466,8 @@ function rev_checkpoint_struct_for(
break
end
end
GC.enable(true)
if !alg.gc
GC.enable(true)
end
return model_final
end

0 comments on commit 3fb3969

Please sign in to comment.