From 79270f01b69dade7b8339358a15b0bf53d53ee8c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 4 Jul 2022 19:02:49 -0400 Subject: [PATCH] Fix code --- .JuliaFormatter.toml | 9 +- Project.toml | 6 +- docs/make.jl | 25 +- src/DeepEquilibriumNetworks.jl | 4 +- src/adjoint.jl | 125 ++--- src/layers/chain.jl | 50 +- src/layers/core.jl | 33 +- src/layers/deq.jl | 228 ++++---- src/layers/jacobian_stabilization.jl | 6 +- src/layers/mdeq.jl | 402 +++++++------- src/operator.jl | 8 +- src/solve.jl | 106 ++-- src/solvers/continuous.jl | 36 +- src/solvers/discrete.jl | 22 +- src/solvers/discrete/broyden.jl | 170 +++--- .../discrete/limited_memory_broyden.jl | 116 ++-- src/solvers/termination.jl | 232 ++++---- src/utils.jl | 58 +- test/runtests.jl | 524 +++++++++--------- 19 files changed, 1077 insertions(+), 1083 deletions(-) diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index e9d91623..09035467 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1,2 +1,9 @@ style = "sciml" -whitespace_in_kwargs = false \ No newline at end of file +whitespace_in_kwargs = false +always_use_return = true +margin = 92 +indent = 2 +format_docstrings = true +join_lines_based_on_source = true +separate_kwargs_with_semicolon = true +always_for_in = true diff --git a/Project.toml b/Project.toml index 3a1815b2..158213f6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DeepEquilibriumNetworks" uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" authors = ["Avik Pal "] -version = "0.1.1" +version = "0.1.2" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" @@ -9,7 +9,6 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" -DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" @@ -18,6 +17,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -31,13 +31,13 @@ ChainRulesCore = "1" ComponentArrays = "0.11, 0.12" DiffEqBase = "6" DiffEqCallbacks = "2.20.1" -DiffEqSensitivity = "6.64" Functors = "0.2, 0.3" LinearSolve = "1" Lux = "0.4" MLUtils = "0.2" OrdinaryDiffEq = "6" SciMLBase = "1.19" +SciMLSensitivity = "7" Setfield = "1" Static = "0.6, 0.7" SteadyStateDiffEq = "1.6" diff --git a/docs/make.jl b/docs/make.jl index 064f8e9b..980658ad 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,8 +1,8 @@ using Documenter, DocumenterCitations, DeepEquilibriumNetworks -bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"), sorting=:nyt) +bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"); sorting=:nyt) -makedocs(bib, +makedocs(bib; sitename="Fast Deep Equilibrium Networks", authors="Avik Pal et al.", clean=true, @@ -10,18 +10,19 @@ makedocs(bib, modules=[DeepEquilibriumNetworks], format=Documenter.HTML(# analytics = "", # assets = ["assets/favicon.ico"], + ; canonical="https://deepequilibriumnetworks.sciml.ai/stable/"), pages=[ - "Home" => "index.md", - "Manual" => [ - "Dynamical Systems" => "manual/solvers.md", - "Non Linear Solvers" => "manual/nlsolve.md", - "General Purpose Layers" => "manual/layers.md", - "DEQ Layers" => "manual/deqs.md", - "Miscellaneous" => "manual/misc.md", - ], - "References" => "references.md", + "Home" => "index.md", + "Manual" => [ + "Dynamical Systems" => "manual/solvers.md", + "Non Linear Solvers" => "manual/nlsolve.md", + "General Purpose Layers" => "manual/layers.md", + "DEQ Layers" => "manual/deqs.md", + "Miscellaneous" => "manual/misc.md", + ], + "References" => "references.md", ]) -deploydocs(repo="github.com/SciML/DeepEquilibriumNetworks.jl.git"; +deploydocs(; repo="github.com/SciML/DeepEquilibriumNetworks.jl.git", push_preview=true) diff --git a/src/DeepEquilibriumNetworks.jl b/src/DeepEquilibriumNetworks.jl index d8e31be0..fef6cf75 100644 --- a/src/DeepEquilibriumNetworks.jl +++ b/src/DeepEquilibriumNetworks.jl @@ -5,7 +5,7 @@ using ChainRulesCore, CUDA, DiffEqBase, DiffEqCallbacks, - DiffEqSensitivity, + SciMLSensitivity, Functors, LinearAlgebra, LinearSolve, @@ -20,7 +20,7 @@ using ChainRulesCore, UnPack, Zygote -import DiffEqSensitivity: AbstractAdjointSensitivityAlgorithm +import SciMLSensitivity: AbstractAdjointSensitivityAlgorithm import Lux: AbstractExplicitContainerLayer, initialparameters, initialstates, parameterlength, statelength import Random: AbstractRNG diff --git a/src/adjoint.jl b/src/adjoint.jl index cbefbd07..7b1594ef 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -1,83 +1,84 @@ neg(x::Any) = hasmethod(-, (typeof(x),)) ? -x : x neg(nt::NamedTuple) = fmap(neg, nt) -@noinline function DiffEqSensitivity.SteadyStateAdjointProblem(sol::EquilibriumSolution, - sensealg::DeepEquilibriumAdjoint, - g::Nothing, dg; - save_idxs=nothing) - @unpack f, p, u0 = sol.prob +@noinline function SciMLSensitivity.SteadyStateAdjointProblem(sol::EquilibriumSolution, + sensealg::DeepEquilibriumAdjoint, + g::Nothing, dg; + save_idxs=nothing, kwargs...) + @unpack f, p, u0 = sol.prob - diffcache, y = DiffEqSensitivity.adjointdiffcache(g, sensealg, false, sol, dg, f; - quad=false, needs_jac=false) + diffcache, y = SciMLSensitivity.adjointdiffcache(g, sensealg, false, sol, dg, f; + quad=false, needs_jac=false) - _save_idxs = save_idxs === nothing ? Colon() : save_idxs - if dg !== nothing - if typeof(_save_idxs) <: Number - diffcache.dg_val[_save_idxs] = dg[_save_idxs] - elseif typeof(dg) <: Number - @. diffcache.dg_val[_save_idxs] = dg - else - @. diffcache.dg_val[_save_idxs] = dg[_save_idxs] - end - end - - if check_adjoint_mode(sensealg, Val(:vanilla)) - # Solve the Linear Problem - _val, back = Zygote.pullback(x -> f(x, p, nothing), y) - s_val = size(_val) - op = ZygotePullbackMultiplyOperator{eltype(y), typeof(back), typeof(s_val)}(back, - s_val) - linear_problem = LinearProblem(op, vec(diffcache.dg_val)) - λ = solve(linear_problem, sensealg.linsolve).u - elseif check_adjoint_mode(sensealg, Val(:jfb)) - # Jacobian Free Backpropagation - λ = diffcache.dg_val + _save_idxs = save_idxs === nothing ? Colon() : save_idxs + if dg !== nothing + if typeof(_save_idxs) <: Number + diffcache.dg_val[_save_idxs] = dg[_save_idxs] + elseif typeof(dg) <: Number + @. diffcache.dg_val[_save_idxs] = dg else - error("Unknown adjoint mode") + @. diffcache.dg_val[_save_idxs] = dg[_save_idxs] end + end - # Compute the VJP - _, back = Zygote.pullback(p -> vec(f(y, p, nothing)), p) - dp = back(vec(λ))[1] + if check_adjoint_mode(sensealg, Val(:vanilla)) + # Solve the Linear Problem + _val, back = Zygote.pullback(x -> f(x, p, nothing), y) + s_val = size(_val) + op = ZygotePullbackMultiplyOperator{eltype(y), typeof(back), typeof(s_val)}(back, + s_val) + linear_problem = LinearProblem(op, vec(diffcache.dg_val)) + λ = solve(linear_problem, sensealg.linsolve).u + elseif check_adjoint_mode(sensealg, Val(:jfb)) + # Jacobian Free Backpropagation + λ = diffcache.dg_val + else + error("Unknown adjoint mode") + end - return neg(dp) + # Compute the VJP + _, back = Zygote.pullback(p -> vec(f(y, p, nothing)), p) + dp = back(vec(λ))[1] + + return neg(dp) end function DiffEqBase._concrete_solve_adjoint(prob::SteadyStateProblem, alg, sensealg::DeepEquilibriumAdjoint, u0, p, - args...; save_idxs=nothing, kwargs...) - _prob = remake(prob; u0=u0, p=p) - sol = solve(_prob, alg, args...; kwargs...) - _save_idxs = save_idxs === nothing ? Colon() : save_idxs + ::SciMLBase.ADOriginator, args...; + save_idxs=nothing, kwargs...) + _prob = remake(prob; u0=u0, p=p) + sol = solve(_prob, alg, args...; kwargs...) + _save_idxs = save_idxs === nothing ? Colon() : save_idxs - out = save_idxs === nothing ? sol : - DiffEqBase.sensitivity_solution(sol, sol[_save_idxs]) + out = save_idxs === nothing ? sol : + DiffEqBase.sensitivity_solution(sol, sol[_save_idxs]) - function steadystatebackpass(Δ) - # Δ = dg/dx or diffcache.dg_val - # del g/del p = 0 - dp = adjoint_sensitivities(sol, alg; sensealg=sensealg, g=nothing, dg=Δ, - save_idxs=save_idxs) - return (NoTangent(), - NoTangent(), - NoTangent(), - NoTangent(), - dp, - NoTangent(), - ntuple(_ -> NoTangent(), length(args))...) - end + function steadystatebackpass(Δ) + # Δ = dg/dx or diffcache.dg_val + # del g/del p = 0 + dp = adjoint_sensitivities(sol, alg; sensealg=sensealg, g=nothing, dg=Δ, + save_idxs=save_idxs) + return (NoTangent(), + NoTangent(), + NoTangent(), + NoTangent(), + dp, + NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + end - return out, steadystatebackpass + return out, steadystatebackpass end -function DiffEqSensitivity._adjoint_sensitivities(sol, sensealg::DeepEquilibriumAdjoint, - alg, g, dg=nothing; abstol=1e-6, - reltol=1e-3, kwargs...) - return DiffEqSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...) +function SciMLSensitivity._adjoint_sensitivities(sol, sensealg::DeepEquilibriumAdjoint, + alg, g, dg=nothing; abstol=1e-6, + reltol=1e-3, kwargs...) + return SciMLSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...) end -function DiffEqSensitivity._adjoint_sensitivities(sol, sensealg::DeepEquilibriumAdjoint, - alg; g=nothing, dg=nothing, abstol=1e-6, - reltol=1e-3, kwargs...) - return DiffEqSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...) +function SciMLSensitivity._adjoint_sensitivities(sol, sensealg::DeepEquilibriumAdjoint, + alg; g=nothing, dg=nothing, abstol=1e-6, + reltol=1e-3, kwargs...) + return SciMLSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...) end diff --git a/src/layers/chain.jl b/src/layers/chain.jl index ba7c74ee..0b35a101 100644 --- a/src/layers/chain.jl +++ b/src/layers/chain.jl @@ -3,47 +3,47 @@ Sequence of layers divided into 3 chunks -- -* `pre_deq` -- layers that are executed before DEQ is applied -* `deq` -- The Deep Equilibrium Layer -* `post_deq` -- layers that are executed after DEQ is applied + - `pre_deq` -- layers that are executed before DEQ is applied + - `deq` -- The Deep Equilibrium Layer + - `post_deq` -- layers that are executed after DEQ is applied Constraint: Must have one DEQ layer in `layers` """ struct DEQChain{P1, D, P2} <: AbstractExplicitContainerLayer{(:pre_deq, :deq, :post_deq)} - pre_deq::P1 - deq::D - post_deq::P2 + pre_deq::P1 + deq::D + post_deq::P2 end function DEQChain(layers...) - pre_deq, post_deq, deq, encounter_deq = [], [], nothing, false - for l in layers - if l isa AbstractDeepEquilibriumNetwork || l isa AbstractSkipDeepEquilibriumNetwork - @assert !encounter_deq "Can have only 1 DEQ Layer in the Chain!!!" - deq = l - encounter_deq = true - continue - end - push!(encounter_deq ? post_deq : pre_deq, l) + pre_deq, post_deq, deq, encounter_deq = [], [], nothing, false + for l in layers + if l isa AbstractDeepEquilibriumNetwork || l isa AbstractSkipDeepEquilibriumNetwork + @assert !encounter_deq "Can have only 1 DEQ Layer in the Chain!!!" + deq = l + encounter_deq = true + continue end - @assert encounter_deq "No DEQ Layer in the Chain!!! Maybe you wanted to use Chain" - pre_deq = length(pre_deq) == 0 ? NoOpLayer() : Chain(pre_deq...) - post_deq = length(post_deq) == 0 ? NoOpLayer() : Chain(post_deq...) - return DEQChain(pre_deq, deq, post_deq) + push!(encounter_deq ? post_deq : pre_deq, l) + end + @assert encounter_deq "No DEQ Layer in the Chain!!! Maybe you wanted to use Chain" + pre_deq = length(pre_deq) == 0 ? NoOpLayer() : Chain(pre_deq...) + post_deq = length(post_deq) == 0 ? NoOpLayer() : Chain(post_deq...) + return DEQChain(pre_deq, deq, post_deq) end function get_deq_return_type(deq::DEQChain{P1, <:Union{MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork}}, ::T) where {P1, T} - return NTuple{length(deq.deq.scales), T} + return NTuple{length(deq.deq.scales), T} end get_deq_return_type(::DEQChain, ::T) where {T} = T function (deq::DEQChain)(x, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple) - T = get_deq_return_type(deq, x) - x1, st1 = deq.pre_deq(x, ps.pre_deq, st.pre_deq) - (x2::T, deq_soln), st2 = deq.deq(x1, ps.deq, st.deq) - x3, st3 = deq.post_deq(x2, ps.post_deq, st.post_deq) - return (x3, deq_soln), (pre_deq=st1, deq=st2, post_deq=st3) + T = get_deq_return_type(deq, x) + x1, st1 = deq.pre_deq(x, ps.pre_deq, st.pre_deq) + (x2::T, deq_soln), st2 = deq.deq(x1, ps.deq, st.deq) + x3, st3 = deq.post_deq(x2, ps.post_deq, st.post_deq) + return (x3, deq_soln), (pre_deq=st1, deq=st2, post_deq=st3) end diff --git a/src/layers/core.jl b/src/layers/core.jl index f00bfe07..8b436a4d 100644 --- a/src/layers/core.jl +++ b/src/layers/core.jl @@ -1,15 +1,15 @@ abstract type AbstractDeepEquilibriumNetwork <: AbstractExplicitContainerLayer{(:model,)} end function initialstates(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork) - return (model=initialstates(rng, deq.model), fixed_depth=Val(0)) + return (model=initialstates(rng, deq.model), fixed_depth=Val(0)) end abstract type AbstractSkipDeepEquilibriumNetwork <: AbstractExplicitContainerLayer{(:model, :shortcut)} end function initialstates(rng::AbstractRNG, deq::AbstractSkipDeepEquilibriumNetwork) - return (model=initialstates(rng, deq.model), shortcut=initialstates(rng, deq.shortcut), - fixed_depth=Val(0)) + return (model=initialstates(rng, deq.model), shortcut=initialstates(rng, deq.shortcut), + fixed_depth=Val(0)) end @inline check_unrolled_mode(::Val{0})::Bool = false @@ -27,6 +27,7 @@ ChainRulesCore.@non_differentiable get_unrolled_depth(::Any) Stores the solution of a DeepEquilibriumNetwork and its variants. ## Fields + * `z_star`: Steady-State or the value reached due to maxiters * `u₀`: Initial Condition * `residual`: Difference of the ``z^*`` and ``f(z^*, x)`` @@ -34,20 +35,20 @@ Stores the solution of a DeepEquilibriumNetwork and its variants. * `nfe`: Number of Function Evaluations """ struct DeepEquilibriumSolution{T, R <: AbstractFloat} - z_star::T - u₀::T - residual::T - jacobian_loss::R - nfe::Int + z_star::T + u₀::T + residual::T + jacobian_loss::R + nfe::Int end function Base.show(io::IO, l::DeepEquilibriumSolution) - print(io, "DeepEquilibriumSolution(") - print(io, "z_star: ", l.z_star) - print(io, ", initial_condition: ", l.u₀) - print(io, ", residual: ", l.residual) - print(io, ", jacobian_loss: ", l.jacobian_loss) - print(io, ", NFE: ", l.nfe) - print(io, ")") - return nothing + print(io, "DeepEquilibriumSolution(") + print(io, "z_star: ", l.z_star) + print(io, ", initial_condition: ", l.u₀) + print(io, ", residual: ", l.residual) + print(io, ", jacobian_loss: ", l.jacobian_loss) + print(io, ", NFE: ", l.nfe) + print(io, ")") + return nothing end diff --git a/src/layers/deq.jl b/src/layers/deq.jl index 7fc0de70..c460dfdb 100644 --- a/src/layers/deq.jl +++ b/src/layers/deq.jl @@ -5,23 +5,19 @@ Deep Equilibrium Network as proposed in [baideep2019](@cite) ## Arguments -* `model`: Neural Network -* `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) -* `jacobian_regularization`: If true, Jacobian Loss is computed and stored in the [`DeepEquilibriumSolution`](@ref) -* `sensealg`: See [`DeepEquilibriumAdjoint`](@ref) -* `kwargs`: Additional Parameters that are directly passed to `solve` + - `model`: Neural Network + - `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) + - `jacobian_regularization`: If true, Jacobian Loss is computed and stored in the [`DeepEquilibriumSolution`](@ref) + - `sensealg`: See [`DeepEquilibriumAdjoint`](@ref) + - `kwargs`: Additional Parameters that are directly passed to `solve` ## Example ```julia -model = DeepEquilibriumNetwork( - Parallel( - +, - Dense(2, 2; bias=false), - Dense(2, 2; bias=false) - ), - ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0) -) +model = DeepEquilibriumNetwork(Parallel(+, + Dense(2, 2; bias=false), + Dense(2, 2; bias=false)), + ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0)) rng = Random.default_rng() ps, st = Lux.setup(rng, model) @@ -32,60 +28,60 @@ model(rand(Float32, 2, 1), ps, st) See also: [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref) """ struct DeepEquilibriumNetwork{J, M, A, S, K} <: AbstractDeepEquilibriumNetwork - model::M - solver::A - sensealg::S - kwargs::K + model::M + solver::A + sensealg::S + kwargs::K end function DeepEquilibriumNetwork(model, solver; jacobian_regularization::Bool=false, sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs...) - return DeepEquilibriumNetwork{jacobian_regularization, typeof(model), typeof(solver), - typeof(sensealg), typeof(kwargs)}(model, solver, sensealg, - kwargs) + return DeepEquilibriumNetwork{jacobian_regularization, typeof(model), typeof(solver), + typeof(sensealg), typeof(kwargs)}(model, solver, sensealg, + kwargs) end function (deq::DeepEquilibriumNetwork{J})(x::AbstractArray{T}, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple) where {J, T} - z = zero(x) - - if check_unrolled_mode(st) - # Pretraining without Fixed Point Solving - st_ = st.model - z_star = z - for _ in 1:get_unrolled_depth(st) - z_star, st_ = deq.model((z_star, x), ps, st_) - end - - residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps, st.model)[1]) - st = merge(st, (model=st_,)) - - return (z_star, - DeepEquilibriumSolution(z_star, z, residual, 0.0f0, get_unrolled_depth(st))), - st - end + z = zero(x) + if check_unrolled_mode(st) + # Pretraining without Fixed Point Solving st_ = st.model - - function dudt(u, p, t) - u_, st_ = deq.model((u, x), p, st_) - return u_ .- u + z_star = z + for _ in 1:get_unrolled_depth(st) + z_star, st_ = deq.model((z_star, x), ps, st_) end - prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps) - sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) - z_star, st_ = deq.model((sol.u, x), ps, st.model) - - jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps, st.model, z_star, x) : T(0)) residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps, st.model)[1]) - st = merge(st, (model=st_,)) return (z_star, - DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), + DeepEquilibriumSolution(z_star, z, residual, 0.0f0, get_unrolled_depth(st))), st + end + + st_ = st.model + + function dudt(u, p, t) + u_, st_ = deq.model((u, x), p, st_) + return u_ .- u + end + + prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps) + sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) + z_star, st_ = deq.model((sol.u, x), ps, st.model) + + jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps, st.model, z_star, x) : T(0)) + residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps, st.model)[1]) + + st = merge(st, (model=st_,)) + + return (z_star, + DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), + st end """ @@ -95,26 +91,23 @@ Skip Deep Equilibrium Network as proposed in [pal2022mixing](@cite) ## Arguments -* `model`: Neural Network -* `shortcut`: Shortcut for the network (pass `nothing` for SkipDEQV2) -* `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) -* `jacobian_regularization`: If true, Jacobian Loss is computed and stored in the [`DeepEquilibriumSolution`](@ref) -* `sensealg`: See [`DeepEquilibriumAdjoint`](@ref) -* `kwargs`: Additional Parameters that are directly passed to `solve` + - `model`: Neural Network + - `shortcut`: Shortcut for the network (pass `nothing` for SkipDEQV2) + - `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) + - `jacobian_regularization`: If true, Jacobian Loss is computed and stored in the [`DeepEquilibriumSolution`](@ref) + - `sensealg`: See [`DeepEquilibriumAdjoint`](@ref) + - `kwargs`: Additional Parameters that are directly passed to `solve` ## Example ```julia # SkipDEQ -model = SkipDeepEquilibriumNetwork( - Parallel( - +, - Dense(2, 2; bias=false), - Dense(2, 2; bias=false) - ), - Dense(2, 2), - ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0) -) +model = SkipDeepEquilibriumNetwork(Parallel(+, + Dense(2, 2; bias=false), + Dense(2, 2; bias=false)), + Dense(2, 2), + ContinuousDEQSolver(VCABM3(); abstol=0.01f0, + reltol=0.01f0)) rng = Random.default_rng() ps, st = Lux.setup(rng, model) @@ -122,15 +115,12 @@ ps, st = Lux.setup(rng, model) model(rand(Float32, 2, 1), ps, st) # SkipDEQV2 -model = SkipDeepEquilibriumNetwork( - Parallel( - +, - Dense(2, 2; bias=false), - Dense(2, 2; bias=false) - ), - nothing, - ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0) -) +model = SkipDeepEquilibriumNetwork(Parallel(+, + Dense(2, 2; bias=false), + Dense(2, 2; bias=false)), + nothing, + ContinuousDEQSolver(VCABM3(); abstol=0.01f0, + reltol=0.01f0)) rng = Random.default_rng() ps, st = Lux.setup(rng, model) @@ -141,11 +131,11 @@ model(rand(Float32, 2, 1), ps, st) See also: [`DeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref) """ struct SkipDeepEquilibriumNetwork{J, M, Sh, A, S, K} <: AbstractSkipDeepEquilibriumNetwork - model::M - shortcut::Sh - solver::A - sensealg::S - kwargs::K + model::M + shortcut::Sh + solver::A + sensealg::S + kwargs::K end function SkipDeepEquilibriumNetwork(model, @@ -154,59 +144,59 @@ function SkipDeepEquilibriumNetwork(model, jacobian_regularization::Bool=false, sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs...) - return SkipDeepEquilibriumNetwork{ - jacobian_regularization, typeof(model), - typeof(shortcut), typeof(solver), typeof(sensealg), - typeof(kwargs) - }(model, shortcut, solver, sensealg, kwargs) + return SkipDeepEquilibriumNetwork{ + jacobian_regularization, typeof(model), + typeof(shortcut), typeof(solver), typeof(sensealg), + typeof(kwargs) + }(model, shortcut, solver, sensealg, kwargs) end function (deq::SkipDeepEquilibriumNetwork{J, M, S})(x::AbstractArray{T}, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple) where {J, M, S, T} - z, st = if S == Nothing - z__, st__ = deq.model((zero(x), x), ps.model, st.model) - z__, merge(st, (model=st__,)) - else - z__, st__ = deq.shortcut(x, ps.shortcut, st.shortcut) - z__, merge(st, (shortcut=st__,)) + z, st = if S == Nothing + z__, st__ = deq.model((zero(x), x), ps.model, st.model) + z__, merge(st, (model=st__,)) + else + z__, st__ = deq.shortcut(x, ps.shortcut, st.shortcut) + z__, merge(st, (shortcut=st__,)) + end + + if check_unrolled_mode(st) + # Pretraining without Fixed Point Solving + st_ = st.model + z_star = z + for _ in 1:get_unrolled_depth(st) + z_star, st_ = deq.model((z_star, x), ps.model, st_) end - if check_unrolled_mode(st) - # Pretraining without Fixed Point Solving - st_ = st.model - z_star = z - for _ in 1:get_unrolled_depth(st) - z_star, st_ = deq.model((z_star, x), ps.model, st_) - end - - residual = ignore_derivatives(z_star .- - deq.model((z_star, x), ps.model, st.model)[1]) - st = merge(st, (model=st_,)) - - return (z_star, - DeepEquilibriumSolution(z_star, z, residual, 0.0f0, get_unrolled_depth(st))), - st - end + residual = ignore_derivatives(z_star .- + deq.model((z_star, x), ps.model, st.model)[1]) + st = merge(st, (model=st_,)) - st_ = st.model + return (z_star, + DeepEquilibriumSolution(z_star, z, residual, 0.0f0, get_unrolled_depth(st))), + st + end - function dudt(u, p, t) - u_, st_ = deq.model((u, x), p, st_) - return u_ .- u - end + st_ = st.model - prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps.model) - sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) - z_star, st_ = deq.model((sol.u, x), ps.model, st.model) + function dudt(u, p, t) + u_, st_ = deq.model((u, x), p, st_) + return u_ .- u + end - jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps.model, st.model, z_star, x) : - T(0)) - residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps.model, st.model)[1]) + prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps.model) + sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) + z_star, st_ = deq.model((sol.u, x), ps.model, st.model) - st = merge(st, (model=st_,)) + jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps.model, st.model, z_star, x) : + T(0)) + residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps.model, st.model)[1]) - return (z_star, - DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), - st + st = merge(st, (model=st_,)) + + return (z_star, + DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), + st end diff --git a/src/layers/jacobian_stabilization.jl b/src/layers/jacobian_stabilization.jl index 7fe897de..afb2c2a9 100644 --- a/src/layers/jacobian_stabilization.jl +++ b/src/layers/jacobian_stabilization.jl @@ -1,7 +1,7 @@ # Doesn't work as of now function compute_deq_jacobian_loss(model, ps::ComponentArray, st::NamedTuple, z::AbstractArray, x::AbstractArray) - l, back = Zygote.pullback(u -> model((u, x), ps, st)[1], z) - vjp_z = back(gaussian_like(l))[1] - return sum(abs2, vjp_z) / length(z) + l, back = Zygote.pullback(u -> model((u, x), ps, st)[1], z) + vjp_z = back(gaussian_like(l))[1] + return sum(abs2, vjp_z) / length(z) end diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index 12c9789f..40d2987d 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -1,11 +1,11 @@ @generated function evaluate_unrolled_mdeq(model, z_star::NTuple{N}, x, ps, st, ::Val{depth}) where {N, depth} - calls = [] - for _ in 1:depth - push!(calls, :((z_star, st) = model(((z_star[1], x), z_star[2:($N)]...), ps, st))) - end - push!(calls, :(return z_star, st)) - return Expr(:block, calls...) + calls = [] + for _ in 1:depth + push!(calls, :((z_star, st) = model(((z_star[1], x), z_star[2:($N)]...), ps, st))) + end + push!(calls, :(return z_star, st)) + return Expr(:block, calls...) end """ @@ -15,34 +15,29 @@ Multiscale Deep Equilibrium Network as proposed in [baimultiscale2020](@cite) ## Arguments -* `main_layers`: Tuple of Neural Networks. The first network needs to take a tuple of 2 arrays, the other ones only take 1 input -* `mapping_layers`: Matrix of Neural Networks. The ``(i, j)^{th}`` network takes the output of ``i^{th}`` `main_layer` and passes it to the ``j^{th}`` `main_layer` -* `post_fuse_layer`: Tuple of Neural Networks. Each of the scales are passed through this layer -* `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) -* `scales`: Output scales -* `sensealg`: See [`DeepEquilibriumAdjoint`](@ref) -* `kwargs`: Additional Parameters that are directly passed to `solve` + - `main_layers`: Tuple of Neural Networks. The first network needs to take a tuple of 2 arrays, the other ones only take 1 input + - `mapping_layers`: Matrix of Neural Networks. The ``(i, j)^{th}`` network takes the output of ``i^{th}`` `main_layer` and passes it to the ``j^{th}`` `main_layer` + - `post_fuse_layer`: Tuple of Neural Networks. Each of the scales are passed through this layer + - `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) + - `scales`: Output scales + - `sensealg`: See [`DeepEquilibriumAdjoint`](@ref) + - `kwargs`: Additional Parameters that are directly passed to `solve` ## Example ```julia -model = MultiScaleDeepEquilibriumNetwork( - ( - Parallel(+, Dense(4, 4, tanh), Dense(4, 4, tanh)), - Dense(3, 3, tanh), - Dense(2, 2, tanh), - Dense(1, 1, tanh) - ), - [ - NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh); - Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh); - Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh); - Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer() - ], - nothing, - ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0), - ((4,), (3,), (2,), (1,)), -) +model = MultiScaleDeepEquilibriumNetwork((Parallel(+, Dense(4, 4, tanh), Dense(4, 4, tanh)), + Dense(3, 3, tanh), + Dense(2, 2, tanh), + Dense(1, 1, tanh)), + [NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh); + Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh); + Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh); + Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer()], + nothing, + ContinuousDEQSolver(VCABM3(); abstol=0.01f0, + reltol=0.01f0), + ((4,), (3,), (2,), (1,))) rng = Random.default_rng() ps, st = Lux.setup(rng, model) @@ -54,18 +49,18 @@ model(x, ps, st) See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref) """ struct MultiScaleDeepEquilibriumNetwork{N, Sc, M, A, S, K} <: AbstractDeepEquilibriumNetwork - model::M - solver::A - sensealg::S - scales::Sc - kwargs::K + model::M + solver::A + sensealg::S + scales::Sc + kwargs::K end function initialstates(rng::AbstractRNG, deq::MultiScaleDeepEquilibriumNetwork) - return (model=initialstates(rng, deq.model), - split_idxs=static(Tuple(vcat(0, cumsum(prod.(deq.scales))...))), - fixed_depth=Val(0), - initial_condition=zeros(Float32, 1, 1)) + return (model=initialstates(rng, deq.model), + split_idxs=static(Tuple(vcat(0, cumsum(prod.(deq.scales))...))), + fixed_depth=Val(0), + initial_condition=zeros(Float32, 1, 1)) end function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, @@ -75,31 +70,31 @@ function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, scales::NTuple{N, NTuple{L, Int64}}; sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs...) where {N, L} - l1 = Parallel(nothing, main_layers...) - l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...) - model = post_fuse_layer === nothing ? Chain(l1, l2) : - Chain(l1, l2, Parallel(nothing, post_fuse_layer...)) - scales = static(scales) - return MultiScaleDeepEquilibriumNetwork{ - N, typeof(scales), typeof(model), - typeof(solver), typeof(sensealg), typeof(kwargs) - }(model, solver, sensealg, scales, kwargs) + l1 = Parallel(nothing, main_layers...) + l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...) + model = post_fuse_layer === nothing ? Chain(l1, l2) : + Chain(l1, l2, Parallel(nothing, post_fuse_layer...)) + scales = static(scales) + return MultiScaleDeepEquilibriumNetwork{ + N, typeof(scales), typeof(model), + typeof(solver), typeof(sensealg), typeof(kwargs) + }(model, solver, sensealg, scales, kwargs) end @generated function get_initial_condition_mdeq(::S, x::AbstractArray{T, N}, st::NamedTuple{fields}) where {S, T, N, fields} - scales = known(S) - sz = sum(prod.(scales)) - calls = [] - if :initial_condition ∈ fields - push!(calls, :(u0 = st[:initial_condition])) - push!(calls, :(($sz, size(x, $N)) == size(u0) && return u0, st)) - end - push!(calls, :(u0 = fill!(similar(x, $(sz), size(x, N)), $(T(0))))) - push!(calls, :(st = merge(st, (initial_condition=u0,))::typeof(st))) - push!(calls, :(return u0, st)) - return Expr(:block, calls...) + scales = known(S) + sz = sum(prod.(scales)) + calls = [] + if :initial_condition ∈ fields + push!(calls, :(u0 = st[:initial_condition])) + push!(calls, :(($sz, size(x, $N)) == size(u0) && return u0, st)) + end + push!(calls, :(u0 = fill!(similar(x, $(sz), size(x, N)), $(T(0))))) + push!(calls, :(st = merge(st, (initial_condition=u0,))::typeof(st))) + push!(calls, :(return u0, st)) + return Expr(:block, calls...) end ChainRulesCore.@non_differentiable get_initial_condition_mdeq(::Any...) @@ -107,46 +102,46 @@ ChainRulesCore.@non_differentiable get_initial_condition_mdeq(::Any...) function (deq::MultiScaleDeepEquilibriumNetwork{N})(x::AbstractArray{T}, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple) where {N, T} - z, st = get_initial_condition_mdeq(deq.scales, x, st) + z, st = get_initial_condition_mdeq(deq.scales, x, st) - if check_unrolled_mode(st) - z_star = split_and_reshape(z, st.split_idxs, deq.scales) - z_star, st_ = evaluate_unrolled_mdeq(deq.model, z_star, x, ps, st.model, - st.fixed_depth) + if check_unrolled_mode(st) + z_star = split_and_reshape(z, st.split_idxs, deq.scales) + z_star, st_ = evaluate_unrolled_mdeq(deq.model, z_star, x, ps, st.model, + st.fixed_depth) - residual = ignore_derivatives(vcat(flatten.(z_star)...) .- - vcat(flatten.(evaluate_unrolled_mdeq(deq.model, - z_star, x, ps, - st_, Val(1))[1])...)) - st__ = merge(st, (model=st_,)) + residual = ignore_derivatives(vcat(flatten.(z_star)...) .- + vcat(flatten.(evaluate_unrolled_mdeq(deq.model, + z_star, x, ps, + st_, Val(1))[1])...)) + st__ = merge(st, (model=st_,)) - return ((z_star, - DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, - get_unrolled_depth(st))), - st__) - end + return ((z_star, + DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, + get_unrolled_depth(st))), + st__) + end - st_ = st.model + st_ = st.model - function dudt_(u, p, t) - u_split = split_and_reshape(u, st.split_idxs, deq.scales) - u_, st_ = deq.model(((u_split[1], x), u_split[2:N]...), p, st_) - return u_, st_ - end + function dudt_(u, p, t) + u_split = split_and_reshape(u, st.split_idxs, deq.scales) + u_, st_ = deq.model(((u_split[1], x), u_split[2:N]...), p, st_) + return u_, st_ + end - dudt(u, p, t) = vcat(flatten.(dudt_(u, p, t)[1])...) .- u + dudt(u, p, t) = vcat(flatten.(dudt_(u, p, t)[1])...) .- u - prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps) - sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) - z_star, st_ = dudt_(sol.u, ps, nothing) + prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps) + sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) + z_star, st_ = dudt_(sol.u, ps, nothing) - residual = ignore_derivatives(dudt(sol.u, ps, nothing)) + residual = ignore_derivatives(dudt(sol.u, ps, nothing)) - st__ = merge(st, (model=st_,)) + st__ = merge(st, (model=st_,)) - return ((z_star, - DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, - sol.destats.nf + 1)), st__) + return ((z_star, + DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, + sol.destats.nf + 1)), st__) end """ @@ -156,38 +151,38 @@ Multiscale Deep Equilibrium Network as proposed in [baimultiscale2020](@cite) co ## Arguments -* `main_layers`: Tuple of Neural Networks. The first network needs to take a tuple of 2 arrays, the other ones only take 1 input -* `mapping_layers`: Matrix of Neural Networks. The ``(i, j)^{th}`` network takes the output of ``i^{th}`` `main_layer` and passes it to the ``j^{th}`` `main_layer` -* `post_fuse_layer`: Tuple of Neural Networks. Each of the scales are passed through this layer -* `shortcut_layers`: Shortcut for the network (pass `nothing` for SkipDEQV2) -* `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) -* `scales`: Output scales -* `sensealg`: See [`DeepEquilibriumAdjoint`](@ref) -* `kwargs`: Additional Parameters that are directly passed to `solve` + - `main_layers`: Tuple of Neural Networks. The first network needs to take a tuple of 2 arrays, the other ones only take 1 input + - `mapping_layers`: Matrix of Neural Networks. The ``(i, j)^{th}`` network takes the output of ``i^{th}`` `main_layer` and passes it to the ``j^{th}`` `main_layer` + - `post_fuse_layer`: Tuple of Neural Networks. Each of the scales are passed through this layer + - `shortcut_layers`: Shortcut for the network (pass `nothing` for SkipDEQV2) + - `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) + - `scales`: Output scales + - `sensealg`: See [`DeepEquilibriumAdjoint`](@ref) + - `kwargs`: Additional Parameters that are directly passed to `solve` ## Example ```julia # MSkipDEQ -model = MultiScaleSkipDeepEquilibriumNetwork( - ( - Parallel(+, Dense(4, 4, tanh), Dense(4, 4, tanh)), - Dense(3, 3, tanh), - Dense(2, 2, tanh), - Dense(1, 1, tanh), - ), - [ - NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) - Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) - Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) - Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer() - ], - nothing, - (Dense(4, 4, tanh), Dense(4, 3, tanh), Dense(4, 2, tanh), Dense(4, 1, tanh)), - ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), - ((4,), (3,), (2,), (1,)); - sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), -) +model = MultiScaleSkipDeepEquilibriumNetwork((Parallel(+, Dense(4, 4, tanh), + Dense(4, 4, tanh)), + Dense(3, 3, tanh), + Dense(2, 2, tanh), + Dense(1, 1, tanh)), + [NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) + Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) + Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) + Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer()], + nothing, + (Dense(4, 4, tanh), Dense(4, 3, tanh), + Dense(4, 2, tanh), Dense(4, 1, tanh)), + ContinuousDEQSolver(; abstol=0.1f0, + reltol=0.1f0, + abstol_termination=0.1f0, + reltol_termination=0.1f0), + ((4,), (3,), (2,), (1,)); + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, + 10)) rng = Random.default_rng() ps, st = Lux.setup(rng, model) @@ -196,25 +191,24 @@ x = rand(rng, Float32, 4, 2) model(x, ps, st) # MSkipDEQV2 -model = MultiScaleSkipDeepEquilibriumNetwork( - ( - Parallel(+, Dense(4, 4, tanh), Dense(4, 4, tanh)), - Dense(3, 3, tanh), - Dense(2, 2, tanh), - Dense(1, 1, tanh), - ), - [ - NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) - Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) - Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) - Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer() - ], - nothing, - nothing, - ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), - ((4,), (3,), (2,), (1,)); - sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), - ) +model = MultiScaleSkipDeepEquilibriumNetwork((Parallel(+, Dense(4, 4, tanh), + Dense(4, 4, tanh)), + Dense(3, 3, tanh), + Dense(2, 2, tanh), + Dense(1, 1, tanh)), + [NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) + Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) + Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) + Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer()], + nothing, + nothing, + ContinuousDEQSolver(; abstol=0.1f0, + reltol=0.1f0, + abstol_termination=0.1f0, + reltol_termination=0.1f0), + ((4,), (3,), (2,), (1,)); + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, + 10)) rng = Random.default_rng() ps, st = Lux.setup(rng, model) @@ -227,20 +221,20 @@ See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref) """ struct MultiScaleSkipDeepEquilibriumNetwork{N, Sc, M, Sh, A, S, K} <: AbstractSkipDeepEquilibriumNetwork - model::M - shortcut::Sh - solver::A - sensealg::S - scales::Sc - kwargs::K + model::M + shortcut::Sh + solver::A + sensealg::S + scales::Sc + kwargs::K end function initialstates(rng::AbstractRNG, deq::MultiScaleSkipDeepEquilibriumNetwork) - return (model=initialstates(rng, deq.model), - shortcut=initialstates(rng, deq.shortcut), - split_idxs=static(Tuple(vcat(0, cumsum(prod.(deq.scales))...))), - fixed_depth=Val(0), - initial_condition=zeros(Float32, 1, 1)) + return (model=initialstates(rng, deq.model), + shortcut=initialstates(rng, deq.shortcut), + split_idxs=static(Tuple(vcat(0, cumsum(prod.(deq.scales))...))), + fixed_depth=Val(0), + initial_condition=zeros(Float32, 1, 1)) end function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, @@ -252,19 +246,19 @@ function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs...) - l1 = Parallel(nothing, main_layers...) - l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...) - model = post_fuse_layer === nothing ? Chain(l1, l2) : - Chain(l1, l2, Parallel(nothing, post_fuse_layer...)) - shortcut = shortcut_layers === nothing ? nothing : Parallel(nothing, shortcut_layers...) - scales = static(scales) - return MultiScaleSkipDeepEquilibriumNetwork{ - length(scales), typeof(scales), - typeof(model), typeof(shortcut), - typeof(solver), typeof(sensealg), - typeof(kwargs) - }(model, shortcut, solver, sensealg, scales, - kwargs) + l1 = Parallel(nothing, main_layers...) + l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...) + model = post_fuse_layer === nothing ? Chain(l1, l2) : + Chain(l1, l2, Parallel(nothing, post_fuse_layer...)) + shortcut = shortcut_layers === nothing ? nothing : Parallel(nothing, shortcut_layers...) + scales = static(scales) + return MultiScaleSkipDeepEquilibriumNetwork{ + length(scales), typeof(scales), + typeof(model), typeof(shortcut), + typeof(solver), typeof(sensealg), + typeof(kwargs) + }(model, shortcut, solver, sensealg, scales, + kwargs) end function (deq::MultiScaleSkipDeepEquilibriumNetwork{N, Sc, M, Sh})(x::AbstractArray{T}, @@ -275,53 +269,53 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N, Sc, M, Sh})(x::AbstractAr M, Sh, T} - z, st = if Sh == Nothing - u0, st_ = get_initial_condition_mdeq(deq.scales, x, st) - u0_ = split_and_reshape(u0, st.split_idxs, deq.scales) - z0, st__ = deq.model(((u0_[1], x), u0_[2:N]...), ps.model, st_.model) - (vcat(flatten.(z0)...), merge(st_, (model=st__,))) - else - z0, st_ = deq.shortcut(x, ps.shortcut, st.shortcut) - (vcat(flatten.(z0)...), merge(st, (shortcut=st_,))) - end - - if check_unrolled_mode(st) - z_star = split_and_reshape(z, st.split_idxs, deq.scales) - z_star, st_ = evaluate_unrolled_mdeq(deq.model, z_star, x, ps.model, st.model, - st.fixed_depth) - - residual = ignore_derivatives(vcat(flatten.(z_star)...) .- - vcat(flatten.(evaluate_unrolled_mdeq(deq.model, - z_star, x, - ps.model, st_, - Val(1))[1])...)) - st__ = merge(st, (model=st_,)) - - return ((z_star, - DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, - get_unrolled_depth(st))), - st__) - end - - st_ = st.model - - function dudt_(u, p, t) - u_split = split_and_reshape(u, st.split_idxs, deq.scales) - u_, st_ = deq.model(((u_split[1], x), u_split[2:N]...), p, st_) - return u_, st_ - end - - dudt(u, p, t) = vcat(flatten.(dudt_(u, p, t)[1])...) .- u - - prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps.model) - sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) - z_star, st_ = dudt_(sol.u, ps.model, nothing) - - residual = ignore_derivatives(dudt(sol.u, ps.model, nothing)) - + z, st = if Sh == Nothing + u0, st_ = get_initial_condition_mdeq(deq.scales, x, st) + u0_ = split_and_reshape(u0, st.split_idxs, deq.scales) + z0, st__ = deq.model(((u0_[1], x), u0_[2:N]...), ps.model, st_.model) + (vcat(flatten.(z0)...), merge(st_, (model=st__,))) + else + z0, st_ = deq.shortcut(x, ps.shortcut, st.shortcut) + (vcat(flatten.(z0)...), merge(st, (shortcut=st_,))) + end + + if check_unrolled_mode(st) + z_star = split_and_reshape(z, st.split_idxs, deq.scales) + z_star, st_ = evaluate_unrolled_mdeq(deq.model, z_star, x, ps.model, st.model, + st.fixed_depth) + + residual = ignore_derivatives(vcat(flatten.(z_star)...) .- + vcat(flatten.(evaluate_unrolled_mdeq(deq.model, + z_star, x, + ps.model, st_, + Val(1))[1])...)) st__ = merge(st, (model=st_,)) return ((z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, - sol.destats.nf + 1)), st) + get_unrolled_depth(st))), + st__) + end + + st_ = st.model + + function dudt_(u, p, t) + u_split = split_and_reshape(u, st.split_idxs, deq.scales) + u_, st_ = deq.model(((u_split[1], x), u_split[2:N]...), p, st_) + return u_, st_ + end + + dudt(u, p, t) = vcat(flatten.(dudt_(u, p, t)[1])...) .- u + + prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps.model) + sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) + z_star, st_ = dudt_(sol.u, ps.model, nothing) + + residual = ignore_derivatives(dudt(sol.u, ps.model, nothing)) + + st__ = merge(st, (model=st_,)) + + return ((z_star, + DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, + sol.destats.nf + 1)), st) end diff --git a/src/operator.jl b/src/operator.jl index 4fe41f85..1b3ee567 100644 --- a/src/operator.jl +++ b/src/operator.jl @@ -1,6 +1,6 @@ struct ZygotePullbackMultiplyOperator{T, F, S} - f::F - s::S + f::F + s::S end Base.deepcopy(op::ZygotePullbackMultiplyOperator) = op @@ -13,11 +13,11 @@ Base.eltype(::ZygotePullbackMultiplyOperator{T}) where {T} = T function LinearAlgebra.mul!(du::AbstractVector, L::ZygotePullbackMultiplyOperator, x::AbstractVector) - du .= vec(L * x) + return du .= vec(L * x) end function Base.:*(L::ZygotePullbackMultiplyOperator, x::AbstractVector) - return L.f(reshape(x, L.s))[1] + return L.f(reshape(x, L.s))[1] end SciMLBase.isinplace(z::ZygotePullbackMultiplyOperator, ::Int64) = false diff --git a/src/solve.jl b/src/solve.jl index 4599fac8..8448cf9f 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,80 +1,80 @@ struct EquilibriumSolution{T, N, uType, P, A, D} <: SciMLBase.AbstractNonlinearSolution{T, N} - u::uType - resid::uType - prob::P - alg::A - retcode::Symbol - destats::D + u::uType + resid::uType + prob::P + alg::A + retcode::Symbol + destats::D end function transform_solution(soln::EquilibriumSolution) - # Creates a NonlinearSolution/SteadyStateSolution - return DiffEqBase.build_solution(soln.prob, soln.alg, soln.u, soln.resid; - retcode=soln.retcode) + # Creates a NonlinearSolution/SteadyStateSolution + return DiffEqBase.build_solution(soln.prob, soln.alg, soln.u, soln.resid; + retcode=soln.retcode) end function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem{uType}, alg::ContinuousDEQSolver, args...; kwargs...) where {uType} - tspan = alg.tspan isa Tuple ? alg.tspan : - convert.(real(eltype(prob.u0)), (zero(alg.tspan), alg.tspan)) - _prob = ODEProblem(prob.f, prob.u0, tspan, prob.p) + tspan = alg.tspan isa Tuple ? alg.tspan : + convert.(real(eltype(prob.u0)), (zero(alg.tspan), alg.tspan)) + _prob = ODEProblem(prob.f, prob.u0, tspan, prob.p) - terminate_stats = Dict{Symbol, Any}(:best_objective_value => real(eltype(prob.u0))(Inf), - :best_objective_value_iteration => nothing) + terminate_stats = Dict{Symbol, Any}(:best_objective_value => real(eltype(prob.u0))(Inf), + :best_objective_value_iteration => nothing) - sol = solve(_prob, - alg.alg, - args...; - kwargs..., - callback=TerminateSteadyState(alg.abstol_termination, - alg.reltol_termination, - get_terminate_condition(alg, terminate_stats))) + sol = solve(_prob, + alg.alg, + args...; + kwargs..., + callback=TerminateSteadyState(alg.abstol_termination, + alg.reltol_termination, + get_terminate_condition(alg, terminate_stats))) - u, t = if terminate_stats[:best_objective_value_iteration] === nothing - (sol.u[end], sol.t[end]) - else - (sol.u[terminate_stats[:best_objective_value_iteration] + 1], - sol.t[terminate_stats[:best_objective_value_iteration] + 1]) - end + u, t = if terminate_stats[:best_objective_value_iteration] === nothing + (sol.u[end], sol.t[end]) + else + (sol.u[terminate_stats[:best_objective_value_iteration] + 1], + sol.t[terminate_stats[:best_objective_value_iteration] + 1]) + end - # Dont count towards NFE since this is mostly a check for convergence - du = prob.f(u, prob.p, t) + # Dont count towards NFE since this is mostly a check for convergence + du = prob.f(u, prob.p, t) - retcode = (sol.retcode == :Terminated && has_converged(du, u, alg) ? :Success : - :Failure) + retcode = (sol.retcode == :Terminated && has_converged(du, u, alg) ? :Success : + :Failure) - return EquilibriumSolution{eltype(uType), ndims(uType), uType, typeof(prob), - typeof(alg), typeof(sol.destats)}(u, du, prob, alg, retcode, - sol.destats) + return EquilibriumSolution{eltype(uType), ndims(uType), uType, typeof(prob), + typeof(alg), typeof(sol.destats)}(u, du, prob, alg, retcode, + sol.destats) end function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem{uType}, alg::DiscreteDEQSolver, args...; maxiters=10, kwargs...) where {uType} - terminate_stats = Dict{Symbol, Any}(:best_objective_value => real(eltype(prob.u0))(Inf), - :best_objective_value_iteration => nothing) + terminate_stats = Dict{Symbol, Any}(:best_objective_value => real(eltype(prob.u0))(Inf), + :best_objective_value_iteration => nothing) - us, stats = nlsolve(alg.alg, - u -> prob.f(u, prob.p, nothing), - prob.u0; - maxiters=maxiters, - terminate_condition=get_terminate_condition(alg, terminate_stats)) + us, stats = nlsolve(alg.alg, + u -> prob.f(u, prob.p, nothing), + prob.u0; + maxiters=maxiters, + terminate_condition=get_terminate_condition(alg, terminate_stats)) - u = if terminate_stats[:best_objective_value_iteration] === nothing - us[end] - else - us[terminate_stats[:best_objective_value_iteration] + 1] - end + u = if terminate_stats[:best_objective_value_iteration] === nothing + us[end] + else + us[terminate_stats[:best_objective_value_iteration] + 1] + end - # Dont count towards NFE since this is mostly a check for convergence - du = prob.f(u, prob.p, nothing) + # Dont count towards NFE since this is mostly a check for convergence + du = prob.f(u, prob.p, nothing) - retcode = has_converged(du, u, alg) ? :Success : :Failure + retcode = has_converged(du, u, alg) ? :Success : :Failure - destats = (nf=stats.nf,) + destats = (nf=stats.nf,) - return EquilibriumSolution{eltype(uType), ndims(uType), uType, typeof(prob), - typeof(alg), typeof(destats)}(u, du, prob, alg, retcode, - destats) + return EquilibriumSolution{eltype(uType), ndims(uType), uType, typeof(prob), + typeof(alg), typeof(destats)}(u, du, prob, alg, retcode, + destats) end diff --git a/src/solvers/continuous.jl b/src/solvers/continuous.jl index 7edb7fd9..84424728 100644 --- a/src/solvers/continuous.jl +++ b/src/solvers/continuous.jl @@ -6,23 +6,23 @@ for solving DEQ problems. ## Arguments -* `alg`: Algorithm to solve the ODEProblem. (Default: `VCABM3()`) -* `mode`: Termination Mode of the solver. See below for a description of the various termination conditions (Default: `:rel_deq_default`) -* `abstol`: Absolute tolerance for time stepping. (Default: `1f-8`) -* `reltol`: Relative tolerance for time stepping. (Default: `1f-8`) -* `abstol_termination`: Absolute tolerance for termination. (Default: `1f-8`) -* `reltol_termination`: Relative tolerance for termination. (Default: `1f-8`) -* `tspan`: Time span. Users should not change this value, instead control termination through `maxiters` in `solve` (Default: `Inf32`) + - `alg`: Algorithm to solve the ODEProblem. (Default: `VCABM3()`) + - `mode`: Termination Mode of the solver. See below for a description of the various termination conditions (Default: `:rel_deq_default`) + - `abstol`: Absolute tolerance for time stepping. (Default: `1f-8`) + - `reltol`: Relative tolerance for time stepping. (Default: `1f-8`) + - `abstol_termination`: Absolute tolerance for termination. (Default: `1f-8`) + - `reltol_termination`: Relative tolerance for termination. (Default: `1f-8`) + - `tspan`: Time span. Users should not change this value, instead control termination through `maxiters` in `solve` (Default: `Inf32`) See also: [`DiscreteDEQSolver`](@ref) """ struct ContinuousDEQSolver{M, A, T, TS} <: SteadyStateDiffEq.SteadyStateDiffEqAlgorithm - alg::A - abstol::T - reltol::T - abstol_termination::T - reltol_termination::T - tspan::TS + alg::A + abstol::T + reltol::T + abstol_termination::T + reltol_termination::T + tspan::TS end function ContinuousDEQSolver(alg=VCABM3(); @@ -32,9 +32,9 @@ function ContinuousDEQSolver(alg=VCABM3(); abstol_termination::T=1.0f-8, reltol_termination::T=1.0f-8, tspan=Inf32) where {T <: Number} - return ContinuousDEQSolver{Val(mode), typeof(alg), T, typeof(tspan)}(alg, abstol, - reltol, - abstol_termination, - reltol_termination, - tspan) + return ContinuousDEQSolver{Val(mode), typeof(alg), T, typeof(tspan)}(alg, abstol, + reltol, + abstol_termination, + reltol_termination, + tspan) end diff --git a/src/solvers/discrete.jl b/src/solvers/discrete.jl index 6de8829f..fd33f4c1 100644 --- a/src/solvers/discrete.jl +++ b/src/solvers/discrete.jl @@ -2,29 +2,29 @@ DiscreteDEQSolver(alg=LimitedMemoryBroydenSolver(); mode::Symbol=:rel_deq_default, abstol_termination::T=1.0f-8, reltol_termination::T=1.0f-8) Solver for Discrete DEQ Problem ([baideep2019](@cite)). Similar to `SSrootfind` but provides more flexibility needed - for solving DEQ problems. - +for solving DEQ problems. + ## Arguments -* `alg`: Algorithm to solve the Nonlinear Problem (Default: [`LimitedMemoryBroydenSolver`](@ref)) -* `mode`: Termination Mode of the solver. See below for a description of the various termination conditions (Default: `:rel_deq_default`) -* `abstol_termination`: Absolute tolerance for termination. (Default: `1f-8`) -* `reltol_termination`: Relative tolerance for termination. (Default: `1f-8`) + - `alg`: Algorithm to solve the Nonlinear Problem (Default: [`LimitedMemoryBroydenSolver`](@ref)) + - `mode`: Termination Mode of the solver. See below for a description of the various termination conditions (Default: `:rel_deq_default`) + - `abstol_termination`: Absolute tolerance for termination. (Default: `1f-8`) + - `reltol_termination`: Relative tolerance for termination. (Default: `1f-8`) See also: [`ContinuousDEQSolver`](@ref) """ struct DiscreteDEQSolver{M, A, T} <: SteadyStateDiffEq.SteadyStateDiffEqAlgorithm - alg::A - abstol_termination::T - reltol_termination::T + alg::A + abstol_termination::T + reltol_termination::T end function DiscreteDEQSolver(alg=LimitedMemoryBroydenSolver(); mode::Symbol=:rel_deq_default, abstol_termination::T=1.0f-8, reltol_termination::T=1.0f-8) where {T <: Number} - return DiscreteDEQSolver{Val(mode), typeof(alg), T}(alg, abstol_termination, - reltol_termination) + return DiscreteDEQSolver{Val(mode), typeof(alg), T}(alg, abstol_termination, + reltol_termination) end include("discrete/broyden.jl") diff --git a/src/solvers/discrete/broyden.jl b/src/solvers/discrete/broyden.jl index 0960acf0..57a20337 100644 --- a/src/solvers/discrete/broyden.jl +++ b/src/solvers/discrete/broyden.jl @@ -6,15 +6,15 @@ Broyden Solver ([broyden1965class](@cite)) for solving Discrete DEQs. It is reco ## Arguments -* `T`: The type of the elements of the vectors. (Default: `Float32`) -* `device`: The device to use. Pass `gpu` to use the GPU else pass `cpu`. -* `original_dims`: Dimensions to reshape the arrays into (excluding the batch dimension). -* `batch_size`: The batch size of the problem. Your inputs can have a different batch size, but having - them match allows us to efficiently cache internal statistics without reallocation. -* `maxiters`: Maximum number of iterations to run. -* `ϵ`: Tolerance for convergence. -* `abstol`: Absolute tolerance. -* `reltol`: Relative tolerance. (This value is ignored by `BroydenSolver` at the moment) + - `T`: The type of the elements of the vectors. (Default: `Float32`) + - `device`: The device to use. Pass `gpu` to use the GPU else pass `cpu`. + - `original_dims`: Dimensions to reshape the arrays into (excluding the batch dimension). + - `batch_size`: The batch size of the problem. Your inputs can have a different batch size, but having + them match allows us to efficiently cache internal statistics without reallocation. + - `maxiters`: Maximum number of iterations to run. + - `ϵ`: Tolerance for convergence. + - `abstol`: Absolute tolerance. + - `reltol`: Relative tolerance. (This value is ignored by `BroydenSolver` at the moment) See also: [`LimitedMemoryBroydenSolver`](@ref) """ @@ -22,96 +22,96 @@ struct BroydenSolver end function nlsolve(b::BroydenSolver, f::Function, y::AbstractArray{T}; terminate_condition, maxiters::Int=10) where {T} - res, stats = nlsolve(b, - u -> vec(f(reshape(u, size(y)))), - vec(y); - terminate_condition, - maxiters) - return reshape.(res, (size(y),)), stats + res, stats = nlsolve(b, + u -> vec(f(reshape(u, size(y)))), + vec(y); + terminate_condition, + maxiters) + return reshape.(res, (size(y),)), stats end function nlsolve(::BroydenSolver, f::Function, y::AbstractVector{T}; terminate_condition, maxiters::Int=10) where {T} - x = copy(y) - x_old = copy(y) - Δx = copy(y) - fx_old = f(y) - Δfx = copy(fx_old) - Jinv = _init_identity_matrix(y) - p = similar(fx_old, (size(Jinv, 1),)) - ρ, σ₂ = T(0.9), T(0.001) - - # Store the trajectory - xs = [x] - - maybe_stuck, max_resets, resets, nsteps, nf = false, 3, 0, 1, 1 - - while nsteps <= maxiters - mul!(p, Jinv, fx_old) - p .*= -1 - - @. x = x_old + p - fx = f(x) - nf += 1 - - if norm(fx, 2) ≤ ρ * norm(fx_old, 2) - σ₂ * norm(p, 2)^2 - α = T(1) - else - α, _stats = _approximate_norm_descent(f, x, p) - @. x = x_old + α * p - fx = f(x) - nf += 1 + _stats.nf - end - - @. Δx = x - x_old - @. Δfx = fx - fx_old - - maybe_stuck = all(abs.(Δx) .<= eps(T)) || all(abs.(Δfx) .<= eps(T)) - if maybe_stuck - Jinv = _init_identity_matrix(x) - resets += 1 - maybe_stuck = (resets ≤ max_resets) && maybe_stuck - else - ΔxJinv = Δx' * Jinv - Jinv .+= ((Δx .- Jinv * Δfx) ./ (ΔxJinv * Δfx)) * ΔxJinv - end - - maybe_stuck = false - nsteps += 1 - copyto!(fx_old, fx) - copyto!(x_old, x) - - push!(xs, x) - - # Convergence Check - terminate_condition(fx, x) && break + x = copy(y) + x_old = copy(y) + Δx = copy(y) + fx_old = f(y) + Δfx = copy(fx_old) + Jinv = _init_identity_matrix(y) + p = similar(fx_old, (size(Jinv, 1),)) + ρ, σ₂ = T(0.9), T(0.001) + + # Store the trajectory + xs = [x] + + maybe_stuck, max_resets, resets, nsteps, nf = false, 3, 0, 1, 1 + + while nsteps <= maxiters + mul!(p, Jinv, fx_old) + p .*= -1 + + @. x = x_old + p + fx = f(x) + nf += 1 + + if norm(fx, 2) ≤ ρ * norm(fx_old, 2) - σ₂ * norm(p, 2)^2 + α = T(1) + else + α, _stats = _approximate_norm_descent(f, x, p) + @. x = x_old + α * p + fx = f(x) + nf += 1 + _stats.nf + end + + @. Δx = x - x_old + @. Δfx = fx - fx_old + + maybe_stuck = all(abs.(Δx) .<= eps(T)) || all(abs.(Δfx) .<= eps(T)) + if maybe_stuck + Jinv = _init_identity_matrix(x) + resets += 1 + maybe_stuck = (resets ≤ max_resets) && maybe_stuck + else + ΔxJinv = Δx' * Jinv + Jinv .+= ((Δx .- Jinv * Δfx) ./ (ΔxJinv * Δfx)) * ΔxJinv end - return xs, (nf=nf,) + maybe_stuck = false + nsteps += 1 + copyto!(fx_old, fx) + copyto!(x_old, x) + + push!(xs, x) + + # Convergence Check + terminate_condition(fx, x) && break + end + + return xs, (nf=nf,) end function _approximate_norm_descent(f::Function, x::AbstractArray{T, N}, p; λ₀=T(1), β=T(0.5), σ₁=T(0.001), η=T(0.1), max_iter=50) where {T, N} - λ₂, λ₁ = λ₀, λ₀ + λ₂, λ₁ = λ₀, λ₀ - fx = f(x) - fx_norm = norm(fx, 2) - j = 1 - fx = f(x .+ λ₂ .* p) - converged = false - - while j <= max_iter && !converged - j += 1 - λ₁, λ₂ = λ₂, β * λ₂ - converged = _test_approximate_norm_descent_convergence(f, x, fx_norm, p, σ₁, λ₂, η) - end + fx = f(x) + fx_norm = norm(fx, 2) + j = 1 + fx = f(x .+ λ₂ .* p) + converged = false + + while j <= max_iter && !converged + j += 1 + λ₁, λ₂ = λ₂, β * λ₂ + converged = _test_approximate_norm_descent_convergence(f, x, fx_norm, p, σ₁, λ₂, η) + end - return λ₂, (nf=2(j + 1),) + return λ₂, (nf=2(j + 1),) end function _test_approximate_norm_descent_convergence(f, x, fx_norm, p, σ₁, λ₂, η) - n1 = norm(f(x .+ λ₂ .* p), 2) - n2 = norm(f(x), 2) - return n1 ≤ fx_norm - σ₁ * norm(λ₂ .* p, 2) .^ 2 + η * n2 + n1 = norm(f(x .+ λ₂ .* p), 2) + n2 = norm(f(x), 2) + return n1 ≤ fx_norm - σ₁ * norm(λ₂ .* p, 2) .^ 2 + η * n2 end diff --git a/src/solvers/discrete/limited_memory_broyden.jl b/src/solvers/discrete/limited_memory_broyden.jl index 94c0487f..7633d298 100644 --- a/src/solvers/discrete/limited_memory_broyden.jl +++ b/src/solvers/discrete/limited_memory_broyden.jl @@ -8,16 +8,16 @@ Limited Memory Broyden Solver ([baimultiscale2020](@cite)) for solving Discrete ## Arguments -* `T`: The type of the elements of the vectors. (Default: `Float32`) -* `device`: The device to use. Pass `gpu` to use the GPU else pass `cpu`. -* `original_dims`: Dimensions to reshape the arrays into (excluding the batch dimension). -* `batch_size`: The batch size of the problem. Your inputs can have a different batch size, but having - them match allows us to efficiently cache internal statistics without reallocation. -* `maxiters`: Maximum number of iterations to run. -* `ϵ`: Tolerance for convergence. -* `criteria`: The criteria to use for convergence. Can be `:reltol` or `:abstol`. -* `abstol`: Absolute tolerance. -* `reltol`: Relative tolerance. + - `T`: The type of the elements of the vectors. (Default: `Float32`) + - `device`: The device to use. Pass `gpu` to use the GPU else pass `cpu`. + - `original_dims`: Dimensions to reshape the arrays into (excluding the batch dimension). + - `batch_size`: The batch size of the problem. Your inputs can have a different batch size, but having + them match allows us to efficiently cache internal statistics without reallocation. + - `maxiters`: Maximum number of iterations to run. + - `ϵ`: Tolerance for convergence. + - `criteria`: The criteria to use for convergence. Can be `:reltol` or `:abstol`. + - `abstol`: Absolute tolerance. + - `reltol`: Relative tolerance. See also: [`BroydenSolver`](@ref) """ @@ -26,76 +26,76 @@ struct LimitedMemoryBroydenSolver end @inbounds @views function nlsolve(::LimitedMemoryBroydenSolver, f::Function, y::AbstractMatrix{T}; terminate_condition, maxiters::Int=10) where {T} - LBFGS_threshold = min(maxiters, 27) + LBFGS_threshold = min(maxiters, 27) - total_hsize, batch_size = size(y) + total_hsize, batch_size = size(y) - # Initialize the cache - x₀ = copy(y) - fx₀ = f(x₀) - x₁ = copy(y) - Δx = copy(x₀) - Δfx = copy(x₀) - Us = fill!(similar(y, (LBFGS_threshold, total_hsize, batch_size)), T(0)) - VTs = fill!(similar(y, (total_hsize, LBFGS_threshold, batch_size)), T(0)) + # Initialize the cache + x₀ = copy(y) + fx₀ = f(x₀) + x₁ = copy(y) + Δx = copy(x₀) + Δfx = copy(x₀) + Us = fill!(similar(y, (LBFGS_threshold, total_hsize, batch_size)), T(0)) + VTs = fill!(similar(y, (total_hsize, LBFGS_threshold, batch_size)), T(0)) - # Store the trajectory - xs = [x₀] + # Store the trajectory + xs = [x₀] - # Counters - nstep = 1 + # Counters + nstep = 1 - # Main Algorithm - update = fx₀ + # Main Algorithm + update = fx₀ - while nstep <= maxiters - # Update - @. x₁ = x₀ + update - fx₁ = f(x₁) - @. Δx = x₁ - x₀ - @. Δfx = fx₁ - fx₀ + while nstep <= maxiters + # Update + @. x₁ = x₀ + update + fx₁ = f(x₁) + @. Δx = x₁ - x₀ + @. Δfx = fx₁ - fx₀ - push!(xs, x₁) + push!(xs, x₁) - # Convergence Check - terminate_condition(fx₁, x₁) && break + # Convergence Check + terminate_condition(fx₁, x₁) && break - # Compute the update - part_Us = Us[1:min(LBFGS_threshold, nstep), :, :] - part_VTs = VTs[:, 1:min(LBFGS_threshold, nstep), :] + # Compute the update + part_Us = Us[1:min(LBFGS_threshold, nstep), :, :] + part_VTs = VTs[:, 1:min(LBFGS_threshold, nstep), :] - vT = rmatvec(part_Us, part_VTs, Δx) # D x C x N - mvec = matvec(part_Us, part_VTs, Δfx) - vTΔfx = sum(vT .* Δfx; dims=(1, 2)) - @. Δx = (Δx - mvec) / (vTΔfx + eps(T)) # D x C x N + vT = rmatvec(part_Us, part_VTs, Δx) # D x C x N + mvec = matvec(part_Us, part_VTs, Δfx) + vTΔfx = sum(vT .* Δfx; dims=(1, 2)) + @. Δx = (Δx - mvec) / (vTΔfx + eps(T)) # D x C x N - VTs[:, mod1(nstep, LBFGS_threshold), :] .= vT - Us[mod1(nstep, LBFGS_threshold), :, :] .= Δx + VTs[:, mod1(nstep, LBFGS_threshold), :] .= vT + Us[mod1(nstep, LBFGS_threshold), :, :] .= Δx - update = -matvec(Us[1:min(LBFGS_threshold, nstep + 1), :, :], - VTs[:, 1:min(LBFGS_threshold, nstep + 1), :], fx₁) - copyto!(x₀, x₁) - copyto!(fx₀, fx₁) + update = -matvec(Us[1:min(LBFGS_threshold, nstep + 1), :, :], + VTs[:, 1:min(LBFGS_threshold, nstep + 1), :], fx₁) + copyto!(x₀, x₁) + copyto!(fx₀, fx₁) - # Increment Counter - nstep += 1 - end + # Increment Counter + nstep += 1 + end - return xs, (nf=nstep + 1,) + return xs, (nf=nstep + 1,) end @inbounds @views function matvec(part_Us::AbstractArray{E, 3}, part_VTs::AbstractArray{E, 3}, x::AbstractArray{E, 2}) where {E} - # part_Us -> (T x D x N) | part_VTs -> (D x T x N) | x -> (D x N) - xTU = sum(unsqueeze(x; dims=1) .* part_Us; dims=2) # T x 1 x N - return -x .+ dropdims(sum(permutedims(xTU, (2, 1, 3)) .* part_VTs; dims=2); dims=2) + # part_Us -> (T x D x N) | part_VTs -> (D x T x N) | x -> (D x N) + xTU = sum(unsqueeze(x; dims=1) .* part_Us; dims=2) # T x 1 x N + return -x .+ dropdims(sum(permutedims(xTU, (2, 1, 3)) .* part_VTs; dims=2); dims=2) end @inbounds @views function rmatvec(part_Us::AbstractArray{E, 3}, part_VTs::AbstractArray{E, 3}, x::AbstractArray{E, 2}) where {E} - # part_Us -> (T x D x N) | part_VTs -> (D x T x N) | x -> (D x N) - VTx = sum(part_VTs .* unsqueeze(x; dims=2); dims=1) # 1 x T x N - return -x .+ dropdims(sum(part_Us .* permutedims(VTx, (2, 1, 3)); dims=1); dims=1) + # part_Us -> (T x D x N) | part_VTs -> (D x T x N) | x -> (D x N) + VTx = sum(part_VTs .* unsqueeze(x; dims=2); dims=1) # 1 x T x N + return -x .+ dropdims(sum(part_Us .* permutedims(VTx, (2, 1, 3)); dims=1); dims=1) end diff --git a/src/solvers/termination.jl b/src/solvers/termination.jl index e200b58e..0d384e65 100644 --- a/src/solvers/termination.jl +++ b/src/solvers/termination.jl @@ -2,113 +2,113 @@ get_mode(::Val{mode}) where {mode} = mode function get_terminate_condition(alg::ContinuousDEQSolver{M, A, T}, args...; kwargs...) where {M, A, T} - mode = get_mode(M) - if mode ∈ (:abs_deq_default, :rel_deq_default, :abs_deq_best, :rel_deq_best) - nstep, protective_threshold, objective_values = 0, T(1e3), T[] + mode = get_mode(M) + if mode ∈ (:abs_deq_default, :rel_deq_default, :abs_deq_best, :rel_deq_best) + nstep, protective_threshold, objective_values = 0, T(1e3), T[] - if mode ∈ (:rel_deq_best, :abs_deq_best) - @assert length(args) == 1 + if mode ∈ (:rel_deq_best, :abs_deq_best) + @assert length(args) == 1 - args[1][:best_objective_value] = T(Inf) - args[1][:best_objective_value_iteration] = 0 - end + args[1][:best_objective_value] = T(Inf) + args[1][:best_objective_value_iteration] = 0 + end - function terminate_condition_closure_1(integrator, abstol, reltol, min_t) - du, u = DiffEqBase.get_du(integrator), integrator.u - objective = norm(du) / (mode ∈ (:abs_deq_default, :abs_deq_best) ? 1 : - (norm(du .+ u) + eps(T))) - criteria = mode ∈ (:abs_deq_default, :abs_deq_best) ? abstol : reltol - - if mode ∈ (:rel_deq_best, :abs_deq_best) - if objective < args[1][:best_objective_value] - args[1][:best_objective_value] = objective - args[1][:best_objective_value_iteration] = nstep + 1 - end - end - - # Main Termination Criteria - objective <= criteria && return true - - # Terminate if there has been no improvement for the last 30 steps - nstep += 1 - push!(objective_values, objective) - - objective <= 3 * criteria && - nstep >= 30 && - maximum(objective_values[max(1, length(objective_values) - nstep):end]) < - 1.3 * - minimum(objective_values[max(1, length(objective_values) - nstep):end]) && - return true - - # Protective break - objective >= objective_values[1] * protective_threshold * length(du) && - return true - - return false - end - return terminate_condition_closure_1 - else - function terminate_condition_closure_2(integrator, abstol, reltol, min_t) - return has_converged(DiffEqBase.get_du(integrator), integrator.u, M, abstol, - reltol) + function terminate_condition_closure_1(integrator, abstol, reltol, min_t) + du, u = DiffEqBase.get_du(integrator), integrator.u + objective = norm(du) / (mode ∈ (:abs_deq_default, :abs_deq_best) ? 1 : + (norm(du .+ u) + eps(T))) + criteria = mode ∈ (:abs_deq_default, :abs_deq_best) ? abstol : reltol + + if mode ∈ (:rel_deq_best, :abs_deq_best) + if objective < args[1][:best_objective_value] + args[1][:best_objective_value] = objective + args[1][:best_objective_value_iteration] = nstep + 1 end - return terminate_condition_closure_2 + end + + # Main Termination Criteria + objective <= criteria && return true + + # Terminate if there has been no improvement for the last 30 steps + nstep += 1 + push!(objective_values, objective) + + objective <= 3 * criteria && + nstep >= 30 && + maximum(objective_values[max(1, length(objective_values) - nstep):end]) < + 1.3 * + minimum(objective_values[max(1, length(objective_values) - nstep):end]) && + return true + + # Protective break + objective >= objective_values[1] * protective_threshold * length(du) && + return true + + return false end + return terminate_condition_closure_1 + else + function terminate_condition_closure_2(integrator, abstol, reltol, min_t) + return has_converged(DiffEqBase.get_du(integrator), integrator.u, M, abstol, + reltol) + end + return terminate_condition_closure_2 + end end function get_terminate_condition(alg::DiscreteDEQSolver{M, A, T}, args...; kwargs...) where {M, A, T} - mode = get_mode(M) - if mode ∈ (:abs_deq_default, :rel_deq_default, :abs_deq_best, :rel_deq_best) - nstep, protective_threshold, objective_values = 0, T(1e3), T[] + mode = get_mode(M) + if mode ∈ (:abs_deq_default, :rel_deq_default, :abs_deq_best, :rel_deq_best) + nstep, protective_threshold, objective_values = 0, T(1e3), T[] - if mode ∈ (:rel_deq_best, :abs_deq_best) - @assert length(args) == 1 + if mode ∈ (:rel_deq_best, :abs_deq_best) + @assert length(args) == 1 - args[1][:best_objective_value] = T(Inf) - args[1][:best_objective_value_iteration] = 0 - end + args[1][:best_objective_value] = T(Inf) + args[1][:best_objective_value_iteration] = 0 + end - function terminate_condition_closure_1(du, u) - objective = norm(du) / (mode ∈ (:abs_deq_default, :abs_deq_best) ? 1 : - (norm(du .+ u) + eps(T))) - criteria = mode ∈ (:abs_deq_default, :abs_deq_best) ? alg.abstol_termination : - alg.reltol_termination - - if mode ∈ (:rel_deq_best, :abs_deq_best) - if objective < args[1][:best_objective_value] - args[1][:best_objective_value] = objective - args[1][:best_objective_value_iteration] = nstep + 1 - end - end - - # Main Termination Criteria - objective <= criteria && return true - - # Terminate if there has been no improvement for the last 30 steps - nstep += 1 - push!(objective_values, objective) - - objective <= 3 * criteria && - nstep >= 30 && - maximum(objective_values[max(1, length(objective_values) - nstep):end]) < - 1.3 * - minimum(objective_values[max(1, length(objective_values) - nstep):end]) && - return true - - # Protective break - objective >= objective_values[1] * protective_threshold * length(du) && - return true - - return false - end - return terminate_condition_closure_1 - else - function terminate_condition_closure_2(du, u) - return has_converged(du, u, M, alg.abstol_termination, alg.reltol_termination) + function terminate_condition_closure_1(du, u) + objective = norm(du) / (mode ∈ (:abs_deq_default, :abs_deq_best) ? 1 : + (norm(du .+ u) + eps(T))) + criteria = mode ∈ (:abs_deq_default, :abs_deq_best) ? alg.abstol_termination : + alg.reltol_termination + + if mode ∈ (:rel_deq_best, :abs_deq_best) + if objective < args[1][:best_objective_value] + args[1][:best_objective_value] = objective + args[1][:best_objective_value_iteration] = nstep + 1 end - return terminate_condition_closure_2 + end + + # Main Termination Criteria + objective <= criteria && return true + + # Terminate if there has been no improvement for the last 30 steps + nstep += 1 + push!(objective_values, objective) + + objective <= 3 * criteria && + nstep >= 30 && + maximum(objective_values[max(1, length(objective_values) - nstep):end]) < + 1.3 * + minimum(objective_values[max(1, length(objective_values) - nstep):end]) && + return true + + # Protective break + objective >= objective_values[1] * protective_threshold * length(du) && + return true + + return false end + return terminate_condition_closure_1 + else + function terminate_condition_closure_2(du, u) + return has_converged(du, u, M, alg.abstol_termination, alg.reltol_termination) + end + return terminate_condition_closure_2 + end end # Convergence Criterions @@ -117,30 +117,30 @@ end alg::Union{ContinuousDEQSolver{M}, DiscreteDEQSolver{M}}, abstol=alg.abstol_termination, reltol=alg.reltol_termination) where {M} - return has_converged(du, u, M, abstol, reltol) + return has_converged(du, u, M, abstol, reltol) end @inline @inbounds function has_converged(du, u, M, abstol, reltol) - mode = get_mode(M) - if mode == :norm - return norm(du) <= abstol && norm(du) <= reltol * norm(du + u) - elseif mode == :rel - return all(abs.(du) .<= reltol .* abs.(u)) - elseif mode == :rel_norm - return norm(du) <= reltol * norm(du + u) - elseif mode == :rel_deq_default - return norm(du) <= reltol * norm(du + u) - elseif mode == :rel_deq_best - return norm(du) <= reltol * norm(du + u) - elseif mode == :abs - return all(abs.(du) .<= abstol) - elseif mode == :abs_norm - return norm(du) <= abstol - elseif mode == :abs_deq_default - return norm(du) <= abstol - elseif mode == :abs_deq_best - return norm(du) <= abstol - else - return all(abs.(du) .<= abstol .& abs.(du) .<= reltol .* abs.(u)) - end + mode = get_mode(M) + if mode == :norm + return norm(du) <= abstol && norm(du) <= reltol * norm(du + u) + elseif mode == :rel + return all(abs.(du) .<= reltol .* abs.(u)) + elseif mode == :rel_norm + return norm(du) <= reltol * norm(du + u) + elseif mode == :rel_deq_default + return norm(du) <= reltol * norm(du + u) + elseif mode == :rel_deq_best + return norm(du) <= reltol * norm(du + u) + elseif mode == :abs + return all(abs.(du) .<= abstol) + elseif mode == :abs_norm + return norm(du) <= abstol + elseif mode == :abs_deq_default + return norm(du) <= abstol + elseif mode == :abs_deq_best + return norm(du) <= abstol + else + return all(abs.(du) .<= abstol .& abs.(du) .<= reltol .* abs.(u)) + end end diff --git a/src/utils.jl b/src/utils.jl index cfea79e5..a907ac5d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -8,22 +8,22 @@ Creates DeepEquilibriumAdjoint ([johnson2012notes](@cite)) with sensible default ## Arguments -* `reltol`: Relative tolerance. -* `abstol`: Absolute tolerance. -* `maxiters`: Maximum number of iterations. -* `autojacvec`: Which backend to use for VJP. -* `linsolve`: Linear Solver from [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl). -* `mode`: Adjoint mode. Currently only `:vanilla` & `:jfb` are supported. + - `reltol`: Relative tolerance. + - `abstol`: Absolute tolerance. + - `maxiters`: Maximum number of iterations. + - `autojacvec`: Which backend to use for VJP. + - `linsolve`: Linear Solver from [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl). + - `mode`: Adjoint mode. Currently only `:vanilla` & `:jfb` are supported. """ struct DeepEquilibriumAdjoint{CS, AD, FDT, M, VJP, LS} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} - autojacvec::VJP - linsolve::LS + autojacvec::VJP + linsolve::LS end @inline function check_adjoint_mode(::DeepEquilibriumAdjoint{CS, AD, FDT, M}, ::Val{M}) where {CS, AD, FDT, M} - true + return true end @inline check_adjoint_mode(::DeepEquilibriumAdjoint, ::Val) = false @@ -38,10 +38,10 @@ Base.@pure function DeepEquilibriumAdjoint(reltol, chunk_size=0, diff_type=Val{:central}, mode::Symbol=:vanilla) - return DeepEquilibriumAdjoint{ - chunk_size, autodiff, diff_type, mode, typeof(autojacvec), - typeof(linsolve) - }(autojacvec, linsolve) + return DeepEquilibriumAdjoint{ + chunk_size, autodiff, diff_type, mode, typeof(autojacvec), + typeof(linsolve) + }(autojacvec, linsolve) end # Initialization @@ -52,37 +52,37 @@ Initializes the weights of the network with a normal distribution. For DEQs the if we use this as the Initialization """ function NormalInitializer(μ=0.0f0, σ²=0.01f0) - return (rng::AbstractRNG, dims...) -> randn(rng, Float32, dims...) .* σ² .+ μ + return (rng::AbstractRNG, dims...) -> randn(rng, Float32, dims...) .* σ² .+ μ end # For MultiScale DEQs @generated function split_and_reshape(x::AbstractMatrix, ::T, ::S) where {T, S} - idxs, shapes = known(T), known(S) - dims = [reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...) for i in 1:(length(idxs) - 1)] - varnames = [gensym("x_view") for _ in dims] - calls = [] - for (i, dim) in enumerate(dims) - push!(calls, :($(varnames[i]) = view(x, $dim, :))) - end - push!(calls, :(return tuple($(Tuple(varnames)...)))) - return Expr(:block, calls...) + idxs, shapes = known(T), known(S) + dims = [reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...) for i in 1:(length(idxs) - 1)] + varnames = [gensym("x_view") for _ in dims] + calls = [] + for (i, dim) in enumerate(dims) + push!(calls, :($(varnames[i]) = view(x, $dim, :))) + end + push!(calls, :(return tuple($(Tuple(varnames)...)))) + return Expr(:block, calls...) end # General Utils @inline function _init_identity_matrix(x::AbstractArray{T}, scale::T=T(1)) where {T} - x_ = vec(x) - return _init_identity_matrix!(x_ .* x_', scale) + x_ = vec(x) + return _init_identity_matrix!(x_ .* x_', scale) end @inline function _init_identity_matrix!(x::AbstractMatrix{T}, scale::T=T(1)) where {T} - x .= zero(T) - view(x, diagind(x)) .= scale .* true - return x + x .= zero(T) + view(x, diagind(x)) .= scale .* true + return x end @inline _norm(x; dims=Colon()) = sqrt.(sum(abs2, x; dims=dims)) # Compute norm over all dimensions except `except_dim` @inline function _norm(x::AbstractArray{T, N}, except_dim) where {T, N} - _norm(x; dims=filter(i -> i != except_dim, 1:N)) + return _norm(x; dims=filter(i -> i != except_dim, 1:N)) end diff --git a/test/runtests.jl b/test/runtests.jl index 824f4b3d..12d0bca7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,274 +1,274 @@ using CUDA, DeepEquilibriumNetworks, Functors, LinearAlgebra, Lux, Random, Test, Zygote function test_gradient_isfinite(gs::NamedTuple) - gradient_is_finite = [true] - function is_gradient_finite(x) - if !isnothing(x) && !all(isfinite, x) - gradient_is_finite[1] = false - end - return x + gradient_is_finite = [true] + function is_gradient_finite(x) + if !isnothing(x) && !all(isfinite, x) + gradient_is_finite[1] = false end - fmap(is_gradient_finite, gs) - return gradient_is_finite[1] + return x + end + fmap(is_gradient_finite, gs) + return gradient_is_finite[1] end @testset "DeepEquilibriumNetworks.jl" begin - seed = 0 - rng = Random.default_rng() - Random.seed!(rng, seed) - - @info "Testing DEQ" - model = DEQChain(Dense(2, 2), - DeepEquilibriumNetwork(Parallel(+, Dense(2, 2; bias=false), - Dense(2, 2; bias=false)), - ContinuousDEQSolver(; abstol=0.1f0, - reltol=0.1f0, - abstol_termination=0.1f0, - reltol_termination=0.1f0))) - ps, st = gpu.(Lux.setup(rng, model)) - x = gpu(rand(rng, Float32, 2, 1)) - y = gpu(rand(rng, Float32, 2, 1)) - - gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1] - - @test test_gradient_isfinite(gs) - - @info "Testing DEQ without Fixed Point Iterations" - st = Lux.update_state(st, :fixed_depth, Val(5)) - - gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1] - - @test test_gradient_isfinite(gs) - - @info "Testing SkipDEQ" - Random.seed!(rng, seed) - model = DEQChain(Dense(2, 2), - SkipDeepEquilibriumNetwork(Parallel(+, Dense(2, 2), Dense(2, 2)), - Dense(2, 2), - ContinuousDEQSolver(; abstol=0.1f0, - reltol=0.1f0, - abstol_termination=0.1f0, - reltol_termination=0.1f0); - sensealg=DeepEquilibriumAdjoint(0.1f0, - 0.1f0, 10))) - ps, st = gpu.(Lux.setup(rng, model)) - x = gpu(rand(rng, Float32, 2, 1)) - y = gpu(rand(rng, Float32, 2, 1)) - - gs = gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end[1] - - @test test_gradient_isfinite(gs) - - @info "Testing SkipDEQ without Fixed Point Iterations" - st = Lux.update_state(st, :fixed_depth, Val(5)) - - gs = gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end[1] - - @test test_gradient_isfinite(gs) - - @info "Testing SkipDEQV2" - Random.seed!(rng, seed) - model = DEQChain(Dense(2, 2), - SkipDeepEquilibriumNetwork(Parallel(+, Dense(2, 2), Dense(2, 2)), - nothing, - ContinuousDEQSolver(; abstol=0.1f0, - reltol=0.1f0, - abstol_termination=0.1f0, - reltol_termination=0.1f0); - sensealg=DeepEquilibriumAdjoint(0.1f0, - 0.1f0, 10))) - ps, st = gpu.(Lux.setup(rng, model)) - x = gpu(rand(rng, Float32, 2, 1)) - y = gpu(rand(rng, Float32, 2, 1)) - - gs = gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end[1] - - @test test_gradient_isfinite(gs) - - @info "Testing SkipDEQV2 without Fixed Point Iterations" - st = Lux.update_state(st, :fixed_depth, Val(5)) - - gs = gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end[1] - - @test test_gradient_isfinite(gs) - - @info "Testing SkipDEQ with Broyden Solver" - Random.seed!(rng, seed) - model = DEQChain(Dense(2, 2), - SkipDeepEquilibriumNetwork(Parallel(+, Dense(2, 2), Dense(2, 2)), - Dense(2, 2), - DiscreteDEQSolver(BroydenSolver(); + seed = 0 + rng = Random.default_rng() + Random.seed!(rng, seed) + + @info "Testing DEQ" + model = DEQChain(Dense(2, 2), + DeepEquilibriumNetwork(Parallel(+, Dense(2, 2; bias=false), + Dense(2, 2; bias=false)), + ContinuousDEQSolver(; abstol=0.1f0, + reltol=0.1f0, + abstol_termination=0.1f0, + reltol_termination=0.1f0))) + ps, st = gpu.(Lux.setup(rng, model)) + x = gpu(rand(rng, Float32, 2, 1)) + y = gpu(rand(rng, Float32, 2, 1)) + + gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1] + + @test test_gradient_isfinite(gs) + + @info "Testing DEQ without Fixed Point Iterations" + st = Lux.update_state(st, :fixed_depth, Val(5)) + + gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1] + + @test test_gradient_isfinite(gs) + + @info "Testing SkipDEQ" + Random.seed!(rng, seed) + model = DEQChain(Dense(2, 2), + SkipDeepEquilibriumNetwork(Parallel(+, Dense(2, 2), Dense(2, 2)), + Dense(2, 2), + ContinuousDEQSolver(; abstol=0.1f0, + reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0); - sensealg=DeepEquilibriumAdjoint(0.1f0, - 0.1f0, 10))) - ps, st = gpu.(Lux.setup(rng, model)) - x = gpu(rand(rng, Float32, 2, 1)) - y = gpu(rand(rng, Float32, 2, 1)) - - gs = gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end[1] - - @test test_gradient_isfinite(gs) - - @info "Testing SkipDEQ with L-Broyden Solver" - Random.seed!(rng, seed) - model = DEQChain(Dense(2, 2), - SkipDeepEquilibriumNetwork(Parallel(+, Dense(2, 2), Dense(2, 2)), - Dense(2, 2), - DiscreteDEQSolver(LimitedMemoryBroydenSolver(); + sensealg=DeepEquilibriumAdjoint(0.1f0, + 0.1f0, 10))) + ps, st = gpu.(Lux.setup(rng, model)) + x = gpu(rand(rng, Float32, 2, 1)) + y = gpu(rand(rng, Float32, 2, 1)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + return sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing SkipDEQ without Fixed Point Iterations" + st = Lux.update_state(st, :fixed_depth, Val(5)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + return sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing SkipDEQV2" + Random.seed!(rng, seed) + model = DEQChain(Dense(2, 2), + SkipDeepEquilibriumNetwork(Parallel(+, Dense(2, 2), Dense(2, 2)), + nothing, + ContinuousDEQSolver(; abstol=0.1f0, + reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0); - sensealg=DeepEquilibriumAdjoint(0.1f0, - 0.1f0, 10))) - ps, st = gpu.(Lux.setup(rng, model)) - x = gpu(rand(rng, Float32, 2, 1)) - y = gpu(rand(rng, Float32, 2, 1)) - - gs = gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end[1] - - @test test_gradient_isfinite(gs) - - @info "Testing MultiScaleDEQ" - Random.seed!(rng, seed) - model = MultiScaleDeepEquilibriumNetwork((Parallel(+, Dense(4, 4, tanh), - Dense(4, 4, tanh)), - Dense(3, 3, tanh), - Dense(2, 2, tanh), - Dense(1, 1, tanh)), - [NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) - Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) - Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) - Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer()], - nothing, - ContinuousDEQSolver(; abstol=0.1f0, - reltol=0.1f0, - abstol_termination=0.1f0, - reltol_termination=0.1f0), - ((4,), (3,), (2,), (1,)); - sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, - 10)) - - ps, st = gpu.(Lux.setup(rng, model)) - x = gpu(rand(rng, Float32, 4, 2)) - y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) - - gs = gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(Base.Fix1(sum, abs2), ŷ .- y) - end[1] - - @test test_gradient_isfinite(gs) - - @info "Testing MultiScaleDEQ without Fixed Point Iterations" - st = Lux.update_state(st, :fixed_depth, Val(5)) - - gs = gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(Base.Fix1(sum, abs2), ŷ .- y) - end[1] - - @test test_gradient_isfinite(gs) - - @info "Testing MultiScaleSkipDEQ" - Random.seed!(rng, seed) - model = MultiScaleSkipDeepEquilibriumNetwork((Parallel(+, Dense(4, 4, tanh), - Dense(4, 4, tanh)), - Dense(3, 3, tanh), - Dense(2, 2, tanh), - Dense(1, 1, tanh)), - [NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) - Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) - Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) - Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer()], - nothing, - (Dense(4, 4, tanh), Dense(4, 3, tanh), - Dense(4, 2, tanh), Dense(4, 1, tanh)), - ContinuousDEQSolver(; abstol=0.1f0, - reltol=0.1f0, - abstol_termination=0.1f0, - reltol_termination=0.1f0), - ((4,), (3,), (2,), (1,)); - sensealg=DeepEquilibriumAdjoint(0.1f0, - 0.1f0, 10)) - - ps, st = gpu.(Lux.setup(rng, model)) - x = gpu(rand(rng, Float32, 4, 2)) - y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) - - gs = gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end[1] - - @test test_gradient_isfinite(gs) - - @info "Testing MultiScaleSkipDEQ without Fixed Point Iterations" - st = Lux.update_state(st, :fixed_depth, Val(5)) - - gs = gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end[1] - - @test test_gradient_isfinite(gs) - - @info "Testing MultiScaleSkipDEQV2" - Random.seed!(rng, seed) - model = MultiScaleSkipDeepEquilibriumNetwork((Parallel(+, Dense(4, 4, tanh), - Dense(4, 4, tanh)), - Dense(3, 3, tanh), - Dense(2, 2, tanh), - Dense(1, 1, tanh)), - [NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) - Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) - Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) - Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer()], - nothing, - nothing, - ContinuousDEQSolver(; abstol=0.1f0, - reltol=0.1f0, - abstol_termination=0.1f0, - reltol_termination=0.1f0), - ((4,), (3,), (2,), (1,)); - sensealg=DeepEquilibriumAdjoint(0.1f0, - 0.1f0, 10)) - - ps, st = gpu.(Lux.setup(rng, model)) - x = gpu(rand(rng, Float32, 4, 2)) - y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) - - gs = gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end[1] - - @test test_gradient_isfinite(gs) - - @info "Testing MultiScaleSkipDEQV2 without Fixed Point Iterations" - st = Lux.update_state(st, :fixed_depth, Val(5)) - - gs = gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end[1] - - @test test_gradient_isfinite(gs) + sensealg=DeepEquilibriumAdjoint(0.1f0, + 0.1f0, 10))) + ps, st = gpu.(Lux.setup(rng, model)) + x = gpu(rand(rng, Float32, 2, 1)) + y = gpu(rand(rng, Float32, 2, 1)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + return sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing SkipDEQV2 without Fixed Point Iterations" + st = Lux.update_state(st, :fixed_depth, Val(5)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + return sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing SkipDEQ with Broyden Solver" + Random.seed!(rng, seed) + model = DEQChain(Dense(2, 2), + SkipDeepEquilibriumNetwork(Parallel(+, Dense(2, 2), Dense(2, 2)), + Dense(2, 2), + DiscreteDEQSolver(BroydenSolver(); + abstol_termination=0.1f0, + reltol_termination=0.1f0); + sensealg=DeepEquilibriumAdjoint(0.1f0, + 0.1f0, 10))) + ps, st = gpu.(Lux.setup(rng, model)) + x = gpu(rand(rng, Float32, 2, 1)) + y = gpu(rand(rng, Float32, 2, 1)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + return sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing SkipDEQ with L-Broyden Solver" + Random.seed!(rng, seed) + model = DEQChain(Dense(2, 2), + SkipDeepEquilibriumNetwork(Parallel(+, Dense(2, 2), Dense(2, 2)), + Dense(2, 2), + DiscreteDEQSolver(LimitedMemoryBroydenSolver(); + abstol_termination=0.1f0, + reltol_termination=0.1f0); + sensealg=DeepEquilibriumAdjoint(0.1f0, + 0.1f0, 10))) + ps, st = gpu.(Lux.setup(rng, model)) + x = gpu(rand(rng, Float32, 2, 1)) + y = gpu(rand(rng, Float32, 2, 1)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + return sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing MultiScaleDEQ" + Random.seed!(rng, seed) + model = MultiScaleDeepEquilibriumNetwork((Parallel(+, Dense(4, 4, tanh), + Dense(4, 4, tanh)), + Dense(3, 3, tanh), + Dense(2, 2, tanh), + Dense(1, 1, tanh)), + [NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) + Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) + Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) + Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer()], + nothing, + ContinuousDEQSolver(; abstol=0.1f0, + reltol=0.1f0, + abstol_termination=0.1f0, + reltol_termination=0.1f0), + ((4,), (3,), (2,), (1,)); + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, + 10)) + + ps, st = gpu.(Lux.setup(rng, model)) + x = gpu(rand(rng, Float32, 4, 2)) + y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + return sum(Base.Fix1(sum, abs2), ŷ .- y) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing MultiScaleDEQ without Fixed Point Iterations" + st = Lux.update_state(st, :fixed_depth, Val(5)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + return sum(Base.Fix1(sum, abs2), ŷ .- y) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing MultiScaleSkipDEQ" + Random.seed!(rng, seed) + model = MultiScaleSkipDeepEquilibriumNetwork((Parallel(+, Dense(4, 4, tanh), + Dense(4, 4, tanh)), + Dense(3, 3, tanh), + Dense(2, 2, tanh), + Dense(1, 1, tanh)), + [NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) + Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) + Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) + Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer()], + nothing, + (Dense(4, 4, tanh), Dense(4, 3, tanh), + Dense(4, 2, tanh), Dense(4, 1, tanh)), + ContinuousDEQSolver(; abstol=0.1f0, + reltol=0.1f0, + abstol_termination=0.1f0, + reltol_termination=0.1f0), + ((4,), (3,), (2,), (1,)); + sensealg=DeepEquilibriumAdjoint(0.1f0, + 0.1f0, 10)) + + ps, st = gpu.(Lux.setup(rng, model)) + x = gpu(rand(rng, Float32, 4, 2)) + y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + return sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing MultiScaleSkipDEQ without Fixed Point Iterations" + st = Lux.update_state(st, :fixed_depth, Val(5)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + return sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing MultiScaleSkipDEQV2" + Random.seed!(rng, seed) + model = MultiScaleSkipDeepEquilibriumNetwork((Parallel(+, Dense(4, 4, tanh), + Dense(4, 4, tanh)), + Dense(3, 3, tanh), + Dense(2, 2, tanh), + Dense(1, 1, tanh)), + [NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) + Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) + Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) + Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer()], + nothing, + nothing, + ContinuousDEQSolver(; abstol=0.1f0, + reltol=0.1f0, + abstol_termination=0.1f0, + reltol_termination=0.1f0), + ((4,), (3,), (2,), (1,)); + sensealg=DeepEquilibriumAdjoint(0.1f0, + 0.1f0, 10)) + + ps, st = gpu.(Lux.setup(rng, model)) + x = gpu(rand(rng, Float32, 4, 2)) + y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + return sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing MultiScaleSkipDEQV2 without Fixed Point Iterations" + st = Lux.update_state(st, :fixed_depth, Val(5)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + return sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] + + @test test_gradient_isfinite(gs) end