Skip to content

Commit

Permalink
Calibrate uses IntermediateResults. Optimization iterations are repor…
Browse files Browse the repository at this point in the history
…ted. (DARPA-ASKEM#129)
  • Loading branch information
jClugstor committed Oct 26, 2023
1 parent 20d48a0 commit 65a8d94
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,14 @@ function amr_get(df::DataFrame, sys::ODESystem, ::Val{:data})
end

#--------------------------------------------------------------------# IntermediateResults callback
# Publish intermediate results to RabbitMQ with at least `every` seconds inbetween callbacks
# Publish intermediate results to RabbitMQ with at least `every` seconds in between callbacks
mutable struct IntermediateResults
last_callback::Dates.DateTime # Track the last time the callback was called
every::Dates.TimePeriod # Callback frequency e.g. `Dates.Second(5)`
id::String
iter::Int # Track how many iterations of the calibration have happened
function IntermediateResults(id::String; every=Dates.Second(0))
new(typemin(Dates.DateTime), every, id)
new(typemin(Dates.DateTime), every, id, 0)
end
end

Expand All @@ -134,8 +135,17 @@ function (o::IntermediateResults)(integrator)
EasyModelAnalysis.DifferentialEquations.u_modified!(integrator, false)
end

# Intermediate results functor for calibrate
function (o::IntermediateResults)(p,lossval,ode_sol)
if o.last_callback + o.every Dates.now()
param_dict = Dict(parameters(ode_sol.prob.f.sys) .=> ode_sol.prob.p)
state_dict = Dict([state => ode_sol[state] for state in states(ode_sol.prob.f.sys)])
o.iter = o.iter + 1
publish_to_rabbitmq(; iter = o.iter, loss = lossval, sol_data = state_dict, params = param_dict, id=o.id)
end


return false
end
#----------------------------------------------------------------------# dataframe_with_observables
function dataframe_with_observables(sol::ODESolution)
sys = sol.prob.f.sys
Expand All @@ -160,8 +170,7 @@ function Simulate(o::OperationRequest)
end

function get_callback(o::OperationRequest, ::Type{Simulate})
DiscreteCallback((args...) -> true, IntermediateResults(o.id,every = Dates.Second(0)),
save_positions = (false,false))
DiscreteCallback((args...) -> true, IntermediateResults(o.id,every = Dates.Second(0)))
end

# callback for Simulate requests
Expand All @@ -186,11 +195,7 @@ end

# callback for Calibrate requests
function get_callback(o::OperationRequest, ::Type{Calibrate})
function (p,lossval,ode_sol)
param_dict = Dict(parameters(ode_sol.prob.f.sys) .=> ode_sol.prob.p)
state_dict = Dict([state => ode_sol[state] for state in states(ode_sol.prob.f.sys)])
publish_to_rabbitmq(; loss = lossval, sol_data = state_dict, params = param_dict, id=o.id)
end
IntermediateResults(o.id,every = Dates.Second(0))
end

function Calibrate(o::OperationRequest)
Expand Down

0 comments on commit 65a8d94

Please sign in to comment.