Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HDF5 support through AbstractStorage #14

Merged
merged 4 commits into from
Jul 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
name = "Checkpointing"
uuid = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca"
authors = ["Michel Schanen <[email protected]>", "Sri Hari Krishna Narayanan <[email protected]>"]
version = "0.4.0"
version = "0.5.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[compat]
ChainRulesCore = "1.0"
Enzyme = "0.9"
HDF5 = "0.16"
julia = "1.7"

[extras]
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ heat.Tnext[end] = 0
# Number of available snapshots
snaps = 4
verbose = 0
revolve = Revolve(tsteps, snaps; verbose=verbose)
revolve = Revolve{Heat}(tsteps, snaps; verbose=verbose)

# Compute gradient
g = Zygote.gradient(sumheat, heat, revolve)
Expand Down
12 changes: 5 additions & 7 deletions docs/src/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@ function advance(heat)
end


function sumheat(heat::Heat)
# AD: Create shadow copy for derivatives
shadowheat = Heat(zeros(n), zeros(n), 0, 0.0, 0)
@checkpoint_struct revolve heat shadowheat for i in 1:tsteps
function sumheat(heat::Heat, chkpt::Scheme)
@checkpoint_struct revolve heat for i in 1:tsteps
heat.Tlast .= heat.Tnext
advance(heat)
end
Expand All @@ -56,10 +54,10 @@ heat.Tnext[end] = 0
# Number of available snapshots
snaps = 4
verbose = 0
revolve = Revolve(tsteps, snaps; verbose=verbose)
revolve = Revolve{Heat}(tsteps, snaps; verbose=verbose)

# Compute gradient
g = Zygote.gradient(sumheat,heat)
g = Zygote.gradient(sumheat, heat, revolve)
```

Plot function values:
Expand All @@ -69,4 +67,4 @@ plot(heat.Tnext)
Plot gradient with respect to sum(T):
```@example heat
plot(g[1].Tnext[2:end-1])
```
```
2 changes: 1 addition & 1 deletion examples/printaction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using Checkpointing

function main(steps, checkpoints; verbose=0)
store = function f() end
revolve = Revolve(steps, checkpoints, store, store; verbose=verbose)
revolve = Revolve{Nothing}(steps, checkpoints, store, store; verbose=verbose)
guessed_checkpoints = guess(revolve)
@show guessed_checkpoints
println("Revolve suggests : $guessed_checkpoints checkpoints for a factor of $(factor(revolve, steps,guessed_checkpoints))")
Expand Down
22 changes: 22 additions & 0 deletions src/Checkpointing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ module Checkpointing
using ChainRulesCore
using LinearAlgebra
using Enzyme
using Serialization
using HDF5

"""
Scheme
Expand Down Expand Up @@ -60,6 +62,26 @@ end

export Scheme, AbstractADTool, jacobian, @checkpoint, @checkpoint_struct, checkpoint_struct

function serialize(x)
s = IOBuffer()
Serialization.serialize(s, x)
take!(s)
end

function deserialize(x)
s = IOBuffer(x)
Serialization.deserialize(s)
end

export serialize, deserialize

abstract type AbstractStorage end

include("Storage/ArrayStorage.jl")
include("Storage/HDF5Storage.jl")

export AbstractStorage, ArrayStorage, HDF5Storage

include("deprecated.jl")
include("Schemes/Revolve.jl")
include("Schemes/Periodic.jl")
Expand Down
12 changes: 7 additions & 5 deletions src/Schemes/Periodic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,26 @@
Periodic checkpointing scheme.

"""
mutable struct Periodic <: Scheme
mutable struct Periodic{MT} <: Scheme where {MT}
steps::Int
acp::Int
period::Int
verbose::Int
fstore::Union{Function,Nothing}
frestore::Union{Function,Nothing}
storage::AbstractStorage
end

