Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix code #57

Merged
merged 1 commit into from
Jul 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,9 @@
style = "sciml"
whitespace_in_kwargs = false
whitespace_in_kwargs = false
always_use_return = true
margin = 92
indent = 2
format_docstrings = true
join_lines_based_on_source = true
separate_kwargs_with_semicolon = true
always_for_in = true
Comment on lines +2 to +9
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what are these for?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most are these except the indent one just enforce some "stricter" conventions compared to SciMLStyle. Also gets rid of some conventions which I find weird, especially the return last value without an explicit return statement. Some of them need to be removed (I just copied it over from one of my other repos).

6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
name = "DeepEquilibriumNetworks"
uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
authors = ["Avik Pal <[email protected]>"]
version = "0.1.1"
version = "0.1.2"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Expand All @@ -18,6 +17,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -31,13 +31,13 @@ ChainRulesCore = "1"
ComponentArrays = "0.11, 0.12"
DiffEqBase = "6"
DiffEqCallbacks = "2.20.1"
DiffEqSensitivity = "6.64"
Functors = "0.2, 0.3"
LinearSolve = "1"
Lux = "0.4"
MLUtils = "0.2"
OrdinaryDiffEq = "6"
SciMLBase = "1.19"
SciMLSensitivity = "7"
Setfield = "1"
Static = "0.6, 0.7"
SteadyStateDiffEq = "1.6"
Expand Down
25 changes: 13 additions & 12 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,28 @@
using Documenter, DocumenterCitations, DeepEquilibriumNetworks

bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"), sorting=:nyt)
bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"); sorting=:nyt)

