Skip to content

Commit

Permalink
Add nonlinear solver options as keyword argument for integrators.
Browse files Browse the repository at this point in the history
  • Loading branch information
michakraus committed Dec 3, 2024
1 parent 1c8cabf commit d82d9bb
Show file tree
Hide file tree
Showing 10 changed files with 31 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ ForwardDiff = "0.10"
GenericLinearAlgebra = "0.2, 0.3"
GeometricBase = "0.10.11"
GeometricEquations = "0.18"
GeometricSolutions = "0.3"
GeometricSolutions = "0.3, 0.4"
OffsetArrays = "1"
Parameters = "0.12"
PrettyTables = "2"
Expand Down
9 changes: 5 additions & 4 deletions src/integrators/integrator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ function GeometricIntegrator(
integratormethod::GeometricMethod,
solvermethod::SolverMethod,
iguess::Extrapolation;
options = default_options(),
method = initmethod(integratormethod, problem),
caches = CacheDict(problem, method),
solver = initsolver(solvermethod, method, caches)
solver = initsolver(solvermethod, options, method, caches)
)
GeometricIntegrator(problem, method, caches, solver, iguess)
end
Expand Down Expand Up @@ -190,8 +191,8 @@ function integrate!(sol::GeometricSolution, int::AbstractIntegrator)
end


function integrate(integrator::AbstractIntegrator; kwargs...)
solution = Solution(problem(integrator); kwargs...)
function integrate(integrator::AbstractIntegrator)
solution = Solution(problem(integrator))
integrate!(solution, integrator)
end

Expand All @@ -201,7 +202,7 @@ function integrate(problem::AbstractProblem, method::GeometricMethod; kwargs...)
end

function integrate(problems::EnsembleProblem, method::GeometricMethod; kwargs...)
solutions = Solution(problems; kwargs...)
solutions = Solution(problems)

for (problem, solution) in zip(problems, solutions)
integrator = GeometricIntegrator(problem, method; kwargs...)
Expand Down
4 changes: 2 additions & 2 deletions src/integrators/rk/integrators_dirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ end
Base.getindex(s::SingleStageSolvers, args...) = getindex(s.solvers, args...)


function initsolver(::NewtonMethod, method::DIRK, caches::CacheDict; kwargs...)
SingleStageSolvers([NewtonSolver(zero(cache(caches).x[i]), zero(cache(caches).x[i]); linesearch = Backtracking(), config = Options(min_iterations = 1)) for i in eachstage(method)]...)
function initsolver(::NewtonMethod, config::Options, method::DIRK, caches::CacheDict; kwargs...)
SingleStageSolvers([NewtonSolver(zero(cache(caches).x[i]), zero(cache(caches).x[i]); linesearch = Backtracking(), config = config) for i in eachstage(method)]...)
end


Expand Down
4 changes: 2 additions & 2 deletions src/integrators/solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ default_options() = Options(
f_abstol = 8eps(),
)

initsolver(::SolverMethod, ::GeometricMethod, ::CacheDict) = NoSolver()
initsolver(::SolverMethod, ::Options, ::GeometricMethod, ::CacheDict) = NoSolver()

