Skip to content

Commit

Permalink
Merge pull request #9 from Argonne-National-Laboratory/feature/online_r2
Browse files Browse the repository at this point in the history
Feature/online r2
  • Loading branch information
sriharikrishna authored Jul 14, 2022
2 parents 8804319 + c3b4e61 commit 6a5ee30
Show file tree
Hide file tree
Showing 6 changed files with 458 additions and 279 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ version = "0.5.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[compat]
ChainRulesCore = "1.0"
DataStructures = "0.18"
Enzyme = "0.9"
HDF5 = "0.16"
julia = "1.7"
Expand Down
81 changes: 81 additions & 0 deletions examples/optcontrolwhile.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# This is a Julia version of Solution of the optimal control problem
# based on code written by Andrea Walther. See:
# Walther, Andrea, and Narayanan, Sri Hari Krishna. Extending the Binomial Checkpointing
# Technique for Resilience. United States: N. p., 2016. https://www.osti.gov/biblio/1364654.

using Checkpointing
using ReverseDiff

include("optcontrolfunc.jl")

function header()
println("**************************************************************************")
println("* Solution of the optimal control problem *")
println("* *")
println("* J(y) = y_2(1) -> min *")
println("* s.t. dy_1/dt = 0.5*y_1(t) + u(t), y_1(0)=1 *")
println("* dy_2/dt = y_1(t)^2 + 0.5*u(t)^2 y_2(0)=0 *")
println("* *")
println("* the adjoints equations fulfill *")
println("* *")
println("* dl_1/dt = -0.5*l_1(t) - 2*y_1(t)*l_2(t) l_1(1)=0 *")
println("* dl_2/dt = 0 l_2(1)=1 *")
println("* *")
println("* with Revolve for Online and (Multi-Stage) Offline Checkpointing *")
println("* *")
println("**************************************************************************")

println("**************************************************************************")
println("* The solution of the optimal control problem above is *")
println("* *")
println("* y_1*(t) = (2*e^(3t)+e^3)/(e^(3t/2)*(2+e^3)) *")
println("* y_2*(t) = (2*e^(3t)-e^(6-3t)-2+e^6)/((2+e^3)^2) *")
println("* u*(t) = (2*e^(3t)-e^3)/(e^(3t/2)*(2+e^3)) *")
println("* l_1*(t) = (2*e^(3-t)-2*e^(2t))/(e^(t/2)*(2+e^3)) *")
println("* l_2*(t) = 1 *")
println("* *")
println("**************************************************************************")

return
end


function optcontrolwhile(scheme, steps, adtool=ReverseDiffADTool())
println( "\n STEPS -> number of time steps to perform")
println("SNAPS -> number of checkpoints")
println("INFO = 1 -> calculate only approximate solution")
println("INFO = 2 -> calculate approximate solution + takeshots")
println("INFO = 3 -> calculate approximate solution + all information ")
println(" ENTER: STEPS, SNAPS, INFO \n")


h = 1.0/steps
L = Array{Float64, 1}(undef, 2)
L_H = Array{Float64, 1}(undef, 2)

t = 0.0
F = [1.0, 0.0]
F_H = [0.0, 0.0]
i = 1
println("steps = ", steps)
#We are specifying the number of steps here to test the approach
#Any test for convergence can be used here instead
#The number of steps is not provided to the online checkpointing scheme
@checkpoint scheme adtool while i < steps
F_H = [F[1], F[2]]
F = advance(F_H,t,h)
t += h
i = i+1
end

F_opt = Array{Float64, 1}(undef, 2)
L_opt = Array{Float64, 1}(undef, 2)
opt_sol(F_opt,1.0)
opt_lambda(L_opt,0.0)
println("\n\n")
println("y_1*(1) = " , F_opt[1] , " y_2*(1) = " , F_opt[2])
println("y_1 (1) = " , F[1] , " y_2 (1) = " , F[2] , " \n\n")
println("l_1*(0) = " , L_opt[1] , " l_2*(0) = " , L_opt[2])
println("l_1 (0) = " , L[1] , " sl_2 (0) = " , L[2] , " ")
return F_opt, F, L_opt, L
end
2 changes: 2 additions & 0 deletions src/Checkpointing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Checkpointing

using ChainRulesCore
using LinearAlgebra
using DataStructures
using Enzyme
using Serialization
using HDF5
Expand Down Expand Up @@ -89,6 +90,7 @@ include("Schemes/Online_r2.jl")

export Revolve, guess, factor, next_action!, ActionFlag, Periodic
export ReverseDiffADTool, ZygoteADTool, EnzymeADTool, ForwardDiffADTool, DiffractorADTool, jacobian
export Online_r2, update_revolve

@generated function copyto!(dest::MT, src::MT) where {MT}
assignments = [
Expand Down
Loading

0 comments on commit 6a5ee30

Please sign in to comment.