makedocs(bib,
makedocs(bib;
sitename="Fast Deep Equilibrium Networks",
authors="Avik Pal et al.",
clean=true,
doctest=false,
modules=[DeepEquilibriumNetworks],
format=Documenter.HTML(# analytics = "",
# assets = ["assets/favicon.ico"],
;
canonical="https://deepequilibriumnetworks.sciml.ai/stable/"),
pages=[
"Home" => "index.md",
"Manual" => [
"Dynamical Systems" => "manual/solvers.md",
"Non Linear Solvers" => "manual/nlsolve.md",
"General Purpose Layers" => "manual/layers.md",
"DEQ Layers" => "manual/deqs.md",
"Miscellaneous" => "manual/misc.md",
],
"References" => "references.md",
"Home" => "index.md",
"Manual" => [
"Dynamical Systems" => "manual/solvers.md",
"Non Linear Solvers" => "manual/nlsolve.md",
"General Purpose Layers" => "manual/layers.md",
"DEQ Layers" => "manual/deqs.md",
"Miscellaneous" => "manual/misc.md",
],
"References" => "references.md",
])

deploydocs(repo="github.com/SciML/DeepEquilibriumNetworks.jl.git";
deploydocs(; repo="github.com/SciML/DeepEquilibriumNetworks.jl.git",
push_preview=true)
4 changes: 2 additions & 2 deletions src/DeepEquilibriumNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using ChainRulesCore,
CUDA,
DiffEqBase,
DiffEqCallbacks,
DiffEqSensitivity,
SciMLSensitivity,
Functors,
LinearAlgebra,
LinearSolve,
Expand All @@ -20,7 +20,7 @@ using ChainRulesCore,
UnPack,
Zygote

import DiffEqSensitivity: AbstractAdjointSensitivityAlgorithm
import SciMLSensitivity: AbstractAdjointSensitivityAlgorithm
import Lux: AbstractExplicitContainerLayer, initialparameters, initialstates,
parameterlength, statelength
import Random: AbstractRNG
Expand Down
125 changes: 63 additions & 62 deletions src/adjoint.jl
Original file line number Diff line number Diff line change
@@ -1,83 +1,84 @@
neg(x::Any) = hasmethod(-, (typeof(x),)) ? -x : x
neg(nt::NamedTuple) = fmap(neg, nt)

@noinline function DiffEqSensitivity.SteadyStateAdjointProblem(sol::EquilibriumSolution,
sensealg::DeepEquilibriumAdjoint,
g::Nothing, dg;
save_idxs=nothing)
@unpack f, p, u0 = sol.prob
@noinline function SciMLSensitivity.SteadyStateAdjointProblem(sol::EquilibriumSolution,
sensealg::DeepEquilibriumAdjoint,
g::Nothing, dg;
save_idxs=nothing, kwargs...)
@unpack f, p, u0 = sol.prob

diffcache, y = DiffEqSensitivity.adjointdiffcache(g, sensealg, false, sol, dg, f;
quad=false, needs_jac=false)
diffcache, y = SciMLSensitivity.adjointdiffcache(g, sensealg, false, sol, dg, f;
quad=false, needs_jac=false)

_save_idxs = save_idxs === nothing ? Colon() : save_idxs
if dg !== nothing
if typeof(_save_idxs) <: Number
diffcache.dg_val[_save_idxs] = dg[_save_idxs]
elseif typeof(dg) <: Number
@. diffcache.dg_val[_save_idxs] = dg
else
@. diffcache.dg_val[_save_idxs] = dg[_save_idxs]
end
end

if check_adjoint_mode(sensealg, Val(:vanilla))
# Solve the Linear Problem
_val, back = Zygote.pullback(x -> f(x, p, nothing), y)
s_val = size(_val)
op = ZygotePullbackMultiplyOperator{eltype(y), typeof(back), typeof(s_val)}(back,
s_val)
linear_problem = LinearProblem(op, vec(diffcache.dg_val))
λ = solve(linear_problem, sensealg.linsolve).u
elseif check_adjoint_mode(sensealg, Val(:jfb))
# Jacobian Free Backpropagation
λ = diffcache.dg_val
_save_idxs = save_idxs === nothing ? Colon() : save_idxs
if dg !== nothing
if typeof(_save_idxs) <: Number
diffcache.dg_val[_save_idxs] = dg[_save_idxs]
elseif typeof(dg) <: Number
@. diffcache.dg_val[_save_idxs] = dg
else
error("Unknown adjoint mode")
@. diffcache.dg_val[_save_idxs] = dg[_save_idxs]
end
end

# Compute the VJP
_, back = Zygote.pullback(p -> vec(f(y, p, nothing)), p)
dp = back(vec(λ))[1]
if check_adjoint_mode(sensealg, Val(:vanilla))
# Solve the Linear Problem
_val, back = Zygote.pullback(x -> f(x, p, nothing), y)
s_val = size(_val)
op = ZygotePullbackMultiplyOperator{eltype(y), typeof(back), typeof(s_val)}(back,
s_val)
linear_problem = LinearProblem(op, vec(diffcache.dg_val))
λ = solve(linear_problem, sensealg.linsolve).u
elseif check_adjoint_mode(sensealg, Val(:jfb))
# Jacobian Free Backpropagation
λ = diffcache.dg_val
else
error("Unknown adjoint mode")
end

return neg(dp)
# Compute the VJP
_, back = Zygote.pullback(p -> vec(f(y, p, nothing)), p)
dp = back(vec(λ))[1]

return neg(dp)
end

function DiffEqBase._concrete_solve_adjoint(prob::SteadyStateProblem, alg,
sensealg::DeepEquilibriumAdjoint, u0, p,
args...; save_idxs=nothing, kwargs...)
_prob = remake(prob; u0=u0, p=p)
sol = solve(_prob, alg, args...; kwargs...)
_save_idxs = save_idxs === nothing ? Colon() : save_idxs
::SciMLBase.ADOriginator, args...;
save_idxs=nothing, kwargs...)
_prob = remake(prob; u0=u0, p=p)
sol = solve(_prob, alg, args...; kwargs...)
_save_idxs = save_idxs === nothing ? Colon() : save_idxs

out = save_idxs === nothing ? sol :
DiffEqBase.sensitivity_solution(sol, sol[_save_idxs])
out = save_idxs === nothing ? sol :
DiffEqBase.sensitivity_solution(sol, sol[_save_idxs])

function steadystatebackpass(Δ)
# Δ = dg/dx or diffcache.dg_val
# del g/del p = 0
dp = adjoint_sensitivities(sol, alg; sensealg=sensealg, g=nothing, dg=Δ,
save_idxs=save_idxs)
return (NoTangent(),
NoTangent(),
NoTangent(),
NoTangent(),
dp,
NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
end
function steadystatebackpass(Δ)
# Δ = dg/dx or diffcache.dg_val
# del g/del p = 0
dp = adjoint_sensitivities(sol, alg; sensealg=sensealg, g=nothing, dg=Δ,
save_idxs=save_idxs)
return (NoTangent(),
NoTangent(),
NoTangent(),
NoTangent(),
dp,
NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
end

return out, steadystatebackpass
return out, steadystatebackpass
end

function DiffEqSensitivity._adjoint_sensitivities(sol, sensealg::DeepEquilibriumAdjoint,
alg, g, dg=nothing; abstol=1e-6,
reltol=1e-3, kwargs...)
return DiffEqSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...)
function SciMLSensitivity._adjoint_sensitivities(sol, sensealg::DeepEquilibriumAdjoint,
alg, g, dg=nothing; abstol=1e-6,
reltol=1e-3, kwargs...)
return SciMLSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...)
end

function DiffEqSensitivity._adjoint_sensitivities(sol, sensealg::DeepEquilibriumAdjoint,
alg; g=nothing, dg=nothing, abstol=1e-6,
reltol=1e-3, kwargs...)
return DiffEqSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...)
function SciMLSensitivity._adjoint_sensitivities(sol, sensealg::DeepEquilibriumAdjoint,
alg; g=nothing, dg=nothing, abstol=1e-6,
reltol=1e-3, kwargs...)
return SciMLSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...)
end
50 changes: 25 additions & 25 deletions src/layers/chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,47 +3,47 @@
Sequence of layers divided into 3 chunks --
* `pre_deq` -- layers that are executed before DEQ is applied
* `deq` -- The Deep Equilibrium Layer
* `post_deq` -- layers that are executed after DEQ is applied
- `pre_deq` -- layers that are executed before DEQ is applied
- `deq` -- The Deep Equilibrium Layer
- `post_deq` -- layers that are executed after DEQ is applied
Constraint: Must have one DEQ layer in `layers`
"""
struct DEQChain{P1, D, P2} <: AbstractExplicitContainerLayer{(:pre_deq, :deq, :post_deq)}
pre_deq::P1
deq::D
post_deq::P2
pre_deq::P1
deq::D
post_deq::P2
end

function DEQChain(layers...)
pre_deq, post_deq, deq, encounter_deq = [], [], nothing, false
for l in layers
if l isa AbstractDeepEquilibriumNetwork || l isa AbstractSkipDeepEquilibriumNetwork
@assert !encounter_deq "Can have only 1 DEQ Layer in the Chain!!!"
deq = l
encounter_deq = true
continue
end
push!(encounter_deq ? post_deq : pre_deq, l)
pre_deq, post_deq, deq, encounter_deq = [], [], nothing, false
for l in layers
if l isa AbstractDeepEquilibriumNetwork || l isa AbstractSkipDeepEquilibriumNetwork
@assert !encounter_deq "Can have only 1 DEQ Layer in the Chain!!!"
deq = l
encounter_deq = true
continue
end
@assert encounter_deq "No DEQ Layer in the Chain!!! Maybe you wanted to use Chain"
pre_deq = length(pre_deq) == 0 ? NoOpLayer() : Chain(pre_deq...)
post_deq = length(post_deq) == 0 ? NoOpLayer() : Chain(post_deq...)
return DEQChain(pre_deq, deq, post_deq)
push!(encounter_deq ? post_deq : pre_deq, l)
end
@assert encounter_deq "No DEQ Layer in the Chain!!! Maybe you wanted to use Chain"
pre_deq = length(pre_deq) == 0 ? NoOpLayer() : Chain(pre_deq...)
post_deq = length(post_deq) == 0 ? NoOpLayer() : Chain(post_deq...)
return DEQChain(pre_deq, deq, post_deq)
end

function get_deq_return_type(deq::DEQChain{P1,
<:Union{MultiScaleDeepEquilibriumNetwork,
MultiScaleSkipDeepEquilibriumNetwork}},
::T) where {P1, T}
return NTuple{length(deq.deq.scales), T}
return NTuple{length(deq.deq.scales), T}
end
get_deq_return_type(::DEQChain, ::T) where {T} = T

function (deq::DEQChain)(x, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple)
T = get_deq_return_type(deq, x)
x1, st1 = deq.pre_deq(x, ps.pre_deq, st.pre_deq)
(x2::T, deq_soln), st2 = deq.deq(x1, ps.deq, st.deq)
x3, st3 = deq.post_deq(x2, ps.post_deq, st.post_deq)
return (x3, deq_soln), (pre_deq=st1, deq=st2, post_deq=st3)
T = get_deq_return_type(deq, x)
x1, st1 = deq.pre_deq(x, ps.pre_deq, st.pre_deq)
(x2::T, deq_soln), st2 = deq.deq(x1, ps.deq, st.deq)
x3, st3 = deq.post_deq(x2, ps.post_deq, st.post_deq)
return (x3, deq_soln), (pre_deq=st1, deq=st2, post_deq=st3)
end
33 changes: 17 additions & 16 deletions src/layers/core.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
abstract type AbstractDeepEquilibriumNetwork <: AbstractExplicitContainerLayer{(:model,)} end

function initialstates(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork)
return (model=initialstates(rng, deq.model), fixed_depth=Val(0))
return (model=initialstates(rng, deq.model), fixed_depth=Val(0))
end

abstract type AbstractSkipDeepEquilibriumNetwork <:
AbstractExplicitContainerLayer{(:model, :shortcut)} end

function initialstates(rng::AbstractRNG, deq::AbstractSkipDeepEquilibriumNetwork)
return (model=initialstates(rng, deq.model), shortcut=initialstates(rng, deq.shortcut),
fixed_depth=Val(0))
return (model=initialstates(rng, deq.model), shortcut=initialstates(rng, deq.shortcut),
fixed_depth=Val(0))
end

@inline check_unrolled_mode(::Val{0})::Bool = false
Expand All @@ -27,27 +27,28 @@ ChainRulesCore.@non_differentiable get_unrolled_depth(::Any)
Stores the solution of a DeepEquilibriumNetwork and its variants.
## Fields
* `z_star`: Steady-State or the value reached due to maxiters
* `u₀`: Initial Condition
* `residual`: Difference of the ``z^*`` and ``f(z^*, x)``
* `jacobian_loss`: Jacobian Stabilization Loss (see individual networks to see how it can be computed)
* `nfe`: Number of Function Evaluations
"""
struct DeepEquilibriumSolution{T, R <: AbstractFloat}
z_star::T
u₀::T
residual::T
jacobian_loss::R
nfe::Int
z_star::T
u₀::T
residual::T
jacobian_loss::R
nfe::Int
end

function Base.show(io::IO, l::DeepEquilibriumSolution)
print(io, "DeepEquilibriumSolution(")
print(io, "z_star: ", l.z_star)
print(io, ", initial_condition: ", l.u₀)
print(io, ", residual: ", l.residual)
print(io, ", jacobian_loss: ", l.jacobian_loss)
print(io, ", NFE: ", l.nfe)
print(io, ")")
return nothing
print(io, "DeepEquilibriumSolution(")
print(io, "z_star: ", l.z_star)
print(io, ", initial_condition: ", l.u₀)
print(io, ", residual: ", l.residual)
print(io, ", jacobian_loss: ", l.jacobian_loss)
print(io, ", NFE: ", l.nfe)
print(io, ")")
return nothing
end
Loading