-
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #57 from SciML/ap/fix
Fix code
- Loading branch information
Showing
19 changed files
with
1,077 additions
and
1,083 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,9 @@ | ||
style = "sciml" | ||
whitespace_in_kwargs = false | ||
whitespace_in_kwargs = false | ||
always_use_return = true | ||
margin = 92 | ||
indent = 2 | ||
format_docstrings = true | ||
join_lines_based_on_source = true | ||
separate_kwargs_with_semicolon = true | ||
always_for_in = true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,14 @@ | ||
name = "DeepEquilibriumNetworks" | ||
uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" | ||
authors = ["Avik Pal <[email protected]>"] | ||
version = "0.1.1" | ||
version = "0.1.2" | ||
|
||
[deps] | ||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | ||
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" | ||
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" | ||
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" | ||
DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2" | ||
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" | ||
|
@@ -18,6 +17,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" | |
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" | ||
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" | ||
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" | ||
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
|
@@ -31,13 +31,13 @@ ChainRulesCore = "1" | |
ComponentArrays = "0.11, 0.12" | ||
DiffEqBase = "6" | ||
DiffEqCallbacks = "2.20.1" | ||
DiffEqSensitivity = "6.64" | ||
Functors = "0.2, 0.3" | ||
LinearSolve = "1" | ||
Lux = "0.4" | ||
MLUtils = "0.2" | ||
OrdinaryDiffEq = "6" | ||
SciMLBase = "1.19" | ||
SciMLSensitivity = "7" | ||
Setfield = "1" | ||
Static = "0.6, 0.7" | ||
SteadyStateDiffEq = "1.6" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,28 @@ | ||
using Documenter, DocumenterCitations, DeepEquilibriumNetworks | ||
|
||
bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"), sorting=:nyt) | ||
bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"); sorting=:nyt) | ||
|
||
makedocs(bib, | ||
makedocs(bib; | ||
sitename="Fast Deep Equilibrium Networks", | ||
authors="Avik Pal et al.", | ||
clean=true, | ||
doctest=false, | ||
modules=[DeepEquilibriumNetworks], | ||
format=Documenter.HTML(# analytics = "", | ||
# assets = ["assets/favicon.ico"], | ||
; | ||
canonical="https://deepequilibriumnetworks.sciml.ai/stable/"), | ||
pages=[ | ||
"Home" => "index.md", | ||
"Manual" => [ | ||
"Dynamical Systems" => "manual/solvers.md", | ||
"Non Linear Solvers" => "manual/nlsolve.md", | ||
"General Purpose Layers" => "manual/layers.md", | ||
"DEQ Layers" => "manual/deqs.md", | ||
"Miscellaneous" => "manual/misc.md", | ||
], | ||
"References" => "references.md", | ||
"Home" => "index.md", | ||
"Manual" => [ | ||
"Dynamical Systems" => "manual/solvers.md", | ||
"Non Linear Solvers" => "manual/nlsolve.md", | ||
"General Purpose Layers" => "manual/layers.md", | ||
"DEQ Layers" => "manual/deqs.md", | ||
"Miscellaneous" => "manual/misc.md", | ||
], | ||
"References" => "references.md", | ||
]) | ||
|
||
deploydocs(repo="github.com/SciML/DeepEquilibriumNetworks.jl.git"; | ||
deploydocs(; repo="github.com/SciML/DeepEquilibriumNetworks.jl.git", | ||
push_preview=true) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,83 +1,84 @@ | ||
neg(x::Any) = hasmethod(-, (typeof(x),)) ? -x : x | ||
neg(nt::NamedTuple) = fmap(neg, nt) | ||
|
||
@noinline function DiffEqSensitivity.SteadyStateAdjointProblem(sol::EquilibriumSolution, | ||
sensealg::DeepEquilibriumAdjoint, | ||
g::Nothing, dg; | ||
save_idxs=nothing) | ||
@unpack f, p, u0 = sol.prob | ||
@noinline function SciMLSensitivity.SteadyStateAdjointProblem(sol::EquilibriumSolution, | ||
sensealg::DeepEquilibriumAdjoint, | ||
g::Nothing, dg; | ||
save_idxs=nothing, kwargs...) | ||
@unpack f, p, u0 = sol.prob | ||
|
||
diffcache, y = DiffEqSensitivity.adjointdiffcache(g, sensealg, false, sol, dg, f; | ||
quad=false, needs_jac=false) | ||
diffcache, y = SciMLSensitivity.adjointdiffcache(g, sensealg, false, sol, dg, f; | ||
quad=false, needs_jac=false) | ||
|
||
_save_idxs = save_idxs === nothing ? Colon() : save_idxs | ||
if dg !== nothing | ||
if typeof(_save_idxs) <: Number | ||
diffcache.dg_val[_save_idxs] = dg[_save_idxs] | ||
elseif typeof(dg) <: Number | ||
@. diffcache.dg_val[_save_idxs] = dg | ||
else | ||
@. diffcache.dg_val[_save_idxs] = dg[_save_idxs] | ||
end | ||
end | ||
|
||
if check_adjoint_mode(sensealg, Val(:vanilla)) | ||
# Solve the Linear Problem | ||
_val, back = Zygote.pullback(x -> f(x, p, nothing), y) | ||
s_val = size(_val) | ||
op = ZygotePullbackMultiplyOperator{eltype(y), typeof(back), typeof(s_val)}(back, | ||
s_val) | ||
linear_problem = LinearProblem(op, vec(diffcache.dg_val)) | ||
λ = solve(linear_problem, sensealg.linsolve).u | ||
elseif check_adjoint_mode(sensealg, Val(:jfb)) | ||
# Jacobian Free Backpropagation | ||
λ = diffcache.dg_val | ||
_save_idxs = save_idxs === nothing ? Colon() : save_idxs | ||
if dg !== nothing | ||
if typeof(_save_idxs) <: Number | ||
diffcache.dg_val[_save_idxs] = dg[_save_idxs] | ||
elseif typeof(dg) <: Number | ||
@. diffcache.dg_val[_save_idxs] = dg | ||
else | ||
error("Unknown adjoint mode") | ||
@. diffcache.dg_val[_save_idxs] = dg[_save_idxs] | ||
end | ||
end | ||
|
||
# Compute the VJP | ||
_, back = Zygote.pullback(p -> vec(f(y, p, nothing)), p) | ||
dp = back(vec(λ))[1] | ||
if check_adjoint_mode(sensealg, Val(:vanilla)) | ||
# Solve the Linear Problem | ||
_val, back = Zygote.pullback(x -> f(x, p, nothing), y) | ||
s_val = size(_val) | ||
op = ZygotePullbackMultiplyOperator{eltype(y), typeof(back), typeof(s_val)}(back, | ||
s_val) | ||
linear_problem = LinearProblem(op, vec(diffcache.dg_val)) | ||
λ = solve(linear_problem, sensealg.linsolve).u | ||
elseif check_adjoint_mode(sensealg, Val(:jfb)) | ||
# Jacobian Free Backpropagation | ||
λ = diffcache.dg_val | ||
else | ||
error("Unknown adjoint mode") | ||
end | ||
|
||
return neg(dp) | ||
# Compute the VJP | ||
_, back = Zygote.pullback(p -> vec(f(y, p, nothing)), p) | ||
dp = back(vec(λ))[1] | ||
|
||
return neg(dp) | ||
end | ||
|
||
function DiffEqBase._concrete_solve_adjoint(prob::SteadyStateProblem, alg, | ||
sensealg::DeepEquilibriumAdjoint, u0, p, | ||
args...; save_idxs=nothing, kwargs...) | ||
_prob = remake(prob; u0=u0, p=p) | ||
sol = solve(_prob, alg, args...; kwargs...) | ||
_save_idxs = save_idxs === nothing ? Colon() : save_idxs | ||
::SciMLBase.ADOriginator, args...; | ||
save_idxs=nothing, kwargs...) | ||
_prob = remake(prob; u0=u0, p=p) | ||
sol = solve(_prob, alg, args...; kwargs...) | ||
_save_idxs = save_idxs === nothing ? Colon() : save_idxs | ||
|
||
out = save_idxs === nothing ? sol : | ||
DiffEqBase.sensitivity_solution(sol, sol[_save_idxs]) | ||
out = save_idxs === nothing ? sol : | ||
DiffEqBase.sensitivity_solution(sol, sol[_save_idxs]) | ||
|
||
function steadystatebackpass(Δ) | ||
# Δ = dg/dx or diffcache.dg_val | ||
# del g/del p = 0 | ||
dp = adjoint_sensitivities(sol, alg; sensealg=sensealg, g=nothing, dg=Δ, | ||
save_idxs=save_idxs) | ||
return (NoTangent(), | ||
NoTangent(), | ||
NoTangent(), | ||
NoTangent(), | ||
dp, | ||
NoTangent(), | ||
ntuple(_ -> NoTangent(), length(args))...) | ||
end | ||
function steadystatebackpass(Δ) | ||
# Δ = dg/dx or diffcache.dg_val | ||
# del g/del p = 0 | ||
dp = adjoint_sensitivities(sol, alg; sensealg=sensealg, g=nothing, dg=Δ, | ||
save_idxs=save_idxs) | ||
return (NoTangent(), | ||
NoTangent(), | ||
NoTangent(), | ||
NoTangent(), | ||
dp, | ||
NoTangent(), | ||
ntuple(_ -> NoTangent(), length(args))...) | ||
end | ||
|
||
return out, steadystatebackpass | ||
return out, steadystatebackpass | ||
end | ||
|
||
function DiffEqSensitivity._adjoint_sensitivities(sol, sensealg::DeepEquilibriumAdjoint, | ||
alg, g, dg=nothing; abstol=1e-6, | ||
reltol=1e-3, kwargs...) | ||
return DiffEqSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...) | ||
function SciMLSensitivity._adjoint_sensitivities(sol, sensealg::DeepEquilibriumAdjoint, | ||
alg, g, dg=nothing; abstol=1e-6, | ||
reltol=1e-3, kwargs...) | ||
return SciMLSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...) | ||
end | ||
|
||
function DiffEqSensitivity._adjoint_sensitivities(sol, sensealg::DeepEquilibriumAdjoint, | ||
alg; g=nothing, dg=nothing, abstol=1e-6, | ||
reltol=1e-3, kwargs...) | ||
return DiffEqSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...) | ||
function SciMLSensitivity._adjoint_sensitivities(sol, sensealg::DeepEquilibriumAdjoint, | ||
alg; g=nothing, dg=nothing, abstol=1e-6, | ||
reltol=1e-3, kwargs...) | ||
return SciMLSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.