From cb10da6be023a1b8958b4fe8e4c87802685df510 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 19 Dec 2023 15:10:09 -0500
Subject: [PATCH] Finish the basic tutorial

---
 Project.toml                          |   2 +-
 docs/Project.toml                     |   4 +-
 docs/src/tutorials/basic_mnist_deq.md | 160 +++++++++++++++++++++++++-
 src/layers.jl                         |   6 +-
 4 files changed, 162 insertions(+), 10 deletions(-)

diff --git a/Project.toml b/Project.toml
index 90d22b4c..531a0285 100644
--- a/Project.toml
+++ b/Project.toml
@@ -32,7 +32,7 @@ ConcreteStructs = "0.2"
 ConstructionBase = "1"
 DiffEqBase = "6.119"
 LinearAlgebra = "1"
-Lux = "0.5.7"
+Lux = "0.5.11"
 Random = "1"
 SciMLBase = "2"
 SciMLSensitivity = "7.43"
diff --git a/docs/Project.toml b/docs/Project.toml
index 75f8f324..eb03c9e6 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -2,13 +2,13 @@
 DeepEquilibriumNetworks = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
+LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
 MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
-Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
-OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
diff --git a/docs/src/tutorials/basic_mnist_deq.md b/docs/src/tutorials/basic_mnist_deq.md
index 22b6c9cf..99a10ad1 100644
--- a/docs/src/tutorials/basic_mnist_deq.md
+++ b/docs/src/tutorials/basic_mnist_deq.md
@@ -4,9 +4,9 @@ We will train a simple Deep Equilibrium Model on MNIST. First we load a few pack
 ```@example basic_mnist_deq
 using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
-    Statistics, Random, Optimization, OptimizationOptimisers, LuxCUDA
+    Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve
 using MLDatasets: MNIST
-using MLDataUtils: LabelEnc, convertlabel, stratifiedobs
+using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
 
 CUDA.allowscalar(false)
 
 ENV["DATADEPS_ALWAYS_ACCEPT"] = true
@@ -27,9 +27,9 @@ function onehot(labels_raw)
     return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
 end
 
-function loadmnist(batchsize)
+function loadmnist(batchsize, split)
     # Load MNIST
-    mnist = MNIST(; split=:train)
+    mnist = MNIST(; split)
     imgs, labels_raw = mnist.features, mnist.targets
     # Process images into (H,W,C,BS) batches
     x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |>
@@ -40,4 +40,156 @@
     y_train = batchview(y_train, batchsize)
     return x_train, y_train
 end
+
+x_train, y_train = loadmnist(128, :train);
+x_test, y_test = loadmnist(128, :test);
 ```
+
+Construct the Lux Neural Network containing a DEQ layer.
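+
+A DEQ layer returns a fixed point ``z^* = f(z^*, x)`` of its internal model ``f``. As a rough, intuition-only sketch (this block is not executed as part of the tutorial, and the toy map `f`, the sizes, and the constants in it are made up for illustration), the same kind of fixed point can be computed directly with NonlinearSolve.jl by root-finding on ``f(z, x) - z``:
+
+```julia
+using NonlinearSolve
+
+# Toy fixed-point map f(z, x); a DEQ layer returns z* such that z* = f(z*, x).
+f(z, x) = tanh.(0.5f0 .* z .+ x)
+
+# z* is a root of the residual g(z, x) = f(z, x) - z.
+g(z, x) = f(z, x) .- z
+
+x = randn(Float32, 4)                    # toy input
+prob = NonlinearProblem(g, zeros(Float32, 4), x)
+z_star = solve(prob, NewtonRaphson()).u  # the fixed point of f(., x)
+```
+
+With that picture in mind, the full model used in this tutorial follows.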
+
+```@example basic_mnist_deq
+function construct_model(solver; model_type::Symbol = :deq)
+    down = Chain(Conv((3, 3), 1 => 64, gelu; stride = 1), GroupNorm(64, 64),
+        Conv((4, 4), 64 => 64; stride = 2, pad = 1))
+
+    # The input layer of the DEQ
+    deq_model = Chain(Parallel(+,
+            Conv((3, 3), 64 => 64, tanh; stride = 1, pad = SamePad()),
+            Conv((3, 3), 64 => 64, tanh; stride = 1, pad = SamePad())),
+        Conv((3, 3), 64 => 64, tanh; stride = 1, pad = SamePad()))
+
+    if model_type === :skipdeq
+        init = Conv((3, 3), 64 => 64, gelu; stride = 1, pad = SamePad())
+    elseif model_type === :regdeq
+        init = nothing
+    else
+        init = missing
+    end
+
+    deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose = false,
+        linsolve_kwargs = (; maxiters = 10))
+
+    classifier = Chain(GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(),
+        Dense(64, 10))
+
+    model = Chain(; down, deq, classifier)
+
+    # For NVIDIA GPUs this directly generates the parameters on the GPU
+    rng = Random.default_rng() |> gdev
+    ps, st = Lux.setup(rng, model)
+
+    # Warm up the forward and backward passes
+    x = randn(rng, Float32, 28, 28, 1, 128)
+    y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev
+
+    model_ = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+    @info "warming up forward pass"
+    logitcrossentropy(model_, x, ps, y)
+    @info "warming up backward pass"
+    Zygote.gradient(logitcrossentropy, model_, x, ps, y)
+    @info "warmup complete"
+
+    return model, ps, st
+end
+```
+
+Define some helper functions to train the model.
+
+```@example basic_mnist_deq
+logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims = 1); dims = 1))
+function logitcrossentropy(model, x, ps, y)
+    l1 = logitcrossentropy(model(x, ps), y)
+    # Add in some regularization
+    l2 = mean(abs2, model.st.deq.solution.z_star .- model.st.deq.solution.u0)
+    return l1 + 10.0 * l2
+end
+
+classify(x) = argmax.(eachcol(x))
+
+function accuracy(model, data, ps, st)
+    total_correct, total = 0, 0
+    st = Lux.testmode(st)
+    model = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+    for (x, y) in data
+        target_class = classify(cdev(y))
+        predicted_class = classify(cdev(model(x)))
+        total_correct += sum(target_class .== predicted_class)
+        total += length(target_class)
+    end
+    return total_correct / total
+end
+
+function train_model(solver, model_type; data_train = zip(x_train, y_train),
+        data_test = zip(x_test, y_test))
+    model, ps, st = construct_model(solver; model_type)
+    model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st)
+
+    @info "Training Model: $(model_type) with Solver: $(nameof(typeof(solver)))"
+
+    opt_st = Optimisers.setup(Adam(0.001), ps)
+
+    acc = accuracy(model, data_test, ps, st) * 100
+    @info "Starting Accuracy: $(acc)"
+
+    # To disable pretraining, change this line to "#=" (the block comment will end at the "# =#" below); see what happens
+    @info "Pretrain with unrolling to a depth of 5"
+    st = Lux.update_state(st, :fixed_depth, Val(5))
+    model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+
+    for (i, (x, y)) in enumerate(data_train)
+        res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
+        Optimisers.update!(opt_st, ps, res.grad[3])
+        if i % 50 == 1
+            @info "Pretraining Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
+        end
+    end
+
+    acc = accuracy(model, data_test, ps, model_st.st) * 100
+    @info "Pretraining complete. Accuracy: $(acc)"
+    # =#
+
+    st = Lux.update_state(st, :fixed_depth, Val(0))
+    model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+
+    for epoch in 1:3
+        for (i, (x, y)) in enumerate(data_train)
+            res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
+            Optimisers.update!(opt_st, ps, res.grad[3])
+            if i % 50 == 1
+                @info "Epoch: [$(epoch)/3] Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
+            end
+        end
+
+        acc = accuracy(model, data_test, ps, model_st.st) * 100
+        @info "Epoch: [$(epoch)/3] Accuracy: $(acc)"
+    end
+
+    @info "Training complete."
+    println()
+
+    return model, ps, st
+end
+```
+
+Now we can train our model. First we will train a Discrete DEQ, which effectively means passing in a root-finding algorithm. Typically, packages lack good nonlinear solvers and end up using solvers like `Broyden`, but we can simply slap in any of the fancy solvers from NonlinearSolve.jl. Here we will use a Newton-Krylov method:
+
+```@example basic_mnist_deq
+train_model(NewtonRaphson(; linsolve = KrylovJL_GMRES()), :regdeq)
+nothing # hide
+```
+
+We can also train a continuous DEQ by passing in an ODE solver. Here we will use `VCAB3()`, which tends to be quite fast for continuous neural network problems.
+
+```@example basic_mnist_deq
+train_model(VCAB3(), :deq)
+nothing # hide
+```
+
+This code is set up to allow playing around with different DEQ models. Try changing the `model_type` argument of `train_model` (for example to `:skipdeq`) to see how the model behaves. You can also try different solvers from NonlinearSolve.jl and OrdinaryDiffEq.jl! Even third-party solvers from Sundials.jl will work; just remember to use the CPU for those.
diff --git a/src/layers.jl b/src/layers.jl
index fe9a3a4d..f9e56071 100644
--- a/src/layers.jl
+++ b/src/layers.jl
@@ -13,7 +13,7 @@ Stores the solution of a DeepEquilibriumNetwork and its variants.
 - `nfe`: Number of Function Evaluations
 - `original`: Original Internal Solution
 """
-@concrete struct DeepEquilibriumSolution
+struct DeepEquilibriumSolution # This is intentionally left untyped to allow updating `st`
     z_star
     u0
     residual
@@ -85,7 +85,7 @@ function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{true})
         model, ps.model, zˢᵗᵃʳ, x, rng)
 
     solution = DeepEquilibriumSolution(zˢᵗᵃʳ, z, resid, zero(eltype(x)),
-        _unwrap_val(st.fixed_depth), nothing)
+        _unwrap_val(st.fixed_depth), jac_loss)
 
     res = __split_and_reshape(zˢᵗᵃʳ, __getproperty(deq.model, Val(:split_idxs)),
         __getproperty(deq.model, Val(:scales)))
@@ -102,7 +102,7 @@ function (deq::DEQ{pType})(x, ps, st::NamedTuple, ::Val{false}) where {pType}
     prob = __construct_prob(pType, ODEFunction{false}(dudt), z, (; ps=ps.model, x))
     alg = __normalize_alg(deq)
    sol = solve(prob, alg; sensealg=__default_sensealg(prob), abstol=1e-3, reltol=1e-3,
-        termination_condition=AbsNormTerminationMode(), maxiters=100, deq.kwargs...)
+        termination_condition=AbsNormTerminationMode(), maxiters=32, deq.kwargs...)
     zˢᵗᵃʳ = __get_steady_state(sol)
 
     rng = Lux.replicate(st.rng)
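
A usage note on the closing paragraph of the tutorial above: assuming the tutorial's
`train_model` and its imports are in scope, the other solver/model combinations it
mentions could be tried along these lines (hypothetical calls, not part of the patch):

    # Hypothetical follow-ups to the tutorial, not included in this patch:
    train_model(NewtonRaphson(; linsolve = KrylovJL_GMRES()), :skipdeq)  # SkipDEQ + Newton-Krylov
    train_model(Tsit5(), :skipdeq)  # continuous SkipDEQ with another OrdinaryDiffEq.jl integrator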