Skip to content

Commit

Permalink
Integrate online checkpointing with struct checkpointing
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed Sep 14, 2022
1 parent 766a297 commit 6fdf9e7
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 61 deletions.
39 changes: 35 additions & 4 deletions examples/heat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function advance(heat)
end


function sumheat(heat::Heat, chkpscheme::Scheme)
function sumheat_for(heat::Heat, chkpscheme::Scheme, tsteps::Int64)
# AD: Create shadow copy for derivatives
@checkpoint_struct chkpscheme heat for i in 1:tsteps
heat.Tlast .= heat.Tnext
Expand All @@ -31,22 +31,53 @@ function sumheat(heat::Heat, chkpscheme::Scheme)
return reduce(+, heat.Tnext)
end

function heat(scheme::Scheme, tsteps::Int)
"""
    sumheat_while(heat::Heat, chkpscheme::Scheme, tsteps::Int64)

Step `heat` forward inside a checkpointed `while` loop for as long as
`heat.tsteps <= tsteps`, then return the sum of the entries of `heat.Tnext`.

The step counter is stored in `heat.tsteps`; it is reset to 1 on entry and
incremented by the loop body itself.
"""
function sumheat_while(heat::Heat, chkpscheme::Scheme, tsteps::Int64)
# AD: the three-argument `@checkpoint_struct` form creates the shadow copy of
# `heat` used for the derivatives.
heat.tsteps = 1
@checkpoint_struct chkpscheme heat while heat.tsteps <= tsteps
# Keep the previous temperature field before computing the next one.
heat.Tlast .= heat.Tnext
advance(heat)
# The loop condition reads this counter, so it must be advanced here.
heat.tsteps += 1
end
return reduce(+, heat.Tnext)
end

# Driver for the `for`-loop variant: builds a `Heat` object with `n = 100`
# zero-initialised arrays, sets the two boundary values of `Tnext`, and
# differentiates `sumheat_for` with `Zygote.gradient`.
# Returns `heat.Tnext` and the interior entries (`2:end-1`) of the `Tnext`
# field of the gradient with respect to `heat`.
function heat_for(scheme::Scheme, tsteps::Int)
n = 100
# NOTE(review): Δx and Δt are not referenced again in this function; they only
# document the discretisation mentioned in the comment below.
Δx=0.1
Δt=0.001
# Select μ such that λ ≤ 0.5 for stability with μ = (λ*Δt)/Δx^2
λ = 0.5

# Create object from struct. tsteps is not needed for a for-loop
# NOTE(review): the line "# Create object from struct" appeared twice here,
# which looks like the before/after sides of a diff; only the newer wording is kept.
heat = Heat(zeros(n), zeros(n), n, λ, tsteps)

# Boundary conditions
heat.Tnext[1] = 20.0
heat.Tnext[end] = 0

# Compute gradient
# NOTE(review): the next two assignments look like the before/after sides of a
# diff. The first targets `sumheat` (the pre-rename name) and its result is
# overwritten immediately by the second — confirm that only the
# `sumheat_for` call should remain.
g = Zygote.gradient(sumheat, heat, scheme)
g = Zygote.gradient(sumheat_for, heat, scheme, tsteps)

# g[1] is the gradient with respect to the first differentiated argument, `heat`.
return heat.Tnext, g[1].Tnext[2:end-1]
end

"""
    heat_while(scheme::Scheme, tsteps::Int)

Set up a `Heat` problem with `n = 100` zero-initialised entries, apply the
boundary values to `Tnext`, and differentiate `sumheat_while` with
`Zygote.gradient` using the checkpointing `scheme` and `tsteps` time steps.

Returns a tuple `(Tnext, dTnext)`: the `Tnext` array of the `Heat` object and
the interior entries (`2:end-1`) of the `Tnext` field of the gradient with
respect to that object.
"""
function heat_while(scheme::Scheme, tsteps::Int)
    n = 100
    # Select μ such that λ ≤ 0.5 for stability with μ = (λ*Δt)/Δx^2.
    # Only λ is used below; the step sizes noted for this setup were
    # Δx = 0.1 and Δt = 0.001.
    λ = 0.5

    # Create object from struct. The last argument is the initial step counter
    # (presumably the `tsteps` field, which `sumheat_while` resets to 1).
    heat = Heat(zeros(n), zeros(n), n, λ, 1)

    # Boundary conditions
    heat.Tnext[1] = 20.0
    heat.Tnext[end] = 0

    # Compute gradient; g[1] is the gradient with respect to `heat`.
    g = Zygote.gradient(sumheat_while, heat, scheme, tsteps)

    return heat.Tnext, g[1].Tnext[2:end-1]
end
62 changes: 52 additions & 10 deletions src/Checkpointing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ function jacobian(tobedifferentiated, F_H, ::AbstractADTool)
error("No AD tool interface implemented")
end

export Scheme, AbstractADTool, jacobian, @checkpoint, @checkpoint_struct, checkpoint_struct
export Scheme, AbstractADTool, jacobian
export @checkpoint, @checkpoint_struct, checkpoint_struct_for, checkpoint_struct_while

function serialize(x)
s = IOBuffer()
Expand Down Expand Up @@ -115,22 +116,43 @@ function create_tangent(shadowmodel::MT) where {MT}
end