# create nonlinear solver
function initsolver(::NewtonMethod, ::GeometricMethod, caches::CacheDict; config = default_options(), kwargs...)
function initsolver(::NewtonMethod, config::Options, ::GeometricMethod, caches::CacheDict; kwargs...)
x = zero(nlsolution(caches))
y = zero(nlsolution(caches))
NewtonSolver(x, y; linesearch = Backtracking(), config = config, kwargs...)
Expand Down
7 changes: 4 additions & 3 deletions src/integrators/splitting/composition_integrator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ methods(c::Composition{<: GeometricMethod}, neqs) = Tuple(method(c) for _ in 1:n

splitting(c::Composition) = c.splitting


_options(methods::Tuple) = Tuple([default_options() for _ in methods])
_solvers(methods::Tuple) = Tuple([default_solver(m) for m in methods])
_iguesses(methods::Tuple) = Tuple([default_iguess(m) for m in methods])

Expand All @@ -66,16 +66,17 @@ struct CompositionIntegrator{
problem::SODEProblem,
splitting::AbstractSplittingMethod,
methods::Tuple;
options = _options(methods),
solvers = _solvers(methods),
initialguesses = _iguesses(methods))

@assert length(methods) == length(solvers) == length(initialguesses) == _neqs(problem)
@assert length(methods) == length(options) == length(solvers) == length(initialguesses) == _neqs(problem)

# get splitting indices and coefficients
f, c = coefficients(problem, splitting)

# construct composition integrators
subints = Tuple(GeometricIntegrator(SubstepProblem(problem, c[i], f[i]), methods[f[i]]) for i in eachindex(f,c))
subints = Tuple(GeometricIntegrator(SubstepProblem(problem, c[i], f[i]), methods[f[i]]; options = options[f[i]], solver = solvers[f[i]]) for i in eachindex(f,c))

new{typeof(splitting), typeof(problem), typeof(subints)}(problem, splitting, subints)
end
Expand Down
3 changes: 2 additions & 1 deletion src/projections/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ function ProjectionIntegrator(
solvermethod::SolverMethod,
iguess::Extrapolation,
subint::AbstractIntegrator;
options = default_options(),
method = initmethod(projectionmethod, problem),
caches = CacheDict(problem, method),
solver = initsolver(solvermethod, method, caches)
solver = initsolver(solvermethod, options, method, caches)
)
ProjectionIntegrator(problem, method, caches, solver, iguess, subint)
end
Expand Down
4 changes: 2 additions & 2 deletions src/projections/standard_projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ function split_nlsolution(x::AbstractVector, int::StandardProjectionIntegrator)
end


function initsolver(::NewtonMethod, ::ProjectedMethod{<:StandardProjection}, caches::CacheDict; kwargs...)
function initsolver(::NewtonMethod, config::Options, ::ProjectedMethod{<:StandardProjection}, caches::CacheDict; kwargs...)
x̄, x̃ = split_nlsolution(cache(caches))
NewtonSolver(zero(x̃), zero(x̃); kwargs...)
NewtonSolver(zero(x̃), zero(x̃); config = config, kwargs...)
end


Expand Down
4 changes: 2 additions & 2 deletions src/spark/abstract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ hasnullvector(method::AbstractSPARKMethod) = hasnullvector(tableau(method))


# create nonlinear solver
function initsolver(::Newton, method::AbstractSPARKMethod, caches::CacheDict)
NewtonSolver(zero(nlsolution(caches)), zero(nlsolution(caches)); linesearch = Backtracking(), config = Options(min_iterations = 1, x_abstol = 8eps(), f_abstol = 8eps()))
function initsolver(::Newton, config::Options, method::AbstractSPARKMethod, caches::CacheDict)
NewtonSolver(zero(nlsolution(caches)), zero(nlsolution(caches)); linesearch = Backtracking(), config = config)
end
4 changes: 2 additions & 2 deletions src/spark/integrators_slrk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ F^1_{n,i} + F^2_{n,i} &= \frac{\partial L}{\partial q} (Q_{n,i}, V_{n,i}) , & i
const IntegratorSLRK{DT,TT} = GeometricIntegrator{<:LDAEProblem{DT,TT}, <:SLRK}


# function Integrators.initsolver(::Newton, solstep::SolutionStepPDAE{DT}, problem::LDAEProblem, method::SLRK, caches::CacheDict) where {DT}
# function Integrators.initsolver(::Newton, config::Options, solstep::SolutionStepPDAE{DT}, problem::LDAEProblem, method::SLRK, caches::CacheDict) where {DT}
# # create wrapper function f!(b,x)
# f! = (b,x) -> residual!(b, x, solstep, problem, method, caches)

# # create nonlinear solver
# NewtonSolver(zero(caches[DT].x), zero(caches[DT].x), f!; linesearch = Backtracking(), config = Options(min_iterations = 1, x_abstol = 8eps(), f_abstol = 8eps()))
# NewtonSolver(zero(caches[DT].x), zero(caches[DT].x), f!; linesearch = Backtracking(), config = config)
# end


Expand Down
9 changes: 9 additions & 0 deletions test/integrators/euler_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using GeometricIntegrators
using GeometricProblems.HarmonicOscillator
using SimpleSolvers: Options
using Test


Expand All @@ -16,4 +17,12 @@ ref = exact_solution(ode)
err = relative_maximum_error(sol, ref)
@test err.q < 5E-2

sol = integrate(ode, ImplicitEuler(); options = Options(
min_iterations = 1,
x_abstol = 2eps(),
f_abstol = 2eps(),
))
err = relative_maximum_error(sol, ref)
@test err.q < 5E-2

end

0 comments on commit d82d9bb

Please sign in to comment.