From 6fdf9e7cc81c2888b5123829ce2e194a2e99e5b0 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Wed, 14 Sep 2022 09:42:32 -0500 Subject: [PATCH] Integrate online checkpointing with struct checkpointing --- examples/heat.jl | 39 ++++++++++-- src/Checkpointing.jl | 62 ++++++++++++++++--- src/Schemes/Online_r2.jl | 130 +++++++++++++++++++++++++++++++-------- src/Schemes/Periodic.jl | 12 ++-- src/Schemes/Revolve.jl | 12 ++-- test/runtests.jl | 38 +++++++++--- 6 files changed, 232 insertions(+), 61 deletions(-) diff --git a/examples/heat.jl b/examples/heat.jl index bdc14aa..15cf22a 100644 --- a/examples/heat.jl +++ b/examples/heat.jl @@ -22,7 +22,7 @@ function advance(heat) end -function sumheat(heat::Heat, chkpscheme::Scheme) +function sumheat_for(heat::Heat, chkpscheme::Scheme, tsteps::Int64) # AD: Create shadow copy for derivatives @checkpoint_struct chkpscheme heat for i in 1:tsteps heat.Tlast .= heat.Tnext @@ -31,14 +31,25 @@ function sumheat(heat::Heat, chkpscheme::Scheme) return reduce(+, heat.Tnext) end -function heat(scheme::Scheme, tsteps::Int) +function sumheat_while(heat::Heat, chkpscheme::Scheme, tsteps::Int64) + # AD: Create shadow copy for derivatives + heat.tsteps = 1 + @checkpoint_struct chkpscheme heat while heat.tsteps <= tsteps + heat.Tlast .= heat.Tnext + advance(heat) + heat.tsteps += 1 + end + return reduce(+, heat.Tnext) +end + +function heat_for(scheme::Scheme, tsteps::Int) n = 100 Δx=0.1 Δt=0.001 # Select μ such that λ ≤ 0.5 for stability with μ = (λ*Δt)/Δx^2 λ = 0.5 - # Create object from struct + # Create object from struct. tsteps is not needed for a for-loop heat = Heat(zeros(n), zeros(n), n, λ, tsteps) # Boundary conditions @@ -46,7 +57,27 @@ function heat(scheme::Scheme, tsteps::Int) heat.Tnext[end] = 0 # Compute gradient - g = Zygote.gradient(sumheat, heat, scheme) + g = Zygote.gradient(sumheat_for, heat, scheme, tsteps) + + return heat.Tnext, g[1].Tnext[2:end-1] +end + +function heat_while(scheme::Scheme, tsteps::Int) + n = 100 + Δx=0.1 + Δt=0.001 + # Select μ such that λ ≤ 0.5 for stability with μ = (λ*Δt)/Δx^2 + λ = 0.5 + + # Create object from struct + heat = Heat(zeros(n), zeros(n), n, λ, 1) + + # Boundary conditions + heat.Tnext[1] = 20.0 + heat.Tnext[end] = 0 + + # Compute gradient + g = Zygote.gradient(sumheat_while, heat, scheme, tsteps) return heat.Tnext, g[1].Tnext[2:end-1] end diff --git a/src/Checkpointing.jl b/src/Checkpointing.jl index bd915c0..799c8f3 100644 --- a/src/Checkpointing.jl +++ b/src/Checkpointing.jl @@ -61,7 +61,8 @@ function jacobian(tobedifferentiated, F_H, ::AbstractADTool) error("No AD tool interface implemented") end -export Scheme, AbstractADTool, jacobian, @checkpoint, @checkpoint_struct, checkpoint_struct +export Scheme, AbstractADTool, jacobian +export @checkpoint, @checkpoint_struct, checkpoint_struct_for, checkpoint_struct_while function serialize(x) s = IOBuffer() @@ -115,22 +116,43 @@ function create_tangent(shadowmodel::MT) where {MT} end function ChainRulesCore.rrule( - ::typeof(Checkpointing.checkpoint_struct), + ::typeof(Checkpointing.checkpoint_struct_for), body::Function, alg::Scheme, model::MT, shadowmodel::MT, + range::Function ) where {MT} model_input = deepcopy(model) - # shadowmodel = deepcopy(model) for i in 1:alg.steps body(model) end function checkpoint_struct_pullback(dmodel) copyto!(shadowmodel, dmodel) - model = checkpoint_struct(body, alg, model_input, shadowmodel) + model = checkpoint_struct_for(body, alg, model_input, shadowmodel, range) dshadowmodel = create_tangent(shadowmodel) - return NoTangent(), NoTangent(), NoTangent(), dshadowmodel, NoTangent() + return NoTangent(), NoTangent(), NoTangent(), dshadowmodel, NoTangent(), NoTangent() + end + return model, checkpoint_struct_pullback +end + +function ChainRulesCore.rrule( + ::typeof(Checkpointing.checkpoint_struct_while), + body::Function, + alg::Scheme, + model::MT, + shadowmodel::MT, + condition::Function +) where {MT} + model_input = deepcopy(model) + while condition(model) + body(model) + end + function checkpoint_struct_pullback(dmodel) + copyto!(shadowmodel, dmodel) + model = checkpoint_struct_while(body, alg, model_input, shadowmodel, condition) + dshadowmodel = create_tangent(shadowmodel) + return NoTangent(), NoTangent(), NoTangent(), dshadowmodel, NoTangent(), NoTangent() end return model, checkpoint_struct_pullback end @@ -149,11 +171,28 @@ adjoints and is created here. It is supposed to be initialized by ChainRules. """ macro checkpoint_struct(alg, model, loop) - ex = quote - shadowmodel = deepcopy($model) - $model = checkpoint_struct($alg, $model, shadowmodel) do $model - $(loop.args[2]) + if loop.head == :for + ex = quote + shadowmodel = deepcopy($model) + function range() + $(loop.args[1]) + end + $model = checkpoint_struct_for($alg, $model, shadowmodel, range) do $model + $(loop.args[2]) + end end + elseif loop.head == :while + ex = quote + shadowmodel = deepcopy($model) + function condition($model) + $(loop.args[1]) + end + $model = checkpoint_struct_while($alg, $model, shadowmodel, condition) do $model + $(loop.args[2]) + end + end + else + error("Checkpointing.jl: Unknown loop construct.") end esc(ex) end @@ -173,7 +212,10 @@ are seeded and retrieved. """ macro checkpoint_struct(alg, model, shadowmodel, loop) ex = quote - $model = checkpoint_struct($alg, $model, $shadowmodel) do $model + function range() + $(loop.args[1]) + end + $model = checkpoint_struct_for($alg, $model, $shadowmodel, range) do $model $(loop.args[2]) end end diff --git a/src/Schemes/Online_r2.jl b/src/Schemes/Online_r2.jl index 8093764..e1fdde2 100644 --- a/src/Schemes/Online_r2.jl +++ b/src/Schemes/Online_r2.jl @@ -22,19 +22,21 @@ mutable struct Online_r2{MT} <: Scheme where {MT} incr::Int offset::Int t::Int - verbose::Bool + verbose::Int fstore::Union{Function,Nothing} frestore::Union{Function,Nothing} ch::Vector{Int} ord_ch::Vector{Int} num_rep::Vector{Int} - revolve::Revolve + revolve::Revolve{MT} + storage::AbstractStorage end function Online_r2{MT}( checkpoints::Int, fstore::Union{Function,Nothing} = nothing, - frestore::Union{Function,Nothing} = nothing, + frestore::Union{Function,Nothing} = nothing; + storage::AbstractStorage = ArrayStorage{MT}(checkpoints), anActionInstance::Union{Nothing,Action} = nothing, verbose::Int = 0 ) where {MT} @@ -67,11 +69,10 @@ function Online_r2{MT}( ord_ch[i] = -1 num_rep[i] = -1 end - verbose = false revolve = Revolve{MT}(typemax(Int64), acp, fstore, frestore; verbose=3) online_r2 = Online_r2{MT}(check, capo, acp, numfwd, numcmd, numstore, oldcapo, ind, oldind, iter, incr, offset, t, - verbose, fstore, frestore, ch, ord_ch, num_rep, revolve) + verbose, fstore, frestore, ch, ord_ch, num_rep, revolve, storage) return online_r2 end @@ -113,7 +114,7 @@ end function next_action!(online::Online_r2)::Action # Default values for next action actionflag = none - if online.verbose + if online.verbose > 0 if(online.check !=-1) @info(online.check+1, online.ch[online.check+1], online.capo) for i in 1:online.acp @@ -127,18 +128,18 @@ function next_action!(online::Online_r2)::Action end end online.numcmd+=1 - #We use this logic because the C++ version uses short circuiting + # We use this logic because the C++ version uses short circuiting cond2 = false if online.check != -1 cond2 = online.ch[online.check+1] != online.capo end online.oldcapo = online.capo if ((online.check == -1) || ( cond2 && (online.capo <= online.acp-1))) - #condition for takeshot for r=1 + # condition for takeshot for r=1 # (If no checkpoint has been taken before OR # If a store has not just occurred AND the iteration count is # less than the total number of checkpoints) - if online.verbose + if online.verbose > 0 @info("condition for takeshot for r=1") end online.check += 1 @@ -168,20 +169,20 @@ function next_action!(online::Online_r2)::Action online.numstore+=1 return Action(store, online.capo-1, -1, online.check) elseif (online.capo < online.acp-1) - #condition for advance for r=1 + # condition for advance for r=1 # (the iteraton is less that the total number of checkpoints) - if online.verbose + if online.verbose > 0 @info("condition for advance for r=1") end online.capo = online.oldcapo+1 online.numfwd+=1 return Action(forward, online.capo, online.oldcapo, -1) else - #Online_r2-Checkpointing for r=2 + # Online_r2-Checkpointing for r=2 if (online.ch[online.check+1] == online.capo) # condition for advance for r=2 # (checkpoint has just occurred) - if online.verbose + if online.verbose > 0 @info("Online_r2-condition for advance for r=2 online.acp=", online.acp) end if (online.acp == 1) @@ -209,7 +210,7 @@ function next_action!(online::Online_r2)::Action actionflag = forward return Action(forward, online.capo, online.oldcapo, -1) else - if online.verbose + if online.verbose > 0 @info("Online_r2-condition for advance for r=2 online.acp-1=", online.acp-1," online.capo= ", online.capo) end if (online.capo == online.acp-1) @@ -231,27 +232,27 @@ function next_action!(online::Online_r2)::Action end return Action(forward, online.capo, online.oldcapo, -1) end - if (online.verbose) + if online.verbose > 0 @info(" iter ", iter, "incr ", incr) end error(" not implemented yet") return Action(done, online.capo, online.oldcapo, -1) end else - #takeshot for r=2 - if (online.verbose) + # takeshot for r=2 + if online.verbose > 0 @info("Online_r2-condition for takeshot for r=2 online.acp =", online.acp) end if (online.acp == 2) online.ch[1+1] = online.capo online.incr+=1 - #Increase the number of takeshots and the corresponding checkpoint + # Increase the number of takeshots and the corresponding checkpoint online.numstore+=1 return Action(store, online.capo-1, -1, 1+1) elseif (online.acp == 3) online.ch[online.ind+1] = online.capo online.check = online.ind - if (online.verbose) + if online.verbose > 0 @info(" iter ", online.iter, " online.num_rep[1] ", online.num_rep[1+1]) end if (online.iter == online.num_rep[1+1]) @@ -262,11 +263,11 @@ function next_action!(online::Online_r2)::Action online.ind = 2 - online.num_rep[1+1]%2 online.incr=1 end - #Increase the number of takeshots and the corresponding checkpoint + # Increase the number of takeshots and the corresponding checkpoint online.numstore+=1 return Action(store, online.capo-1, -1, online.check) else - if (online.verbose) + if online.verbose > 0 @info(" online.capo ", online.capo, " online.acp ", online.acp) end if (online.capo < online.acp+2) @@ -275,7 +276,7 @@ function next_action!(online::Online_r2)::Action if (online.capo == online.acp+1) online.oldind = online.ord_ch[online.acp-1+1] online.ind = online.ch[online.ord_ch[online.acp-1+1]+1] - if (online.verbose) + if online.verbose > 0 @info(" oldind ", online.oldind, " ind ", online.ind) end for k=online.acp:-1:3 @@ -286,7 +287,7 @@ function next_action!(online::Online_r2)::Action online.ch[online.ord_ch[1+1]+1] = online.ind online.incr = 2 online.ind = 2 - if (online.verbose) + if online.verbose > 0 @info(" ind ", online.ind, " incr ", online.incr, " iter ", online.iter) for j=1:online.acp @info(" j ", j, " ord_ch ", online.ord_ch[j], " ch ", online.ch[online.ord_ch[j]+1], " rep ", online.num_rep[online.ord_ch[j]+1]) @@ -299,7 +300,7 @@ function next_action!(online::Online_r2)::Action end if (online.t == 0) - if (online.verbose) + if online.verbose > 0 @info(" online.ind ", online.ind, " online.incr ", online.incr, " iter ", online.iter, " offset ", online.offset) end if (online.iter == online.offset) @@ -309,7 +310,7 @@ function next_action!(online::Online_r2)::Action online.ch[online.ord_ch[online.acp-1+1]+1] = online.capo online.oldind = online.ord_ch[online.acp-1+1] online.ind = online.ch[online.ord_ch[online.acp-1+1]+1] - if (online.verbose) + if online.verbose > 0 @info(" oldind " , online.oldind , " ind " , online.ind) end for k=online.acp-1:-1:online.incr+1 @@ -320,7 +321,7 @@ function next_action!(online::Online_r2)::Action online.ch[online.ord_ch[online.incr+1]+1] = online.ind online.incr+=1 online.ind=online.incr - if (online.verbose) + if online.verbose > 0 @info(" ind ", online.ind, " incr ", online.incr, " iter ", online.iter) for j=1:online.acp @info(" j ", j, " ord_ch ", online.ord_ch[j], " ch ", online.ch[online.ord_ch[j]+1], " rep ", online.num_rep[online.ord_ch[j]+1]) @@ -331,7 +332,7 @@ function next_action!(online::Online_r2)::Action online.check = online.ord_ch[online.ind+1] online.iter+=1 online.ind+=1 - if (online.verbose) + if online.verbose > 0 @info(" xx ind ", online.ind, " incr ", online.incr, " iter ", online.iter) end end @@ -348,3 +349,78 @@ function next_action!(online::Online_r2)::Action @info("Online_r2 is optimal over the range [0,(numcheckpoints+2)*(numcheckpoints+1)/2]. Online_r3 needs to be implemented") return Action(error, online.capo, online.oldcapo, -1) end + +function checkpoint_struct_while( + body::Function, + alg::Online_r2, + model_input::MT, + shadowmodel::MT, + condition::Function +) where {MT} + model = deepcopy(model_input) + model_check = alg.storage + model_final = [] + freeindices = Stack{Int32}() + storemapinv = Dict{Int32,Int32}() + storemap = Dict{Int32,Int32}() + check = 0 + oldcapo=0 + onlinesteps=0 + go = true + while go + next_action = next_action!(alg) + if (next_action.actionflag == Checkpointing.store) + check=next_action.cpnum+1 + storemapinv[check]=next_action.iteration + model_check[check] = deepcopy(model) + elseif (next_action.actionflag == Checkpointing.forward) + for j= oldcapo:(next_action.iteration-1) + if go + body(model) + go = condition(model) + end + onlinesteps=onlinesteps+1 + end + oldcapo=next_action.iteration + else + @error("Unexpected action in online phase: ", next_action.actionflag) + end + end + for (key, value) in storemapinv + storemap[value]=key + end + + # Switch to offline revolve now. + update_revolve(alg, onlinesteps+1) + while true + next_action = next_action!(alg.revolve) + if (next_action.actionflag == Checkpointing.store) + check=pop!(freeindices) + storemap[next_action.iteration-1]=check + model_check[check] = deepcopy(model) + elseif (next_action.actionflag == Checkpointing.forward) + for j= next_action.startiteration:(next_action.iteration-1) + body(model) + end + elseif (next_action.actionflag == Checkpointing.firstuturn) + # Commented out lines are weird + # body(model) + model_final = deepcopy(model) + # Enzyme.autodiff(body, Duplicated(model,shadowmodel)) + elseif (next_action.actionflag == Checkpointing.uturn) + Enzyme.autodiff(body, Duplicated(model,shadowmodel)) + if haskey(storemap,next_action.iteration-1-1) + push!(freeindices, storemap[next_action.iteration-1-1]) + delete!(storemap,next_action.iteration-1-1) + end + elseif (next_action.actionflag == Checkpointing.restore) + model = deepcopy(model_check[storemap[next_action.iteration-1]]) + elseif next_action.actionflag == Checkpointing.done + if haskey(storemap,next_action.iteration-1-1) + delete!(storemap,next_action.iteration-1-1) + end + break + end + end + return model_final +end diff --git a/src/Schemes/Periodic.jl b/src/Schemes/Periodic.jl index 7246e47..685ce14 100644 --- a/src/Schemes/Periodic.jl +++ b/src/Schemes/Periodic.jl @@ -54,11 +54,13 @@ function forwardcount(periodic::Periodic) end end -function checkpoint_struct(body::Function, - alg::Periodic, - model_input::MT, - shadowmodel::MT - ) where{MT} +function checkpoint_struct_for( + body::Function, + alg::Periodic, + model_input::MT, + shadowmodel::MT, + range::Function +) where{MT} model = deepcopy(model_input) model_final = [] model_check_outer = alg.storage diff --git a/src/Schemes/Revolve.jl b/src/Schemes/Revolve.jl index 2777e6c..635b1e9 100644 --- a/src/Schemes/Revolve.jl +++ b/src/Schemes/Revolve.jl @@ -364,11 +364,13 @@ function forwardcount(revolve::Revolve) return ret end -function checkpoint_struct(body::Function, - alg::Revolve, - model_input::MT, - shadowmodel::MT - ) where {MT} +function checkpoint_struct_for( + body::Function, + alg::Revolve, + model_input::MT, + shadowmodel::MT, + range +) where {MT} model = deepcopy(model_input) storemap = Dict{Int32,Int32}() check = 0 diff --git a/test/runtests.jl b/test/runtests.jl index e3666d6..95f5504 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -65,7 +65,7 @@ include("../examples/adtools.jl") include("../examples/deprecated/optcontrol.jl") global steps = 100 global snaps = 4 - global info = 1 + global info = 0 function store(F_H, F_C,t, i) F_C[1,i] = F_H[1] @@ -88,13 +88,9 @@ include("../examples/adtools.jl") @testset "Testing Online_r2..." begin include("../examples/optcontrolwhile.jl") # Enzyme segfaults if the garbage collector is enabled - if isa(adtool, EnzymeADTool) - GC.gc() - GC.enable(false) - end global steps = 100 global snaps = 20 - global info = 3 + global info = 0 function store(F_H, F_C,t, i) F_C[1,i] = F_H[1] @@ -148,7 +144,7 @@ include("../examples/adtools.jl") info = 0 revolve = Revolve{Heat}(steps, snaps; verbose=info) - T, dT = heat(revolve, steps) + T, dT = heat_for(revolve, steps) @test isapprox(norm(T), 66.21987468492061, atol=1e-11) @test isapprox(norm(dT), 6.970279349365908, atol=1e-11) @@ -160,7 +156,18 @@ include("../examples/adtools.jl") info = 0 periodic = Periodic{Heat}(steps, snaps; verbose=info) - T, dT = heat(periodic, steps) + T, dT = heat_for(periodic, steps) + + @test isapprox(norm(T), 66.21987468492061, atol=1e-11) + @test isapprox(norm(dT), 6.970279349365908, atol=1e-11) + end + + @testset "Testing Online_r2..." begin + steps = 500 + snaps = 100 + info = 0 + online = Online_r2{Heat}(snaps; verbose=info) + T, dT = heat_while(online, steps) @test isapprox(norm(T), 66.21987468492061, atol=1e-11) @test isapprox(norm(dT), 6.970279349365908, atol=1e-11) @@ -174,7 +181,7 @@ include("../examples/adtools.jl") info = 0 revolve = Revolve{Heat}(steps, snaps; storage=HDF5Storage{Heat}(snaps), verbose=info) - T, dT = heat(revolve, steps) + T, dT = heat_for(revolve, steps) @test isapprox(norm(T), 66.21987468492061, atol=1e-11) @test isapprox(norm(dT), 6.970279349365908, atol=1e-11) @@ -186,7 +193,18 @@ include("../examples/adtools.jl") info = 0 periodic = Periodic{Heat}(steps, snaps; storage=HDF5Storage{Heat}(snaps), verbose=info) - T, dT = heat(periodic, steps) + T, dT = heat_for(periodic, steps) + + @test isapprox(norm(T), 66.21987468492061, atol=1e-11) + @test isapprox(norm(dT), 6.970279349365908, atol=1e-11) + end + + @testset "Testing Online_r2..." begin + steps = 500 + snaps = 100 + info = 0 + online = Online_r2{Heat}(snaps; storage=HDF5Storage{Heat}(snaps), verbose=info) + T, dT = heat_while(online, steps) @test isapprox(norm(T), 66.21987468492061, atol=1e-11) @test isapprox(norm(dT), 6.970279349365908, atol=1e-11)