Skip to content

Commit

Permalink
Merge pull request #14 from Argonne-National-Laboratory/ms/storage
Browse files Browse the repository at this point in the history
Add HDF5 support through AbstractStorage
  • Loading branch information
sriharikrishna authored Jul 13, 2022
2 parents a180e7c + 00cb655 commit 5cbaaa9
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 27 deletions.
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
name = "Checkpointing"
uuid = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca"
authors = ["Michel Schanen <[email protected]>", "Sri Hari Krishna Narayanan <[email protected]>"]
version = "0.4.0"
version = "0.5.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[compat]
ChainRulesCore = "1.0"
Enzyme = "0.9"
HDF5 = "0.16"
julia = "1.7"

[extras]
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ heat.Tnext[end] = 0
# Number of available snapshots
snaps = 4
verbose = 0
revolve = Revolve(tsteps, snaps; verbose=verbose)
revolve = Revolve{Heat}(tsteps, snaps; verbose=verbose)

# Compute gradient
g = Zygote.gradient(sumheat, heat, revolve)
Expand Down
12 changes: 5 additions & 7 deletions docs/src/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@ function advance(heat)
end
function sumheat(heat::Heat)
# AD: Create shadow copy for derivatives
shadowheat = Heat(zeros(n), zeros(n), 0, 0.0, 0)
@checkpoint_struct revolve heat shadowheat for i in 1:tsteps
function sumheat(heat::Heat, chkpt::Scheme)
@checkpoint_struct revolve heat for i in 1:tsteps
heat.Tlast .= heat.Tnext
advance(heat)
end
Expand All @@ -56,10 +54,10 @@ heat.Tnext[end] = 0
# Number of available snapshots
snaps = 4
verbose = 0
revolve = Revolve(tsteps, snaps; verbose=verbose)
revolve = Revolve{Heat}(tsteps, snaps; verbose=verbose)
# Compute gradient
g = Zygote.gradient(sumheat,heat)
g = Zygote.gradient(sumheat, heat, revolve)
```

Plot function values:
Expand All @@ -69,4 +67,4 @@ plot(heat.Tnext)
Plot gradient with respect to sum(T):
```@example heat
plot(g[1].Tnext[2:end-1])
```
```
2 changes: 1 addition & 1 deletion examples/printaction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using Checkpointing

function main(steps, checkpoints; verbose=0)
store = function f() end
revolve = Revolve(steps, checkpoints, store, store; verbose=verbose)
revolve = Revolve{Nothing}(steps, checkpoints, store, store; verbose=verbose)
guessed_checkpoints = guess(revolve)
@show guessed_checkpoints
println("Revolve suggests : $guessed_checkpoints checkpoints for a factor of $(factor(revolve, steps,guessed_checkpoints))")
Expand Down
22 changes: 22 additions & 0 deletions src/Checkpointing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ module Checkpointing
using ChainRulesCore
using LinearAlgebra
using Enzyme
using Serialization
using HDF5

"""
Scheme
Expand Down Expand Up @@ -60,6 +62,26 @@ end

export Scheme, AbstractADTool, jacobian, @checkpoint, @checkpoint_struct, checkpoint_struct

function serialize(x)
s = IOBuffer()
Serialization.serialize(s, x)
take!(s)
end

function deserialize(x)
s = IOBuffer(x)
Serialization.deserialize(s)
end

export serialize, deserialize

abstract type AbstractStorage end

include("Storage/ArrayStorage.jl")
include("Storage/HDF5Storage.jl")

export AbstractStorage, ArrayStorage, HDF5Storage

include("deprecated.jl")
include("Schemes/Revolve.jl")
include("Schemes/Periodic.jl")
Expand Down
12 changes: 7 additions & 5 deletions src/Schemes/Periodic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,26 @@
Periodic checkpointing scheme.
"""
mutable struct Periodic <: Scheme
mutable struct Periodic{MT} <: Scheme where {MT}
steps::Int
acp::Int
period::Int
verbose::Int
fstore::Union{Function,Nothing}
frestore::Union{Function,Nothing}
storage::AbstractStorage
end

