Skip to content

Commit

Permalink
Re-enable BackTracking (#1761)
Browse files Browse the repository at this point in the history
`BackTracking` as relaxation is now enabled again, with a thin wrapper
to reject it when the residual gets worse. Upstream issue:

SciML/OrdinaryDiffEq.jl#2442
  • Loading branch information
SouthEndMusic authored Aug 29, 2024
1 parent b80a79a commit 23bbbbc
Show file tree
Hide file tree
Showing 10 changed files with 133 additions and 42 deletions.
12 changes: 3 additions & 9 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.4"
manifest_format = "2.0"
project_hash = "c2cb085c326f61a96abd1a295e6fa775c585beba"
project_hash = "a410a350a7b0c63bc6696029509aa68c14023275"

[[deps.ADTypes]]
git-tree-sha1 = "6778bcc27496dae5723ff37ee30af451db8b35fe"
Expand Down Expand Up @@ -1070,7 +1070,7 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
version = "1.2.0"

[[deps.NonlinearSolve]]
deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "FastBroadcast", "FastClosures", "FiniteDiff", "ForwardDiff", "LazyArrays", "LineSearches", "LinearAlgebra", "LinearSolve", "MaybeInplace", "PrecompileTools", "Preferences", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SimpleNonlinearSolve", "SparseArrays", "SparseDiffTools", "StaticArraysCore", "SymbolicIndexingInterface", "TimerOutputs"]
deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "FastBroadcast", "FastClosures", "FiniteDiff", "ForwardDiff", "LazyArrays", "LineSearches", "LinearAlgebra", "LinearSolve", "MaybeInplace", "PrecompileTools", "Preferences", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SimpleNonlinearSolve", "SparseArrays", "SparseDiffTools", "StaticArraysCore", "SymbolicIndexingInterface"]
git-tree-sha1 = "3adb1e5945b5a6b1eaee754077f25ccc402edd7f"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
version = "3.13.1"
Expand Down Expand Up @@ -1302,7 +1302,7 @@ uuid = "295af30f-e4ad-537b-8983-00126c2a3abe"
version = "3.5.18"

[[deps.Ribasim]]
deps = ["Accessors", "Arrow", "BasicModelInterface", "CodecZstd", "ComponentArrays", "Configurations", "DBInterface", "DataInterpolations", "DataStructures", "Dates", "DiffEqCallbacks", "EnumX", "FiniteDiff", "ForwardDiff", "Graphs", "HiGHS", "IterTools", "JuMP", "Legolas", "LinearSolve", "Logging", "LoggingExtras", "MetaGraphsNext", "OrdinaryDiffEq", "PreallocationTools", "SQLite", "SciMLBase", "SparseArrays", "SparseConnectivityTracer", "StructArrays", "Tables", "TerminalLoggers", "TranscodingStreams"]
deps = ["Accessors", "Arrow", "BasicModelInterface", "CodecZstd", "ComponentArrays", "Configurations", "DBInterface", "DataInterpolations", "DataStructures", "Dates", "DiffEqCallbacks", "EnumX", "FiniteDiff", "ForwardDiff", "Graphs", "HiGHS", "IterTools", "JuMP", "Legolas", "LineSearches", "LinearSolve", "Logging", "LoggingExtras", "MetaGraphsNext", "OrdinaryDiffEq", "PreallocationTools", "SQLite", "SciMLBase", "SparseArrays", "SparseConnectivityTracer", "StructArrays", "Tables", "TerminalLoggers", "TranscodingStreams"]
path = "core"
uuid = "aac5e3d9-0b8f-4d4f-8241-b1a7a9632635"
version = "2024.10.0"
Expand Down Expand Up @@ -1669,12 +1669,6 @@ weakdeps = ["RecipesBase"]
[deps.TimeZones.extensions]
TimeZonesRecipesBaseExt = "RecipesBase"

[[deps.TimerOutputs]]
deps = ["ExprTools", "Printf"]
git-tree-sha1 = "5a13ae8a41237cff5ecf34f73eb1b8f42fff6531"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.24"

[[deps.TranscodingStreams]]
git-tree-sha1 = "d73336d81cafdc277ff45558bb7eaa2b04a8e472"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
Expand Down
2 changes: 2 additions & 0 deletions core/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
Expand Down Expand Up @@ -70,6 +71,7 @@ IOCapture = "0.2"
IterTools = "1.4"
JuMP = "1.15"
Legolas = "0.5"
LineSearches = "7"
LinearSolve = "2.24"
Logging = "<0.0.1, 1"
LoggingExtras = "1"
Expand Down
34 changes: 13 additions & 21 deletions core/ext/RibasimMakieExt.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
module RibasimMakieExt
using DataFrames: DataFrame
using Makie: Figure, Axis, lines!, axislegend
using Makie: Figure, Axis, scatterlines!, axislegend
using Ribasim: Ribasim, Model

function Ribasim.plot_basin_data!(model::Model, ax::Axis, column::Symbol)
basin_data = DataFrame(Ribasim.basin_table(model))
for node_id in unique(basin_data.node_id)
group = filter(:node_id => ==(node_id), basin_data)
lines!(ax, group.time, getproperty(group, column); label = "Basin #$node_id")
scatterlines!(ax, group.time, getproperty(group, column); label = "Basin #$node_id")
end

axislegend(ax)
Expand All @@ -23,31 +23,23 @@ function Ribasim.plot_basin_data(model::Model)
f
end

function Ribasim.plot_flow!(
model::Model,
ax::Axis,
edge_id::Int32;
skip_conservative_out = false,
)
function Ribasim.plot_flow!(model::Model, ax::Axis, edge_metadata::Ribasim.EdgeMetadata)
flow_data = DataFrame(Ribasim.flow_table(model))
flow_data = filter(:edge_id => ==(edge_id), flow_data)
first_row = first(flow_data)
# Skip outflows of conservative nodes because these are the same as the inflows
if skip_conservative_out &&
Ribasim.NodeType.T(first_row.from_node_type) in Ribasim.conservative_nodetypes
return nothing
end
label = "$(first_row.from_node_type) #$(first_row.from_node_id)$(first_row.to_node_type) #$(first_row.to_node_id)"
lines!(ax, flow_data.time, flow_data.flow_rate; label)
flow_data = filter(:edge_id => ==(edge_metadata.id), flow_data)
label = "$(edge_metadata.edge[1])$(edge_metadata.edge[2])"
scatterlines!(ax, flow_data.time, flow_data.flow_rate; label)
return nothing
end

function Ribasim.plot_flow(model::Model)
function Ribasim.plot_flow(model::Model; skip_conservative_out = true)
f = Figure()
ax = Axis(f[1, 1]; xlabel = "time", ylabel = "flow rate [m³s⁻¹]")
edge_ids = unique(Ribasim.flow_table(model).edge_id)
for edge_id in edge_ids
Ribasim.plot_flow!(model, ax, edge_id; skip_conservative_out = true)
for edge_metadata in values(model.integrator.p.graph.edge_data)
if skip_conservative_out &&
edge_metadata.edge[1].type in Ribasim.conservative_nodetypes
continue
end
Ribasim.plot_flow!(model, ax, edge_metadata)
end
axislegend(ax)
f
Expand Down
13 changes: 11 additions & 2 deletions core/src/Ribasim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@ For more granular access, see:
module Ribasim

# Algorithms for solving ODEs.
using OrdinaryDiffEq: OrdinaryDiffEq, OrdinaryDiffEqRosenbrockAdaptiveAlgorithm, get_du
using OrdinaryDiffEq:
OrdinaryDiffEq,
OrdinaryDiffEqRosenbrockAdaptiveAlgorithm,
get_du,
AbstractNLSolver,
relax!,
_compute_rhs!,
calculate_residuals!
using LineSearches: BackTracking

# Interface for defining and solving the ODE problem of the physical layer.
using SciMLBase:
Expand All @@ -31,7 +39,8 @@ using SciMLBase:
ODEProblem,
ODESolution,
VectorContinuousCallback,
get_proposed_dt
get_proposed_dt,
DEIntegrator

# Automatically detecting the sparsity pattern of the Jacobian of water_balance!
# through operator overloading
Expand Down
6 changes: 4 additions & 2 deletions core/src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ const algorithms = Dict{String, Type}(
)

"Create an OrdinaryDiffEqAlgorithm from solver config"
function algorithm(solver::Solver)::OrdinaryDiffEqAlgorithm
function algorithm(solver::Solver; u0 = [])::OrdinaryDiffEqAlgorithm
algotype = get(algorithms, solver.algorithm, nothing)
if algotype === nothing
options = join(keys(algorithms), ", ")
Expand All @@ -239,7 +239,9 @@ function algorithm(solver::Solver)::OrdinaryDiffEqAlgorithm
end
kwargs = Dict{Symbol, Any}()
if algotype <: OrdinaryDiffEqNewtonAdaptiveAlgorithm
kwargs[:nlsolve] = NLNewton(; relax = 0.1)
kwargs[:nlsolve] = NLNewton(;
relax = Ribasim.MonitoredBackTracking(; z_tmp = copy(u0), dz_tmp = copy(u0)),
)
end
# not all algorithms support this keyword
kwargs[:autodiff] = solver.autodiff
Expand Down
4 changes: 3 additions & 1 deletion core/src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ function Model(config_path::AbstractString)::Model
end

function Model(config::Config)::Model
alg = algorithm(config.solver)
db_path = input_path(config, config.database)
if !isfile(db_path)
@error "Database file not found" db_path
Expand Down Expand Up @@ -109,6 +108,9 @@ function Model(config::Config)::Model
u0 = ComponentVector{Float64}(; storage, integral)
du0 = zero(u0)

# The Solver algorithm
alg = algorithm(config.solver; u0)

# Synchronize level with storage
set_current_basin_properties!(parameters.basin, u0, du0)

Expand Down
10 changes: 7 additions & 3 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,8 @@ function Basin(db::DB, config::Config, graph::MetaGraph)::Basin
error("Invalid Basin / profile table.")
end

level_to_area = LinearInterpolation.(area, level; extrapolate = true)
level_to_area =
LinearInterpolation.(area, level; extrapolate = true, cache_parameters = true)
storage_to_level = invert_integral.(level_to_area)

t_end = seconds_since(config.endtime, config.starttime)
Expand Down Expand Up @@ -921,6 +922,7 @@ function user_demand_static!(
fill(first_row.return_factor, 2),
return_factor_old.t;
extrapolate = true,
cache_parameters = true,
)
min_level[user_demand_idx] = first_row.min_level

Expand Down Expand Up @@ -1026,8 +1028,10 @@ function UserDemand(db::DB, config::Config, graph::MetaGraph)::UserDemand
]
demand_from_timeseries = fill(false, n_user)
allocated = fill(Inf, n_user, n_priority)
return_factor =
[LinearInterpolation(zeros(2), trivial_timespan) for i in eachindex(node_ids)]
return_factor = [
LinearInterpolation(zeros(2), trivial_timespan; cache_parameters = true) for
i in eachindex(node_ids)
]
min_level = zeros(n_user)

