Skip to content

Commit

Permalink
Merge pull request #57 from SciML/ap/fix
Browse files Browse the repository at this point in the history
Fix code
  • Loading branch information
avik-pal authored Jul 5, 2022
2 parents 089e7c7 + 79270f0 commit 58a444f
Show file tree
Hide file tree
Showing 19 changed files with 1,077 additions and 1,083 deletions.
9 changes: 8 additions & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,9 @@
style = "sciml"
whitespace_in_kwargs = false
whitespace_in_kwargs = false
always_use_return = true
margin = 92
indent = 2
format_docstrings = true
join_lines_based_on_source = true
separate_kwargs_with_semicolon = true
always_for_in = true
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
name = "DeepEquilibriumNetworks"
uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
authors = ["Avik Pal <[email protected]>"]
version = "0.1.1"
version = "0.1.2"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Expand All @@ -18,6 +17,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -31,13 +31,13 @@ ChainRulesCore = "1"
ComponentArrays = "0.11, 0.12"
DiffEqBase = "6"
DiffEqCallbacks = "2.20.1"
DiffEqSensitivity = "6.64"
Functors = "0.2, 0.3"
LinearSolve = "1"
Lux = "0.4"
MLUtils = "0.2"
OrdinaryDiffEq = "6"
SciMLBase = "1.19"
SciMLSensitivity = "7"
Setfield = "1"
Static = "0.6, 0.7"
SteadyStateDiffEq = "1.6"
Expand Down
25 changes: 13 additions & 12 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,28 @@
using Documenter, DocumenterCitations, DeepEquilibriumNetworks

# Documentation build script: load the bibliography from `ref.bib` next to this file, build
# the site with `makedocs`, then publish it with `deploydocs`.
# NOTE(review): this span is a rendered diff without +/- markers — each removed line is
# followed by its added replacement (positional/keyword separators switched to `;`, the
# `pages` entries re-indented), so several statements appear twice.
bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"), sorting=:nyt)
bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"); sorting=:nyt)

