diff --git a/scripts/transformer_analysis.jl b/scripts/transformer_analysis.jl index 01e7e0f15..59439c734 100644 --- a/scripts/transformer_analysis.jl +++ b/scripts/transformer_analysis.jl @@ -14,8 +14,8 @@ GeometricMachineLearning.Chain(d::AbstractNeuralNetworks.AbstractExplicitLayer, image_dim = 28 patch_length = 7 transformer_dim = 49 -n_heads = 1 -n_layers = 16 +n_heads = 7 +n_layers = 1 number_of_patch = (image_dim÷patch_length)^2 batch_size = 2048 activation = softmax