# Process static table
Expand Down
11 changes: 11 additions & 0 deletions core/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,20 @@ function water_balance!(
# Formulate du (controlled by PidControl)
formulate_du_pid_controlled!(du, graph, pid_control)

# https://github.com/Deltares/Ribasim/issues/1705#issuecomment-2283293974
stop_declining_negative_storage!(du, u)

return nothing
end

function stop_declining_negative_storage!(du, u)
for (i, s) in enumerate(u.storage)
if s < 0
du.storage[i] = max(du.storage[i], 0.0)
end
end
end

function formulate_continuous_control!(du, p, t)::Nothing
(; compound_variable, target_ref, func) = p.continuous_control

Expand Down
82 changes: 79 additions & 3 deletions core/src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,18 @@ end
Compute the area and level of a basin given its storage.
"""
function get_area_and_level(basin::Basin, state_idx::Int, storage::T)::Tuple{T, T} where {T}
level = basin.storage_to_level[state_idx](max(storage, 0.0))
area = basin.level_to_area[state_idx](level)

storage_to_level = basin.storage_to_level[state_idx]
level_to_area = basin.level_to_area[state_idx]
if storage >= 0
level = storage_to_level(storage)
else
# Negative storage is not feasible and this yields a level
# below the basin bottom, but this does yield usable gradients
# for the non-linear solver
bottom = first(level_to_area.t)
level = bottom + derivative(storage_to_level, 0.0) * storage
end
area = level_to_area(level)
return area, level
end

Expand Down Expand Up @@ -887,3 +896,70 @@ end
(A::AbstractInterpolation)(t::GradientTracer) = t
reduction_factor(x::GradientTracer, threshold::Real) = x
relaxed_root(x::GradientTracer, threshold::Real) = x
get_area_and_level(basin::Basin, state_idx::Int, storage::GradientTracer) = storage, storage
stop_declining_negative_storage!(du, u::ComponentVector{<:GradientTracer}) = nothing

@kwdef struct MonitoredBackTracking{B, V}
linesearch::B = BackTracking()
dz_tmp::V = []
z_tmp::V = []
end

"""
Compute the residual of the non-linear solver, i.e. a measure of the
error in the solution to the implicit equation defined by the solver algorithm
"""
function residual(z, integrator, nlsolver, f)
(; uprev, t, p, dt, opts, isdae) = integrator
(; tmp, ztmp, γ, α, cache, method) = nlsolver
(; ustep, atmp, tstep, k, invγdt, tstep, k, invγdt) = cache
if isdae
_uprev = get_dae_uprev(integrator, uprev)
b, ustep2 =
_compute_rhs!(tmp, ztmp, ustep, α, tstep, k, invγdt, p, _uprev, f::TF, z)
else
b, ustep2 =
_compute_rhs!(tmp, ztmp, ustep, γ, α, tstep, k, invγdt, method, p, dt, f, z)
end
calculate_residuals!(
atmp,
b,
uprev,
ustep2,
opts.abstol,
opts.reltol,
opts.internalnorm,
t,
)
ndz = opts.internalnorm(atmp, t)
return ndz
end

"""
MonitoredBackTracing is a thin wrapper of BackTracking, making sure that
the BackTracking relaxation is rejected if it results in a residual increase
"""
function OrdinaryDiffEq.relax!(
dz,
nlsolver::AbstractNLSolver,
integrator::DEIntegrator,
f,
linesearch::MonitoredBackTracking,
)
(; linesearch, dz_tmp, z_tmp) = linesearch

# Store step before relaxation
@. dz_tmp = dz

# Apply relaxation and measure the residual change
@. z_tmp = nlsolver.z + dz
resid_before = residual(z_tmp, integrator, nlsolver, f)
relax!(dz, nlsolver, integrator, f, linesearch)
@. z_tmp = nlsolver.z + dz
resid_after = residual(z_tmp, integrator, nlsolver, f)

# If the residual increased due to the relaxation, reject it
if resid_after > resid_before
@. dz = dz_tmp
end
end
1 change: 0 additions & 1 deletion core/test/main_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
@show backtrace
end
@test occursin("version in the TOML config file does not match", output)
@test occursin("Info: Convergence bottlenecks in descending order of severity:", output)
end

@testitem "main error logging" begin
Expand Down

0 comments on commit 23bbbbc

Please sign in to comment.