Skip to content

Commit

Permalink
Fixed script and adjusted to new syntax.
Browse files Browse the repository at this point in the history
  • Loading branch information
benedict-96 committed Dec 20, 2023
1 parent e5ba56f commit fb6dcb9
Showing 1 changed file with 38 additions and 25 deletions.
63 changes: 38 additions & 25 deletions docs/src/tutorials/mnist_tutorial.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# MNIST tutorial

This is a short tutorial that shows how we can use `GeometricMachineLearning` to build a vision transformer and apply it for MNIST, while also putting some of the weights on a manifold.
This is a short tutorial that shows how we can use `GeometricMachineLearning` to build a vision transformer and apply it for MNIST, while also putting some of the weights on a manifold. This is also the result presented in [brantner2023generalizing](@cite).

First, we need to import the relevant packages:

```julia
using GeometricMachineLearning, CUDA
import Zygote, MLDatasets
using GeometricMachineLearning, CUDA, Plots
import Zygote, MLDatasets, KernelAbstractions
```

In this example `Zygote` as an AD routine and we get the dataset from `MLDatasets`. First we need to load the data set, and put it on GPU (if you have one):
For the AD routine we here use the `GeometricMachineLearning` default and we get the dataset from `MLDatasets`. First we need to load the data set, and put it on GPU (if you have one):

```julia
train_x, train_y = MLDatasets.MNIST(split=:train)[:]
Expand All @@ -24,36 +24,49 @@ test_y = test_y |> cu

```julia
patch_length = 7
dl = DataLoader(train_x, train_y, batch_size=512, patch_length=patch_length)
dl_test = DataLoader(train_x, train_y, batch_size=length(y), patch_length=patch_length)
dl = DataLoader(train_x, train_y, patch_length=patch_length)
dl_test = DataLoader(train_x, train_y, patch_length=patch_length)
```

The second line in the above code snippet indicates that we use the entire data set as one "batch" when processing the test set. For training, the batch size was here set to 512.
Here `patch_length` indicates the size one patch has. One image in MNIST is of dimension ``28\times28``, this means that we decompose this into 16 ``(7\times7)`` images (also see [brantner2023generalizing](@cite)).

We next define the model with which we want to train:

```julia
ps = initialparameters(backend, eltype(dl.data), Ψᵉ)
model = ClassificationTransformer(dl, n_heads=n_heads, n_layers=n_layers, Stiefel=true)
```

optimizer_instance = Optimizer(o, ps)
Here we have chosen a `ClassificationTransformer`, i.e. a composition of a specific number of transformer layers composed with a classification layer. We also set the *Stiefel option* to `true`, i.e. we are optimizing on the Stiefel manifold.

println("initial test accuracy: ", accuracy(Ψᵉ, ps, dl_test), "\n")
We now have to initialize the neural network weights. This is done with the constructor for `NeuralNetwork`:

```julia
backend = KernelAbstractions.get_backend(dl)
T = eltype(dl)
nn = NeuralNetwork(model, backend, T)
```

And with this we can finally perform the training:

progress_object = Progress(n_training_steps; enabled=true)
```julia
# an instance of batch is needed for the optimizer
batch = Batch(batch_size)

optimizer_instance = Optimizer(AdamOptimizer(), nn)

loss_array = zeros(eltype(train_x), n_training_steps)
for i in 1:n_training_steps
redraw_batch!(dl)
# get rid of try catch statement. This softmax issue should be solvable!
loss_val, pb = try Zygote.pullback(ps -> loss(Ψᵉ, ps, dl), ps)
catch
loss_array[i] = loss_array[i-1]
continue
end
dp = pb(one(loss_val))[1]
# this prints the accuracy and is optional
println("initial test accuracy: ", accuracy(Ψᵉ, ps, dl_test), "\n")

optimization_step!(optimizer_instance, Ψᵉ, ps, dp)
ProgressMeter.next!(progress_object; showvalues = [(:TrainingLoss, loss_val)])
loss_array[i] = loss_val
end
loss_array = optimizer_instance(nn, dl, batch, n_epochs)

println("final test accuracy: ", accuracy(Ψᵉ, ps, dl_test), "\n")
```

It is instructive to play with `n_layers`, `n_epochs` and the Stiefel property.

```@bibliography
Pages = []
Canonical = false
brantner2023generalizing
```

0 comments on commit fb6dcb9

Please sign in to comment.