Skip to content

Commit

Permalink
Make GC optional and add more verbose output
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed Jun 9, 2023
1 parent a7e73a9 commit 0f225ef
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Checkpointing"
uuid = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca"
authors = ["Michel Schanen <[email protected]>", "Sri Hari Krishna Narayanan <[email protected]>"]
version = "0.8.1"
version = "0.8.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
27 changes: 22 additions & 5 deletions src/Schemes/Revolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ mutable struct Revolve{MT} <: Scheme where {MT}
fstore::Union{Function,Nothing}
frestore::Union{Function,Nothing}
storage::AbstractStorage
gc::Bool
end

function Revolve{MT}(
Expand All @@ -34,14 +35,19 @@ function Revolve{MT}(
storage::AbstractStorage = ArrayStorage{MT}(checkpoints),
anActionInstance::Union{Nothing,Action} = nothing,
bundle_::Union{Nothing,Int} = nothing,
verbose::Int = 0
verbose::Int = 0,
gc::Bool = true
) where {MT}
if !isa(anActionInstance, Nothing)
# same as default init above
anActionInstance.actionflag = 0
anActionInstance.iteration = 0
anActionInstance.cpNum = 0
end
if verbose > 0
@info "Revolve: Number of checkpoints: $checkpoints"
@info "Revolve: Number of steps: $steps"
end
!isa(bundle_, Nothing) ? bundle = bundle_ : bundle = 1
if bundle < 1 || bundle > steps
error("Revolve: bundle parameter out of range [1,steps]")
Expand Down Expand Up @@ -74,7 +80,7 @@ function Revolve{MT}(
revolve = Revolve{MT}(
steps, bundle, tail, acp, cstart, cend, numfwd,
numinv, numstore, rwcp, prevcend, firstuturned,
stepof, verbose, fstore, frestore, storage
stepof, verbose, fstore, frestore, storage, gc
)

if verbose > 0
Expand Down Expand Up @@ -393,11 +399,14 @@ function rev_checkpoint_struct_for(
range
) where {MT}
model = deepcopy(model_input)
@info "Size per checkpoint: $(Base.format_bytes(Base.summarysize(model)))"
storemap = Dict{Int32,Int32}()
check = 0
model_check = alg.storage
model_final = []
GC.enable(false)
if !alg.gc
GC.enable(false)
end
while true
next_action = next_action!(alg)
if (next_action.actionflag == Checkpointing.store)
Expand All @@ -411,11 +420,19 @@ function rev_checkpoint_struct_for(
elseif (next_action.actionflag == Checkpointing.firstuturn)
body(model)
model_final = deepcopy(model)
if alg.verbose > 0
@info "Revolve: First Uturn"
@info "Size of total storage: $(Base.format_bytes(Base.summarysize(alg.storage)))"
end
Enzyme.autodiff(Reverse, body, Duplicated(model,shadowmodel))
GC.gc()
if !alg.gc
GC.gc()
end
elseif (next_action.actionflag == Checkpointing.uturn)
Enzyme.autodiff(Reverse, body, Duplicated(model,shadowmodel))
GC.gc()
if !alg.gc
GC.gc()
end
if haskey(storemap,next_action.iteration-1-1)
delete!(storemap,next_action.iteration-1-1)
check=check-1
Expand Down

0 comments on commit 0f225ef

Please sign in to comment.