Skip to content

Commit

Permalink
Added arg to solve_network to allow for early return of integrators.
Browse files Browse the repository at this point in the history
  • Loading branch information
joegilkes committed Mar 12, 2024
1 parent a76ff8b commit a68d507
Showing 1 changed file with 81 additions and 23 deletions.
104 changes: 81 additions & 23 deletions src/solving/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ end


"""
sol = solve_network(method::StaticODESolve, sd, rd[, copy_network])
solve_network(method::StaticODESolve, sd, rd[, copy_network, return_integrator])
Solve a network with static kinetics.
Expand All @@ -87,10 +87,18 @@ Setting `copy_network=true` generates a `deepcopy` of
the original network in `rd` and `sd` and uses these in
the solution, to avoid side effects from calculators
modifying the original network that is passed in. The
copied (modified) network is retruned as part of the
copied (modified) network is returned as part of the
resulting `ODESolveOutput`.
Setting `return_integrator=true` sets up and returns the
underlying integrator without solving (i.e. at
`t = method.pars.tspan[1]`), allowing for manual stepping
through the solution. Note that chunkwise solutions implement
many reinitialisations of this integrator, which will have
to be mirrored in the calling script to get the same results.
"""
function solve_network(method::StaticODESolve, sd::SpeciesData, rd::RxData; copy_network::Bool=false)
function solve_network(method::StaticODESolve, sd::SpeciesData, rd::RxData;
copy_network=true, return_integrator=false)
if copy_network
sd_active = deepcopy(sd)
rd_active = deepcopy(rd)
Expand All @@ -106,13 +114,17 @@ function solve_network(method::StaticODESolve, sd::SpeciesData, rd::RxData; copy

setup_network!(sd_active, rd_active, method.calculator)
split_method = method.pars.solve_chunks ? :chunkwise : :complete
sol = solve_network(method, sd_active, rd_active, Val(split_method))
sol = solve_network(method, sd_active, rd_active, Val(split_method), return_integrator)

res = ODESolveOutput(method, sol, sd_active, rd_active)
return res
if return_integrator
return sol
else
res = ODESolveOutput(method, sol, sd_active, rd_active)
return res
end
end

function solve_network(method::StaticODESolve, sd::SpeciesData, rd::RxData, ::Val{:complete})
function solve_network(method::StaticODESolve, sd::SpeciesData, rd::RxData, ::Val{:complete}, return_integrator)
@info " - Removing low-rate reactions"; flush_log()
apply_low_k_cutoff!(rd, method.calculator, method.pars, method.conditions)

Expand Down Expand Up @@ -150,15 +162,19 @@ function solve_network(method::StaticODESolve, sd::SpeciesData, rd::RxData, ::Va
solvecall_kwargs[:isoutofdomain] = (u,p,t)->any(x->x<0,u)
end

@info " - Solving network..."
@info " - Setting up integrator..."
integ = init(oprob, method.pars.solver; solvecall_kwargs...)
if return_integrator
@info " - Returning integrator early.\n"
return integ
end

adaptive_solve!(integ, method.pars, solvecall_kwargs; print_status=true)
@info " - Solved.\n"

return integ.sol
end

function solve_network(method::StaticODESolve, sd::SpeciesData, rd::RxData, ::Val{:chunkwise})
function solve_network(method::StaticODESolve, sd::SpeciesData, rd::RxData, ::Val{:chunkwise}, return_integrator)
@info " - Removing low-rate reactions"; flush_log()
apply_low_k_cutoff!(rd, method.calculator, method.pars, method.conditions)

Expand Down Expand Up @@ -210,8 +226,13 @@ function solve_network(method::StaticODESolve, sd::SpeciesData, rd::RxData, ::Va
if method.pars.ban_negatives
solvecall_kwargs[:isoutofdomain] = (u,p,t)->any(x->x<0,u)
end
@info " - Solving network..."
@info " - Setting up integrator..."
integ = init(oprob, method.pars.solver; solvecall_kwargs...)
if return_integrator
@info " - Returning integrator early."
@info " - Note that integrators for chunkwise solutions require significant work to fully solve outside of their intended solution methods.\n"
return integ
end

# Set up progress bar (if required).
if method.pars.progress
Expand All @@ -220,6 +241,7 @@ function solve_network(method::StaticODESolve, sd::SpeciesData, rd::RxData, ::Va
@info Progress(pbar_sid, name="Chunkwise ODE")
end
end
@info " - Solving network..."

# Loop over the solution chunks needed to generate the full solution.
for nc in 0:n_chunks_reqd-1
Expand Down Expand Up @@ -271,7 +293,7 @@ end


"""
sol = solve_network(method::VariableODESolve, sd, rd[, copy_network])
sol = solve_network(method::VariableODESolve, sd, rd[, copy_network, return_integrator])
Solve a network with variable kinetics.
Expand All @@ -283,10 +305,18 @@ Setting `copy_network=true` generates a `deepcopy` of
the original network in `rd` and `sd` and uses these in
the solution, to avoid side effects from calculators
modifying the original network that is passed in. The
copied (modified) network is retruned as part of the
copied (modified) network is returned as part of the
resulting `ODESolveOutput`.
Setting `return_integrator=true` sets up and returns the
underlying integrator without solving (i.e. at
`t = method.pars.tspan[1]`), allowing for manual stepping
through the solution. Note that chunkwise solutions implement
many reinitialisations of this integrator, which will have
to be mirrored in the calling script to get the same results.
"""
function solve_network(method::VariableODESolve, sd::SpeciesData, rd::RxData; copy_network::Bool=true)
function solve_network(method::VariableODESolve, sd::SpeciesData, rd::RxData;
copy_network=true, return_integrator=false)
if copy_network
sd_active = deepcopy(sd)
rd_active = deepcopy(rd)
Expand All @@ -307,14 +337,18 @@ function solve_network(method::VariableODESolve, sd::SpeciesData, rd::RxData; co
setup_network!(sd_active, rd_active, method.calculator)
split_method = method.pars.solve_chunks ? :chunkwise : :complete
update_method = method.conditions.discrete_updates ? :discrete : :continuous
sol = solve_network(method, sd_active, rd_active, Val(split_method), Val(update_method))
sol = solve_network(method, sd_active, rd_active, Val(split_method), Val(update_method), return_integrator)

res = ODESolveOutput(method, sol, sd_active, rd_active)
return res
if return_integrator
return sol
else
res = ODESolveOutput(method, sol, sd_active, rd_active)
return res
end
end


function solve_network(method::VariableODESolve, sd::SpeciesData, rd::RxData, ::Val{:complete}, ::Val{:continuous})
function solve_network(method::VariableODESolve, sd::SpeciesData, rd::RxData, ::Val{:complete}, ::Val{:continuous}, return_integrator)
@info " - Removing low-rate reactions"; flush_log()
apply_low_k_cutoff!(rd, method.calculator, method.pars, method.conditions)

Expand Down Expand Up @@ -394,14 +428,20 @@ function solve_network(method::VariableODESolve, sd::SpeciesData, rd::RxData, ::
solvecall_kwargs[:isoutofdomain] = (u,p,t)->any(x->x<0,u)
end

@info " - Setting up integrator..."
integ = init(oprob, method.pars.solver; solvecall_kwargs...)
if return_integrator
@info " - Returning integrator early.\n"
return integ
end

adaptive_solve!(integ, method.pars, solvecall_kwargs; print_status=true)

return rebuild_vc_solution(integ.sol, gradient_profile_symbols)
end


function solve_network(method::VariableODESolve, sd::SpeciesData, rd::RxData, ::Val{:chunkwise}, ::Val{:continuous})
function solve_network(method::VariableODESolve, sd::SpeciesData, rd::RxData, ::Val{:chunkwise}, ::Val{:continuous}, return_integrator)
@info " - Removing low-rate reactions"; flush_log()
apply_low_k_cutoff!(rd, method.calculator, method.pars, method.conditions)

Expand Down Expand Up @@ -496,8 +536,13 @@ function solve_network(method::VariableODESolve, sd::SpeciesData, rd::RxData, ::
if method.pars.ban_negatives
solvecall_kwargs[:isoutofdomain] = (u,p,t)->any(x->x<0,u)
end
@info " - Solving network..."
@info " - Setting up integrator..."
integ = init(oprob, method.pars.solver; solvecall_kwargs...)
if return_integrator
@info " - Returning integrator early."
@info " - Note that integrators for chunkwise solutions require significant work to fully solve outside of their intended solution methods.\n"
return integ
end

# Set up progress bar (if required).
if method.pars.progress
Expand All @@ -506,6 +551,7 @@ function solve_network(method::VariableODESolve, sd::SpeciesData, rd::RxData, ::
@info Progress(pbar_sid, name="Chunkwise ODE")
end
end
@info " - Solving network..."

# Loop over the solution chunks needed to generate the full solution.
for nc in 0:n_chunks_reqd-1
Expand Down Expand Up @@ -573,7 +619,7 @@ function solve_network(method::VariableODESolve, sd::SpeciesData, rd::RxData, ::
return sol
end

function solve_network(method::VariableODESolve, sd::SpeciesData, rd::RxData, ::Val{:complete}, ::Val{:discrete})
function solve_network(method::VariableODESolve, sd::SpeciesData, rd::RxData, ::Val{:complete}, ::Val{:discrete}, return_integrator)
@info " - Removing low-rate reactions"; flush_log()
apply_low_k_cutoff!(rd, method.calculator, method.pars, method.conditions)

Expand Down Expand Up @@ -618,14 +664,20 @@ function solve_network(method::VariableODESolve, sd::SpeciesData, rd::RxData, ::
solvecall_kwargs[:isoutofdomain] = (u,p,t)->any(x->x<0,u)
end

@info " - Setting up integrator..."
integ = init(oprob, method.pars.solver; solvecall_kwargs...)
if return_integrator
@info " - Returning integrator early.\n"
return integ
end

adaptive_solve!(integ, method.pars, solvecall_kwargs; print_status=true)

return build_discrete_rate_solution(integ.sol, k_precalc)
end


function solve_network(method::VariableODESolve, sd::SpeciesData, rd::RxData, ::Val{:chunkwise}, ::Val{:discrete})
function solve_network(method::VariableODESolve, sd::SpeciesData, rd::RxData, ::Val{:chunkwise}, ::Val{:discrete}, return_integrator)
@info " - Removing low-rate reactions"; flush_log()
apply_low_k_cutoff!(rd, method.calculator, method.pars, method.conditions)

Expand Down Expand Up @@ -682,8 +734,13 @@ function solve_network(method::VariableODESolve, sd::SpeciesData, rd::RxData, ::
if method.pars.ban_negatives
solvecall_kwargs[:isoutofdomain] = (u,p,t)->any(x->x<0,u)
end
@info " - Solving network..."
@info " - Setting up integrator..."
integ = init(oprob, method.pars.solver; solvecall_kwargs...)
if return_integrator
@info " - Returning integrator early."
@info " - Note that integrators for chunkwise solutions require significant work to fully solve outside of their intended solution methods.\n"
return integ
end

# Set up progress bar (if required).
if method.pars.progress
Expand All @@ -692,6 +749,7 @@ function solve_network(method::VariableODESolve, sd::SpeciesData, rd::RxData, ::
@info Progress(pbar_sid, name="Chunkwise ODE")
end
end
@info " - Solving network..."

# Loop over the solution chunks needed to generate the full solution.
for nc in 0:n_chunks_reqd-1
Expand Down

0 comments on commit a68d507

Please sign in to comment.