diff --git a/Project.toml b/Project.toml index 9a253db..af5d9a2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,16 +1,19 @@ name = "Checkpointing" uuid = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca" authors = ["Michel Schanen ", "Sri Hari Krishna Narayanan "] -version = "0.4.0" +version = "0.5.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" [compat] ChainRulesCore = "1.0" Enzyme = "0.9" +HDF5 = "0.16" julia = "1.7" [extras] diff --git a/README.md b/README.md index 87e93f0..79d5516 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ heat.Tnext[end] = 0 # Number of available snapshots snaps = 4 verbose = 0 -revolve = Revolve(tsteps, snaps; verbose=verbose) +revolve = Revolve{Heat}(tsteps, snaps; verbose=verbose) # Compute gradient g = Zygote.gradient(sumheat, heat, revolve) diff --git a/docs/src/quickstart.md b/docs/src/quickstart.md index fa3cf5f..6a572b3 100644 --- a/docs/src/quickstart.md +++ b/docs/src/quickstart.md @@ -27,10 +27,8 @@ function advance(heat) end -function sumheat(heat::Heat) - # AD: Create shadow copy for derivatives - shadowheat = Heat(zeros(n), zeros(n), 0, 0.0, 0) - @checkpoint_struct revolve heat shadowheat for i in 1:tsteps +function sumheat(heat::Heat, chkpt::Scheme) + @checkpoint_struct revolve heat for i in 1:tsteps heat.Tlast .= heat.Tnext advance(heat) end @@ -56,10 +54,10 @@ heat.Tnext[end] = 0 # Number of available snapshots snaps = 4 verbose = 0 -revolve = Revolve(tsteps, snaps; verbose=verbose) +revolve = Revolve{Heat}(tsteps, snaps; verbose=verbose) # Compute gradient -g = Zygote.gradient(sumheat,heat) +g = Zygote.gradient(sumheat, heat, revolve) ``` Plot function values: @@ -69,4 +67,4 @@ plot(heat.Tnext) Plot gradient with respect to sum(T): ```@example heat plot(g[1].Tnext[2:end-1]) -``` \ No newline at end of file +``` diff --git a/examples/printaction.jl b/examples/printaction.jl index 84aea34..e668ea0 100644 --- a/examples/printaction.jl +++ b/examples/printaction.jl @@ -2,7 +2,7 @@ using Checkpointing function main(steps, checkpoints; verbose=0) store = function f() end - revolve = Revolve(steps, checkpoints, store, store; verbose=verbose) + revolve = Revolve{Nothing}(steps, checkpoints, store, store; verbose=verbose) guessed_checkpoints = guess(revolve) @show guessed_checkpoints println("Revolve suggests : $guessed_checkpoints checkpoints for a factor of $(factor(revolve, steps,guessed_checkpoints))") diff --git a/src/Checkpointing.jl b/src/Checkpointing.jl index 4204ec2..653cf5c 100644 --- a/src/Checkpointing.jl +++ b/src/Checkpointing.jl @@ -3,6 +3,8 @@ module Checkpointing using ChainRulesCore using LinearAlgebra using Enzyme +using Serialization +using HDF5 """ Scheme @@ -60,6 +62,26 @@ end export Scheme, AbstractADTool, jacobian, @checkpoint, @checkpoint_struct, checkpoint_struct +function serialize(x) + s = IOBuffer() + Serialization.serialize(s, x) + take!(s) +end + +function deserialize(x) + s = IOBuffer(x) + Serialization.deserialize(s) +end + +export serialize, deserialize + +abstract type AbstractStorage end + +include("Storage/ArrayStorage.jl") +include("Storage/HDF5Storage.jl") + +export AbstractStorage, ArrayStorage, HDF5Storage + include("deprecated.jl") include("Schemes/Revolve.jl") include("Schemes/Periodic.jl") diff --git a/src/Schemes/Periodic.jl b/src/Schemes/Periodic.jl index 55adbd8..7246e47 100644 --- a/src/Schemes/Periodic.jl +++ b/src/Schemes/Periodic.jl @@ -9,24 +9,26 @@ Periodic checkpointing scheme. """ -mutable struct Periodic <: Scheme +mutable struct Periodic{MT} <: Scheme where {MT} steps::Int acp::Int period::Int verbose::Int fstore::Union{Function,Nothing} frestore::Union{Function,Nothing} + storage::AbstractStorage end -function Periodic( +function Periodic{MT}( steps::Int, checkpoints::Int, fstore::Union{Function,Nothing} = nothing, frestore::Union{Function,Nothing} = nothing; + storage::AbstractStorage = ArrayStorage{MT}(checkpoints), anActionInstance::Union{Nothing,Action} = nothing, bundle_::Union{Nothing,Int} = nothing, verbose::Int = 0 -) +) where {MT} if !isa(anActionInstance, Nothing) # same as default init above anActionInstance.actionflag = 0 @@ -36,7 +38,7 @@ function Periodic( acp = checkpoints period = div(steps, checkpoints) - periodic = Periodic(steps, acp, period, verbose, fstore, frestore) + periodic = Periodic{MT}(steps, acp, period, verbose, fstore, frestore, storage) forwardcount(periodic) return periodic @@ -59,7 +61,7 @@ function checkpoint_struct(body::Function, ) where{MT} model = deepcopy(model_input) model_final = [] - model_check_outer = Array{MT}(undef, alg.acp) + model_check_outer = alg.storage model_check_inner = Array{MT}(undef, alg.period) check = 0 for i = 1:alg.acp diff --git a/src/Schemes/Revolve.jl b/src/Schemes/Revolve.jl index c612d1a..2777e6c 100644 --- a/src/Schemes/Revolve.jl +++ b/src/Schemes/Revolve.jl @@ -6,7 +6,7 @@ A minor extension is the optional `bundle` parameter that allows to treat as ma iterations in one tape/adjoint sweep. If `bundle` is 1, the default, then the behavior is that of Alg. 799. """ -mutable struct Revolve <: Scheme +mutable struct Revolve{MT} <: Scheme where {MT} steps::Int bundle::Int tail::Int @@ -23,17 +23,19 @@ mutable struct Revolve <: Scheme verbose::Int fstore::Union{Function,Nothing} frestore::Union{Function,Nothing} + storage::AbstractStorage end -function Revolve( +function Revolve{MT}( steps::Int, checkpoints::Int, fstore::Union{Function,Nothing} = nothing, frestore::Union{Function,Nothing} = nothing; + storage::AbstractStorage = ArrayStorage{MT}(checkpoints), anActionInstance::Union{Nothing,Action} = nothing, bundle_::Union{Nothing,Int} = nothing, verbose::Int = 0 -) +) where {MT} if !isa(anActionInstance, Nothing) # same as default init above anActionInstance.actionflag = 0 @@ -69,7 +71,11 @@ function Revolve( firstuturned = false stepof = Vector{Int}(undef, acp+1) - revolve = Revolve(steps, bundle, tail, acp, cstart, cend, numfwd, numinv, numstore, rwcp, prevcend, firstuturned, stepof, verbose, fstore, frestore) + revolve = Revolve{MT}( + steps, bundle, tail, acp, cstart, cend, numfwd, + numinv, numstore, rwcp, prevcend, firstuturned, + stepof, verbose, fstore, frestore, storage + ) if verbose > 0 predfwdcnt = forwardcount(revolve) @@ -366,7 +372,7 @@ function checkpoint_struct(body::Function, model = deepcopy(model_input) storemap = Dict{Int32,Int32}() check = 0 - model_check = Array{MT}(undef, alg.acp) + model_check = alg.storage model_final = [] while true next_action = next_action!(alg) diff --git a/src/Storage/ArrayStorage.jl b/src/Storage/ArrayStorage.jl new file mode 100644 index 0000000..c2b0c42 --- /dev/null +++ b/src/Storage/ArrayStorage.jl @@ -0,0 +1,17 @@ +struct ArrayStorage{MT} <: AbstractStorage where {MT} + _storage::Array{MT} +end + +function ArrayStorage{MT}(acp::Int) where {MT} + storage = Array{MT}(undef, acp) + return ArrayStorage(storage) +end + +Base.getindex(storage::ArrayStorage{MT}, i) where {MT} = storage._storage[i] + +function Base.setindex!(storage::ArrayStorage{MT}, v, i) where {MT} + storage._storage[i] = v +end + +Base.ndims(::Type{ArrayStorage{MT}}) where {MT} = 1 +Base.size(storage::ArrayStorage{MT}) where {MT} = size(storage._storage) diff --git a/src/Storage/HDF5Storage.jl b/src/Storage/HDF5Storage.jl new file mode 100644 index 0000000..80b86f4 --- /dev/null +++ b/src/Storage/HDF5Storage.jl @@ -0,0 +1,31 @@ +mutable struct HDF5Storage{MT} <: AbstractStorage where {MT} + fid::HDF5.File + filename::String + acp::Int64 +end + +function HDF5Storage{MT}(acp::Int; filename=tempname()) where {MT} + fid = h5open(filename, "w") + storage = HDF5Storage{MT}(fid, filename, acp) + function _finalizer(storage::HDF5Storage{MT}) + close(storage.fid) + return storage + end + finalizer(_finalizer, storage) + return storage +end + +function Base.getindex(storage::HDF5Storage{MT}, i)::MT where {MT} + @assert i >= 1 && i <= storage.acp + blob = read(storage.fid["$i"]) + return deserialize(blob) +end + +function Base.setindex!(storage::HDF5Storage{MT}, v::MT, i) where {MT} + @assert i >= 1 && i <= storage.acp + if haskey(storage.fid, "$i") + delete_object(storage.fid, "$i") + end + blob = serialize(v) + storage.fid["$i"] = blob +end diff --git a/test/runtests.jl b/test/runtests.jl index 72ad3f4..f0d72c7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,7 +53,7 @@ include("../examples/adtools.jl") t = F_C[3,i] return F_H, t end - revolve = Revolve(steps, snaps, store, restore; verbose=info) + revolve = Revolve{Nothing}(steps, snaps, store, restore; verbose=info) F_opt, F_final, L_opt, L = optcontrol(revolve, steps, adtool) @test isapprox(F_opt, F_final, rtol=1e-4) @test isapprox(L_opt, L, rtol=1e-4) @@ -77,7 +77,7 @@ include("../examples/adtools.jl") t = F_C[3,i] return F_H, t end - periodic = Periodic(steps, snaps, store, restore; verbose=info) + periodic = Periodic{Nothing}(steps, snaps, store, restore; verbose=info) F_opt, F_final, L_opt, L = optcontrol(periodic, steps, adtool) @test isapprox(F_opt, F_final, rtol=1e-4) @test isapprox(L_opt, L, rtol=1e-4) @@ -92,7 +92,7 @@ include("../examples/adtools.jl") snaps = 3 info = 0 - revolve = Revolve(steps, snaps; verbose=info) + revolve = Revolve{Model}(steps, snaps; verbose=info) F, L, F_opt, L_opt = muoptcontrol(revolve, steps) @test isapprox(F_opt, F, rtol=1e-4) @test isapprox(L_opt, L, rtol=1e-4) @@ -103,20 +103,20 @@ include("../examples/adtools.jl") snaps = 4 info = 0 - periodic = Periodic(steps, snaps; verbose=info) + periodic = Periodic{Model}(steps, snaps; verbose=info) F, L, F_opt, L_opt = muoptcontrol(periodic, steps) @test isapprox(F_opt, F, rtol=1e-4) @test isapprox(L_opt, L, rtol=1e-4) end end - @testset "Test heat" begin + @testset "Test heat example" begin include("../examples/heat.jl") @testset "Testing Revolve..." begin steps = 500 snaps = 4 info = 0 - revolve = Revolve(steps, snaps; verbose=info) + revolve = Revolve{Heat}(steps, snaps; verbose=info) T, dT = heat(revolve, steps) @test isapprox(norm(T), 66.21987468492061, atol=1e-11) @@ -128,7 +128,33 @@ include("../examples/adtools.jl") snaps = 4 info = 0 - periodic = Periodic(steps, snaps; verbose=info) + periodic = Periodic{Heat}(steps, snaps; verbose=info) + T, dT = heat(periodic, steps) + + @test isapprox(norm(T), 66.21987468492061, atol=1e-11) + @test isapprox(norm(dT), 6.970279349365908, atol=1e-11) + end + end + @testset "Test HDF5 storage using heat example" begin + include("../examples/heat.jl") + @testset "Testing Revolve..." begin + steps = 500 + snaps = 4 + info = 0 + + revolve = Revolve{Heat}(steps, snaps; storage=HDF5Storage{Heat}(snaps), verbose=info) + T, dT = heat(revolve, steps) + + @test isapprox(norm(T), 66.21987468492061, atol=1e-11) + @test isapprox(norm(dT), 6.970279349365908, atol=1e-11) + end + + @testset "Testing Periodic..." begin + steps = 500 + snaps = 4 + info = 0 + + periodic = Periodic{Heat}(steps, snaps; storage=HDF5Storage{Heat}(snaps), verbose=info) T, dT = heat(periodic, steps) @test isapprox(norm(T), 66.21987468492061, atol=1e-11)