Better Steady State Adjoint #93

Merged: 11 commits, Oct 20, 2023
4 changes: 2 additions & 2 deletions .JuliaFormatter.toml
@@ -2,6 +2,6 @@ style = "sciml"
whitespace_in_kwargs = false
always_use_return = true
format_docstrings = true
join_lines_based_on_source = false
separate_kwargs_with_semicolon = true
format_markdown = true
format_markdown = true
annotate_untyped_fields_with_any = false
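
As a rough illustration of the formatter options in this hunk, the snippet below is hand-written in the style these settings are generally understood to produce. The `Point` struct and `scale` function are hypothetical and not part of the package; the snippet was not generated by running the formatter.

```julia
# Hypothetical code, hand-formatted to match the options above.

struct Point
    x           # annotate_untyped_fields_with_any = false: no `::Any` is added
    y::Float64
end

# separate_kwargs_with_semicolon = true: keyword arguments follow a `;`
# whitespace_in_kwargs = false: `factor=2.0`, not `factor = 2.0`
function scale(p::Point; factor=2.0)
    # always_use_return = true: the result is returned explicitly
    return Point(p.x * factor, p.y * factor)
end

q = scale(Point(1, 2.0); factor=3.0)
println((q.x, q.y))
```
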
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
@@ -27,7 +27,6 @@ jobs:
- ADJOINT
version:
- '1'
- '1.6'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
3 changes: 2 additions & 1 deletion .gitignore
@@ -9,4 +9,5 @@ profs
logs
benchmarking
*/tensorflow_datasets/
-checkpoints
+checkpoints
+wip
26 changes: 11 additions & 15 deletions Project.toml
@@ -1,44 +1,40 @@
name = "DeepEquilibriumNetworks"
uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
authors = ["Avik Pal <[email protected]>"]
version = "1.3.0"
version = "1.4.0"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
CUDA = "3, 4, 5"
ChainRulesCore = "1"
ConcreteStructs = "0.2"
DiffEqBase = "6.119"
LinearSolve = "1, 2"
Lux = "0.4, 0.5"
MLUtils = "0.2, 0.3, 0.4"
Lux = "0.5.7"
NonlinearSolve = "2"
OrdinaryDiffEq = "6"
SciMLBase = "1.19, 2"
SciMLSensitivity = "7"
Reexport = "1"
SciMLBase = "2"
SciMLSensitivity = "7.43"
Setfield = "1"
SimpleNonlinearSolve = "0.1.14"
Static = "0.6, 0.7, 0.8"
SteadyStateDiffEq = "1.16"
TruncatedStacktraces = "1.1"
Zygote = "0.6.34"
ZygoteRules = "0.2"
julia = "1.6"
julia = "1.9"
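
To make the tightened `[compat]` bounds concrete, here is a minimal sketch of how a caret-style entry such as `Lux = "0.5.7"` or `julia = "1.9"` maps to a half-open version range. The helpers `caret_upper` and `satisfies` are hypothetical and written for illustration only; they are not Pkg's implementation and only handle fully specified lower bounds.

```julia
# Hypothetical helpers, not part of Pkg: approximate caret semantics for a
# fully specified lower bound (the leftmost nonzero component may not change).
function caret_upper(lo::VersionNumber)
    lo.major > 0 && return VersionNumber(lo.major + 1, 0, 0)
    lo.minor > 0 && return VersionNumber(0, lo.minor + 1, 0)
    return VersionNumber(0, 0, lo.patch + 1)
end

# A version satisfies the entry when it lies in [lo, caret_upper(lo)).
satisfies(v::VersionNumber, lo::VersionNumber) = lo <= v < caret_upper(lo)

# Lux = "0.5.7"
println(satisfies(v"0.5.9", v"0.5.7"))
println(satisfies(v"0.6.0", v"0.5.7"))

# julia = "1.9" (treated as 1.9.0 in this sketch)
println(satisfies(v"1.10.2", v"1.9.0"))
println(satisfies(v"1.6.7", v"1.9.0"))
```

Under these assumptions, raising `julia` from `1.6` to `1.9` is what makes the `'1.6'` CI entry in this PR unsatisfiable.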
39 changes: 20 additions & 19 deletions README.md
@@ -24,30 +24,31 @@ Pkg.add("DeepEquilibriumNetworks")
## Quickstart

```julia
-import DeepEquilibriumNetworks as DEQs
-import Lux
-import Random
-import Zygote
+using DeepEquilibriumNetworks, Lux, Random, Zygote
+# using LuxCUDA, LuxAMDGPU ## Install and Load for GPU Support

 seed = 0
 rng = Random.default_rng()
 Random.seed!(rng, seed)

-model = Lux.Chain(Lux.Dense(2, 2),
-    DEQs.DeepEquilibriumNetwork(Lux.Parallel(+,
-            Lux.Dense(2, 2; use_bias=false),
-            Lux.Dense(2, 2; use_bias=false)),
-        DEQs.ContinuousDEQSolver(;
-            abstol=0.1f0,
-            reltol=0.1f0,
-            abstol_termination=0.1f0,
-            reltol_termination=0.1f0)))
-
-ps, st = gpu.(Lux.setup(rng, model))
-x = gpu(rand(rng, Float32, 2, 1))
-y = gpu(rand(rng, Float32, 2, 1))
-
-gs = Zygote.gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1]
+model = Chain(Dense(2 => 2),
+    DeepEquilibriumNetwork(Parallel(+,
+            Dense(2 => 2; use_bias=false),
+            Dense(2 => 2; use_bias=false)),
+        ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0,
+            reltol_termination=0.1f0);
+        save_everystep=true))
+
+gdev = gpu_device()
+cdev = cpu_device()
+
+ps, st = Lux.setup(rng, model) |> gdev
+x = rand(rng, Float32, 2, 1) |> gdev
+y = rand(rng, Float32, 2, 1) |> gdev
+
+model(x, ps, st)
+
+gs = only(Zygote.gradient(p -> sum(abs2, first(first(model(x, p, st))) .- y), ps))
```
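
The updated quickstart swaps positional indexing for `only` and `first`, and broadcasting `gpu.` for the pipe operator. A minimal sketch of why these forms are interchangeable, using plain tuples in place of the real model output: `fake_model` and `fake_gradient` are hypothetical stand-ins that assume the model returns `((output, solution), state)` and that the gradient call returns a one-element tuple, as the indexing in the quickstart suggests.

```julia
# Hypothetical stand-ins; neither Lux nor Zygote is needed to run this.
fake_model(x) = ((2 .* x, :solution_info), :state)
fake_gradient(x) = (ones(length(x)),)

x = [1.0, 2.0]

# Old and new ways of extracting the network output agree.
old_out = fake_model(x)[1][1]
new_out = first(first(fake_model(x)))
println(old_out == new_out)

# `only` returns the single element and throws if there is not exactly one,
# whereas `[1]` would silently ignore any extras.
println(only(fake_gradient(x)) == fake_gradient(x)[1])

# `value |> f` is the same as `f(value)`, as in `Lux.setup(rng, model) |> gdev`.
println((x |> sum) == sum(x))
```
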

## Citation
40 changes: 20 additions & 20 deletions docs/src/index.md
@@ -2,7 +2,7 @@

DeepEquilibriumNetworks.jl is a framework built on top of
[DifferentialEquations.jl](https://docs.sciml.ai/DiffEqDocs/stable/) and
[Lux.jl](https://docs.sciml.ai/Lux/stable/), enabling the efficient training and inference for
[Lux.jl](https://lux.csail.mit.edu/), enabling the efficient training and inference for
Deep Equilibrium Networks (Infinitely Deep Neural Networks).
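
As a rough sketch of what "equilibrium" means here: a deep equilibrium layer outputs a fixed point z* = f(z*, x) instead of the result of a fixed stack of layers. The code below finds such a point by naive iteration in plain Julia. It is an illustration only; `f`, `W`, and `fixed_point` are hypothetical names, and the package relies on solvers from the SciML ecosystem instead of this loop.

```julia
using LinearAlgebra: norm

# Hypothetical layer. W has spectral norm below 1 and tanh is 1-Lipschitz,
# so f is a contraction in z and naive iteration converges.
W = [0.1 0.2; -0.3 0.1]
f(z, x) = tanh.(W * z .+ x)

# Iterate z <- f(z, x) until successive iterates stop changing.
function fixed_point(f, x; tol=1e-10, maxiters=1_000)
    z = zero(x)
    for _ in 1:maxiters
        z_next = f(z, x)
        norm(z_next - z) < tol && return z_next
        z = z_next
    end
    return z
end

x = [0.5, -0.25]
z_star = fixed_point(f, x)

# Residual of the fixed-point equation; small when z_star is an equilibrium.
println(norm(f(z_star, x) - z_star))
```
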

## Installation
@@ -17,30 +17,30 @@ Pkg.add("DeepEquilibriumNetworks")
## Quick-start

```julia
import DeepEquilibriumNetworks as DEQs
import Lux
import Random
import Zygote
using DeepEquilibriumNetworks, Lux, Random, Zygote
# using LuxCUDA, LuxAMDGPU ## Install and Load for GPU Support

seed = 0
rng = Random.default_rng()
Random.seed!(rng, seed)
model = Chain(Dense(2 => 2),
DeepEquilibriumNetwork(Parallel(+,
Dense(2 => 2; use_bias=false),
Dense(2 => 2; use_bias=false)),
ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0,
reltol_termination=0.1f0);
save_everystep=true))

model = Lux.Chain(Lux.Dense(2, 2),
DEQs.DeepEquilibriumNetwork(Lux.Parallel(+,
Lux.Dense(2, 2; use_bias=false),
Lux.Dense(2, 2; use_bias=false)),
DEQs.ContinuousDEQSolver(;
abstol=0.1f0,
reltol=0.1f0,
abstol_termination=0.1f0,
reltol_termination=0.1f0)))

ps, st = gpu.(Lux.setup(rng, model))
x = gpu(rand(rng, Float32, 2, 1))
y = gpu(rand(rng, Float32, 2, 1))

gs = Zygote.gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1]
gdev = gpu_device()
cdev = cpu_device()

ps, st = Lux.setup(rng, model) |> gdev
x = rand(rng, Float32, 2, 1) |> gdev
y = rand(rng, Float32, 2, 1) |> gdev

model(x, ps, st)

gs = only(Zygote.gradient(p -> sum(abs2, first(first(model(x, p, st))) .- y), ps))
```

## Citation
1 change: 0 additions & 1 deletion docs/src/manual/misc.md
@@ -4,6 +4,5 @@
DeepEquilibriumSolution
EquilibriumSolution
DeepEquilibriumNetworks.split_and_reshape
DeepEquilibriumNetworks.init_identity_matrix
DeepEquilibriumNetworks.estimate_jacobian_trace
```
44 changes: 0 additions & 44 deletions experiments/Project.toml

This file was deleted.

74 changes: 0 additions & 74 deletions experiments/cifar10/large.yml

This file was deleted.
