Skip to content

Commit

Permalink
Default to using SimpleGMRES for the backward pass
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 22, 2023
1 parent 8b1d48d commit 64f84a1
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 12 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DeepEquilibriumNetworks"
uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
authors = ["Avik Pal <[email protected]>"]
version = "2.0.0"
version = "2.0.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -18,11 +18,12 @@ SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"

[weakdeps]
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DeepEquilibriumNetworksSciMLSensitivityExt = "SciMLSensitivity"
DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
DeepEquilibriumNetworksZygoteExt = "Zygote"

[compat]
Expand All @@ -32,6 +33,7 @@ ConcreteStructs = "0.2"
ConstructionBase = "1"
DiffEqBase = "6.119"
LinearAlgebra = "1"
LinearSolve = "2.21.2"
Lux = "0.5.11"
Random = "1"
SciMLBase = "2"
Expand Down
38 changes: 38 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,44 @@ To construct a continuous DEQ, any ODE solver compatible with `DifferentialEquat
can be passed as the solver. To construct a discrete DEQ, any root finding algorithm
compatible with `NonlinearSolve.jl` API can be passed as the solver.

## Choosing a Solver

### Root Finding Algorithms

Using Root Finding Algorithms give fast convergence when possible, but these methods also
tend to be unstable. If you must use a root finding algorithm, we recommend using:

1. `NewtonRaphson` or `TrustRegion` for small models
2. `LimitedMemoryBroyden` for large Deep Learning applications (with well-conditioned
Jacobians)
3. `NewtonRaphson(; linsolve = KrylovJL_GMRES())` for cases when Broyden methods fail

Note that Krylov Methods rely on efficient VJPs which are not available for all Lux models.
If you think this is causing a performance regression, please open an issue in
[Lux.jl](https://github.com/LuxDL/Lux.jl).

### ODE Solvers

Using ODE Solvers give slower convergence, but are more stable. We generally recommend these
methods over root finding algorithms. If you use implicit ODE solvers, remember to use
Krylov linear solvers, see OrdinaryDiffEq.jl documentation for these. For most cases, we
recommend:

1. `VCAB3()` for high tolerance problems
2. `Tsit5()` for high tolerance problems where `VCAB3()` fails
3. In all other cases, follow the recommendation given in [OrdinaryDiffEq.jl](https://docs.sciml.ai/DiffEqDocs/stable/solvers/ode_solve/#ode_solve) documentation

### Sensitivity Analysis

1. For `MultiScaleNeuralODE`, we default to `GaussAdjoint(; autojacvec = ZygoteVJP())`. A
faster alternative would be `BacksolveAdjoint(; autojacvec = ZygoteVJP())` but there are
stability concerns for using that. Follow the recommendation given in [SciMLSensitivity.jl](https://docs.sciml.ai/SciMLSensitivity/stable/manual/differential_equation_sensitivities/#Choosing-a-Sensitivity-Algorithm) documentation.
2. For Steady State Problems, we default to
`SteadyStateAdjoint(; linsolve = SimpleGMRES(; blocksize, linsolve_kwargs = (; maxiters=10, abstol=1e-3, reltol=1e-3)))`.
This default will perform poorly on small models. It is recommended to pass
`sensealg = SteadyStateAdjoint()` or
`sensealg = SteadyStateAdjoint(; linsolve = LUFactorization())` for small models.

## Standard Models

```@docs
Expand Down
18 changes: 18 additions & 0 deletions ext/DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt

# Linear Solve is a dependency of SciMLSensitivity, so we only need to load SciMLSensitivity
# to load this extension
using LinearSolve, SciMLBase, SciMLSensitivity
import DeepEquilibriumNetworks: __default_sensealg

@inline function __default_sensealg(prob::SteadyStateProblem)
# We want to avoid the cost for cache construction for linsolve = nothing
# For small problems we should use concrete jacobian but we assume users want to solve
# large problems with this package so we default to GMRES and avoid runtime dispatches
linsolve = SimpleGMRES{true}(; blocksize=prod(size(prob.u0)[1:(end - 1)]))
linsolve_kwargs = (; maxiters=10, abstol=1e-3, reltol=1e-3)
return SteadyStateAdjoint(; linsolve, linsolve_kwargs, autojacvec=ZygoteVJP())
end
@inline __default_sensealg(::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP())

end
10 changes: 0 additions & 10 deletions ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl

This file was deleted.

0 comments on commit 64f84a1

Please sign in to comment.