# Reverse rule for the checkpointed `for` loop.
# Forward sweep: keeps a `deepcopy` of `model`, then applies `body(model)`
# once per iteration of `1:alg.steps`.
# Pullback: copies the incoming cotangent `dmodel` into `shadowmodel`, calls the
# checkpointed loop on the saved input, and returns a tangent built from
# `shadowmodel` in the slot that corresponds to the `model` argument.
function ChainRulesCore.rrule(
# NOTE(review): the next two `::typeof(...)` lines look like the before/after
# sides of a diff (`checkpoint_struct` was renamed to `checkpoint_struct_for`);
# confirm that only one of them belongs in the signature.
::typeof(Checkpointing.checkpoint_struct),
::typeof(Checkpointing.checkpoint_struct_for),
body::Function,
alg::Scheme,
model::MT,
shadowmodel::MT,
range::Function
) where {MT}
# Snapshot of the input state, taken before the forward sweep mutates `model`.
model_input = deepcopy(model)
# shadowmodel = deepcopy(model)
# NOTE(review): the forward sweep iterates over `1:alg.steps`; `range` is never
# called here, only forwarded to `checkpoint_struct_for` in the pullback —
# confirm that both describe the same number of iterations.
for i in 1:alg.steps
body(model)
end
function checkpoint_struct_pullback(dmodel)
# Seed the shadow with the incoming cotangent.
copyto!(shadowmodel, dmodel)
# NOTE(review): the two `model = ...` lines and the two `return` lines below
# look like before/after sides of a diff. As written, the first `return` (five
# values) makes the second (six values, one per rrule argument) unreachable.
model = checkpoint_struct(body, alg, model_input, shadowmodel)
model = checkpoint_struct_for(body, alg, model_input, shadowmodel, range)
dshadowmodel = create_tangent(shadowmodel)
return NoTangent(), NoTangent(), NoTangent(), dshadowmodel, NoTangent()
return NoTangent(), NoTangent(), NoTangent(), dshadowmodel, NoTangent(), NoTangent()
end
return model, checkpoint_struct_pullback
end

"""
    ChainRulesCore.rrule(::typeof(Checkpointing.checkpoint_struct_while),
                         body, alg, model, shadowmodel, condition)

Reverse rule for the checkpointed `while` loop.

The forward sweep keeps a `deepcopy` of `model` and then calls `body(model)`
for as long as `condition(model)` returns `true`. It returns the resulting
`model` together with a pullback.

The pullback copies the incoming cotangent `dmodel` into `shadowmodel`, calls
`checkpoint_struct_while` on the saved input state, and returns one tangent per
argument: `NoTangent()` everywhere except the slot of `model`, which receives
`create_tangent(shadowmodel)`.
"""
function ChainRulesCore.rrule(
    ::typeof(Checkpointing.checkpoint_struct_while),
    body::Function,
    alg::Scheme,
    model::MT,
    shadowmodel::MT,
    condition::Function
) where {MT}
    # Snapshot of the input state, taken before the forward sweep mutates `model`.
    model_input = deepcopy(model)
    while condition(model)
        body(model)
    end
    function checkpoint_struct_pullback(dmodel)
        # Seed the shadow with the incoming cotangent.
        copyto!(shadowmodel, dmodel)
        # The return value is intentionally not bound to `model`: nothing reads it
        # afterwards, and assigning to a variable captured from the enclosing
        # function would force Julia to box it. Presumably the call leaves the
        # adjoints in `shadowmodel`, which is read on the next line.
        checkpoint_struct_while(body, alg, model_input, shadowmodel, condition)
        dshadowmodel = create_tangent(shadowmodel)
        # Tangent order: rrule'd function, body, alg, model, shadowmodel, condition.
        return NoTangent(), NoTangent(), NoTangent(), dshadowmodel, NoTangent(), NoTangent()
    end
    return model, checkpoint_struct_pullback
end
Expand All @@ -149,11 +171,28 @@ adjoints and is created here. It is supposed to be initialized by ChainRules.
"""
macro checkpoint_struct(alg, model, loop)
# Three-argument form: the generated code creates `shadowmodel` itself as a
# `deepcopy` of the model, then dispatches on the kind of loop expression.
# NOTE(review): the next four lines (an unterminated `ex = quote` that calls
# `checkpoint_struct`) look like the removed side of a diff that the
# `if`/`elseif` below replaces — confirm and drop them.
ex = quote
shadowmodel = deepcopy($model)
$model = checkpoint_struct($alg, $model, shadowmodel) do $model
$(loop.args[2])
if loop.head == :for
# `for` loop: `loop.args[1]` (the iteration specification) is wrapped in a
# zero-argument function `range`, and `loop.args[2]` (the loop body) becomes
# the `do` block passed to `checkpoint_struct_for`.
ex = quote
shadowmodel = deepcopy($model)
function range()
$(loop.args[1])
end
$model = checkpoint_struct_for($alg, $model, shadowmodel, range) do $model
$(loop.args[2])
end
end
elseif loop.head == :while
# `while` loop: `loop.args[1]` (the loop condition) is wrapped in a function
# of the model, so it can be re-evaluated on whatever model instance is
# passed in; `loop.args[2]` is the loop body.
ex = quote
shadowmodel = deepcopy($model)
function condition($model)
$(loop.args[1])
end
$model = checkpoint_struct_while($alg, $model, shadowmodel, condition) do $model
$(loop.args[2])
end
end
else
# Any other expression head is rejected when the macro is expanded.
error("Checkpointing.jl: Unknown loop construct.")
end
# The whole expression is escaped, so `shadowmodel`, `range` and `condition`
# are defined in the caller's scope rather than under hygienic names.
esc(ex)
end
Expand All @@ -173,7 +212,10 @@ are seeded and retrieved.
"""
macro checkpoint_struct(alg, model, shadowmodel, loop)
ex = quote
$model = checkpoint_struct($alg, $model, $shadowmodel) do $model
function range()
$(loop.args[1])
end
$model = checkpoint_struct_for($alg, $model, $shadowmodel, range) do $model
$(loop.args[2])
end
end
Expand Down
Loading

0 comments on commit 6fdf9e7

Please sign in to comment.