Skip to content

Commit

Permalink
Update the code to the latest versions of Lux
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 19, 2024
1 parent 7bf3127 commit ce97811
Show file tree
Hide file tree
Showing 11 changed files with 82 additions and 164 deletions.
16 changes: 9 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DeepEquilibriumNetworks"
uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
authors = ["Avik Pal <[email protected]>"]
version = "2.1.0"
version = "2.1.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -13,21 +13,19 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DeepEquilibriumNetworksSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
DeepEquilibriumNetworksZygoteExt = ["ForwardDiff", "Zygote"]

[compat]
ADTypes = "0.2.5, 1"
Expand All @@ -37,16 +35,18 @@ CommonSolve = "0.2.4"
ConcreteStructs = "0.2"
ConstructionBase = "1"
DiffEqBase = "6.119"
ExplicitImports = "1.4.1"
Documenter = "1.4"
ExplicitImports = "1.6.0"
FastClosures = "0.3"
ForwardDiff = "0.10.36"
Functors = "0.4.10"
LinearSolve = "2.21.2"
Lux = "0.5.38"
Lux = "0.5.50"
LuxCUDA = "0.3.2"
LuxCore = "0.1.14"
LuxTestUtils = "0.1.15"
NLsolve = "4.5.1"
NNlib = "0.9.17"
NonlinearSolve = "3.10.0"
OrdinaryDiffEq = "6.74.1"
PrecompileTools = "1"
Expand All @@ -63,7 +63,9 @@ julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
Expand All @@ -77,4 +79,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ExplicitImports", "Functors", "LuxCUDA", "LuxTestUtils", "NLsolve", "NonlinearSolve", "OrdinaryDiffEq", "ReTestItems", "SciMLSensitivity", "StableRNGs", "Test", "Zygote"]
test = ["Aqua", "Documenter", "ExplicitImports", "ForwardDiff", "Functors", "LuxCUDA", "LuxTestUtils", "NLsolve", "NonlinearSolve", "OrdinaryDiffEq", "ReTestItems", "SciMLSensitivity", "StableRNGs", "Test", "Zygote"]
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ Random = "1"
SciMLSensitivity = "7"
Statistics = "1"
Zygote = "0.6"
julia = "1.9"
julia = "1.10"
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ makedocs(; sitename="Deep Equilibrium Networks",
authors="Avik Pal et al.",
modules=[DeepEquilibriumNetworks],
clean=true,
doctest=true,
doctest=false, # Tested in CI
linkcheck=true,
format=Documenter.HTML(; assets=["assets/favicon.ico"],
canonical="https://docs.sciml.ai/DeepEquilibriumNetworks/stable/"),
Expand Down
13 changes: 7 additions & 6 deletions docs/src/tutorials/basic_mnist_deq.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ const cdev = cpu_device()
const gdev = gpu_device()
```

We can now construct our dataloader.
We can now construct our dataloader. We are using only a limited part of the data for
demonstration purposes.

```@example basic_mnist_deq
function onehot(labels_raw)
Expand All @@ -32,17 +33,17 @@ function loadmnist(batchsize, split)
mnist = MNIST(; split)
imgs, labels_raw = mnist.features, mnist.targets
# Process images into (H,W,C,BS) batches
x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |>
gdev
x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))[
:, :, 1:1, 1:128] |> gdev
x_train = batchview(x_train, batchsize)
# Onehot and batch the labels
y_train = onehot(labels_raw) |> gdev
y_train = onehot(labels_raw)[:, 1:128] |> gdev
y_train = batchview(y_train, batchsize)
return x_train, y_train
end
x_train, y_train = loadmnist(128, :train);
x_test, y_test = loadmnist(128, :test);
x_train, y_train = loadmnist(16, :train);
x_test, y_test = loadmnist(16, :test);
```

Construct the Lux Neural Network containing a DEQ layer.
Expand Down
57 changes: 0 additions & 57 deletions ext/DeepEquilibriumNetworksZygoteExt.jl

This file was deleted.

5 changes: 3 additions & 2 deletions src/DeepEquilibriumNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module DeepEquilibriumNetworks
import PrecompileTools: @recompile_invalidations

@recompile_invalidations begin
using ADTypes: AutoFiniteDiff
using ADTypes: AutoFiniteDiff, AutoForwardDiff, AutoZygote
using ChainRulesCore: ChainRulesCore
using CommonSolve: solve
using ConcreteStructs: @concrete
Expand All @@ -12,7 +12,8 @@ import PrecompileTools: @recompile_invalidations
using FastClosures: @closure
using Lux: Lux, BranchLayer, Chain, NoOpLayer, Parallel, RepeatedLayer,
StatefulLuxLayer, WrappedFunction
using LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
using NNlib:
using Random: Random, AbstractRNG, randn!
using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, AbstractODEAlgorithm,
NonlinearSolution, ODESolution, ODEFunction, ODEProblem,
Expand Down
107 changes: 25 additions & 82 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ const DEQ = DeepEquilibriumNetwork

ConstructionBase.constructorof(::Type{<:DEQ{pType}}) where {pType} = DEQ{pType}

function Lux.initialstates(rng::AbstractRNG, deq::DEQ)
rng = Lux.replicate(rng)
function LuxCore.initialstates(rng::AbstractRNG, deq::DEQ)
rng = LuxCore.replicate(rng)
randn(rng, 1)
return (; model=Lux.initialstates(rng, deq.model), fixed_depth=Val(0),
init=Lux.initialstates(rng, deq.init), solution=DeepEquilibriumSolution(), rng)
return (; model=LuxCore.initialstates(rng, deq.model), fixed_depth=Val(0),
init=LuxCore.initialstates(rng, deq.init), solution=DeepEquilibriumSolution(), rng)
end

(deq::DEQ)(x, ps, st::NamedTuple) = deq(x, ps, st, __check_unrolled_mode(st))
Expand All @@ -79,10 +79,10 @@ function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{true})
repeated_model = RepeatedLayer(deq.model; repeats=st.fixed_depth)

z_star, st_ = repeated_model((z, x), ps.model, st.model)
model = StatefulLuxLayer(deq.model, ps.model, st_)
model = StatefulLuxLayer{true}(deq.model, ps.model, st_)
resid = CRC.ignore_derivatives(z_star .- model((z_star, x)))

rng = Lux.replicate(st.rng)
rng = LuxCore.replicate(st.rng)
jac_loss = __estimate_jacobian_trace(
__getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng)

Expand All @@ -97,7 +97,7 @@ end
function (deq::DEQ{pType})(x, ps, st::NamedTuple, ::Val{false}) where {pType}
z, st = __get_initial_condition(deq, x, ps, st)

model = StatefulLuxLayer(deq.model, ps.model, st.model)
model = StatefulLuxLayer{true}(deq.model, ps.model, st.model)

dudt = @closure (u, p, t) -> begin
# The type-assert is needed because of an upstream Lux issue with type stability of
Expand All @@ -113,7 +113,7 @@ function (deq::DEQ{pType})(x, ps, st::NamedTuple, ::Val{false}) where {pType}
reltol=1e-3, termination_condition, maxiters=32, deq.kwargs...)
z_star = __get_steady_state(sol)

rng = Lux.replicate(st.rng)
rng = LuxCore.replicate(st.rng)
jac_loss = __estimate_jacobian_trace(
__getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng)

Expand Down Expand Up @@ -144,17 +144,16 @@ Deep Equilibrium Network as proposed in [baideep2019](@cite) and [pal2022mixing]
condition is set to `zero(x)`. If `missing`, the initial condition is set to
`WrappedFunction(zero)`. In other cases the initial condition is set to
`init(x, ps, st)`.
- `jacobian_regularization`: Must be one of `nothing`, `AutoFiniteDiff` or `AutoZygote`.
- `jacobian_regularization`: Must be one of `nothing`, `AutoForwardDiff`, `AutoFiniteDiff`
or `AutoZygote`.
- `problem_type`: Provides a way to simulate a Vanilla Neural ODE by setting the
`problem_type` to `ODEProblem`. By default, the problem type is set to
`SteadyStateProblem`.
  - `kwargs`: Additional parameters that are directly passed to `SciMLBase.solve`.
## Example
```julia
julia> using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq
```jldoctest
julia> model = DeepEquilibriumNetwork(
Parallel(+, Dense(2, 2; use_bias=false), Dense(2, 2; use_bias=false)),
VCABM3(); verbose=false)
Expand All @@ -168,13 +167,12 @@ DeepEquilibriumNetwork(
) # Total: 8 parameters,
# plus 0 states.
julia> rng = Random.default_rng()
TaskLocalRNG()
julia> rng = Xoshiro(0);
julia> ps, st = Lux.setup(rng, model);
julia> model(ones(Float32, 2, 1), ps, st);
julia> size(first(model(ones(Float32, 2, 1), ps, st)))
(2, 1)
```
See also: [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref),
Expand Down Expand Up @@ -234,84 +232,27 @@ For keyword arguments, see [`DeepEquilibriumNetwork`](@ref).
## Example
```julia
julia> using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve
```jldoctest
julia> main_layers = (
Parallel(+, Dense(4 => 4, tanh; use_bias=false), Dense(4 => 4, tanh; use_bias=false)),
Dense(3 => 3, tanh), Dense(2 => 2, tanh), Dense(1 => 1, tanh))
(Parallel(), Dense(3 => 3, tanh_fast), Dense(2 => 2, tanh_fast), Dense(1 => 1, tanh_fast))
Dense(3 => 3, tanh), Dense(2 => 2, tanh), Dense(1 => 1, tanh));
julia> mapping_layers = [NoOpLayer() Dense(4 => 3, tanh) Dense(4 => 2, tanh) Dense(4 => 1, tanh);
Dense(3 => 4, tanh) NoOpLayer() Dense(3 => 2, tanh) Dense(3 => 1, tanh);
Dense(2 => 4, tanh) Dense(2 => 3, tanh) NoOpLayer() Dense(2 => 1, tanh);
Dense(1 => 4, tanh) Dense(1 => 3, tanh) Dense(1 => 2, tanh) NoOpLayer()]
4×4 Matrix{LuxCore.AbstractExplicitLayer}:
NoOpLayer() … Dense(4 => 1, tanh_fast)
Dense(3 => 4, tanh_fast) Dense(3 => 1, tanh_fast)
Dense(2 => 4, tanh_fast) Dense(2 => 1, tanh_fast)
Dense(1 => 4, tanh_fast) NoOpLayer()
Dense(1 => 4, tanh) Dense(1 => 3, tanh) Dense(1 => 2, tanh) NoOpLayer()];
julia> model = MultiScaleDeepEquilibriumNetwork(
main_layers, mapping_layers, nothing, NewtonRaphson(), ((4,), (3,), (2,), (1,)))
DeepEquilibriumNetwork(
model = MultiScaleInputLayer{scales = 4}(
model = Chain(
layer_1 = Parallel(
layer_1 = Parallel(
+
Dense(4 => 4, tanh_fast, bias=false), # 16 parameters
Dense(4 => 4, tanh_fast, bias=false), # 16 parameters
),
layer_2 = Dense(3 => 3, tanh_fast), # 12 parameters
layer_3 = Dense(2 => 2, tanh_fast), # 6 parameters
layer_4 = Dense(1 => 1, tanh_fast), # 2 parameters
),
layer_2 = BranchLayer(
layer_1 = Parallel(
+
NoOpLayer(),
Dense(3 => 4, tanh_fast), # 16 parameters
Dense(2 => 4, tanh_fast), # 12 parameters
Dense(1 => 4, tanh_fast), # 8 parameters
),
layer_2 = Parallel(
+
Dense(4 => 3, tanh_fast), # 15 parameters
NoOpLayer(),
Dense(2 => 3, tanh_fast), # 9 parameters
Dense(1 => 3, tanh_fast), # 6 parameters
),
layer_3 = Parallel(
+
Dense(4 => 2, tanh_fast), # 10 parameters
Dense(3 => 2, tanh_fast), # 8 parameters
NoOpLayer(),
Dense(1 => 2, tanh_fast), # 4 parameters
),
layer_4 = Parallel(
+
Dense(4 => 1, tanh_fast), # 5 parameters
Dense(3 => 1, tanh_fast), # 4 parameters
Dense(2 => 1, tanh_fast), # 3 parameters
NoOpLayer(),
),
),
),
),
init = WrappedFunction(Base.Fix1{typeof(DeepEquilibriumNetworks.__zeros_init), Val{((4,), (3,), (2,), (1,))}}(DeepEquilibriumNetworks.__zeros_init, Val{((4,), (3,), (2,), (1,))}())),
) # Total: 152 parameters,
# plus 0 states.
main_layers, mapping_layers, nothing, NewtonRaphson(), ((4,), (3,), (2,), (1,)));
julia> rng = Random.default_rng()
TaskLocalRNG()
julia> rng = Xoshiro(0);
julia> ps, st = Lux.setup(rng, model);
julia> x = rand(rng, Float32, 4, 12);
julia> model(x, ps, st);
julia> size.(first(model(x, ps, st)))
((4, 12), (3, 12), (2, 12), (1, 12))
```
"""
function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
Expand Down Expand Up @@ -390,7 +331,9 @@ end
# Tell ConstructionBase (used by Setfield/Accessors when rebuilding structs) how to
# reconstruct a `MultiScaleInputLayer`: the scale count `N` is a type parameter, so it
# must be re-supplied explicitly rather than inferred from the field values.
function ConstructionBase.constructorof(::Type{<:MultiScaleInputLayer{N}}) where {N}
return MultiScaleInputLayer{N}
end
Lux.display_name(::MultiScaleInputLayer{N}) where {N} = "MultiScaleInputLayer{scales = $N}"
function LuxCore.display_name(::MultiScaleInputLayer{N}) where {N}
return "MultiScaleInputLayer{scales = $N}"
end

# Convenience constructor: derive the type parameter `N` (number of scales) from the
# length of the `Val`-wrapped scale tuple `S`, so callers need not spell it out.
function MultiScaleInputLayer(model, split_idxs, scales::Val{S}) where {S}
return MultiScaleInputLayer{length(S)}(model, split_idxs, scales)
end
return quote
u, x = z
u_ = __split_and_reshape(u, m.split_idxs, m.scales)
u_res, st = Lux.apply(m.model, ($(inputs...),), ps, st)
u_res, st = LuxCore.apply(m.model, ($(inputs...),), ps, st)
return __flatten_vcat(u_res), st
end
end
Loading

0 comments on commit ce97811

Please sign in to comment.