Finish the tutorials
avik-pal committed Dec 19, 2023
1 parent 29951bd commit b67f684
Showing 3 changed files with 178 additions and 7 deletions.
2 changes: 0 additions & 2 deletions docs/src/tutorials/basic_mnist_deq.md
@@ -131,7 +131,6 @@ function train_model(solver, model_type; data_train=zip(x_train, y_train),
acc = accuracy(model, data_test, ps, st) * 100
@info "Starting Accuracy: $(acc)"
# = Uncomment these lines to enavle pretraining. See what happens
@info "Pretrain with unrolling to a depth of 5"
st = Lux.update_state(st, :fixed_depth, Val(5))
model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
@@ -146,7 +145,6 @@ function train_model(solver, model_type; data_train=zip(x_train, y_train),
acc = accuracy(model, data_test, ps, model_st.st) * 100
@info "Pretraining complete. Accuracy: $(acc)"
# =#
st = Lux.update_state(st, :fixed_depth, Val(0))
model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
177 changes: 176 additions & 1 deletion docs/src/tutorials/reduced_dim_deq.md
@@ -1,3 +1,178 @@
# Modelling Equilibrium Models with Reduced State Size

This Tutorial is currently under preparation. Check back soon.
Sometimes we don't want to solve the root finding problem over the full state size. Working
with a reduced state is often faster, since the root finding problem becomes smaller. We will
use the same MNIST example as before, but this time with a reduced state size.
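
To make this concrete (a sketch of the formulation in our own notation, not text from the
package docs): a DEQ computes its output as a fixed point of its layer. In the reduced-state
variant below, the fixed point lives in a 128-dimensional space even though the injected
features are 512-dimensional,

```math
z^\ast = f_\theta(z^\ast, x), \qquad z^\ast \in \mathbb{R}^{128}, \quad x \in \mathbb{R}^{512},
```

so the nonlinear solve operates on a 128-dimensional state rather than the full
512-dimensional feature vector. The `init` keyword of `DeepEquilibriumNetwork` (used in
`construct_model` below) supplies the initial guess for the reduced state.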

```@example reduced_dim_mnist
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
    Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve
using MLDatasets: MNIST
using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview

CUDA.allowscalar(false)
ENV["DATADEPS_ALWAYS_ACCEPT"] = true

const cdev = cpu_device()
const gdev = gpu_device()

function onehot(labels_raw)
    return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
end

function loadmnist(batchsize, split)
    # Load MNIST
    mnist = MNIST(; split)
    imgs, labels_raw = mnist.features, mnist.targets
    # Process images into (H, W, C, BS) batches
    x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |>
              gdev
    x_train = batchview(x_train, batchsize)
    # Onehot and batch the labels
    y_train = onehot(labels_raw) |> gdev
    y_train = batchview(y_train, batchsize)
    return x_train, y_train
end

x_train, y_train = loadmnist(128, :train);
x_test, y_test = loadmnist(128, :test);
```

Now we will define the `construct_model` function. Here we will use `Dense` layers and
downsample the features to the reduced state size using the `init` kwarg.

```@example reduced_dim_mnist
function construct_model(solver; model_type::Symbol=:regdeq)
    down = Chain(FlattenLayer(), Dense(784 => 512, gelu))

    # The input layer of the DEQ
    deq_model = Chain(Parallel(+,
            Dense(128 => 64, tanh),   # Reduced dim of `128`
            Dense(512 => 64, tanh)),  # Original dim of `512`
        Dense(64 => 64, tanh), Dense(64 => 128)) # Return the reduced dim of `128`

    if model_type === :skipdeq
        init = Dense(512 => 128, tanh)
    elseif model_type === :regdeq
        error(":regdeq is not supported for reduced dim models")
    else
        # This should preferably be done via `ChainRulesCore.@ignore_derivatives`. But here
        # we are only using Zygote so this is fine.
        init = WrappedFunction(x -> Zygote.@ignore(fill!(similar(x, 128, size(x, 2)),
            false)))
    end

    deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
        linsolve_kwargs=(; maxiters=10))

    classifier = Chain(Dense(128 => 128, gelu), Dense(128, 10))

    model = Chain(; down, deq, classifier)

    # For NVIDIA GPUs this directly generates the parameters on the GPU
    rng = Random.default_rng() |> gdev
    ps, st = Lux.setup(rng, model)

    # Warmup the forward and backward passes
    x = randn(rng, Float32, 28, 28, 1, 128)
    y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev

    model_ = Lux.Experimental.StatefulLuxLayer(model, ps, st)

    @info "warming up forward pass"
    logitcrossentropy(model_, x, ps, y)
    @info "warming up backward pass"
    Zygote.gradient(logitcrossentropy, model_, x, ps, y)
    @info "warmup complete"

    return model, ps, st
end
```
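
As a quick usage sketch (illustrative only and not executed as part of the tutorial; it
assumes the packages loaded above), the constructor is called with a nonlinear solver and a
model type, mirroring the training runs below:

```julia
# Hypothetical standalone call; see `train_model` below for how it is actually used
model, ps, st = construct_model(NewtonRaphson(; linsolve=KrylovJL_GMRES());
    model_type=:skipdeq)
```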

Define some helper functions to train the model.

```@example reduced_dim_mnist
logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims=1); dims=1))

function logitcrossentropy(model, x, ps, y)
    l1 = logitcrossentropy(model(x, ps), y)
    # Add in some regularization
    l2 = mean(abs2, model.st.deq.solution.z_star .- model.st.deq.solution.u0)
    return l1 + 0.1f0 * l2
end

classify(x) = argmax.(eachcol(x))

function accuracy(model, data, ps, st)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    model = Lux.Experimental.StatefulLuxLayer(model, ps, st)
    for (x, y) in data
        target_class = classify(cdev(y))
        predicted_class = classify(cdev(model(x)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end

function train_model(solver, model_type; data_train=zip(x_train, y_train),
        data_test=zip(x_test, y_test))
    model, ps, st = construct_model(solver; model_type)
    model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st)

    @info "Training Model: $(model_type) with Solver: $(nameof(typeof(solver)))"

    opt_st = Optimisers.setup(Adam(0.001), ps)

    acc = accuracy(model, data_test, ps, st) * 100
    @info "Starting Accuracy: $(acc)"

    @info "Pretrain with unrolling to a depth of 5"
    st = Lux.update_state(st, :fixed_depth, Val(5))
    model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)

    for (i, (x, y)) in enumerate(data_train)
        res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
        Optimisers.update!(opt_st, ps, res.grad[3])
        if i % 50 == 1
            @info "Pretraining Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
        end
    end

    acc = accuracy(model, data_test, ps, model_st.st) * 100
    @info "Pretraining complete. Accuracy: $(acc)"

    st = Lux.update_state(st, :fixed_depth, Val(0))
    model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)

    for epoch in 1:3
        for (i, (x, y)) in enumerate(data_train)
            res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
            Optimisers.update!(opt_st, ps, res.grad[3])
            if i % 50 == 1
                @info "Epoch: [$(epoch)/3] Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
            end
        end

        acc = accuracy(model, data_test, ps, model_st.st) * 100
        @info "Epoch: [$(epoch)/3] Accuracy: $(acc)"
    end

    @info "Training complete."
    println()

    return model, ps, st
end
```
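
One design note on `train_model` (our summary, not part of the original text): setting the
`fixed_depth` state to `Val(5)` makes the DEQ unroll its layer for a fixed five iterations
instead of solving the implicit fixed-point problem, which is a cheap way to pretrain the
weights before switching back to the implicit solve with `Val(0)`. A minimal recap of the
toggle used above:

```julia
# Pretraining mode: unroll the DEQ layer for exactly 5 steps (no nonlinear solve)
st = Lux.update_state(st, :fixed_depth, Val(5))
# Implicit mode: solve for the fixed point with the chosen nonlinear solver
st = Lux.update_state(st, :fixed_depth, Val(0))
```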

Now we can train our model. We can't use `:regdeq` here currently, but we will support this
in the future.

```@example reduced_dim_mnist
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :skipdeq)
nothing # hide
```

```@example reduced_dim_mnist
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :deq)
nothing # hide
```
6 changes: 2 additions & 4 deletions test/layers.jl
@@ -49,8 +49,7 @@ end
z, st = model(x, ps, st)

opt_broken = solver isa NewtonRaphson ||
solver isa SimpleLimitedMemoryBroyden ||
jacobian_regularization isa AutoZygote
solver isa SimpleLimitedMemoryBroyden
@jet model(x, ps, st) opt_broken=opt_broken # Broken due to nfe dynamic dispatch

@test all(isfinite, z)
@@ -142,8 +141,7 @@ end
z_ = DEQs.__flatten_vcat(z)

opt_broken = solver isa NewtonRaphson ||
solver isa SimpleLimitedMemoryBroyden ||
jacobian_regularization isa AutoZygote
solver isa SimpleLimitedMemoryBroyden
@jet model(x, ps, st) opt_broken=opt_broken # Broken due to nfe dynamic dispatch

@test all(isfinite, z_)