Finish the basic tutorial
avik-pal committed Dec 19, 2023
1 parent c4f0d10 commit cb10da6
Showing 4 changed files with 162 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -32,7 +32,7 @@ ConcreteStructs = "0.2"
ConstructionBase = "1"
DiffEqBase = "6.119"
LinearAlgebra = "1"
Lux = "0.5.7"
Lux = "0.5.11"
Random = "1"
SciMLBase = "2"
SciMLSensitivity = "7.43"
4 changes: 2 additions & 2 deletions docs/Project.toml
@@ -2,13 +2,13 @@
DeepEquilibriumNetworks = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
+LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
-Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
-OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
160 changes: 156 additions & 4 deletions docs/src/tutorials/basic_mnist_deq.md
@@ -4,9 +4,9 @@ We will train a simple Deep Equilibrium Model on MNIST. First we load a few packages.

```@example basic_mnist_deq
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
-    Statistics, Random, Optimization, OptimizationOptimisers, LuxCUDA
+    Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve
using MLDatasets: MNIST
-using MLDataUtils: LabelEnc, convertlabel, stratifiedobs
+using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
CUDA.allowscalar(false)
ENV["DATADEPS_ALWAYS_ACCEPT"] = true
@@ -27,9 +27,9 @@ function onehot(labels_raw)
return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
end
-function loadmnist(batchsize)
+function loadmnist(batchsize, split)
# Load MNIST
-mnist = MNIST(; split=:train)
+mnist = MNIST(; split)
imgs, labels_raw = mnist.features, mnist.targets
# Process images into (H,W,C,BS) batches
x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |>
@@ -40,4 +40,156 @@ function loadmnist(batchsize)
y_train = batchview(y_train, batchsize)
return x_train, y_train
end
x_train, y_train = loadmnist(128, :train);
x_test, y_test = loadmnist(128, :test);
```

Construct the Lux Neural Network containing a DEQ layer.

```@example basic_mnist_deq
function construct_model(solver; model_type::Symbol = :deq)
down = Chain(Conv((3, 3), 1 => 64, gelu; stride = 1), GroupNorm(64, 64),
Conv((4, 4), 64 => 64; stride = 2, pad = 1))
# The input layer of the DEQ
deq_model = Chain(Parallel(+,
Conv((3, 3), 64 => 64, tanh; stride = 1, pad = SamePad()),
Conv((3, 3), 64 => 64, tanh; stride = 1, pad = SamePad())),
Conv((3, 3), 64 => 64, tanh; stride = 1, pad = SamePad()))
if model_type === :skipdeq
init = Conv((3, 3), 64 => 64, gelu; stride = 1, pad = SamePad())
elseif model_type === :regdeq
init = nothing
else
init = missing
end
deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose = false,
linsolve_kwargs = (; maxiters = 10))
classifier = Chain(GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(),
Dense(64, 10))
model = Chain(; down, deq, classifier)
# For NVIDIA GPUs this directly generates the parameters on the GPU
rng = Random.default_rng() |> gdev
ps, st = Lux.setup(rng, model)
# Warmup the forward and backward passes
x = randn(rng, Float32, 28, 28, 1, 128)
y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev
model_ = Lux.Experimental.StatefulLuxLayer(model, ps, st)
@info "warming up forward pass"
logitcrossentropy(model_, x, ps, y)
@info "warming up backward pass"
Zygote.gradient(logitcrossentropy, model_, x, ps, y)
@info "warmup complete"
return model, ps, st
end
```

Define some helper functions to train the model.

```@example basic_mnist_deq
logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims = 1); dims = 1))
function logitcrossentropy(model, x, ps, y)
l1 = logitcrossentropy(model(x, ps), y)
# Add in some regularization
l2 = mean(abs2, model.st.deq.solution.z_star .- model.st.deq.solution.u0)
return l1 + 10.0 * l2
end
classify(x) = argmax.(eachcol(x))
function accuracy(model, data, ps, st)
total_correct, total = 0, 0
st = Lux.testmode(st)
model = Lux.Experimental.StatefulLuxLayer(model, ps, st)
for (x, y) in data
target_class = classify(cdev(y))
predicted_class = classify(cdev(model(x)))
total_correct += sum(target_class .== predicted_class)
total += length(target_class)
end
return total_correct / total
end
function train_model(solver, model_type; data_train = zip(x_train, y_train),
data_test = zip(x_test, y_test))
model, ps, st = construct_model(solver; model_type)
model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st)
@info "Training Model: $(model_type) with Solver: $(nameof(typeof(solver)))"
opt_st = Optimisers.setup(Adam(0.001), ps)
acc = accuracy(model, data_test, ps, st) * 100
@info "Starting Accuracy: $(acc)"
# = Uncomment these lines to enable pretraining. See what happens
@info "Pretrain with unrolling to a depth of 5"
st = Lux.update_state(st, :fixed_depth, Val(5))
model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
for (i, (x, y)) in enumerate(data_train)
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
Optimisers.update!(opt_st, ps, res.grad[3])
if i % 50 == 1
@info "Pretraining Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
end
end
acc = accuracy(model, data_test, ps, model_st.st) * 100
@info "Pretraining complete. Accuracy: $(acc)"
# =#
st = Lux.update_state(st, :fixed_depth, Val(0))
model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
for epoch in 1:3
for (i, (x, y)) in enumerate(data_train)
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
Optimisers.update!(opt_st, ps, res.grad[3])
if i % 50 == 1
@info "Epoch: [$(epoch)/3] Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
end
end
acc = accuracy(model, data_test, ps, model_st.st) * 100
@info "Epoch: [$(epoch)/3] Accuracy: $(acc)"
end
@info "Training complete."
println()
return model, ps, st
end
```
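Before running the full training loop, here is a minimal sketch of how these helpers compose: construct an (untrained) model and evaluate it with `accuracy` on the test split. Everything used here is defined above; the snippet is illustrative and not executed as part of the docs build, and the roughly 10% figure is simply chance-level accuracy for 10 classes.

```julia
# Evaluate a freshly constructed (untrained) model -- accuracy should be near chance (~10%).
model, ps, st = construct_model(VCAB3(); model_type = :deq)
test_acc = accuracy(model, zip(x_test, y_test), ps, st) * 100
@info "Untrained accuracy: $(test_acc)%"
```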

Now we can train our model. First we will train a discrete DEQ, which effectively means
passing in a root-finding algorithm. Most deep learning frameworks only ship simple
fixed-point solvers such as `Broyden`'s method, but here we can plug in any of the
solvers from NonlinearSolve.jl. We will use a Newton-Krylov method:

```@example basic_mnist_deq
train_model(NewtonRaphson(; linsolve = KrylovJL_GMRES()), :regdeq)
nothing # hide
```

We can also train a continuous DEQ by passing in an ODE solver. Here we will use `VCAB3()`,
which tends to be quite fast for continuous neural network problems.

```@example basic_mnist_deq
train_model(VCAB3(), :deq)
nothing # hide
```

This code is set up to allow experimenting with different DEQ models. Try changing the
`model_type` argument of `train_model` to `:skipdeq` or `:deq` to see how the model
behaves; a short sketch of one such variation follows below. You can also try different
solvers from NonlinearSolve.jl and OrdinaryDiffEq.jl! Even third-party solvers from
Sundials.jl will work, just remember to run those on the CPU.
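As a concrete starting point, here is a sketch of one such variation: the skip-initialization variant (`:skipdeq`), trained once with the same Newton-Krylov solver as above and once as a continuous DEQ. Only functions already defined in this tutorial are used (`train_model`, `NewtonRaphson`, `KrylovJL_GMRES`, `VCAB3`); the snippet is illustrative and not executed as part of the docs build.

```julia
# SkipDEQ with the Newton-Krylov root finder: the learned initial guess usually
# lets the fixed-point solve converge in fewer iterations than the plain DEQ.
train_model(NewtonRaphson(; linsolve = KrylovJL_GMRES()), :skipdeq)

# The same SkipDEQ, but driven to its steady state with an ODE solver instead.
train_model(VCAB3(), :skipdeq)
```

Both calls go through the same `construct_model` and training loop, so the logged losses and accuracies are directly comparable with the runs above.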
6 changes: 3 additions & 3 deletions src/layers.jl
@@ -13,7 +13,7 @@ Stores the solution of a DeepEquilibriumNetwork and its variants.
- `nfe`: Number of Function Evaluations
- `original`: Original Internal Solution
"""
-@concrete struct DeepEquilibriumSolution
+struct DeepEquilibriumSolution # This is intentionally left untyped to allow updating `st`
z_star
u0
residual
@@ -85,7 +85,7 @@ function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{true})
model, ps.model, zˢᵗᵃʳ, x, rng)

solution = DeepEquilibriumSolution(zˢᵗᵃʳ, z, resid, zero(eltype(x)),
-    _unwrap_val(st.fixed_depth), nothing)
+    _unwrap_val(st.fixed_depth), jac_loss)
res = __split_and_reshape(zˢᵗᵃʳ, __getproperty(deq.model, Val(:split_idxs)),
__getproperty(deq.model, Val(:scales)))

@@ -102,7 +102,7 @@ function (deq::DEQ{pType})(x, ps, st::NamedTuple, ::Val{false}) where {pType}
prob = __construct_prob(pType, ODEFunction{false}(dudt), z, (; ps=ps.model, x))
alg = __normalize_alg(deq)
sol = solve(prob, alg; sensealg=__default_sensealg(prob), abstol=1e-3, reltol=1e-3,
-    termination_condition=AbsNormTerminationMode(), maxiters=100, deq.kwargs...)
+    termination_condition=AbsNormTerminationMode(), maxiters=32, deq.kwargs...)
zˢᵗᵃʳ = __get_steady_state(sol)

rng = Lux.replicate(st.rng)
