
Merge pull request #154 from SciML/ap/updates
Update the code to the latest versions of Lux
avik-pal authored Jun 19, 2024
2 parents 7bf3127 + ce97811 commit dbcad8d
Showing 11 changed files with 82 additions and 164 deletions.
16 changes: 9 additions & 7 deletions Project.toml
@@ -1,7 +1,7 @@
name = "DeepEquilibriumNetworks"
uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
authors = ["Avik Pal <[email protected]>"]
version = "2.1.0"
version = "2.1.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -13,21 +13,19 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
+NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"

[weakdeps]
-ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DeepEquilibriumNetworksSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
-DeepEquilibriumNetworksZygoteExt = ["ForwardDiff", "Zygote"]

[compat]
ADTypes = "0.2.5, 1"
@@ -37,16 +35,18 @@ CommonSolve = "0.2.4"
ConcreteStructs = "0.2"
ConstructionBase = "1"
DiffEqBase = "6.119"
ExplicitImports = "1.4.1"
Documenter = "1.4"
ExplicitImports = "1.6.0"
FastClosures = "0.3"
ForwardDiff = "0.10.36"
Functors = "0.4.10"
LinearSolve = "2.21.2"
Lux = "0.5.38"
Lux = "0.5.50"
LuxCUDA = "0.3.2"
LuxCore = "0.1.14"
LuxTestUtils = "0.1.15"
NLsolve = "4.5.1"
NNlib = "0.9.17"
NonlinearSolve = "3.10.0"
OrdinaryDiffEq = "6.74.1"
PrecompileTools = "1"
@@ -63,7 +63,9 @@ julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
@@ -77,4 +79,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ExplicitImports", "Functors", "LuxCUDA", "LuxTestUtils", "NLsolve", "NonlinearSolve", "OrdinaryDiffEq", "ReTestItems", "SciMLSensitivity", "StableRNGs", "Test", "Zygote"]
test = ["Aqua", "Documenter", "ExplicitImports", "ForwardDiff", "Functors", "LuxCUDA", "LuxTestUtils", "NLsolve", "NonlinearSolve", "OrdinaryDiffEq", "ReTestItems", "SciMLSensitivity", "StableRNGs", "Test", "Zygote"]
2 changes: 1 addition & 1 deletion docs/Project.toml
@@ -33,4 +33,4 @@ Random = "1"
SciMLSensitivity = "7"
Statistics = "1"
Zygote = "0.6"
julia = "1.9"
julia = "1.10"
2 changes: 1 addition & 1 deletion docs/make.jl
@@ -11,7 +11,7 @@ makedocs(; sitename="Deep Equilibrium Networks",
authors="Avik Pal et al.",
modules=[DeepEquilibriumNetworks],
clean=true,
-    doctest=true,
+    doctest=false, # Tested in CI
linkcheck=true,
format=Documenter.HTML(; assets=["assets/favicon.ico"],
canonical="https://docs.sciml.ai/DeepEquilibriumNetworks/stable/"),
13 changes: 7 additions & 6 deletions docs/src/tutorials/basic_mnist_deq.md
@@ -20,7 +20,8 @@ const cdev = cpu_device()
const gdev = gpu_device()
```

-We can now construct our dataloader.
+We can now construct our dataloader. We use only a limited part of the data for
+demonstration.

```@example basic_mnist_deq
function onehot(labels_raw)
@@ -32,17 +33,17 @@ function loadmnist(batchsize, split)
mnist = MNIST(; split)
imgs, labels_raw = mnist.features, mnist.targets
# Process images into (H,W,C,BS) batches
-    x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |>
-              gdev
+    x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))[
+        :, :, 1:1, 1:128] |> gdev
x_train = batchview(x_train, batchsize)
# Onehot and batch the labels
-    y_train = onehot(labels_raw) |> gdev
+    y_train = onehot(labels_raw)[:, 1:128] |> gdev
y_train = batchview(y_train, batchsize)
return x_train, y_train
end
-x_train, y_train = loadmnist(128, :train);
-x_test, y_test = loadmnist(128, :test);
+x_train, y_train = loadmnist(16, :train);
+x_test, y_test = loadmnist(16, :test);
```

Construct the Lux Neural Network containing a DEQ layer.
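The diff truncates the model definition here. As an illustrative sketch only (the layer sizes, solver choice, and overall shape below are assumptions, not the tutorial's actual code), a small Lux network wrapping a DEQ layer might look like this:

```julia
# Sketch of a Lux chain containing a DEQ layer (assumed sizes and solver).
using DeepEquilibriumNetworks, Lux, NonlinearSolve, Random

model = Chain(
    FlattenLayer(),          # (28, 28, 1, B) image batch -> (784, B)
    Dense(784 => 32, tanh),
    # The DEQ layer solves z = f(z, x) for the fixed point z* in its forward pass
    DeepEquilibriumNetwork(
        Parallel(+, Dense(32 => 32, tanh; use_bias=false),
            Dense(32 => 32, tanh; use_bias=false)),
        NewtonRaphson()),
    Dense(32 => 10))

ps, st = Lux.setup(Xoshiro(0), model)
y, st_ = model(randn(Float32, 28, 28, 1, 16), ps, st)  # logits of size (10, 16)
```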
57 changes: 0 additions & 57 deletions ext/DeepEquilibriumNetworksZygoteExt.jl

This file was deleted.

5 changes: 3 additions & 2 deletions src/DeepEquilibriumNetworks.jl
@@ -3,7 +3,7 @@ module DeepEquilibriumNetworks
import PrecompileTools: @recompile_invalidations

@recompile_invalidations begin
-    using ADTypes: AutoFiniteDiff
+    using ADTypes: AutoFiniteDiff, AutoForwardDiff, AutoZygote
using ChainRulesCore: ChainRulesCore
using CommonSolve: solve
using ConcreteStructs: @concrete
@@ -12,7 +12,8 @@ import PrecompileTools: @recompile_invalidations
using FastClosures: @closure
using Lux: Lux, BranchLayer, Chain, NoOpLayer, Parallel, RepeatedLayer,
StatefulLuxLayer, WrappedFunction
-    using LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer
+    using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
+    using NNlib: ⊠
using Random: Random, AbstractRNG, randn!
using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, AbstractODEAlgorithm,
NonlinearSolution, ODESolution, ODEFunction, ODEProblem,
107 changes: 25 additions & 82 deletions src/layers.jl
@@ -64,11 +64,11 @@ const DEQ = DeepEquilibriumNetwork

ConstructionBase.constructorof(::Type{<:DEQ{pType}}) where {pType} = DEQ{pType}

-function Lux.initialstates(rng::AbstractRNG, deq::DEQ)
-    rng = Lux.replicate(rng)
+function LuxCore.initialstates(rng::AbstractRNG, deq::DEQ)
+    rng = LuxCore.replicate(rng)
    randn(rng, 1)
-    return (; model=Lux.initialstates(rng, deq.model), fixed_depth=Val(0),
-        init=Lux.initialstates(rng, deq.init), solution=DeepEquilibriumSolution(), rng)
+    return (; model=LuxCore.initialstates(rng, deq.model), fixed_depth=Val(0),
+        init=LuxCore.initialstates(rng, deq.init), solution=DeepEquilibriumSolution(), rng)
end

(deq::DEQ)(x, ps, st::NamedTuple) = deq(x, ps, st, __check_unrolled_mode(st))
@@ -79,10 +79,10 @@ function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{true})
repeated_model = RepeatedLayer(deq.model; repeats=st.fixed_depth)

z_star, st_ = repeated_model((z, x), ps.model, st.model)
-    model = StatefulLuxLayer(deq.model, ps.model, st_)
+    model = StatefulLuxLayer{true}(deq.model, ps.model, st_)
resid = CRC.ignore_derivatives(z_star .- model((z_star, x)))

-    rng = Lux.replicate(st.rng)
+    rng = LuxCore.replicate(st.rng)
jac_loss = __estimate_jacobian_trace(
__getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng)

@@ -97,7 +97,7 @@
function (deq::DEQ{pType})(x, ps, st::NamedTuple, ::Val{false}) where {pType}
z, st = __get_initial_condition(deq, x, ps, st)

-    model = StatefulLuxLayer(deq.model, ps.model, st.model)
+    model = StatefulLuxLayer{true}(deq.model, ps.model, st.model)

dudt = @closure (u, p, t) -> begin
# The type-assert is needed because of an upstream Lux issue with type stability of
@@ -113,7 +113,7 @@ function (deq::DEQ{pType})(x, ps, st::NamedTuple, ::Val{false}) where {pType}
reltol=1e-3, termination_condition, maxiters=32, deq.kwargs...)
z_star = __get_steady_state(sol)

-    rng = Lux.replicate(st.rng)
+    rng = LuxCore.replicate(st.rng)
jac_loss = __estimate_jacobian_trace(
__getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng)

@@ -144,17 +144,16 @@ Deep Equilibrium Network as proposed in [baideep2019](@cite) and [pal2022mixing]
condition is set to `zero(x)`. If `missing`, the initial condition is set to
`WrappedFunction(zero)`. In other cases the initial condition is set to
`init(x, ps, st)`.
-  - `jacobian_regularization`: Must be one of `nothing`, `AutoFiniteDiff` or `AutoZygote`.
+  - `jacobian_regularization`: Must be one of `nothing`, `AutoForwardDiff`, `AutoFiniteDiff`
+    or `AutoZygote`.
- `problem_type`: Provides a way to simulate a Vanilla Neural ODE by setting the
`problem_type` to `ODEProblem`. By default, the problem type is set to
`SteadyStateProblem`.
- `kwargs`: Additional Parameters that are directly passed to `SciMLBase.solve`.
## Example
-```julia
-julia> using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq
+```jldoctest
julia> model = DeepEquilibriumNetwork(
Parallel(+, Dense(2, 2; use_bias=false), Dense(2, 2; use_bias=false)),
VCABM3(); verbose=false)
@@ -168,13 +167,12 @@ DeepEquilibriumNetwork(
) # Total: 8 parameters,
# plus 0 states.
-julia> rng = Random.default_rng()
-TaskLocalRNG()
+julia> rng = Xoshiro(0);
julia> ps, st = Lux.setup(rng, model);
-julia> model(ones(Float32, 2, 1), ps, st);
+julia> size(first(model(ones(Float32, 2, 1), ps, st)))
+(2, 1)
```
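Jacobian regularization (the option extended in this commit) is selected when constructing the network. A minimal sketch, continuing from the example above and assuming Zygote is loaded in the session and the `jacobian_regularization` keyword accepts an ADTypes backend:

```julia
# Sketch: request Jacobian regularization through an ADTypes backend (assumed usage).
using ADTypes: AutoZygote

model = DeepEquilibriumNetwork(
    Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)),
    VCABM3(); jacobian_regularization=AutoZygote(), verbose=false)
```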
See also: [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref),
@@ -234,84 +232,27 @@ For keyword arguments, see [`DeepEquilibriumNetwork`](@ref).
## Example
-```julia
-julia> using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve
+```jldoctest
julia> main_layers = (
Parallel(+, Dense(4 => 4, tanh; use_bias=false), Dense(4 => 4, tanh; use_bias=false)),
-           Dense(3 => 3, tanh), Dense(2 => 2, tanh), Dense(1 => 1, tanh))
-(Parallel(), Dense(3 => 3, tanh_fast), Dense(2 => 2, tanh_fast), Dense(1 => 1, tanh_fast))
+           Dense(3 => 3, tanh), Dense(2 => 2, tanh), Dense(1 => 1, tanh));
julia> mapping_layers = [NoOpLayer() Dense(4 => 3, tanh) Dense(4 => 2, tanh) Dense(4 => 1, tanh);
Dense(3 => 4, tanh) NoOpLayer() Dense(3 => 2, tanh) Dense(3 => 1, tanh);
Dense(2 => 4, tanh) Dense(2 => 3, tanh) NoOpLayer() Dense(2 => 1, tanh);
-           Dense(1 => 4, tanh) Dense(1 => 3, tanh) Dense(1 => 2, tanh) NoOpLayer()]
-4×4 Matrix{LuxCore.AbstractExplicitLayer}:
- NoOpLayer()              …  Dense(4 => 1, tanh_fast)
- Dense(3 => 4, tanh_fast)    Dense(3 => 1, tanh_fast)
- Dense(2 => 4, tanh_fast)    Dense(2 => 1, tanh_fast)
- Dense(1 => 4, tanh_fast)    NoOpLayer()
+           Dense(1 => 4, tanh) Dense(1 => 3, tanh) Dense(1 => 2, tanh) NoOpLayer()];
julia> model = MultiScaleDeepEquilibriumNetwork(
-           main_layers, mapping_layers, nothing, NewtonRaphson(), ((4,), (3,), (2,), (1,)))
-DeepEquilibriumNetwork(
-    model = MultiScaleInputLayer{scales = 4}(
-        model = Chain(
-            layer_1 = Parallel(
-                layer_1 = Parallel(
-                    +
-                    Dense(4 => 4, tanh_fast, bias=false),  # 16 parameters
-                    Dense(4 => 4, tanh_fast, bias=false),  # 16 parameters
-                ),
-                layer_2 = Dense(3 => 3, tanh_fast),  # 12 parameters
-                layer_3 = Dense(2 => 2, tanh_fast),  # 6 parameters
-                layer_4 = Dense(1 => 1, tanh_fast),  # 2 parameters
-            ),
-            layer_2 = BranchLayer(
-                layer_1 = Parallel(
-                    +
-                    NoOpLayer(),
-                    Dense(3 => 4, tanh_fast),  # 16 parameters
-                    Dense(2 => 4, tanh_fast),  # 12 parameters
-                    Dense(1 => 4, tanh_fast),  # 8 parameters
-                ),
-                layer_2 = Parallel(
-                    +
-                    Dense(4 => 3, tanh_fast),  # 15 parameters
-                    NoOpLayer(),
-                    Dense(2 => 3, tanh_fast),  # 9 parameters
-                    Dense(1 => 3, tanh_fast),  # 6 parameters
-                ),
-                layer_3 = Parallel(
-                    +
-                    Dense(4 => 2, tanh_fast),  # 10 parameters
-                    Dense(3 => 2, tanh_fast),  # 8 parameters
-                    NoOpLayer(),
-                    Dense(1 => 2, tanh_fast),  # 4 parameters
-                ),
-                layer_4 = Parallel(
-                    +
-                    Dense(4 => 1, tanh_fast),  # 5 parameters
-                    Dense(3 => 1, tanh_fast),  # 4 parameters
-                    Dense(2 => 1, tanh_fast),  # 3 parameters
-                    NoOpLayer(),
-                ),
-            ),
-        ),
-    ),
-    init = WrappedFunction(Base.Fix1{typeof(DeepEquilibriumNetworks.__zeros_init), Val{((4,), (3,), (2,), (1,))}}(DeepEquilibriumNetworks.__zeros_init, Val{((4,), (3,), (2,), (1,))}())),
-) # Total: 152 parameters,
-  #        plus 0 states.
+           main_layers, mapping_layers, nothing, NewtonRaphson(), ((4,), (3,), (2,), (1,)));
-julia> rng = Random.default_rng()
-TaskLocalRNG()
+julia> rng = Xoshiro(0);
julia> ps, st = Lux.setup(rng, model);
julia> x = rand(rng, Float32, 4, 12);
-julia> model(x, ps, st);
+julia> size.(first(model(x, ps, st)))
+((4, 12), (3, 12), (2, 12), (1, 12))
```
"""
function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
@@ -390,7 +331,9 @@ end
function ConstructionBase.constructorof(::Type{<:MultiScaleInputLayer{N}}) where {N}
return MultiScaleInputLayer{N}
end
-Lux.display_name(::MultiScaleInputLayer{N}) where {N} = "MultiScaleInputLayer{scales = $N}"
+function LuxCore.display_name(::MultiScaleInputLayer{N}) where {N}
+    return "MultiScaleInputLayer{scales = $N}"
+end

function MultiScaleInputLayer(model, split_idxs, scales::Val{S}) where {S}
return MultiScaleInputLayer{length(S)}(model, split_idxs, scales)
Expand All @@ -401,7 +344,7 @@ end
return quote
u, x = z
u_ = __split_and_reshape(u, m.split_idxs, m.scales)
-        u_res, st = Lux.apply(m.model, ($(inputs...),), ps, st)
+        u_res, st = LuxCore.apply(m.model, ($(inputs...),), ps, st)
return __flatten_vcat(u_res), st
end
end

2 comments on commit dbcad8d

@avik-pal (Member Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/109371

Tip: Release Notes

Did you know you can add release notes too? Just add Markdown-formatted text underneath the comment after the text
"Release notes:", and it will be added to the registry PR. If TagBot is installed, it will also be added to the
release that TagBot creates. For example:

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v2.1.1 -m "<description of version>" dbcad8dae399068fc80cbadfee7527a0011d7b73
git push origin v2.1.1
