Merge pull request #100 from JuliaGNI/simplify_transformer_tests
Simplify transformer tests
michakraus authored Dec 20, 2023
2 parents 9b56781 + dd343b9 commit 9dbdb58
Showing 13 changed files with 132 additions and 79 deletions.
2 changes: 1 addition & 1 deletion src/data_loader/batch.jl
@@ -64,7 +64,7 @@ With the optional argument:
The output of `optimize_for_one_epoch!` is the average loss over all batches of the epoch:
```math
output = \frac{1}{mathtt{steps\_per\_epoch}}\sum_{t=1}^mathtt{steps\_per\_epoch}loss(\theta^{(t-1)}).
output = \frac{1}{\mathtt{steps\_per\_epoch}}\sum_{t=1}^\mathtt{steps\_per\_epoch}loss(\theta^{(t-1)}).
```
This is done because any **reverse differentiation** routine always has two outputs: a pullback and the value of the function it is differentiating. In the case of zygote: `loss_value, pullback = Zygote.pullback(ps -> loss(ps), ps)` (if the loss only depends on the parameters).
"""
9 changes: 9 additions & 0 deletions src/layers/multi_head_attention.jl
@@ -125,14 +125,23 @@ function (d::MultiHeadAttention{M, M, Stiefel, Retraction, false})(x::AbstractAr
end

import ChainRules
"""
This has to be extended to tensors; you should probably do a PR in ChainRules for this.
"""
function ChainRules._adjoint_mat_pullback(y::AbstractArray{T, 3}, proj) where T
(NoTangent(), proj(tensor_transpose(y)))
end

"""
Extend `mat_tensor_mul` to a multiplication by the adjoint of an element of `StiefelManifold`.
"""
function mat_tensor_mul(Y::AT, x::AbstractArray{T, 3}) where {T, BT <: AbstractArray{T}, ST <: StiefelManifold{T, BT}, AT <: Adjoint{T, ST}}
mat_tensor_mul(Y.parent.A', x)
end

"""
Extend `mat_tensor_mul` to a multiplication by an element of `StiefelManifold`.
"""
function mat_tensor_mul(Y::StiefelManifold, x::AbstractArray{T, 3}) where T
mat_tensor_mul(Y.A, x)
end
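For context, a sketch of the slice-wise semantics that `mat_tensor_mul` presumably implements (multiplying every matrix slice of a 3-tensor from the left); `naive_mat_tensor_mul` below is purely illustrative and not part of the package:

```julia
# Illustrative only: assumed behaviour of mat_tensor_mul on a 3-tensor,
# i.e. result[:, :, k] == A * x[:, :, k] for every slice k.
function naive_mat_tensor_mul(A::AbstractMatrix{T}, x::AbstractArray{T, 3}) where T
    y = similar(x, size(A, 1), size(x, 2), size(x, 3))
    for k in axes(x, 3)
        y[:, :, k] = A * view(x, :, :, k)
    end
    y
end
```

With that picture, the two new methods simply forward to the underlying matrix of the `StiefelManifold` element (`Y.A`, or its adjoint `Y.parent.A'`) before the slice-wise product.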
4 changes: 2 additions & 2 deletions src/optimizers/bfgs_cache.jl
@@ -5,7 +5,7 @@ It stores an array for the previous time step `B` and the inverse of the Hessian
It is important to note that setting up this cache already requires a derivative! This is not the case for the other optimizers.
"""
struct BFGSCache{T, AT<:AbstractArray{T}} <: AbstractCache
struct BFGSCache{T, AT<:AbstractArray{T}} <: AbstractCache{T}
B::AT
S::AT
H::AbstractMatrix{T}
@@ -19,7 +19,7 @@ In order to initialize `BFGSCache` we first need gradient information. This is w
NOTE: we may not need this.
"""
struct BFGSDummyCache{T, AT<:AbstractArray{T}} <: AbstractCache
struct BFGSDummyCache{T, AT<:AbstractArray{T}} <: AbstractCache{T}
function BFGSDummyCache(B::AbstractArray)
new{eltype(B), typeof(zero(B))}()
end
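To picture the remark that setting up the BFGS cache already requires a derivative: a schematic sketch (not the package's implementation) of a BFGS-style initialization, where a first gradient is needed before any optimization step can be taken:

```julia
import LinearAlgebra

# Schematic only: a BFGS-type cache keeps information from the previous step (B, S)
# and an approximation of the (inverse) Hessian H, so ∇L(θ₀) is needed up front;
# Adam and momentum caches, by contrast, can simply be initialized with zeros.
function init_bfgs_state(grad_loss, θ₀::AbstractVector)
    g₀ = grad_loss(θ₀)                                               # derivative required at setup
    H₀ = Matrix{eltype(θ₀)}(LinearAlgebra.I, length(θ₀), length(θ₀))
    (B = copy(g₀), S = zero(θ₀), H = H₀)                             # mirrors the B, S, H fields above
end
```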
15 changes: 5 additions & 10 deletions src/optimizers/optimizer_caches.jl
@@ -7,37 +7,32 @@ AbstractCache has subtypes:
All of them can be initialized with providing an array (also supporting manifold types).
"""
abstract type AbstractCache end
abstract type AbstractCache{T} end

#############################################################################
# All the definitions of the caches

struct AdamCache{T, AT <: AbstractArray{T}} <: AbstractCache
struct AdamCache{T, AT <: AbstractArray{T}} <: AbstractCache{T}
B₁::AT
B₂::AT
function AdamCache(Y::AbstractArray)
new{eltype(Y), typeof(zero(Y))}(zero(Y), zero(Y))
end
end

struct MomentumCache{T, AT <: AbstractArray{T}} <:AbstractCache
struct MomentumCache{T, AT <: AbstractArray{T}} <:AbstractCache{T}
B::AT
function MomentumCache(Y::AbstractArray)
new{eltype(Y), typeof(zero(Y))}(zero(Y))
end
end

struct GradientCache <: AbstractCache end
GradientCache(::AbstractArray) = GradientCache()
struct GradientCache{T} <: AbstractCache{T} end
GradientCache(::AbstractArray{T}) where T = GradientCache{T}()

#############################################################################
# All the setup_cache functions

# I don't really understand what we need these for ???
# setup_adam_cache(B::AbstractArray) = reshape([setup_adam_cache(b) for b in B], size(B))
# setup_momentum_cache(B::AbstractArray) = reshape([setup_momentum_cache(b) for b in B], size(B))
# setup_gradient_cache(B::AbstractArray) = reshape([setup_gradient_cache(b) for b in B], size(B))

setup_adam_cache(ps::NamedTuple) = apply_toNT(setup_adam_cache, ps)
setup_momentum_cache(ps::NamedTuple) = apply_toNT(setup_momentum_cache, ps)
setup_gradient_cache(ps::NamedTuple) = apply_toNT(setup_gradient_cache, ps)
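A short sketch of what the new `{T}` parameter buys: the scalar type can be read off any cache, e.g. for element-type-dependent tolerances as used in the reworked tests below (`default_tolerance` is an illustrative name, not package API):

```julia
using GeometricMachineLearning

# Illustrative: with AbstractCache{T} the scalar type is available for dispatch,
# so defaults such as a test tolerance can be derived from the cache itself.
default_tolerance(::AbstractCache{T}) where {T} = T(10) * eps(T)

default_tolerance(AdamCache(zeros(Float32, 4, 4)))  # == 10 * eps(Float32)
```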
6 changes: 6 additions & 0 deletions test/arrays/array_tests.jl
@@ -76,6 +76,11 @@ function stiefel_lie_alg_add_sub_test(N, n)
@test all(abs.(projection(W₁ - W₂) .- S₄) .< 1e-10)
end

function stiefel_lie_alg_vectorization_test(N, n; T=Float32)
A = rand(StiefelLieAlgHorMatrix{T}, N, n)
@test isapprox(StiefelLieAlgHorMatrix(vec(A), N, n), A)
end

# TODO: tests for ADAM functions

# test everything for different n & N values
@@ -96,4 +101,5 @@ for (N, n) ∈ zip(N_vec, n_vec)
skew_mat_mul_test2(N)
stiefel_proj_test(N,n)
stiefel_lie_alg_add_sub_test(N,n)
stiefel_lie_alg_vectorization_test(N, n)
end
16 changes: 9 additions & 7 deletions test/runtests.jl
@@ -11,13 +11,15 @@ using SafeTestsets
@safetestset "Manifold Neural Network Layers " begin include("layers/manifold_layers.jl") end
@safetestset "Custom AD rules for kernels " begin include("custom_ad_rules/kernel_pullbacks.jl") end
@safetestset "ResNet " begin include("layers/resnet_tests.jl") end
@safetestset "Transformer Networks #1 " begin include("transformer_related/multi_head_attention_stiefel_optim_cache.jl") end
@safetestset "Transformer Networks #2 " begin include("transformer_related/multi_head_attention_stiefel_retraction.jl") end
@safetestset "Transformer Networks #3 " begin include("transformer_related/multi_head_attention_stiefel_setup.jl") end
@safetestset "Transformer Networks #4 " begin include("transformer_related/transformer_setup.jl") end
@safetestset "Transformer Networks #5 " begin include("transformer_related/transformer_application.jl") end
@safetestset "Transformer Networks #6 " begin include("transformer_related/transformer_gradient.jl") end
@safetestset "Transformer Networks #7 " begin include("transformer_related/transformer_optimizer.jl") end

# transformer-related tests
@safetestset "Test setup of MultiHeadAttention layer Stiefel weights " begin include("transformer_related/multi_head_attention_stiefel_setup.jl") end
@safetestset "Test geodesic and Cayley retr for the MultiHeadAttention layer w/ St weights " begin include("transformer_related/multi_head_attention_stiefel_retraction.jl") end
@safetestset "Test the correct setup of the various optimizer caches for MultiHeadAttention " begin include("transformer_related/multi_head_attention_stiefel_optim_cache.jl") end
@safetestset "Check if the transformer can be applied to a tensor. " begin include("transformer_related/transformer_application.jl") end
@safetestset "Check if the gradient/pullback of MultiHeadAttention changes type in St case " begin include("transformer_related/transformer_gradient.jl") end
@safetestset "Check if the optimization_step! changes the parameters of the transformer " begin include("transformer_related/transformer_optimizer.jl") end

@safetestset "Attention layer #1 " begin include("attention_layer/attention_setup.jl") end
@safetestset "(MultiHead)Attention " begin include("attention_layer/apply_multi_head_attention.jl") end
@safetestset "Classification layer " begin include("layers/classification.jl") end
@@ -2,19 +2,10 @@ using GeometricMachineLearning, Test

import Lux, Random, LinearAlgebra

dim = 64
n_heads = 8
Dₕ = dim÷8
tol = eps(Float32)

model = Chain(MultiHeadAttention(dim, n_heads), MultiHeadAttention(dim, n_heads, Stiefel=true))
ps = initialparameters(CPU(), Float32, model)

o₁ = Optimizer(AdamOptimizer(), ps)
o₂ = Optimizer(MomentumOptimizer(), ps)
o₃ = Optimizer(GradientOptimizer(), ps)

function check_adam_cache(C::AbstractCache)
@doc raw"""
This checks if the Adam cache was set up in the right way
"""
function check_adam_cache(C::AbstractCache{T}, tol= T(10) * eps(T)) where T
@test typeof(C) <: AdamCache
@test propertynames(C) == (:B₁, :B₂)
@test typeof(C.B₁) <: StiefelLieAlgHorMatrix
@@ -24,20 +15,44 @@ end
end
check_adam_cache(B::NamedTuple) = apply_toNT(check_adam_cache, B)

function check_momentum_cache(C::AbstractCache)
@doc raw"""
This checks if the momentum cache was set up in the right way
"""
function check_momentum_cache(C::AbstractCache{T}, tol= T(10) * eps(T)) where T
@test typeof(C) <: MomentumCache
@test propertynames(C) == (:B,)
@test typeof(C.B) <: StiefelLieAlgHorMatrix
@test LinearAlgebra.norm(C.B) < tol
end
check_momentum_cache(B::NamedTuple) = apply_toNT(check_momentum_cache, B)

function check_gradient_cache(C::AbstractCache)
@doc raw"""
This checks if the gradient cache was set up in the right way
"""
function check_gradient_cache(C::AbstractCache{T}) where T
@test typeof(C) <: GradientCache
@test propertynames(C) == ()
end
check_gradient_cache(B::NamedTuple) = apply_toNT(check_gradient_cache, B)

check_adam_cache(o₁.cache[2])
check_momentum_cache(o₂.cache[2])
check_gradient_cache(o₃.cache[2])
@doc raw"""
This checks if all the caches are set up in the right way for the `MultiHeadAttention` layer with Stiefel weights.
TODO:
- [ ] `BFGSOptimizer` !!
"""
function test_cache_setups_for_optimizer_for_multihead_attention_layer(T::Type, dim::Int, n_heads::Int)
@assert dim % n_heads == 0
model = MultiHeadAttention(dim, n_heads, Stiefel=true)
ps = initialparameters(CPU(), T, model)

o₁ = Optimizer(AdamOptimizer(), ps)
o₂ = Optimizer(MomentumOptimizer(), ps)
o₃ = Optimizer(GradientOptimizer(), ps)

check_adam_cache(o₁.cache)
check_momentum_cache(o₂.cache)
check_gradient_cache(o₃.cache)
end

test_cache_setups_for_optimizer_for_multihead_attention_layer(Float32, 64, 8)
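The `check_*(B::NamedTuple) = apply_toNT(check_*, B)` pattern recurs throughout these tests. As a rough mental model (the real `apply_toNT` lives in GeometricMachineLearning and may differ in detail), it maps a function over every entry of a `NamedTuple`:

```julia
# Rough mental model only: apply a function to all values of a NamedTuple while
# keeping its keys, so a check written for a single cache can be applied to the
# whole parameter tuple (e.g. the PQ/PK/PV heads of a MultiHeadAttention layer).
apply_to_namedtuple(f, nt::NamedTuple) = NamedTuple{keys(nt)}(map(f, values(nt)))

apply_to_namedtuple(x -> 2x, (a = 1, b = 2))  # (a = 2, b = 4)
```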
51 changes: 25 additions & 26 deletions test/transformer_related/multi_head_attention_stiefel_retraction.jl
@@ -1,45 +1,44 @@
"""
This is a test for that checks if the retractions (geodesic and Cayley for now) map from StiefelLieAlgHorMatrix to StiefelManifold when used with MultiHeadAttention.
"""

import Random, Test, Lux, LinearAlgebra, KernelAbstractions

using GeometricMachineLearning, Test
using GeometricMachineLearning: geodesic
using GeometricMachineLearning: cayley
using GeometricMachineLearning: init_optimizer_cache

dim = 64
n_heads = 8
Dₕ = dim÷8
tol = eps(Float32)
T = Float32
backend = KernelAbstractions.CPU()

model = MultiHeadAttention(dim, n_heads, Stiefel=true)

ps = initialparameters(backend, T, model)

cache = init_optimizer_cache(MomentumOptimizer(), ps)

E = StiefelProjection(dim, Dₕ, T)
function check_retraction_geodesic(A::AbstractMatrix)
@doc raw"""
This function computes the geodesic retraction of an element of `StiefelLieAlgHorMatrix` and then checks if the resulting element is `StiefelProjection`.
"""
function check_retraction_geodesic(A::AbstractMatrix{T}, tol=eps(T)) where T
A_retracted = geodesic(A)
@test typeof(A_retracted) <: StiefelManifold
@test LinearAlgebra.norm(A_retracted - E) < tol
@test LinearAlgebra.norm(A_retracted - StiefelProjection(A_retracted)) < tol
end
check_retraction_geodesic(cache::NamedTuple) = apply_toNT(check_retraction_geodesic, cache)
check_retraction_geodesic(B::MomentumCache) = check_retraction_geodesic(B.B)

check_retraction_geodesic(cache)

E = StiefelProjection(dim, Dₕ)
function check_retraction_cayley(A::AbstractMatrix)
@doc raw"""
This function computes the Cayley retraction of an element of `StiefelLieAlgHorMatrix` and then checks if the resulting element is `StiefelProjection`.
"""
function check_retraction_cayley(A::AbstractMatrix{T}, tol=eps(T)) where T
A_retracted = cayley(A)
@test typeof(A_retracted) <: StiefelManifold
@test LinearAlgebra.norm(A_retracted - E) < tol
@test LinearAlgebra.norm(A_retracted - StiefelProjection(A_retracted)) < tol
end
check_retraction_cayley(cache::NamedTuple) = apply_toNT(check_retraction_cayley, cache)
check_retraction_cayley(B::MomentumCache) = check_retraction_cayley(B.B)

check_retraction_cayley(cache)
@doc raw"""
This is a test that checks whether the retractions (geodesic and Cayley for now) map from `StiefelLieAlgHorMatrix` to `StiefelManifold` when used with `MultiHeadAttention`.
"""
function test_multi_head_attention_retraction(T::Type, dim, n_heads, tol=eps(T), backend=KernelAbstractions.CPU())
model = MultiHeadAttention(dim, n_heads, Stiefel=true)

ps = initialparameters(backend, T, model)
cache = init_optimizer_cache(MomentumOptimizer(), ps)

check_retraction_geodesic(cache)

check_retraction_cayley(cache)
end

test_multi_head_attention_retraction(Float32, 64, 8)
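Why the retracted cache is compared against `StiefelProjection`: the momentum cache is initialized to zero in the horizontal Lie algebra, and assuming the geodesic retraction takes the usual form ``B \mapsto \exp(B)E`` with ``E`` the distinguished element returned by `StiefelProjection`, the zero element retracts exactly to ``E``:

```math
\mathrm{geodesic}(0) = \exp(0)E = E,
```

and analogously for the Cayley retraction, since the Cayley transform of the zero matrix is the identity.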
33 changes: 22 additions & 11 deletions test/transformer_related/multi_head_attention_stiefel_setup.jl
@@ -3,25 +3,36 @@ import Random, Test, Lux, LinearAlgebra, KernelAbstractions
using GeometricMachineLearning, Test
using GeometricMachineLearning: init_optimizer_cache

T = Float32
model = MultiHeadAttention(64, 8, Stiefel=true)
ps = initialparameters(KernelAbstractions.CPU(), T, model)
tol = 10*eps(T)

function check_setup(A::AbstractMatrix)
@doc raw"""
This checks for an arbitrary matrix ``A\in\mathbb{R}^{N\times{}n}`` if ``A\in{}St(n,N)``.
"""
function check_setup(A::AbstractMatrix{T}, tol=T(10)*eps(T)) where T
@test typeof(A) <: StiefelManifold
@test check(A) < tol
end
check_setup(ps::NamedTuple) = apply_toNT(check_setup, ps)
check_setup(ps)

######## check if the gradients are set up the correct way
function check_grad_setup(B::AbstractMatrix)
@doc raw"""
This checks for an arbitrary matrix ``B\in\mathbb{R}^{N\times{}N}`` if ``B\in\mathfrak{g}^\mathrm{hor}``.
"""
function check_grad_setup(B::AbstractMatrix{T}, tol=T(10)*eps(T)) where T
@test typeof(B) <: StiefelLieAlgHorMatrix
@test LinearAlgebra.norm(B) < tol
end
check_grad_setup(gx::NamedTuple) = apply_toNT(check_grad_setup, gx)
check_grad_setup(B::MomentumCache) = check_grad_setup(B.B)

gx = init_optimizer_cache(MomentumOptimizer(), ps)
check_grad_setup(gx)
@doc raw"""
Check if `initialparameters` and `init_optimizer_cache` do the right thing for `MultiHeadAttentionLayer`.
"""
function check_multi_head_attention_stiefel_setup(T::Type, N::Int, n::Int)
model = MultiHeadAttention(N, n, Stiefel=true)
ps = initialparameters(KernelAbstractions.CPU(), T, model)

check_setup(ps)

gx = init_optimizer_cache(MomentumOptimizer(), ps)
check_grad_setup(gx)
end

check_multi_head_attention_stiefel_setup(Float32, 64, 8)
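For reference, the property these checks verify: a matrix ``A\in\mathbb{R}^{N\times{}n}`` lies on the Stiefel manifold exactly when its columns are orthonormal,

```math
A \in St(n, N) \iff A^TA = \mathbb{I}_n,
```

and `check(A)` presumably measures the deviation from this identity (its exact definition is not shown in this diff), so a value below `10 * eps(T)` means ``A^TA`` equals ``\mathbb{I}_n`` up to floating-point precision. The gradient check is analogous: the cache elements of ``\mathfrak{g}^\mathrm{hor}`` are zero-initialized, so their norm should vanish.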
9 changes: 6 additions & 3 deletions test/transformer_related/transformer_application.jl
@@ -1,5 +1,8 @@
using Test, KernelAbstractions, GeometricMachineLearning

@doc raw"""
This tests whether the size of the input array is kept constant when it is fed into the transformer (for both a matrix and a tensor input).
"""
function transformer_application_test(T, dim, n_heads, L, seq_length=8, batch_size=10)
model₁ = Chain(Transformer(dim, n_heads, L, Stiefel=false), ResNet(dim))
model₂ = Chain(Transformer(dim, n_heads, L, Stiefel=true), ResNet(dim))
@@ -8,11 +11,11 @@ function transformer_application_test(T, dim, n_heads, L, seq_length=8, batch_si
ps₂ = initialparameters(KernelAbstractions.CPU(), T, model₂)

input₁ = rand(T, dim, seq_length, batch_size)
input₂ = rand(T, dim, seq_length, batch_size)
input₂ = rand(T, dim, seq_length)

@test size(model₁(input₁, ps₁)) == size(input₁)
@test size(model₁(input₂, ps₁)) == size(input₂)
@test size(model₂(input₁, ps₂)) == size(input₁)
@test size(model₁(input₂, ps₁)) == size(input₂)
@test size(model₂(input₂, ps₂)) == size(input₂)
end

3 changes: 3 additions & 0 deletions test/transformer_related/transformer_gradient.jl
@@ -1,5 +1,8 @@
using Test, KernelAbstractions, GeometricMachineLearning, Zygote, LinearAlgebra

@doc raw"""
This checks whether the gradients of the transformer change type in the case of the Stiefel manifold, and whether they keep the same type in the case of regular weights.
"""
function transformer_gradient_test(T, dim, n_heads, L, seq_length=8, batch_size=10)
model₁ = Chain(Transformer(dim, n_heads, L, Stiefel=false), ResNet(dim))
model₂ = Chain(Transformer(dim, n_heads, L, Stiefel=true), ResNet(dim))
9 changes: 8 additions & 1 deletion test/transformer_related/transformer_optimizer.jl
@@ -1,5 +1,8 @@
using Test, KernelAbstractions, GeometricMachineLearning, Zygote, LinearAlgebra

@doc raw"""
This function tests if the `GradientOptimizer`, `MomentumOptimizer`, `AdamOptimizer` and `BFGSOptimizer` act on the neural network weights via `optimization_step!`.
"""
function transformer_gradient_test(T, dim, n_heads, L, seq_length=8, batch_size=10)
model = Chain(Transformer(dim, n_heads, L, Stiefel=true), ResNet(dim))
model = Transformer(dim, n_heads, L, Stiefel=true)
@@ -14,15 +17,19 @@ function transformer_gradient_test(T, dim, n_heads, L, seq_length=8, batch_size=
o₁ = Optimizer(GradientOptimizer(), ps)
o₂ = Optimizer(MomentumOptimizer(), ps)
o₃ = Optimizer(AdamOptimizer(), ps)
o₄ = Optimizer(BFGSOptimizer(), ps)

ps₁ = deepcopy(ps)
ps₂ = deepcopy(ps)
ps₃ = deepcopy(ps)
ps₄ = deepcopy(ps)

optimization_step!(o₁, model, ps₁, dx)
optimization_step!(o₂, model, ps₂, dx)
optimization_step!(o₃, model, ps₃, dx)
@test typeof(ps₁) == typeof(ps₂) == typeof(ps₃) == typeof(ps)
optimization_step!(o₄, model, ps₄, dx)
@test typeof(ps₁) == typeof(ps₂) == typeof(ps₃) == typeof(ps₄) == typeof(ps)
@test ps₁[1].PQ.head_1 ≉ ps₂[1].PQ.head_1 ≉ ps₃[1].PQ.head_1 ≉ ps₄[1].PQ.head_1 ≉ ps[1].PQ.head_1
end

transformer_gradient_test(Float32, 10, 5, 4)