makedocs(bib,
makedocs(bib;
sitename="Fast Deep Equilibrium Networks",
authors="Avik Pal et al.",
clean=true,
doctest=false,
modules=[DeepEquilibriumNetworks],
format=Documenter.HTML(# analytics = "",
# assets = ["assets/favicon.ico"],
;
canonical="https://deepequilibriumnetworks.sciml.ai/stable/"),
pages=[
"Home" => "index.md",
"Manual" => [
"Dynamical Systems" => "manual/solvers.md",
"Non Linear Solvers" => "manual/nlsolve.md",
"General Purpose Layers" => "manual/layers.md",
"DEQ Layers" => "manual/deqs.md",
"Miscellaneous" => "manual/misc.md",
],
"References" => "references.md",
"Home" => "index.md",
"Manual" => [
"Dynamical Systems" => "manual/solvers.md",
"Non Linear Solvers" => "manual/nlsolve.md",
"General Purpose Layers" => "manual/layers.md",
"DEQ Layers" => "manual/deqs.md",
"Miscellaneous" => "manual/misc.md",
],
"References" => "references.md",
])

# Publish the built docs to the package repository; `push_preview=true` is passed through
# to `deploydocs`.
deploydocs(repo="github.com/SciML/DeepEquilibriumNetworks.jl.git";
deploydocs(; repo="github.com/SciML/DeepEquilibriumNetworks.jl.git",
push_preview=true)
4 changes: 2 additions & 2 deletions src/DeepEquilibriumNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using ChainRulesCore,
CUDA,
DiffEqBase,
DiffEqCallbacks,
DiffEqSensitivity,
SciMLSensitivity,
Functors,
LinearAlgebra,
LinearSolve,
Expand All @@ -20,7 +20,7 @@ using ChainRulesCore,
UnPack,
Zygote

import DiffEqSensitivity: AbstractAdjointSensitivityAlgorithm
import SciMLSensitivity: AbstractAdjointSensitivityAlgorithm
import Lux: AbstractExplicitContainerLayer, initialparameters, initialstates,
parameterlength, statelength
import Random: AbstractRNG
Expand Down
125 changes: 63 additions & 62 deletions src/adjoint.jl
Original file line number Diff line number Diff line change
@@ -1,83 +1,84 @@
"""
    neg(x)

Return `-x` when a unary `-` method exists for `typeof(x)`; otherwise return `x` unchanged.
For a `NamedTuple`, delegate to `fmap(neg, nt)`.
"""
function neg(x::Any)
    # Values without a unary minus are passed through untouched.
    hasmethod(-, (typeof(x),)) || return x
    return -x
end

function neg(nt::NamedTuple)
    return fmap(neg, nt)
end

# Steady-state adjoint for `DeepEquilibriumAdjoint` when no cost function `g` is given
# (`g::Nothing`). Steps, as visible below:
#   1. build `diffcache` and `y` via `adjointdiffcache`;
#   2. copy `dg` into `diffcache.dg_val` (restricted to `save_idxs` when given);
#   3. obtain `λ` — in `:vanilla` mode by solving a `LinearProblem` whose operator wraps the
#      Zygote pullback of `x -> f(x, p, nothing)`, in `:jfb` mode by using `dg_val` directly;
#   4. pull `λ` back through `p -> vec(f(y, p, nothing))` and return the negated result.
# Any other adjoint mode raises `error("Unknown adjoint mode")`.
# NOTE(review): this span is a rendered diff without +/- markers — the removed
# `DiffEqSensitivity` lines and the added `SciMLSensitivity` lines (whose signature also
# accepts `kwargs...`) are interleaved, so most statements appear twice.
@noinline function DiffEqSensitivity.SteadyStateAdjointProblem(sol::EquilibriumSolution,
sensealg::DeepEquilibriumAdjoint,
g::Nothing, dg;
save_idxs=nothing)
@unpack f, p, u0 = sol.prob
@noinline function SciMLSensitivity.SteadyStateAdjointProblem(sol::EquilibriumSolution,
sensealg::DeepEquilibriumAdjoint,
g::Nothing, dg;
save_idxs=nothing, kwargs...)
@unpack f, p, u0 = sol.prob

diffcache, y = DiffEqSensitivity.adjointdiffcache(g, sensealg, false, sol, dg, f;
quad=false, needs_jac=false)
diffcache, y = SciMLSensitivity.adjointdiffcache(g, sensealg, false, sol, dg, f;
quad=false, needs_jac=false)

# `save_idxs === nothing` selects every index.
_save_idxs = save_idxs === nothing ? Colon() : save_idxs
if dg !== nothing
if typeof(_save_idxs) <: Number
diffcache.dg_val[_save_idxs] = dg[_save_idxs]
elseif typeof(dg) <: Number
@. diffcache.dg_val[_save_idxs] = dg
else
@. diffcache.dg_val[_save_idxs] = dg[_save_idxs]
end
end

if check_adjoint_mode(sensealg, Val(:vanilla))
# Solve the Linear Problem
_val, back = Zygote.pullback(x -> f(x, p, nothing), y)
s_val = size(_val)
op = ZygotePullbackMultiplyOperator{eltype(y), typeof(back), typeof(s_val)}(back,
s_val)
linear_problem = LinearProblem(op, vec(diffcache.dg_val))
λ = solve(linear_problem, sensealg.linsolve).u
elseif check_adjoint_mode(sensealg, Val(:jfb))
# Jacobian Free Backpropagation
λ = diffcache.dg_val
_save_idxs = save_idxs === nothing ? Colon() : save_idxs
if dg !== nothing
if typeof(_save_idxs) <: Number
diffcache.dg_val[_save_idxs] = dg[_save_idxs]
elseif typeof(dg) <: Number
@. diffcache.dg_val[_save_idxs] = dg
else
error("Unknown adjoint mode")
@. diffcache.dg_val[_save_idxs] = dg[_save_idxs]
end
end

# Compute the VJP
_, back = Zygote.pullback(p -> vec(f(y, p, nothing)), p)
dp = back(vec(λ))[1]
if check_adjoint_mode(sensealg, Val(:vanilla))
# Solve the Linear Problem
_val, back = Zygote.pullback(x -> f(x, p, nothing), y)
s_val = size(_val)
op = ZygotePullbackMultiplyOperator{eltype(y), typeof(back), typeof(s_val)}(back,
s_val)
linear_problem = LinearProblem(op, vec(diffcache.dg_val))
λ = solve(linear_problem, sensealg.linsolve).u
elseif check_adjoint_mode(sensealg, Val(:jfb))
# Jacobian Free Backpropagation
λ = diffcache.dg_val
else
error("Unknown adjoint mode")
end

return neg(dp)
# Compute the VJP
_, back = Zygote.pullback(p -> vec(f(y, p, nothing)), p)
dp = back(vec(λ))[1]

return neg(dp)
end

# Adjoint rule for solving a `SteadyStateProblem` with `DeepEquilibriumAdjoint`.
# Forward: rebuild the problem with `u0`/`p`, solve it, and return the solution (or a
# `sensitivity_solution` restricted to `save_idxs`) together with the pullback.
# Pullback: `steadystatebackpass(Δ)` calls `adjoint_sensitivities` with `g=nothing, dg=Δ`
# and returns `dp` in the fifth slot; every other slot, including one per entry of
# `args`, is `NoTangent()`.
# NOTE(review): this span is a rendered diff without +/- markers — removed and added lines
# both appear below; the added signature inserts an unnamed `::SciMLBase.ADOriginator`
# argument before `args...`.
function DiffEqBase._concrete_solve_adjoint(prob::SteadyStateProblem, alg,
sensealg::DeepEquilibriumAdjoint, u0, p,
args...; save_idxs=nothing, kwargs...)
_prob = remake(prob; u0=u0, p=p)
sol = solve(_prob, alg, args...; kwargs...)
_save_idxs = save_idxs === nothing ? Colon() : save_idxs
::SciMLBase.ADOriginator, args...;
save_idxs=nothing, kwargs...)
_prob = remake(prob; u0=u0, p=p)
sol = solve(_prob, alg, args...; kwargs...)
_save_idxs = save_idxs === nothing ? Colon() : save_idxs

out = save_idxs === nothing ? sol :
DiffEqBase.sensitivity_solution(sol, sol[_save_idxs])
out = save_idxs === nothing ? sol :
DiffEqBase.sensitivity_solution(sol, sol[_save_idxs])

function steadystatebackpass(Δ)
# Δ = dg/dx or diffcache.dg_val
# del g/del p = 0
dp = adjoint_sensitivities(sol, alg; sensealg=sensealg, g=nothing, dg=Δ,
save_idxs=save_idxs)
return (NoTangent(),
NoTangent(),
NoTangent(),
NoTangent(),
dp,
NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
end
function steadystatebackpass(Δ)
# Δ = dg/dx or diffcache.dg_val
# del g/del p = 0
dp = adjoint_sensitivities(sol, alg; sensealg=sensealg, g=nothing, dg=Δ,
save_idxs=save_idxs)
return (NoTangent(),
NoTangent(),
NoTangent(),
NoTangent(),
dp,
NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
end

return out, steadystatebackpass
return out, steadystatebackpass
end

# Positional-`g` entry point for `DeepEquilibriumAdjoint`: forwards `sol`, `sensealg`, `g`,
# `dg` and the remaining keyword arguments to `SteadyStateAdjointProblem`. `alg`, `abstol`
# and `reltol` are accepted but not passed on.
# NOTE(review): rendered diff — the removed `DiffEqSensitivity` lines are followed by the
# added `SciMLSensitivity` lines.
function DiffEqSensitivity._adjoint_sensitivities(sol, sensealg::DeepEquilibriumAdjoint,
alg, g, dg=nothing; abstol=1e-6,
reltol=1e-3, kwargs...)
return DiffEqSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...)
function SciMLSensitivity._adjoint_sensitivities(sol, sensealg::DeepEquilibriumAdjoint,
alg, g, dg=nothing; abstol=1e-6,
reltol=1e-3, kwargs...)
return SciMLSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...)
end

# Keyword-`g` entry point for `DeepEquilibriumAdjoint`: same forwarding to
# `SteadyStateAdjointProblem` as the positional method, but `g` and `dg` arrive as keyword
# arguments (both default to `nothing`). `alg`, `abstol` and `reltol` are accepted but not
# passed on.
# NOTE(review): rendered diff — the removed `DiffEqSensitivity` lines are followed by the
# added `SciMLSensitivity` lines.
function DiffEqSensitivity._adjoint_sensitivities(sol, sensealg::DeepEquilibriumAdjoint,
alg; g=nothing, dg=nothing, abstol=1e-6,
reltol=1e-3, kwargs...)
return DiffEqSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...)
function SciMLSensitivity._adjoint_sensitivities(sol, sensealg::DeepEquilibriumAdjoint,
alg; g=nothing, dg=nothing, abstol=1e-6,
reltol=1e-3, kwargs...)
return SciMLSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...)
end
50 changes: 25 additions & 25 deletions src/layers/chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,47 +3,47 @@
Sequence of layers divided into 3 chunks --
* `pre_deq` -- layers that are executed before DEQ is applied
* `deq` -- The Deep Equilibrium Layer
* `post_deq` -- layers that are executed after DEQ is applied
- `pre_deq` -- layers that are executed before DEQ is applied
- `deq` -- The Deep Equilibrium Layer
- `post_deq` -- layers that are executed after DEQ is applied
Constraint: `layers` must contain exactly one DEQ layer
"""
struct DEQChain{P1, D, P2} <: AbstractExplicitContainerLayer{(:pre_deq, :deq, :post_deq)}
# Three fields, each with its own type parameter: `pre_deq` (layers run before the DEQ
# layer), `deq` (the DEQ layer itself) and `post_deq` (layers run after it).
# NOTE(review): rendered diff — every field is listed twice below (removed line, then the
# re-indented added line).
pre_deq::P1
deq::D
post_deq::P2
pre_deq::P1
deq::D
post_deq::P2
end

# Build a `DEQChain` from a flat list of layers. The single layer that is an
# `AbstractDeepEquilibriumNetwork` or `AbstractSkipDeepEquilibriumNetwork` becomes `deq`;
# layers before it are collected into `pre_deq`, layers after it into `post_deq`. Each
# side is wrapped in a `Chain`, or replaced by `NoOpLayer()` when empty. The `@assert`s
# reject a second DEQ layer and a list with no DEQ layer.
# NOTE(review): rendered diff without +/- markers — removed and re-indented added lines are
# interleaved, so most statements appear twice.
function DEQChain(layers...)
pre_deq, post_deq, deq, encounter_deq = [], [], nothing, false
for l in layers
if l isa AbstractDeepEquilibriumNetwork || l isa AbstractSkipDeepEquilibriumNetwork
@assert !encounter_deq "Can have only 1 DEQ Layer in the Chain!!!"
deq = l
encounter_deq = true
continue
end
push!(encounter_deq ? post_deq : pre_deq, l)
pre_deq, post_deq, deq, encounter_deq = [], [], nothing, false
for l in layers
if l isa AbstractDeepEquilibriumNetwork || l isa AbstractSkipDeepEquilibriumNetwork
@assert !encounter_deq "Can have only 1 DEQ Layer in the Chain!!!"
deq = l
encounter_deq = true
continue
end
@assert encounter_deq "No DEQ Layer in the Chain!!! Maybe you wanted to use Chain"
pre_deq = length(pre_deq) == 0 ? NoOpLayer() : Chain(pre_deq...)
post_deq = length(post_deq) == 0 ? NoOpLayer() : Chain(post_deq...)
return DEQChain(pre_deq, deq, post_deq)
# Layers seen after the DEQ layer go to `post_deq`, earlier ones to `pre_deq`.
push!(encounter_deq ? post_deq : pre_deq, l)
end
@assert encounter_deq "No DEQ Layer in the Chain!!! Maybe you wanted to use Chain"
pre_deq = length(pre_deq) == 0 ? NoOpLayer() : Chain(pre_deq...)
post_deq = length(post_deq) == 0 ? NoOpLayer() : Chain(post_deq...)
return DEQChain(pre_deq, deq, post_deq)
end

# Expected DEQ output type for a chain whose `deq` is a multiscale network: an `NTuple`
# with one entry of the input's type `T` per element of `deq.deq.scales`.
# NOTE(review): rendered diff — the `return` line appears twice (removed, then re-indented).
function get_deq_return_type(deq::DEQChain{P1,
<:Union{MultiScaleDeepEquilibriumNetwork,
MultiScaleSkipDeepEquilibriumNetwork}},
::T) where {P1, T}
return NTuple{length(deq.deq.scales), T}
return NTuple{length(deq.deq.scales), T}
end
# Fallback for any other `DEQChain`: the expected DEQ output type is the input's own type.
function get_deq_return_type(::DEQChain, ::T) where {T}
    return T
end

# Forward pass: run `pre_deq`, then `deq`, then `post_deq`, threading each sub-layer's own
# parameters (`ps.<name>`) and state (`st.<name>`). The DEQ layer returns
# `(output, deq_soln)`; the output is type-asserted against `get_deq_return_type(deq, x)`.
# Returns `((x3, deq_soln), new_state)` where `new_state` has the keys `pre_deq`, `deq`
# and `post_deq`.
# NOTE(review): rendered diff — each body line appears twice (removed, then re-indented).
function (deq::DEQChain)(x, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple)
T = get_deq_return_type(deq, x)
x1, st1 = deq.pre_deq(x, ps.pre_deq, st.pre_deq)
(x2::T, deq_soln), st2 = deq.deq(x1, ps.deq, st.deq)
x3, st3 = deq.post_deq(x2, ps.post_deq, st.post_deq)
return (x3, deq_soln), (pre_deq=st1, deq=st2, post_deq=st3)
T = get_deq_return_type(deq, x)
x1, st1 = deq.pre_deq(x, ps.pre_deq, st.pre_deq)
(x2::T, deq_soln), st2 = deq.deq(x1, ps.deq, st.deq)
x3, st3 = deq.post_deq(x2, ps.post_deq, st.post_deq)
return (x3, deq_soln), (pre_deq=st1, deq=st2, post_deq=st3)
end
33 changes: 17 additions & 16 deletions src/layers/core.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# Abstract supertype for DEQ layers; an `AbstractExplicitContainerLayer` over the single
# field `:model`.
abstract type AbstractDeepEquilibriumNetwork <: AbstractExplicitContainerLayer{(:model,)} end

# Initial state of a DEQ layer: the wrapped model's initial states plus
# `fixed_depth=Val(0)`.
# NOTE(review): rendered diff — the `return` line appears twice (removed, then re-indented).
function initialstates(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork)
return (model=initialstates(rng, deq.model), fixed_depth=Val(0))
return (model=initialstates(rng, deq.model), fixed_depth=Val(0))
end

# Abstract supertype for skip-DEQ layers; an `AbstractExplicitContainerLayer` over the
# fields `:model` and `:shortcut`.
abstract type AbstractSkipDeepEquilibriumNetwork <:
AbstractExplicitContainerLayer{(:model, :shortcut)} end

# Initial state of a skip-DEQ layer: the initial states of `model` and of `shortcut`, plus
# `fixed_depth=Val(0)`.
# NOTE(review): rendered diff — the two-line `return` appears twice (removed, then
# re-indented).
function initialstates(rng::AbstractRNG, deq::AbstractSkipDeepEquilibriumNetwork)
return (model=initialstates(rng, deq.model), shortcut=initialstates(rng, deq.shortcut),
fixed_depth=Val(0))
return (model=initialstates(rng, deq.model), shortcut=initialstates(rng, deq.shortcut),
fixed_depth=Val(0))
end

# A depth of `Val(0)` means unrolled mode is off.
@inline function check_unrolled_mode(::Val{0})::Bool
    return false
end
Expand All @@ -27,27 +27,28 @@ ChainRulesCore.@non_differentiable get_unrolled_depth(::Any)
Stores the solution of a DeepEquilibriumNetwork and its variants.
## Fields
* `z_star`: Steady-State or the value reached due to maxiters
* `u₀`: Initial Condition
* `residual`: Difference between ``z^*`` and ``f(z^*, x)``
* `jacobian_loss`: Jacobian Stabilization Loss (see individual networks to see how it can be computed)
* `nfe`: Number of Function Evaluations
"""
struct DeepEquilibriumSolution{T, R <: AbstractFloat}
# `z_star`, `u₀` and `residual` share the single type parameter `T`; `jacobian_loss` has
# the floating-point type `R`; `nfe` is an `Int`.
# NOTE(review): rendered diff — every field is listed twice below (removed line, then the
# re-indented added line).
z_star::T
u₀::T
residual::T
jacobian_loss::R
nfe::Int
z_star::T
u₀::T
residual::T
jacobian_loss::R
nfe::Int
end

# Print a one-line summary of a `DeepEquilibriumSolution` to `io`, labelling the fields as
# `z_star`, `initial_condition` (the `u₀` field), `residual`, `jacobian_loss` and `NFE`.
# Returns `nothing`.
# NOTE(review): rendered diff — each body line appears twice (removed, then re-indented).
function Base.show(io::IO, l::DeepEquilibriumSolution)
print(io, "DeepEquilibriumSolution(")
print(io, "z_star: ", l.z_star)
print(io, ", initial_condition: ", l.u₀)
print(io, ", residual: ", l.residual)
print(io, ", jacobian_loss: ", l.jacobian_loss)
print(io, ", NFE: ", l.nfe)
print(io, ")")
return nothing
print(io, "DeepEquilibriumSolution(")
print(io, "z_star: ", l.z_star)
print(io, ", initial_condition: ", l.u₀)
print(io, ", residual: ", l.residual)
print(io, ", jacobian_loss: ", l.jacobian_loss)
print(io, ", NFE: ", l.nfe)
print(io, ")")
return nothing
end
Loading

0 comments on commit 58a444f

Please sign in to comment.