function Periodic(
function Periodic{MT}(
steps::Int,
checkpoints::Int,
fstore::Union{Function,Nothing} = nothing,
frestore::Union{Function,Nothing} = nothing;
storage::AbstractStorage = ArrayStorage{MT}(checkpoints),
anActionInstance::Union{Nothing,Action} = nothing,
bundle_::Union{Nothing,Int} = nothing,
verbose::Int = 0
)
) where {MT}
if !isa(anActionInstance, Nothing)
# same as default init above
anActionInstance.actionflag = 0
Expand All @@ -36,7 +38,7 @@ function Periodic(
acp = checkpoints
period = div(steps, checkpoints)

periodic = Periodic(steps, acp, period, verbose, fstore, frestore)
periodic = Periodic{MT}(steps, acp, period, verbose, fstore, frestore, storage)

forwardcount(periodic)
return periodic
Expand All @@ -59,7 +61,7 @@ function checkpoint_struct(body::Function,
) where{MT}
model = deepcopy(model_input)
model_final = []
model_check_outer = Array{MT}(undef, alg.acp)
model_check_outer = alg.storage
model_check_inner = Array{MT}(undef, alg.period)
check = 0
for i = 1:alg.acp
Expand Down
16 changes: 11 additions & 5 deletions src/Schemes/Revolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ A minor extension is the optional `bundle` parameter that allows to treat as ma
iterations in one tape/adjoint sweep. If `bundle` is 1, the default, then the behavior is that of Alg. 799.
"""
mutable struct Revolve <: Scheme
mutable struct Revolve{MT} <: Scheme where {MT}
steps::Int
bundle::Int
tail::Int
Expand All @@ -23,17 +23,19 @@ mutable struct Revolve <: Scheme
verbose::Int
fstore::Union{Function,Nothing}
frestore::Union{Function,Nothing}
storage::AbstractStorage
end

function Revolve(
function Revolve{MT}(
steps::Int,
checkpoints::Int,
fstore::Union{Function,Nothing} = nothing,
frestore::Union{Function,Nothing} = nothing;
storage::AbstractStorage = ArrayStorage{MT}(checkpoints),
anActionInstance::Union{Nothing,Action} = nothing,
bundle_::Union{Nothing,Int} = nothing,
verbose::Int = 0
)
) where {MT}
if !isa(anActionInstance, Nothing)
# same as default init above
anActionInstance.actionflag = 0
Expand Down Expand Up @@ -69,7 +71,11 @@ function Revolve(
firstuturned = false
stepof = Vector{Int}(undef, acp+1)

revolve = Revolve(steps, bundle, tail, acp, cstart, cend, numfwd, numinv, numstore, rwcp, prevcend, firstuturned, stepof, verbose, fstore, frestore)
revolve = Revolve{MT}(
steps, bundle, tail, acp, cstart, cend, numfwd,
numinv, numstore, rwcp, prevcend, firstuturned,
stepof, verbose, fstore, frestore, storage
)

if verbose > 0
predfwdcnt = forwardcount(revolve)
Expand Down Expand Up @@ -366,7 +372,7 @@ function checkpoint_struct(body::Function,
model = deepcopy(model_input)
storemap = Dict{Int32,Int32}()
check = 0
model_check = Array{MT}(undef, alg.acp)
model_check = alg.storage
model_final = []
while true
next_action = next_action!(alg)
Expand Down
17 changes: 17 additions & 0 deletions src/Storage/ArrayStorage.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
struct ArrayStorage{MT} <: AbstractStorage where {MT}
_storage::Array{MT}
end

function ArrayStorage{MT}(acp::Int) where {MT}
storage = Array{MT}(undef, acp)
return ArrayStorage(storage)
end

Base.getindex(storage::ArrayStorage{MT}, i) where {MT} = storage._storage[i]

function Base.setindex!(storage::ArrayStorage{MT}, v, i) where {MT}
storage._storage[i] = v
end

Base.ndims(::Type{ArrayStorage{MT}}) where {MT} = 1
Base.size(storage::ArrayStorage{MT}) where {MT} = size(storage._storage)
31 changes: 31 additions & 0 deletions src/Storage/HDF5Storage.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
mutable struct HDF5Storage{MT} <: AbstractStorage where {MT}
fid::HDF5.File
filename::String
acp::Int64
end

function HDF5Storage{MT}(acp::Int; filename=tempname()) where {MT}
fid = h5open(filename, "w")
storage = HDF5Storage{MT}(fid, filename, acp)
function _finalizer(storage::HDF5Storage{MT})
close(storage.fid)
return storage
end
finalizer(_finalizer, storage)
return storage
end

function Base.getindex(storage::HDF5Storage{MT}, i)::MT where {MT}
@assert i >= 1 && i <= storage.acp
blob = read(storage.fid["$i"])
return deserialize(blob)
end

function Base.setindex!(storage::HDF5Storage{MT}, v::MT, i) where {MT}
@assert i >= 1 && i <= storage.acp
if haskey(storage.fid, "$i")
delete_object(storage.fid, "$i")
end
blob = serialize(v)
storage.fid["$i"] = blob
end
40 changes: 33 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ include("../examples/adtools.jl")
t = F_C[3,i]
return F_H, t
end
revolve = Revolve(steps, snaps, store, restore; verbose=info)
revolve = Revolve{Nothing}(steps, snaps, store, restore; verbose=info)
F_opt, F_final, L_opt, L = optcontrol(revolve, steps, adtool)
@test isapprox(F_opt, F_final, rtol=1e-4)
@test isapprox(L_opt, L, rtol=1e-4)
Expand All @@ -77,7 +77,7 @@ include("../examples/adtools.jl")
t = F_C[3,i]
return F_H, t
end
periodic = Periodic(steps, snaps, store, restore; verbose=info)
periodic = Periodic{Nothing}(steps, snaps, store, restore; verbose=info)
F_opt, F_final, L_opt, L = optcontrol(periodic, steps, adtool)
@test isapprox(F_opt, F_final, rtol=1e-4)
@test isapprox(L_opt, L, rtol=1e-4)
Expand All @@ -92,7 +92,7 @@ include("../examples/adtools.jl")
snaps = 3
info = 0

revolve = Revolve(steps, snaps; verbose=info)
revolve = Revolve{Model}(steps, snaps; verbose=info)
F, L, F_opt, L_opt = muoptcontrol(revolve, steps)
@test isapprox(F_opt, F, rtol=1e-4)
@test isapprox(L_opt, L, rtol=1e-4)
Expand All @@ -103,20 +103,20 @@ include("../examples/adtools.jl")
snaps = 4
info = 0

periodic = Periodic(steps, snaps; verbose=info)
periodic = Periodic{Model}(steps, snaps; verbose=info)
F, L, F_opt, L_opt = muoptcontrol(periodic, steps)
@test isapprox(F_opt, F, rtol=1e-4)
@test isapprox(L_opt, L, rtol=1e-4)
end
end
@testset "Test heat" begin
@testset "Test heat example" begin
include("../examples/heat.jl")
@testset "Testing Revolve..." begin
steps = 500
snaps = 4
info = 0

revolve = Revolve(steps, snaps; verbose=info)
revolve = Revolve{Heat}(steps, snaps; verbose=info)
T, dT = heat(revolve, steps)

@test isapprox(norm(T), 66.21987468492061, atol=1e-11)
Expand All @@ -128,7 +128,33 @@ include("../examples/adtools.jl")
snaps = 4
info = 0

periodic = Periodic(steps, snaps; verbose=info)
periodic = Periodic{Heat}(steps, snaps; verbose=info)
T, dT = heat(periodic, steps)

@test isapprox(norm(T), 66.21987468492061, atol=1e-11)
@test isapprox(norm(dT), 6.970279349365908, atol=1e-11)
end
end
@testset "Test HDF5 storage using heat example" begin
include("../examples/heat.jl")
@testset "Testing Revolve..." begin
steps = 500
snaps = 4
info = 0

revolve = Revolve{Heat}(steps, snaps; storage=HDF5Storage{Heat}(snaps), verbose=info)
T, dT = heat(revolve, steps)

@test isapprox(norm(T), 66.21987468492061, atol=1e-11)
@test isapprox(norm(dT), 6.970279349365908, atol=1e-11)
end

@testset "Testing Periodic..." begin
steps = 500
snaps = 4
info = 0

periodic = Periodic{Heat}(steps, snaps; storage=HDF5Storage{Heat}(snaps), verbose=info)
T, dT = heat(periodic, steps)

@test isapprox(norm(T), 66.21987468492061, atol=1e-11)
Expand Down

0 comments on commit 5cbaaa9

Please sign in to comment.