Skip to content

Commit

Permalink
separate compute gradient and optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
enigne committed Mar 22, 2024
1 parent d9c630f commit 85d0b00
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 19 deletions.
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Expand Down
23 changes: 11 additions & 12 deletions src/core/control.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,16 @@ using Optimization, OptimizationOptimJL
Enzyme.Compiler.RunAttributor[] = false

function Control_Core(md::model, femmodel::FemModel) #{{{
# Compute gradient
computeGradient(md, femmodel)
# α = md.inversion.independent
# n = length(α)
# optprob = OptimizationFunction(costfunction, Optimization.AutoEnzyme())#, grad=enzymerule)
# prob = Optimization.OptimizationProblem(optprob, α, femmodel, lb = -ones(n), ub = ones(n))
# sol = solve(prob, Optim.LBFGS())
#
# #Put gradient in results
# InputUpdateFromVectorx(femmodel, sol.u, GradientEnum, VertexSIdEnum)
# RequestedOutputsx(femmodel, [GradientEnum])
# solve for optimization
α = md.inversion.independent
n = length(α)
optprob = OptimizationFunction(costfunction, Optimization.AutoEnzyme())
prob = Optimization.OptimizationProblem(optprob, α, femmodel, lb=md.inversion.min_parameters, ub=md.inversion.max_parameters)
sol = Optimization.solve(prob, Optim.LBFGS())

#TODO: Put the solution back in results according to its Enum
#InputUpdateFromVectorx(femmodel, sol.u, GradientEnum, VertexSIdEnum)
#RequestedOutputsx(femmodel, [GradientEnum])
end#}}}
function computeGradient(md::model, femmodel::FemModel) #{{{
#independent variable
Expand All @@ -28,7 +27,7 @@ function computeGradient(md::model, femmodel::FemModel) #{{{
# zero ALL depth of the model, make sure we get correct gradient
dfemmodel = Enzyme.Compiler.make_zero(Base.Core.Typeof(femmodel), IdDict(), femmodel)
# compute the gradient
println("CALLING AUTODIFF, prepare to die...")
#println("CALLING AUTODIFF, prepare to die...")
@time autodiff(Enzyme.Reverse, costfunction, Duplicated(femmodel, dfemmodel), Duplicated(α, ∂J_∂α))

#Put gradient in results
Expand Down
8 changes: 6 additions & 2 deletions src/core/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ include("./utils.jl")
function solve(md::model, solution::Symbol) #{{{

#Process incoming string
if solution===:sb || solution===:Stressbalance
if solution===:sb || solution===:Stressbalance || solution===:grad
solutionkey = :StressbalanceSolution
elseif solution===:tr || solution===:Transient
solutionkey = :TransientSolution
Expand All @@ -41,7 +41,11 @@ function solve(md::model, solution::Symbol) #{{{

#Solve (FIXME: to be improved later...)
if (md.inversion.iscontrol) # solve inverse problem
Control_Core(md, femmodel)
if solution===:grad
computeGradient(md, femmodel)
else
Control_Core(md, femmodel)
end
else # otherwise forward problem
if(solutionkey===:StressbalanceSolution)
analysis = StressbalanceAnalysis()
Expand Down
5 changes: 4 additions & 1 deletion src/usr/classes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,14 @@ mutable struct Inversion
iscontrol::Bool
vx_obs::Vector{Float64}
vy_obs::Vector{Float64}
min_parameters::Vector{Float64}
max_parameters::Vector{Float64}
independent::Vector{Float64}
maxiter::Int64
independent_string::String
end
function Inversion() #{{{
return Inversion( false, Vector{Float64}(undef,0), Vector{Float64}(undef,0), Vector{Float64}(undef,0), "Friction")
return Inversion( false, Vector{Float64}(undef,0), Vector{Float64}(undef,0), Vector{Float64}(undef,0), Vector{Float64}(undef,0), Vector{Float64}(undef,0), 0, "Friction")
end# }}}
function Base.show(io::IO, this::Inversion)# {{{
IssmStructDisp(io, this)
Expand Down
4 changes: 2 additions & 2 deletions test/testad.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module enzymeDiff
module enzymeDiff_grad_frictionC

using dJUICE
using MAT
Expand All @@ -22,7 +22,7 @@ md.inversion.iscontrol = 1
md.inversion.independent = md.friction.coefficient
md.inversion.independent_string = "FrictionCoefficient"

md = solve(md, :sb)
md = solve(md, :grad)

# compute gradient by finite differences at each node
addJ = md.results["StressbalanceSolution"]["Gradient"]
Expand Down
18 changes: 17 additions & 1 deletion test/testad2.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module enzymeDiff_grad_rheologyB

using dJUICE
using MAT
using Test
Expand All @@ -20,11 +22,23 @@ md.inversion.iscontrol = 1
md.inversion.independent = md.materials.rheology_B
md.inversion.independent_string = "MaterialsRheologyB"

md = solve(md, :sb)
md = solve(md, :grad)

# compute gradient by finite differences at each node
addJ = md.results["StressbalanceSolution"]["Gradient"]

@testset "Quick AD test with Cost function" begin
#Now call AD!
md.inversion.iscontrol = 1
md.inversion.independent = md.materials.rheology_B
md.inversion.independent_string = "MaterialsRheologyB"

α = md.inversion.independent
femmodel=dJUICE.ModelProcessor(md, :StressbalanceSolution)
J1 = dJUICE.costfunction(femmodel, α)
@test ~isnothing(J1)
end

@testset "AD results RheologyB" begin
α = md.inversion.independent
delta = 1e-8
Expand All @@ -40,3 +54,5 @@ addJ = md.results["StressbalanceSolution"]["Gradient"]
@test abs(dJ - addJ[i])< 1e-6
end
end

end

0 comments on commit 85d0b00

Please sign in to comment.