Skip to content

Commit

Permalink
Merge pull request #19 from JuliaGNI/revise-initialstate-handling
Browse files Browse the repository at this point in the history
Revise initialstate handling, simplify constructors and make similar methods more flexible.
  • Loading branch information
michakraus authored May 7, 2024
2 parents 887c292 + de22636 commit 291b4c4
Show file tree
Hide file tree
Showing 9 changed files with 135 additions and 120 deletions.
16 changes: 16 additions & 0 deletions src/geometric_equation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,22 @@ GeometricBase.arrtype(equ::GeometricEquation, ics::NamedTuple) = error("arrtype(
check_initial_conditions(equ::GeometricEquation, ics::NamedTuple) = error("check_initial_conditions(::GeometricEquation, ::NamedTuple) not implemented for ", typeof(equ), ".")
check_methods(equ::GeometricEquation, tspan, ics, params) = error("check_methods(::GeometricEquation, ::Tuple, ::NamedTuple, ::OptionalParameters) not implemented for ", typeof(equ), ".")

function initialstate(::GeometricEquation, ics::NamedTuple)
for s in ics
@assert typeof(s) <: Union{AlgebraicVariable, StateVariable}
end

return ics
end

function initialstate(equ::GeometricEquation, ics::AbstractVector{<:NamedTuple})
for ic in ics
initialstate(equ, ic)
end

return ics
end

function check_parameters(equ::GeometricEquation, params::NamedTuple)
typeof(parameters(equ)) <: NamedTuple || return false
keys(parameters(equ)) == keys(params) || return false
Expand Down
36 changes: 18 additions & 18 deletions src/odes/hode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,20 @@ function Base.show(io::IO, equation::HODE)
print(io, " ", invariants(equation))
end

function initialstate(::HODE, q₀::InitialState, p₀::InitialState)
(
q = _statevariable(q₀),
p = _statevariable(p₀),
)
end

function initialstate(::HODE, q₀::InitialStateVector, p₀::InitialStateVector)
[(
q = _statevariable(q),
p = _statevariable(p),
) for (q,p) in zip(q₀,p₀)]
end

function check_initial_conditions(::HODE, ics::NamedTuple)
haskey(ics, :q) || return false
haskey(ics, :p) || return false
Expand Down Expand Up @@ -199,17 +213,12 @@ $(hode_functions)
"""
const HODEProblem = EquationProblem{HODE}

function HODEProblem(v, f, hamiltonian, tspan, tstep, ics::NamedTuple;
function HODEProblem(v, f, hamiltonian, tspan, tstep, ics...;
invariants = NullInvariants(),
parameters = NullParameters(),
periodicity = NullPeriodicity())
equ = HODE(v, f, hamiltonian, invariants, parameter_types(parameters), periodicity)
EquationProblem(equ, tspan, tstep, ics, parameters)
end

function HODEProblem(v, f, hamiltonian, tspan, tstep, q₀::InitialState, p₀::InitialState; kwargs...)
ics = (q = _statevariable(q₀), p = _statevariable(p₀))
HODEProblem(v, f, hamiltonian, tspan, tstep, ics; kwargs...)
EquationProblem(equ, tspan, tstep, initialstate(equ, ics...), parameters)
end

function GeometricBase.periodicity(prob::HODEProblem)
Expand Down Expand Up @@ -245,19 +254,10 @@ For possible keyword arguments see the documentation on [`EnsembleProblem`](@ref
"""
const HODEEnsemble = EnsembleProblem{HODE}

function HODEEnsemble(v, f, hamiltonian, tspan, tstep, ics::InitialConditions;
function HODEEnsemble(v, f, hamiltonian, tspan, tstep, ics...;
invariants = NullInvariants(),
parameters = NullParameters(),
periodicity = NullPeriodicity())
equ = HODE(v, f, hamiltonian, invariants, parameter_types(parameters), periodicity)
EnsembleProblem(equ, tspan, tstep, ics, parameters)
end

function HODEEnsemble(v, f, hamiltonian, tspan, tstep, q₀::ISV, p₀::ISV; kwargs...) where {ISV <: InitialStateVector}
_ics = [(q = _statevariable(q), p = _statevariable(p)) for (q,p) in zip(q₀,p₀)]
HODEEnsemble(v, f, hamiltonian, tspan, tstep, _ics; kwargs...)
end

function HODEEnsemble(v, f, hamiltonian, tspan, tstep, q₀::IS, p₀::IS; kwargs...) where {IS <: InitialState}
HODEEnsemble(v, f, hamiltonian, tspan, tstep, (q = _statevariable(q₀), p = _statevariable(p₀)); kwargs...)
EnsembleProblem(equ, tspan, tstep, initialstate(equ, ics...), parameters)
end
42 changes: 22 additions & 20 deletions src/odes/iode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,22 @@ function Base.show(io::IO, equation::IODE)
print(io, " ", invariants(equation))
end

function initialstate(::IODE, q₀::InitialState, p₀::InitialState, λ₀::InitialAlgebraic = zeroalgebraic(q₀))
(
q = _statevariable(q₀),
p = _statevariable(p₀),
λ = _algebraicvariable(λ₀),
)
end

function initialstate(::IODE, q₀::InitialStateVector, p₀::InitialStateVector, λ₀::InitialAlgebraicVector = zeroalgebraic(q₀))
[(
q = _statevariable(q),
p = _statevariable(p),
λ = _algebraicvariable(λ)
) for (q,p,λ) in zip(q₀,p₀,λ₀)]
end

function check_initial_conditions(::IODE, ics::NamedTuple)
haskey(ics, :q) || return false
haskey(ics, :p) || return false
Expand Down Expand Up @@ -263,18 +279,13 @@ values `v̄ = _iode_default_v̄` and `f̄ = f`.
"""
const IODEProblem = EquationProblem{IODE}

function IODEProblem(ϑ, f, g, tspan, tstep, ics::NamedTuple;
function IODEProblem(ϑ, f, g, tspan::Tuple, tstep::Real, ics...;
invariants = NullInvariants(),
parameters = NullParameters(),
periodicity = NullPeriodicity(),
= _iode_default_v̄, f̄ = f)
equ = IODE(ϑ, f, g, v̄, f̄, invariants, parameter_types(parameters), periodicity)
EquationProblem(equ, tspan, tstep, ics, parameters)
end

function IODEProblem(ϑ, f, g, tspan, tstep, q₀::InitialState, p₀::InitialState, λ₀::InitialAlgebraic = zeroalgebraic(q₀); kwargs...)
ics = (q = _statevariable(q₀), p = _statevariable(p₀), λ = _algebraicvariable(λ₀))
IODEProblem(ϑ, f, g, tspan, tstep, ics; kwargs...)
EquationProblem(equ, tspan, tstep, initialstate(equ, ics...), parameters)
end

function IODEProblem(ϑ, f, args...; kwargs...)
Expand Down Expand Up @@ -326,26 +337,17 @@ For possible keyword arguments see the documentation on [`EnsembleProblem`](@ref
"""
const IODEEnsemble = EnsembleProblem{IODE}

function IODEEnsemble(ϑ, f, g, tspan, tstep, ics::InitialConditions;
function IODEEnsemble(ϑ, f, g, tspan::Tuple, tstep::Real, ics...;
invariants = NullInvariants(),
parameters = NullParameters(),
periodicity = NullPeriodicity(),
= _iode_default_v̄,
= f
)
equ = IODE(ϑ, f, g, v̄, f̄, invariants, parameter_types(parameters), periodicity)
EnsembleProblem(equ, tspan, tstep, ics, parameters)
end

function IODEEnsemble(ϑ, f, tspan::Tuple, tstep::Real, ics::InitialConditions; kwargs...)
IODEEnsemble(ϑ, f, _iode_default_g, tspan, tstep, ics; kwargs...)
end

function IODEEnsemble(ϑ, f, g, tspan::Tuple, tstep::Real, q₀::InitialStateVector, p₀::InitialStateVector, λ₀::InitialAlgebraicVector = zeroalgebraic(q₀); kwargs...)
_ics = [(q = _statevariable(q), p = _statevariable(p), λ = _algebraicvariable(λ)) for (q,p,λ) in zip(q₀,p₀,λ₀)]
IODEEnsemble(ϑ, f, g, tspan, tstep, _ics; kwargs...)
EnsembleProblem(equ, tspan, tstep, initialstate(equ, ics...), parameters)
end

function IODEEnsemble(ϑ, f, tspan::Tuple, tstep::Real, q₀::InitialStateVector, p₀::InitialStateVector, λ₀::InitialAlgebraicVector = zeroalgebraic(q₀); kwargs...)
IODEEnsemble(ϑ, f, _iode_default_g, tspan, tstep, q₀::InitialStateVector, p₀::InitialStateVector, λ₀::InitialAlgebraicVector; kwargs...)
function IODEEnsemble(ϑ, f, args...; kwargs...)
IODEEnsemble(ϑ, f, _iode_default_g, args...; kwargs...)
end
46 changes: 22 additions & 24 deletions src/odes/lode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,22 @@ function Base.show(io::IO, equation::LODE)
print(io, " ", invariants(equation))
end

function initialstate(::LODE, q₀::InitialState, p₀::InitialState, λ₀::InitialAlgebraic = zeroalgebraic(q₀))
(
q = _statevariable(q₀),
p = _statevariable(p₀),
λ = _algebraicvariable(λ₀),
)
end

function initialstate(::LODE, q₀::InitialStateVector, p₀::InitialStateVector, λ₀::InitialAlgebraicVector = zeroalgebraic(q₀))
[(
q = _statevariable(q),
p = _statevariable(p),
λ = _algebraicvariable(λ)
) for (q,p,λ) in zip(q₀,p₀,λ₀)]
end

function check_initial_conditions(::LODE, ics::NamedTuple)
haskey(ics, :q) || return false
haskey(ics, :v) || haskey(ics, :p) || return false
Expand Down Expand Up @@ -298,22 +314,13 @@ values `v̄ = _lode_default_v̄` and `f̄ = f`.
"""
const LODEProblem = EquationProblem{LODE}

function LODEProblem(ϑ, f, g, ω, l, tspan::Tuple, tstep::Real, ics::NamedTuple;
function LODEProblem(ϑ, f, g, ω, l, tspan::Tuple, tstep::Real, ics...;
invariants = NullInvariants(),
parameters = NullParameters(),
periodicity = NullPeriodicity(),
= _lode_default_v̄, f̄ = f)
equ = LODE(ϑ, f, g, ω, v̄, f̄, l, invariants, parameter_types(parameters), periodicity)
EquationProblem(equ, tspan, tstep, ics, parameters)
end

function LODEProblem(ϑ, f, ω, l, tspan::Tuple, tstep::Real, ics::NamedTuple; kwargs...)
LODEProblem(ϑ, f, _lode_default_g, ω, l, tspan, tstep, ics; kwargs...)
end

function LODEProblem(ϑ, f, g, ω, l, tspan::Tuple, tstep::Real, q₀::InitialState, p₀::InitialState, λ₀::InitialAlgebraic = zeroalgebraic(q₀); kwargs...)
ics = (q = _statevariable(q₀), p = _statevariable(p₀), λ = _algebraicvariable(λ₀))
LODEProblem(ϑ, f, g, ω, l, tspan, tstep, ics; kwargs...)
EquationProblem(equ, tspan, tstep, initialstate(equ, ics...), parameters)
end

function LODEProblem(ϑ, f, ω, l, args...; kwargs...)
Expand Down Expand Up @@ -362,26 +369,17 @@ values `v̄ = _lode_default_v̄` and `f̄ = f`.
"""
const LODEEnsemble = EnsembleProblem{LODE}

function LODEEnsemble(ϑ, f, g, ω, l, tspan, tstep, ics::InitialConditions;
function LODEEnsemble(ϑ, f, g, ω, l, tspan::Tuple, tstep::Real, ics...;
invariants = NullInvariants(),
parameters = NullParameters(),
periodicity = NullPeriodicity(),
= _lode_default_v̄,
= f
)
equ = LODE(ϑ, f, g, ω, v̄, f̄, l, invariants, parameter_types(parameters), periodicity)
EnsembleProblem(equ, tspan, tstep, ics, parameters)
end

function LODEEnsemble(ϑ, f, ω, l, tspan::Tuple, tstep::Real, ics::InitialConditions; kwargs...)
LODEEnsemble(ϑ, f, _lode_default_g, ω, l, tspan, tstep, ics; kwargs...)
end

function LODEEnsemble(ϑ, f, g, ω, l, tspan::Tuple, tstep::Real, q₀::InitialStateVector, p₀::InitialStateVector, λ₀::InitialAlgebraicVector = zeroalgebraic(q₀); kwargs...)
_ics = [(q = _statevariable(q), p = _statevariable(p), λ = _algebraicvariable(λ)) for (q,p,λ) in zip(q₀,p₀,λ₀)]
LODEEnsemble(ϑ, f, g, ω, l, tspan, tstep, _ics; kwargs...)
EnsembleProblem(equ, tspan, tstep, initialstate(equ, ics...), parameters)
end

function LODEEnsemble(ϑ, f, ω, l, tspan::Tuple, tstep::Real, q₀::InitialStateVector, p₀::InitialStateVector, λ₀::InitialAlgebraicVector = zeroalgebraic(q₀); kwargs...)
LODEEnsemble(ϑ, f, _lode_default_g, ω, l, tspan, tstep, q₀::InitialStateVector, p₀::InitialStateVector, λ₀::InitialAlgebraicVector; kwargs...)
function LODEEnsemble(ϑ, f, ω, l, args...; kwargs...)
LODEEnsemble(ϑ, f, _lode_default_g, ω, l, args...; kwargs...)
end
34 changes: 16 additions & 18 deletions src/odes/ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,18 @@ function Base.show(io::IO, equation::ODE)
print(io, " ", invariants(equation))
end

function initialstate(::ODE, q₀::InitialState)
(
q = _statevariable(q₀),
)
end

function initialstate(::ODE, q₀::InitialStateVector)
[(
q = _statevariable(q),
) for q in q₀]
end

function check_initial_conditions(::ODE, ics::NamedTuple)
haskey(ics, :q) || return false
typeof(ics.q) <: StateVariable || return false
Expand Down Expand Up @@ -156,17 +168,12 @@ $(ode_functions)
"""
const ODEProblem = EquationProblem{ODE}

function ODEProblem(v, tspan, tstep, ics::NamedTuple;
function ODEProblem(v, tspan, tstep, ics...;
invariants = NullInvariants(),
parameters = NullParameters(),
periodicity = NullPeriodicity())
equ = ODE(v, invariants, parameter_types(parameters), periodicity)
EquationProblem(equ, tspan, tstep, ics, parameters)
end

function ODEProblem(v, tspan, tstep, q₀::InitialState; kwargs...)
ics = (q = _statevariable(q₀),)
ODEProblem(v, tspan, tstep, ics; kwargs...)
EquationProblem(equ, tspan, tstep, initialstate(equ, ics...), parameters)
end

GeometricBase.periodicity(prob::ODEProblem) = (q = periodicity(equation(prob)),)
Expand Down Expand Up @@ -200,19 +207,10 @@ For possible keyword arguments see the documentation on [`EnsembleProblem`](@ref
"""
const ODEEnsemble = EnsembleProblem{ODE}

function ODEEnsemble(v, tspan, tstep, ics::InitialConditions;
function ODEEnsemble(v, tspan, tstep, ics...;
invariants = NullInvariants(),
parameters = NullParameters(),
periodicity = NullPeriodicity())
equ = ODE(v, invariants, parameter_types(parameters), periodicity)
EnsembleProblem(equ, tspan, tstep, ics, parameters)
end

function ODEEnsemble(v, tspan, tstep, q₀::InitialStateVector; kwargs...)
_ics = [(q = _statevariable(q),) for q in q₀]
ODEEnsemble(v, tspan, tstep, _ics; kwargs...)
end

function ODEEnsemble(v, tspan, tstep, q₀::InitialState; kwargs...)
ODEEnsemble(v, tspan, tstep, (q = _statevariable(q₀),); kwargs...)
EnsembleProblem(equ, tspan, tstep, initialstate(equ, ics...), parameters)
end
36 changes: 18 additions & 18 deletions src/odes/pode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,20 @@ function Base.show(io::IO, equation::PODE)
print(io, " ", invariants(equation))
end

function initialstate(::PODE, q₀::InitialState, p₀::InitialState)
(
q = _statevariable(q₀),
p = _statevariable(p₀),
)
end

function initialstate(::PODE, q₀::InitialStateVector, p₀::InitialStateVector)
[(
q = _statevariable(q),
p = _statevariable(p),
) for (q,p) in zip(q₀,p₀)]
end

function check_initial_conditions(::PODE, ics::NamedTuple)
haskey(ics, :q) || return false
haskey(ics, :p) || return false
Expand Down Expand Up @@ -169,17 +183,12 @@ $(pode_functions)
"""
const PODEProblem = EquationProblem{PODE}

function PODEProblem(v, f, tspan, tstep, ics::NamedTuple;
function PODEProblem(v, f, tspan, tstep, ics...;
invariants = NullInvariants(),
parameters = NullParameters(),
periodicity = NullPeriodicity())
equ = PODE(v, f, invariants, parameter_types(parameters), periodicity)
EquationProblem(equ, tspan, tstep, ics, parameters)
end

function PODEProblem(v, f, tspan, tstep, q₀::InitialState, p₀::InitialState; kwargs...)
ics = (q = _statevariable(q₀), p = _statevariable(p₀))
PODEProblem(v, f, tspan, tstep, ics; kwargs...)
EquationProblem(equ, tspan, tstep, initialstate(equ, ics...), parameters)
end

function GeometricBase.periodicity(prob::PODEProblem)
Expand Down Expand Up @@ -214,19 +223,10 @@ For possible keyword arguments see the documentation on [`EnsembleProblem`](@ref
"""
const PODEEnsemble = EnsembleProblem{PODE}

function PODEEnsemble(v, f, tspan, tstep, ics::InitialConditions;
function PODEEnsemble(v, f, tspan, tstep, ics...;
invariants = NullInvariants(),
parameters = NullParameters(),
periodicity = NullPeriodicity())
equ = PODE(v, f, invariants, parameter_types(parameters), periodicity)
EnsembleProblem(equ, tspan, tstep, ics, parameters)
end

function PODEEnsemble(v, f, tspan, tstep, q₀::ISV, p₀::ISV; kwargs...) where {ISV <: InitialStateVector}
_ics = [(q = _statevariable(q), p = _statevariable(p)) for (q,p) in zip(q₀,p₀)]
PODEEnsemble(v, f, tspan, tstep, _ics; kwargs...)
end

function PODEEnsemble(v, f, tspan, tstep, q₀::IS, p₀::IS; kwargs...) where {IS <: InitialState}
PODEEnsemble(v, f, tspan, tstep, (q = _statevariable(q₀), p = _statevariable(p₀)); kwargs...)
EnsembleProblem(equ, tspan, tstep, initialstate(equ, ics...), parameters)
end
Loading

0 comments on commit 291b4c4

Please sign in to comment.