function Periodic(
function Periodic{MT}(
steps::Int,
checkpoints::Int,
fstore::Union{Function,Nothing} = nothing,
frestore::Union{Function,Nothing} = nothing;
storage::AbstractStorage = ArrayStorage{MT}(checkpoints),
anActionInstance::Union{Nothing,Action} = nothing,
bundle_::Union{Nothing,Int} = nothing,
verbose::Int = 0
)
) where {MT}
if !isa(anActionInstance, Nothing)
# same as default init above
anActionInstance.actionflag = 0
Expand All @@ -36,7 +38,7 @@ function Periodic(
acp = checkpoints
period = div(steps, checkpoints)

periodic = Periodic(steps, acp, period, verbose, fstore, frestore)
periodic = Periodic{MT}(steps, acp, period, verbose, fstore, frestore, storage)

forwardcount(periodic)
return periodic
Expand All @@ -59,7 +61,7 @@ function checkpoint_struct(body::Function,
) where{MT}
model = deepcopy(model_input)
model_final = []
model_check_outer = Array{MT}(undef, alg.acp)
model_check_outer = alg.storage
model_check_inner = Array{MT}(undef, alg.period)
check = 0
for i = 1:alg.acp
Expand Down
16 changes: 11 additions & 5 deletions src/Schemes/Revolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ A minor extension is the optional `bundle` parameter that allows to treat as ma
iterations in one tape/adjoint sweep. If `bundle` is 1, the default, then the behavior is that of Alg. 799.

"""
mutable struct Revolve <: Scheme
mutable struct Revolve{MT} <: Scheme where {MT}
steps::Int
bundle::Int
tail::Int
Expand All @@ -23,17 +23,19 @@ mutable struct Revolve <: Scheme
verbose::Int
fstore::Union{Function,Nothing}
frestore::Union{Function,Nothing}
storage::AbstractStorage
end

function Revolve(
function Revolve{MT}(
steps::Int,
checkpoints::Int,
fstore::Union{Function,Nothing} = nothing,
frestore::Union{Function,Nothing} = nothing;
storage::AbstractStorage = ArrayStorage{MT}(checkpoints),
anActionInstance::Union{Nothing,Action} = nothing,
bundle_::Union{Nothing,Int} = nothing,
verbose::Int = 0
)
) where {MT}
if !isa(anActionInstance, Nothing)
# same as default init above
anActionInstance.actionflag = 0
Expand Down Expand Up @@ -69,7 +71,11 @@ function Revolve(
firstuturned = false
stepof = Vector{Int}(undef, acp+1)

revolve = Revolve(steps, bundle, tail, acp, cstart, cend, numfwd, numinv, numstore, rwcp, prevcend, firstuturned, stepof, verbose, fstore, frestore)
revolve = Revolve{MT}(
steps, bundle, tail, acp, cstart, cend, numfwd,
numinv, numstore, rwcp, prevcend, firstuturned,
stepof, verbose, fstore, frestore, storage
)

if verbose > 0
predfwdcnt = forwardcount(revolve)
Expand Down Expand Up @@ -366,7 +372,7 @@ function checkpoint_struct(body::Function,
model = deepcopy(model_input)
storemap = Dict{Int32,Int32}()
check = 0
model_check = Array{MT}(undef, alg.acp)
model_check = alg.storage
model_final = []
while true
next_action = next_action!(alg)
Expand Down
17 changes: 17 additions & 0 deletions src/Storage/ArrayStorage.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
struct ArrayStorage{MT} <: AbstractStorage where {MT}
_storage::Array{MT}
end

function ArrayStorage{MT}(acp::Int) where {MT}
storage = Array{MT}(undef, acp)
return ArrayStorage(storage)
end

Base.getindex(storage::ArrayStorage{MT}, i) where {MT} = storage._storage[i]

function Base.setindex!(storage::ArrayStorage{MT}, v, i) where {MT}
storage._storage[i] = v
end

Base.ndims(::Type{ArrayStorage{MT}}) where {MT} = 1
Base.size(storage::ArrayStorage{MT}) where {MT} = size(storage._storage)
31 changes: 31 additions & 0 deletions src/Storage/HDF5Storage.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
mutable struct HDF5Storage{MT} <: AbstractStorage where {MT}
fid::HDF5.File
filename::String
acp::Int64
end

function HDF5Storage{MT}(acp::Int; filename=tempname()) where {MT}
fid = h5open(filename, "w")
storage = HDF5Storage{MT}(fid, filename, acp)
function _finalizer(storage::HDF5Storage{MT})
close(storage.fid)
return storage
end
finalizer(_finalizer, storage)
return storage
end

function Base.getindex(storage::HDF5Storage{MT}, i)::MT where {MT}
@assert i >= 1 && i <= storage.acp
blob = read(storage.fid["$i"])
return deserialize(blob)
end

function Base.setindex!(storage::HDF5Storage{MT}, v::MT, i) where {MT}
@assert i >= 1 && i <= storage.acp
if haskey(storage.fid, "$i")
delete_object(storage.fid, "$i")
end
blob = serialize(v)
storage.fid["$i"] = blob
end
40 changes: 33 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ include("../examples/adtools.jl")
t = F_C[3,i]
return F_H, t
end
revolve = Revolve(steps, snaps, store, restore; verbose=info)
revolve = Revolve{Nothing}(steps, snaps, store, restore; verbose=info)
F_opt, F_final, L_opt, L = optcontrol(revolve, steps, adtool)
@test isapprox(F_opt, F_final, rtol=1e-4)
@test isapprox(L_opt, L, rtol=1e-4)
Expand All @@ -77,7 +77,7 @@ include("../examples/adtools.jl")
t = F_C[3,i]
return F_H, t
end
periodic = Periodic(steps, snaps, store, restore; verbose=info)
periodic = Periodic{Nothing}(steps, snaps, store, restore; verbose=info)
F_opt, F_final, L_opt, L = optcontrol(periodic, steps, adtool)
@test isapprox(F_opt, F_final, rtol=1e-4)
@test isapprox(L_opt, L, rtol=1e-4)
Expand All @@ -92,7 +92,7 @@ include("../examples/adtools.jl")
snaps = 3
info = 0

revolve = Revolve(steps, snaps; verbose=info)
revolve = Revolve{Model}(steps, snaps; verbose=info)
F, L, F_opt, L_opt = muoptcontrol(revolve, steps)
@test isapprox(F_opt, F, rtol=1e-4)
@test isapprox(L_opt, L, rtol=1e-4)
Expand All @@ -103,20 +103,20 @@ include("../examples/adtools.jl")
snaps = 4
info = 0

periodic = Periodic(steps, snaps; verbose=info)
periodic = Periodic{Model}(steps, snaps; verbose=info)
F, L, F_opt, L_opt = muoptcontrol(periodic, steps)
@test isapprox(F_opt, F, rtol=1e-4)
@test isapprox(L_opt, L, rtol=1e-4)
end
end
@testset "Test heat" begin
@testset "Test heat example" begin
include("../examples/heat.jl")
@testset "Testing Revolve..." begin
steps = 500
snaps = 4
info = 0

revolve = Revolve(steps, snaps; verbose=info)
revolve = Revolve{Heat}(steps, snaps; verbose=info)
T, dT = heat(revolve, steps)

@test isapprox(norm(T), 66.21987468492061, atol=1e-11)
Expand All @@ -128,7 +128,33 @@ include("../examples/adtools.jl")
snaps = 4
info = 0

periodic = Periodic(steps, snaps; verbose=info)
periodic = Periodic{Heat}(steps, snaps; verbose=info)
T, dT = heat(periodic, steps)

@test isapprox(norm(T), 66.21987468492061, atol=1e-11)
@test isapprox(norm(dT), 6.970279349365908, atol=1e-11)
end
end
@testset "Test HDF5 storage using heat example" begin
include("../examples/heat.jl")
@testset "Testing Revolve..." begin
steps = 500
snaps = 4
info = 0

revolve = Revolve{Heat}(steps, snaps; storage=HDF5Storage{Heat}(snaps), verbose=info)
T, dT = heat(revolve, steps)

@test isapprox(norm(T), 66.21987468492061, atol=1e-11)
@test isapprox(norm(dT), 6.970279349365908, atol=1e-11)
end

@testset "Testing Periodic..." begin
steps = 500
snaps = 4
info = 0

periodic = Periodic{Heat}(steps, snaps; storage=HDF5Storage{Heat}(snaps), verbose=info)
T, dT = heat(periodic, steps)

@test isapprox(norm(T), 66.21987468492061, atol=1e-11)
Expand Down