Skip to content

Commit

Permalink
Merge pull request #22 from DJ4Earth/upgrade_enzyme_0_13
Browse files Browse the repository at this point in the history
upgrade to enzyme 0.13
  • Loading branch information
enigne authored Nov 6, 2024
2 parents 5a2af0d + d016b8e commit 2fa48ca
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 12 deletions.
10 changes: 7 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,30 @@ authors = ["Mathieu Morlighem <[email protected]>", "Gong Cheng <g
version = "0.1.1"

[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Triangulate = "f7e6ffb2-c36d-4f8f-a77e-16e897189344"

[compat]
ColorSchemes = "3.25"
Enzyme = "0.12"
Enzyme = "0.13"
Flux = "0.14"
MAT = "0.10"
Makie = "0.21"
SparseArrays = "1.8.0"
SparseArrays = "1"
StatsBase = "0.34"
Triangulate = "2.3"
julia = "1.8"
julia = "1.10"

[extras]
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Expand Down
9 changes: 6 additions & 3 deletions src/core/control.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function ComputeGradient(∂J_∂α::Vector{Float64}, α::Vector{Float64}, femmo
# zero ALL depth of the model, make sure we get correct gradient
dfemmodel = Enzyme.Compiler.make_zero(Base.Core.Typeof(femmodel), IdDict(), femmodel)
# compute the gradient
autodiff(Enzyme.Reverse, costfunction, Active, Duplicated(α, ∂J_∂α), Duplicated(femmodel,dfemmodel))
autodiff(set_runtime_activity(Enzyme.Reverse), costfunction, Active, Duplicated(α, ∂J_∂α), Duplicated(femmodel,dfemmodel))
end#}}}
function CostFunctionx(femmodel::FemModel, α::Vector{Float64}, controlvar_enum::IssmEnum, SId_enum::IssmEnum, cost_enum_list::Vector{IssmEnum}, ::Val{solutionstring}) where solutionstring #{{{
#Update FemModel accordingly
Expand Down Expand Up @@ -61,7 +61,7 @@ end#}}}

# cost function handler for autodiff
function costfunction::Vector{Float64}, femmodel::FemModel) #{{{
# get the md.inversion.control_string
# get the md.inversion.independent_string
control_string = FindParam(String, femmodel.parameters, InversionControlParametersEnum)
# get the Enum
controlvar_enum = StringToEnum(control_string)
Expand All @@ -71,7 +71,10 @@ function costfunction(α::Vector{Float64}, femmodel::FemModel) #{{{

# get the cost function list from md.inversion.dependent_string
cost_list = FindParam(Vector{String}, femmodel.parameters, InversionCostFunctionsEnum)
cost_enum_list = map(StringToEnum, cost_list)
cost_enum_list = Vector{IssmEnum}(undef, length(cost_list))
for (index, value) in enumerate(cost_list)
cost_enum_list[index] = StringToEnum(value)
end

# compute cost function
# TODO: loop through all controls with respect to all the components in the cost function
Expand Down
3 changes: 1 addition & 2 deletions src/core/modules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,7 @@ function InputUpdateFromSolutionx(analysis::Analysis,ug::IssmVector,femmodel::Fe
InputUpdateFromSolution(analysis,ug.vector,femmodel.elements[i])
end

return ug

return Nothing
end#}}}
function InputUpdateFromVectorx(femmodel::FemModel, vector::Vector{Float64}, enum::IssmEnum, layout::IssmEnum)# {{{

Expand Down
9 changes: 5 additions & 4 deletions test/testoptimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ md.inversion.dependent_string = ["SurfaceAbsVelMisfit"]
femmodel=DJUICE.ModelProcessor(md, :StressbalanceSolution)
n = length(α)

DJUICE.costfunction(α, femmodel)
# test Enzyme autodiff only
dfemmodel = Enzyme.Compiler.make_zero(Base.Core.Typeof(femmodel), IdDict(), femmodel)
autodiff(Enzyme.Reverse, DJUICE.costfunction, Active, Duplicated(α, ∂J_∂α), Duplicated(femmodel,dfemmodel))
autodiff(set_runtime_activity(Enzyme.Reverse), DJUICE.costfunction, Active, Duplicated(α, ∂J_∂α), Duplicated(femmodel,dfemmodel))

# use user defined grad, errors!
optprob = OptimizationFunction(DJUICE.costfunction, Optimization.AutoEnzyme())
#optprob = OptimizationFunction(DJUICE.costfunction, Optimization.AutoEnzyme())
#prob = Optimization.OptimizationProblem(optprob, α, femmodel, lb=md.inversion.min_parameters, ub=md.inversion.max_parameters)
prob = Optimization.OptimizationProblem(optprob, α, femmodel)
sol = Optimization.solve(prob, Optimization.LBFGS())
#prob = Optimization.OptimizationProblem(optprob, α, femmodel)
#sol = Optimization.solve(prob, Optimization.LBFGS())
#sol = Optimization.solve(prob, Optim.GradientDescent())
#sol = Optimization.solve(prob, Optim.NelderMead())

0 comments on commit 2fa48ca

Please sign in to comment.