Refactored all the transformer tests and gave each of them a name.
benedict-96 committed Dec 14, 2023
1 parent 9002452 commit 4c82474
Showing 8 changed files with 109 additions and 67 deletions.
17 changes: 9 additions & 8 deletions test/runtests.jl
@@ -3,21 +3,22 @@ using SafeTestsets

@safetestset "Check parameterlength " begin include("parameterlength/check_parameterlengths.jl") end
@safetestset "Arrays #1 " begin include("arrays/array_tests.jl") end
@safetestset "Arrays #2 " begin include("arrays/array_tests_old.jl") end
@safetestset "Manifolds (Grassmann): " begin include("manifolds/grassmann_manifold.jl") end
@safetestset "Gradient Layer " begin include("layers/gradient_layer_tests.jl") end
@safetestset "Test symplecticity of upscaling layer " begin include("layers/sympnet_layers_test.jl") end
@safetestset "Hamiltonian Neural Network " begin include("hamiltonian_neural_network_tests.jl") end
@safetestset "Manifold Neural Network Layers " begin include("layers/manifold_layers.jl") end
@safetestset "Custom AD rules for kernels " begin include("custom_ad_rules/kernel_pullbacks.jl") end
@safetestset "ResNet " begin include("layers/resnet_tests.jl") end
@safetestset "Transformer Networks #1 " begin include("transformer_related/multi_head_attention_stiefel_optim_cache.jl") end
@safetestset "Transformer Networks #2 " begin include("transformer_related/multi_head_attention_stiefel_retraction.jl") end
@safetestset "Transformer Networks #3 " begin include("transformer_related/multi_head_attention_stiefel_setup.jl") end
@safetestset "Transformer Networks #4 " begin include("transformer_related/transformer_setup.jl") end
@safetestset "Transformer Networks #5 " begin include("transformer_related/transformer_application.jl") end
@safetestset "Transformer Networks #6 " begin include("transformer_related/transformer_gradient.jl") end
@safetestset "Transformer Networks #7 " begin include("transformer_related/transformer_optimizer.jl") end

# transformer-related tests
@safetestset "Test setup of MultiHeadAttention layer Stiefel weights " begin include("transformer_related/multi_head_attention_stiefel_setup.jl") end
@safetestset "Test geodesic and Cayley retr for the MultiHeadAttention layer w/ St weights " begin include("transformer_related/multi_head_attention_stiefel_retraction.jl") end
@safetestset "Test the correct setup of the various optimizer caches for MultiHeadAttention " begin include("transformer_related/multi_head_attention_stiefel_optim_cache.jl") end
@safetestset "Check if the transformer can be applied to a tensor. " begin include("transformer_related/transformer_application.jl") end
@safetestset "Check if the gradient/pullback of MultiHeadAttention changes type in St case " begin include("transformer_related/transformer_gradient.jl") end
@safetestset "Check if the optimization_step! changes the parameters of the transformer " begin include("transformer_related/transformer_optimizer.jl") end

@safetestset "Attention layer #1 " begin include("attention_layer/attention_setup.jl") end
@safetestset "(MultiHead)Attention " begin include("attention_layer/apply_multi_head_attention.jl") end
@safetestset "Classification layer " begin include("layers/classification.jl") end
51 changes: 33 additions & 18 deletions test/transformer_related/multi_head_attention_stiefel_optim_cache.jl
@@ -2,19 +2,10 @@ using GeometricMachineLearning, Test

import Lux, Random, LinearAlgebra

dim = 64
n_heads = 8
Dₕ = dim÷8
tol = eps(Float32)

model = Chain(MultiHeadAttention(dim, n_heads), MultiHeadAttention(dim, n_heads, Stiefel=true))
ps = initialparameters(CPU(), Float32, model)

o₁ = Optimizer(AdamOptimizer(), ps)
o₂ = Optimizer(MomentumOptimizer(), ps)
o₃ = Optimizer(GradientOptimizer(), ps)

function check_adam_cache(C::AbstractCache)
@doc raw"""
This checks if the Adam cache was set up in the right way
"""
function check_adam_cache(C::AbstractCache{T}, tol= T(10) * eps(T)) where T
@test typeof(C) <: AdamCache
@test propertynames(C) == (:B₁, :B₂)
@test typeof(C.B₁) <: StiefelLieAlgHorMatrix
@@ -24,20 +15,44 @@ function check_adam_cache(C::AbstractCache)
end
check_adam_cache(B::NamedTuple) = apply_toNT(check_adam_cache, B)

function check_momentum_cache(C::AbstractCache)
@doc raw"""
This checks if the momentum cache was set up in the right way
"""
function check_momentum_cache(C::AbstractCache{T}, tol= T(10) * eps(T)) where T
@test typeof(C) <: MomentumCache
@test propertynames(C) == (:B,)
@test typeof(C.B) <: StiefelLieAlgHorMatrix
@test LinearAlgebra.norm(C.B) < tol
end
check_momentum_cache(B::NamedTuple) = apply_toNT(check_momentum_cache, B)

function check_gradient_cache(C::AbstractCache)
@doc raw"""
This checks if the gradient cache was set up in the right way
"""
function check_gradient_cache(C::AbstractCache{T}) where T
@test typeof(C) <: GradientCache
@test propertynames(C) == ()
end
check_gradient_cache(B::NamedTuple) = apply_toNT(check_gradient_cache, B)

check_adam_cache(o₁.cache[2])
check_momentum_cache(o₂.cache[2])
check_gradient_cache(o₃.cache[2])
@doc raw"""
This checks if all the caches are set up in the right way for the `MultiHeadAttention` layer with Stiefel weights.
TODO:
- [ ] `BFGSOptimizer` !!
"""
function test_cache_setups_for_optimizer_for_multihead_attention_layer(T::Type, dim::Int, n_heads::Int)
@assert dim % n_heads == 0
model = MultiHeadAttention(dim, n_heads, Stiefel=true)
ps = initialparameters(CPU(), T, model)

o₁ = Optimizer(AdamOptimizer(), ps)
o₂ = Optimizer(MomentumOptimizer(), ps)
o₃ = Optimizer(GradientOptimizer(), ps)

check_adam_cache(o₁.cache)
check_momentum_cache(o₂.cache)
check_gradient_cache(o₃.cache)
end

test_cache_setups_for_optimizer_for_multihead_attention_layer(Float32, 64, 8)
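For reference, a minimal standalone sketch of the property these cache checks verify: a freshly initialized Adam cache holds (near-)zero first- and second-moment accumulators. The `ToyAdamCache` type and `init_toy_adam_cache` helper below are local stand-ins, not the library's `AdamCache` API:

    using LinearAlgebra

    # Local stand-in for an Adam cache: a first- and a second-moment accumulator.
    struct ToyAdamCache{AT<:AbstractMatrix}
        B₁::AT
        B₂::AT
    end

    # A fresh cache starts with zero moments, which is what check_adam_cache asserts above.
    init_toy_adam_cache(T::Type, n::Int) = ToyAdamCache(zeros(T, n, n), zeros(T, n, n))

    cache = init_toy_adam_cache(Float32, 4)
    @assert propertynames(cache) == (:B₁, :B₂)
    @assert norm(cache.B₁) < 10 * eps(Float32) && norm(cache.B₂) < 10 * eps(Float32)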
51 changes: 25 additions & 26 deletions test/transformer_related/multi_head_attention_stiefel_retraction.jl
@@ -1,45 +1,44 @@
"""
This is a test for that checks if the retractions (geodesic and Cayley for now) map from StiefelLieAlgHorMatrix to StiefelManifold when used with MultiHeadAttention.
"""

import Random, Test, Lux, LinearAlgebra, KernelAbstractions

using GeometricMachineLearning, Test
using GeometricMachineLearning: geodesic
using GeometricMachineLearning: cayley
using GeometricMachineLearning: init_optimizer_cache

dim = 64
n_heads = 8
Dₕ = dim÷8
tol = eps(Float32)
T = Float32
backend = KernelAbstractions.CPU()

model = MultiHeadAttention(dim, n_heads, Stiefel=true)

ps = initialparameters(backend, T, model)

cache = init_optimizer_cache(MomentumOptimizer(), ps)

E = StiefelProjection(dim, Dₕ, T)
function check_retraction_geodesic(A::AbstractMatrix)
@doc raw"""
This function computes the geodesic retraction of an element of `StiefelLieAlgHorMatrix` and then checks if the resulting element is close to `StiefelProjection`.
"""
function check_retraction_geodesic(A::AbstractMatrix{T}, tol=eps(T)) where T
A_retracted = geodesic(A)
@test typeof(A_retracted) <: StiefelManifold
@test LinearAlgebra.norm(A_retracted - E) < tol
@test LinearAlgebra.norm(A_retracted - StiefelProjection(A_retracted)) < tol
end
check_retraction_geodesic(cache::NamedTuple) = apply_toNT(check_retraction_geodesic, cache)
check_retraction_geodesic(B::MomentumCache) = check_retraction_geodesic(B.B)

check_retraction_geodesic(cache)

E = StiefelProjection(dim, Dₕ)
function check_retraction_cayley(A::AbstractMatrix)
@doc raw"""
This function computes the Cayley retraction of an element of `StiefelLieAlgHorMatrix` and then checks if the resulting element is close to `StiefelProjection`.
"""
function check_retraction_cayley(A::AbstractMatrix{T}, tol=eps(T)) where T
A_retracted = cayley(A)
@test typeof(A_retracted) <: StiefelManifold
@test LinearAlgebra.norm(A_retracted - E) < tol
@test LinearAlgebra.norm(A_retracted - StiefelProjection(A_retracted)) < tol
end
check_retraction_cayley(cache::NamedTuple) = apply_toNT(check_retraction_cayley, cache)
check_retraction_cayley(B::MomentumCache) = check_retraction_cayley(B.B)

check_retraction_cayley(cache)
@doc raw"""
This is a test that checks if the retractions (geodesic and Cayley for now) map from `StiefelLieAlgHorMatrix` to `StiefelManifold` when used with `MultiHeadAttention`.
"""
function test_multi_head_attention_retraction(T::Type, dim, n_heads, tol=eps(T), backend=KernelAbstractions.CPU())
model = MultiHeadAttention(dim, n_heads, Stiefel=true)

ps = initialparameters(backend, T, model)
cache = init_optimizer_cache(MomentumOptimizer(), ps)

check_retraction_geodesic(cache)

check_retraction_cayley(cache)
end

test_multi_head_attention_retraction(Float32, 64, 8)
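For reference, a self-contained sketch of the retraction property tested here, assuming the geodesic retraction of a horizontal Lie-algebra element B applied at E = [I; 0] can be written as exp(B)E (plain `LinearAlgebra`, no library types). Retracting the zero element — which is what a fresh momentum cache contains — returns E itself, and the result has orthonormal columns:

    using LinearAlgebra

    N, n = 6, 3
    E = vcat(Matrix{Float32}(I, n, n), zeros(Float32, N - n, n))  # plays the role of StiefelProjection
    B = zeros(Float32, N, N)                                      # zero Lie-algebra element, as in a fresh cache
    A_retracted = exp(B) * E                                      # geodesic retraction written via the matrix exponential

    @assert norm(A_retracted - E) < 10 * eps(Float32)                 # retracting zero lands on E ...
    @assert norm(A_retracted' * A_retracted - I) < 10 * eps(Float32)  # ... i.e. on the Stiefel manifold (AᵀA = I)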
33 changes: 22 additions & 11 deletions test/transformer_related/multi_head_attention_stiefel_setup.jl
@@ -3,25 +3,36 @@ import Random, Test, Lux, LinearAlgebra, KernelAbstractions
using GeometricMachineLearning, Test
using GeometricMachineLearning: init_optimizer_cache

T = Float32
model = MultiHeadAttention(64, 8, Stiefel=true)
ps = initialparameters(KernelAbstractions.CPU(), T, model)
tol = 10*eps(T)

function check_setup(A::AbstractMatrix)
@doc raw"""
This checks for an arbitrary matrix ``A\in\mathbb{R}^{N\times{}n}`` if ``A\in{}St(n,N)``.
"""
function check_setup(A::AbstractMatrix{T}, tol=T(10)*eps(T)) where T
@test typeof(A) <: StiefelManifold
@test check(A) < tol
end
check_setup(ps::NamedTuple) = apply_toNT(check_setup, ps)
check_setup(ps)

######## check if the gradients are set up the correct way
function check_grad_setup(B::AbstractMatrix)
@doc raw"""
This checks for an arbitrary matrix ``B\in\mathbb{R}^{N\times{}N}`` if ``B\in\mathfrak{g}^\mathrm{hor}``.
"""
function check_grad_setup(B::AbstractMatrix{T}, tol=T(10)*eps(T)) where T
@test typeof(B) <: StiefelLieAlgHorMatrix
@test LinearAlgebra.norm(B) < tol
end
check_grad_setup(gx::NamedTuple) = apply_toNT(check_grad_setup, gx)
check_grad_setup(B::MomentumCache) = check_grad_setup(B.B)

gx = init_optimizer_cache(MomentumOptimizer(), ps)
check_grad_setup(gx)
@doc raw"""
Check if `initialparameters` and `init_optimizer_cache` set up the `MultiHeadAttention` layer with Stiefel weights correctly.
"""
function check_multi_head_attention_stiefel_setup(T::Type, N::Int, n::Int)
model = MultiHeadAttention(N, n, Stiefel=true)
ps = initialparameters(KernelAbstractions.CPU(), T, model)

check_setup(ps)

gx = init_optimizer_cache(MomentumOptimizer(), ps)
check_grad_setup(gx)
end

check_multi_head_attention_stiefel_setup(Float32, 64, 8)
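The condition ``A\in{}St(n,N)`` referred to above amounts to ``A^TA = \mathbb{I}``. A standalone illustration of such a check (the `stiefel_defect` helper is local and not the library's `check` function):

    using LinearAlgebra

    # Distance of AᵀA from the identity; zero (up to round-off) iff A has orthonormal columns.
    stiefel_defect(A::AbstractMatrix) = norm(A' * A - I)

    F = qr(randn(Float32, 64, 8))
    A = Matrix(F.Q)                      # thin Q factor: a 64×8 matrix with orthonormal columns
    @assert stiefel_defect(A) < sqrt(eps(Float32))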
9 changes: 6 additions & 3 deletions test/transformer_related/transformer_application.jl
@@ -1,5 +1,8 @@
using Test, KernelAbstractions, GeometricMachineLearning

@doc raw"""
This tests if the size of the input array is preserved when it is fed through the transformer (for both a matrix and a tensor input).
"""
function transformer_application_test(T, dim, n_heads, L, seq_length=8, batch_size=10)
model₁ = Chain(Transformer(dim, n_heads, L, Stiefel=false), ResNet(dim))
model₂ = Chain(Transformer(dim, n_heads, L, Stiefel=true), ResNet(dim))
@@ -8,11 +11,11 @@ function transformer_application_test(T, dim, n_heads, L, seq_length=8, batch_si
ps₂ = initialparameters(KernelAbstractions.CPU(), T, model₂)

input₁ = rand(T, dim, seq_length, batch_size)
input₂ = rand(T, dim, seq_length, batch_size)
input₂ = rand(T, dim, seq_length)

@test size(model₁(input₁, ps₁)) == size(input₁)
@test size(model₁(input₂, ps₁)) == size(input₂)
@test size(model₂(input₁, ps₂)) == size(input₁)
@test size(model₁(input₂, ps₁)) == size(input₂)
@test size(model₂(input₂, ps₂)) == size(input₂)
end

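The size-preservation property asserted here can be illustrated without the library; `toy_layer` below is a hypothetical stand-in for the transformer, acting on the feature dimension only and restoring the original shape afterwards:

    # Toy size-preserving layer (not the library's Transformer).
    function toy_layer(x::AbstractArray, W::AbstractMatrix)
        y = tanh.(W * reshape(x, size(x, 1), :))   # act on the feature (first) dimension only
        reshape(y, size(x))                        # restore the original shape
    end

    W = randn(Float32, 64, 64)
    x_mat = rand(Float32, 64, 8)          # dim × seq_length
    x_ten = rand(Float32, 64, 8, 10)      # dim × seq_length × batch_size
    @assert size(toy_layer(x_mat, W)) == size(x_mat)
    @assert size(toy_layer(x_ten, W)) == size(x_ten)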
3 changes: 3 additions & 0 deletions test/transformer_related/transformer_gradient.jl
@@ -1,5 +1,8 @@
using Test, KernelAbstractions, GeometricMachineLearning, Zygote, LinearAlgebra

@doc raw"""
This checks if the gradients of the transformer change type in the case of Stiefel weights, and checks that they keep their type in the case of regular weights.
"""
function transformer_gradient_test(T, dim, n_heads, L, seq_length=8, batch_size=10)
model₁ = Chain(Transformer(dim, n_heads, L, Stiefel=false), ResNet(dim))
model₂ = Chain(Transformer(dim, n_heads, L, Stiefel=true), ResNet(dim))
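For the regular-weight half of the docstring above, a minimal Zygote sketch (independent of the library) showing that the pullback of a plain weight has the same shape as the weight itself; in the Stiefel case the library is expected to return a manifold-specific type instead:

    using Zygote

    W = randn(Float32, 4, 4)
    loss(W) = sum(W * ones(Float32, 4))     # a scalar function of the weight
    dW = Zygote.gradient(loss, W)[1]        # pullback with respect to W
    @assert size(dW) == size(W)             # a plain weight gets a gradient of the same shape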
9 changes: 8 additions & 1 deletion test/transformer_related/transformer_optimizer.jl
@@ -1,5 +1,8 @@
using Test, KernelAbstractions, GeometricMachineLearning, Zygote, LinearAlgebra

@doc raw"""
This function tests if the `GradientOptimizer`, `MomentumOptimizer`, `AdamOptimizer` and `BFGSOptimizer` act on the neural network weights via `optimization_step!`.
"""
function transformer_gradient_test(T, dim, n_heads, L, seq_length=8, batch_size=10)
model = Chain(Transformer(dim, n_heads, L, Stiefel=true), ResNet(dim))
model = Transformer(dim, n_heads, L, Stiefel=true)
@@ -14,15 +17,19 @@ function transformer_gradient_test(T, dim, n_heads, L, seq_length=8, batch_size=
o₁ = Optimizer(GradientOptimizer(), ps)
o₂ = Optimizer(MomentumOptimizer(), ps)
o₃ = Optimizer(AdamOptimizer(), ps)
o₄ = Optimizer(BFGSOptimizer(), ps)

ps₁ = deepcopy(ps)
ps₂ = deepcopy(ps)
ps₃ = deepcopy(ps)
ps₄ = deepcopy(ps)

optimization_step!(o₁, model, ps₁, dx)
optimization_step!(o₂, model, ps₂, dx)
optimization_step!(o₃, model, ps₃, dx)
@test typeof(ps₁) == typeof(ps₂) == typeof(ps₃) == typeof(ps)
optimization_step!(o₄, model, ps₄, dx)
@test typeof(ps₁) == typeof(ps₂) == typeof(ps₃) == typeof(ps₄) == typeof(ps)
@test ps₁[1].PQ.head_1 ≉ ps₂[1].PQ.head_1 ≉ ps₃[1].PQ.head_1 ≉ ps₄[1].PQ.head_1 ≉ ps[1].PQ.head_1
end

transformer_gradient_test(Float32, 10, 5, 4)
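A toy illustration of the property asserted here, using a plain gradient-descent step rather than the library's `optimization_step!`: a single optimizer step must change the parameters without changing their type.

    using LinearAlgebra

    W  = randn(Float32, 4, 4)      # stand-in for one transformer weight
    dW = randn(Float32, 4, 4)      # stand-in for its gradient
    η  = 1f-2                      # learning rate

    W_updated = W .- η .* dW       # plain gradient-descent step
    @assert norm(W_updated - W) > 0          # the step actually moved the weights
    @assert typeof(W_updated) == typeof(W)   # ... without changing their type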
3 changes: 3 additions & 0 deletions test/transformer_related/transformer_setup.jl
@@ -1,5 +1,8 @@
using Test, KernelAbstractions, GeometricMachineLearning

@doc raw"""
This function tests the setup of the transformer with Stiefel weights.
"""
function transformer_setup_test(dim, n_heads, L, T)
model = Transformer(dim, n_heads, L, Stiefel=true)
ps = initialparameters(KernelAbstractions.CPU(), T, model)
