From 46078d3f0c91d1fb93fa49e58af0b45dde3e69c0 Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Thu, 13 Jul 2023 13:40:33 -0400 Subject: [PATCH] add observables support, and a couple tests for ode_system_from_amr (#67) --- src/SimulationService.jl | 16 +++++++++++++--- test/runtests.jl | 11 +++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/SimulationService.jl b/src/SimulationService.jl index 6aa664e..1f86274 100644 --- a/src/SimulationService.jl +++ b/src/SimulationService.jl @@ -12,7 +12,7 @@ import InteractiveUtils: subtypes import JobSchedulers import JSON3 import MathML -import ModelingToolkit: @parameters, substitute, Differential, Num, @variables, ODESystem, ODEProblem +import ModelingToolkit: @parameters, substitute, Differential, Num, @variables, ODESystem, ODEProblem, structural_simplify import OpenAPI import Oxygen import SciMLBase: SciMLBase, DiscreteCallback, solve @@ -147,6 +147,11 @@ function ode_system_from_amr(obj::Config) statenames = [Symbol(s.id) for s in model.states] statevars = [only(@variables $s) for s in statenames] statefuncs = [only(@variables $s(t)) for s in statenames] + obsnames = [Symbol(o.id) for o in ode.observables] + obsvars = [only(@variables $o) for o in obsnames] + obsfuncs = [only(@variables $o(t)) for o in obsnames] + allvars = [statevars; obsvars] + allfuncs = [statefuncs; obsfuncs] # get parameter values and state initial values paramnames = [Symbol(x.id) for x in ode.parameters] @@ -171,10 +176,15 @@ function ode_system_from_amr(obj::Config) end end - subst = Dict(statevars .=> statefuncs) + subst = merge!(Dict(allvars .=> allfuncs), Dict(paramvars .=> paramvars)) eqs = [D(statef) ~ substitute(eqs[state], subst) for (state, statef) in (statenames .=> statefuncs)] - ODESystem(eqs, t, statefuncs, paramvars; defaults = [statefuncs .=> initial_vals; sym_defs], name=Symbol(obj.name)) + for (o, ofunc) in zip(ode.observables, obsfuncs) + expr = substitute(MathML.parse_str(o.expression_mathml), subst) + push!(eqs, ofunc ~ expr) + end + + structural_simplify(ODESystem(eqs, t, allfuncs, paramvars; defaults = [statefuncs .=> initial_vals; sym_defs], name=Symbol(obj.name))) end #-----------------------------------------------------------------------------# health: GET / diff --git a/test/runtests.jl b/test/runtests.jl index 8536568..bb315ac 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,11 +5,22 @@ using HTTP using JSON3 using Oxygen using SciMLBase: solve +using ModelingToolkit using SimulationService SimulationService.SIMSERVICE_ENABLE_TDS = false +#-----------------------------------------------------------------------------# AMR parsing +@testset "AMR parsing" begin + file = joinpath(@__DIR__, "..", "examples", "BIOMD0000000955_askenet.json") + amr = JSON3.read(read(file), Config) + sys = SimulationService.ode_system_from_amr(amr) + @test string.(states(sys)) == ["Susceptible(t)", "Diagnosed(t)", "Infected(t)", "Ailing(t)", "Recognized(t)", "Healed(t)", "Threatened(t)", "Extinct(t)"] + @test string.(parameters(sys)) == ["beta", "gamma", "delta", "alpha", "epsilon", "zeta", "lambda", "eta", "rho", "theta", "kappa", "mu", "nu", "xi", "tau", "sigma"] + @test map(x->string(x.lhs), observed(sys)) == ["Cases(t)", "Hospitalizations(t)", "Deaths(t)"] +end + #-----------------------------------------------------------------------------# Operations @testset "Operations" begin @testset "simulate" begin