From d420d6ff01fff80dd92f51c0b7b3d636691d36d3 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 23 Oct 2023 11:14:05 -0400 Subject: [PATCH 1/3] fixed callback to return false, added opt counter --- src/SimulationService.jl | 3 +++ src/operations.jl | 5 ++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/SimulationService.jl b/src/SimulationService.jl index 2443c51..eb3942b 100644 --- a/src/SimulationService.jl +++ b/src/SimulationService.jl @@ -58,6 +58,9 @@ const RABBITMQ_ROUTE = Ref{String}() const RABBITMQ_HOST = Ref{String}() const RABBITMQ_PORT = Ref{Int}() +# I don't like this, but needed for now to count optimization iterations +opt_callback_counter = Ref{Int}() + function __init__() if Threads.nthreads() == 1 @warn "SimulationService.jl expects `Threads.nthreads() > 1`. Use e.g. `julia --threads=auto`." diff --git a/src/operations.jl b/src/operations.jl index b0e51ff..62e3e15 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -187,7 +187,9 @@ function get_callback(o::OperationRequest, ::Type{Calibrate}) function (p,lossval,ode_sol) param_dict = Dict(parameters(ode_sol.prob.f.sys) .=> ode_sol.prob.p) state_dict = Dict([state => ode_sol[state] for state in states(ode_sol.prob.f.sys)]) - publish_to_rabbitmq(; loss = lossval, sol_data = state_dict, params = param_dict, id=o.id) + opt_callback_counter[] = opt_callback_counter[] + 1 + publish_to_rabbitmq(; loss = lossval, sol_data = state_dict, params = param_dict, id=o.id, iteration = opt_callback_counter[]) + return false end end @@ -214,6 +216,7 @@ function solve(o::Calibrate; callback) prob = ODEProblem(o.sys, [], o.timespan) statenames = [states(o.sys);getproperty.(observed(o.sys), :lhs)] + opt_iter_counter[] = 0 # bayesian datafit if o.calibrate_method == "bayesian" p_posterior = EasyModelAnalysis.bayesian_datafit(prob, o.priors, o.data; From 3eebdca9ffb3c2dee416c9f0162a322487f96b19 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 23 Oct 2023 11:16:49 -0400 Subject: [PATCH 2/3] typo --- src/operations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operations.jl b/src/operations.jl index 62e3e15..eff3ba6 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -216,7 +216,7 @@ function solve(o::Calibrate; callback) prob = ODEProblem(o.sys, [], o.timespan) statenames = [states(o.sys);getproperty.(observed(o.sys), :lhs)] - opt_iter_counter[] = 0 + opt_callback_counter[] = 0 # bayesian datafit if o.calibrate_method == "bayesian" p_posterior = EasyModelAnalysis.bayesian_datafit(prob, o.priors, o.data; From 40d4ce1cc54428d14bc7ea54d1e3218ab2100a7e Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 24 Oct 2023 14:06:37 -0400 Subject: [PATCH 3/3] Using Intermediate results struct for calibrate. Calibrate now reports number of callbacks hit / calibration iterations. --- src/SimulationService.jl | 3 --- src/operations.jl | 28 ++++++++++++++++------------ 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/SimulationService.jl b/src/SimulationService.jl index eb3942b..2443c51 100644 --- a/src/SimulationService.jl +++ b/src/SimulationService.jl @@ -58,9 +58,6 @@ const RABBITMQ_ROUTE = Ref{String}() const RABBITMQ_HOST = Ref{String}() const RABBITMQ_PORT = Ref{Int}() -# I don't like this, but needed for now to count optimization iterations -opt_callback_counter = Ref{Int}() - function __init__() if Threads.nthreads() == 1 @warn "SimulationService.jl expects `Threads.nthreads() > 1`. Use e.g. `julia --threads=auto`." diff --git a/src/operations.jl b/src/operations.jl index eff3ba6..c6ba2ce 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -110,13 +110,14 @@ function amr_get(df::DataFrame, sys::ODESystem, ::Val{:data}) end #--------------------------------------------------------------------# IntermediateResults callback -# Publish intermediate results to RabbitMQ with at least `every` seconds inbetween callbacks +# Publish intermediate results to RabbitMQ with at least `every` seconds in between callbacks mutable struct IntermediateResults last_callback::Dates.DateTime # Track the last time the callback was called every::Dates.TimePeriod # Callback frequency e.g. `Dates.Second(5)` id::String + iter::Int # Track how many iterations of the calibration have happened function IntermediateResults(id::String; every=Dates.Second(0)) - new(typemin(Dates.DateTime), every, id) + new(typemin(Dates.DateTime), every, id, 0) end end @@ -134,6 +135,17 @@ function (o::IntermediateResults)(integrator) EasyModelAnalysis.DifferentialEquations.u_modified!(integrator, false) end +# Intermediate results functor for calibrate +function (o::IntermediateResults)(p,lossval,ode_sol) + if o.last_callback + o.every ≤ Dates.now() + param_dict = Dict(parameters(ode_sol.prob.f.sys) .=> ode_sol.prob.p) + state_dict = Dict([state => ode_sol[state] for state in states(ode_sol.prob.f.sys)]) + o.iter = o.iter + 1 + publish_to_rabbitmq(; iter = o.iter, loss = lossval, sol_data = state_dict, params = param_dict, id=o.id) + end + + return false +end #----------------------------------------------------------------------# dataframe_with_observables function dataframe_with_observables(sol::ODESolution) sys = sol.prob.f.sys @@ -158,8 +170,7 @@ function Simulate(o::OperationRequest) end function get_callback(o::OperationRequest, ::Type{Simulate}) - DiscreteCallback((args...) -> true, IntermediateResults(o.id,every = Dates.Second(0)), - save_positions = (false,false)) + DiscreteCallback((args...) -> true, IntermediateResults(o.id,every = Dates.Second(0))) end # callback for Simulate requests @@ -184,13 +195,7 @@ end # callback for Calibrate requests function get_callback(o::OperationRequest, ::Type{Calibrate}) - function (p,lossval,ode_sol) - param_dict = Dict(parameters(ode_sol.prob.f.sys) .=> ode_sol.prob.p) - state_dict = Dict([state => ode_sol[state] for state in states(ode_sol.prob.f.sys)]) - opt_callback_counter[] = opt_callback_counter[] + 1 - publish_to_rabbitmq(; loss = lossval, sol_data = state_dict, params = param_dict, id=o.id, iteration = opt_callback_counter[]) - return false - end + IntermediateResults(o.id,every = Dates.Second(0)) end function Calibrate(o::OperationRequest) @@ -216,7 +221,6 @@ function solve(o::Calibrate; callback) prob = ODEProblem(o.sys, [], o.timespan) statenames = [states(o.sys);getproperty.(observed(o.sys), :lhs)] - opt_callback_counter[] = 0 # bayesian datafit if o.calibrate_method == "bayesian" p_posterior = EasyModelAnalysis.bayesian_datafit(prob, o.priors, o.data;