Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Calibrate uses IntermediateResults. Optimization iterations are reported. #129

Merged
merged 3 commits into from
Oct 25, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,14 @@ function amr_get(df::DataFrame, sys::ODESystem, ::Val{:data})
end

#--------------------------------------------------------------------# IntermediateResults callback
# Publish intermediate results to RabbitMQ with at least `every` seconds inbetween callbacks
# Publish intermediate results to RabbitMQ with at least `every` seconds in between callbacks
mutable struct IntermediateResults
last_callback::Dates.DateTime # Track the last time the callback was called
every::Dates.TimePeriod # Callback frequency e.g. `Dates.Second(5)`
id::String
iter::Int # Track how many iterations of the calibration have happened
function IntermediateResults(id::String; every=Dates.Second(0))
new(typemin(Dates.DateTime), every, id)
new(typemin(Dates.DateTime), every, id, 0)
end
end

Expand All @@ -134,6 +135,17 @@ function (o::IntermediateResults)(integrator)
EasyModelAnalysis.DifferentialEquations.u_modified!(integrator, false)
end

# Intermediate results functor for calibrate
function (o::IntermediateResults)(p,lossval,ode_sol)
if o.last_callback + o.every ≤ Dates.now()
param_dict = Dict(parameters(ode_sol.prob.f.sys) .=> ode_sol.prob.p)
state_dict = Dict([state => ode_sol[state] for state in states(ode_sol.prob.f.sys)])
o.iter = o.iter + 1
publish_to_rabbitmq(; iter = o.iter, loss = lossval, sol_data = state_dict, params = param_dict, id=o.id)
end

return false
end
#----------------------------------------------------------------------# dataframe_with_observables
function dataframe_with_observables(sol::ODESolution)
sys = sol.prob.f.sys
Expand All @@ -158,8 +170,7 @@ function Simulate(o::OperationRequest)
end

function get_callback(o::OperationRequest, ::Type{Simulate})
DiscreteCallback((args...) -> true, IntermediateResults(o.id,every = Dates.Second(0)),
save_positions = (false,false))
DiscreteCallback((args...) -> true, IntermediateResults(o.id,every = Dates.Second(0)))
end

# callback for Simulate requests
Expand All @@ -184,11 +195,7 @@ end

# callback for Calibrate requests
function get_callback(o::OperationRequest, ::Type{Calibrate})
function (p,lossval,ode_sol)
param_dict = Dict(parameters(ode_sol.prob.f.sys) .=> ode_sol.prob.p)
state_dict = Dict([state => ode_sol[state] for state in states(ode_sol.prob.f.sys)])
publish_to_rabbitmq(; loss = lossval, sol_data = state_dict, params = param_dict, id=o.id)
end
IntermediateResults(o.id,every = Dates.Second(0))
end

function Calibrate(o::OperationRequest)
Expand Down
Loading