diff --git a/src/data_loader/batch.jl b/src/data_loader/batch.jl index 10b434ffd..449fa05a8 100644 --- a/src/data_loader/batch.jl +++ b/src/data_loader/batch.jl @@ -64,7 +64,7 @@ With the optional argument: The output of `optimize_for_one_epoch!` is the average loss over all batches of the epoch: ```math -output = \frac{1}{mathtt{steps\_per\_epoch}}\sum_{t=1}^mathtt{steps\_per\_epoch}loss(\theta^{(t-1)}). +output = \frac{1}{\mathtt{steps\_per\_epoch}}\sum_{t=1}^\mathtt{steps\_per\_epoch}loss(\theta^{(t-1)}). ``` This is done because any **reverse differentiation** routine always has two outputs: a pullback and the value of the function it is differentiating. In the case of zygote: `loss_value, pullback = Zygote.pullback(ps -> loss(ps), ps)` (if the loss only depends on the parameters). """ diff --git a/src/layers/multi_head_attention.jl b/src/layers/multi_head_attention.jl index 05c3c53cc..e07820cb7 100644 --- a/src/layers/multi_head_attention.jl +++ b/src/layers/multi_head_attention.jl @@ -125,14 +125,23 @@ function (d::MultiHeadAttention{M, M, Stiefel, Retraction, false})(x::AbstractAr end import ChainRules +""" +This has to be extended to tensors; you should probably do a PR in ChainRules for this. +""" function ChainRules._adjoint_mat_pullback(y::AbstractArray{T, 3}, proj) where T (NoTangent(), proj(tensor_transpose(y))) end +""" +Extend `mat_tensor_mul` to a multiplication by the adjoint of an element of `StiefelManifold`. +""" function mat_tensor_mul(Y::AT, x::AbstractArray{T, 3}) where {T, BT <: AbstractArray{T}, ST <: StiefelManifold{T, BT}, AT <: Adjoint{T, ST}} mat_tensor_mul(Y.parent.A', x) end +""" +Extend `mat_tensor_mul` to a multiplication by an element of `StiefelManifold`. +""" function mat_tensor_mul(Y::StiefelManifold, x::AbstractArray{T, 3}) where T mat_tensor_mul(Y.A, x) end \ No newline at end of file diff --git a/src/optimizers/bfgs_cache.jl b/src/optimizers/bfgs_cache.jl index 26f1c0b96..96561f8be 100644 --- a/src/optimizers/bfgs_cache.jl +++ b/src/optimizers/bfgs_cache.jl @@ -5,7 +5,7 @@ It stores an array for the previous time step `B` and the inverse of the Hessian It is important to note that setting up this cache already requires a derivative! This is not the case for the other optimizers. """ -struct BFGSCache{T, AT<:AbstractArray{T}} <: AbstractCache +struct BFGSCache{T, AT<:AbstractArray{T}} <: AbstractCache{T} B::AT S::AT H::AbstractMatrix{T} @@ -19,7 +19,7 @@ In order to initialize `BGGSCache` we first need gradient information. This is w NOTE: we may not need this. """ -struct BFGSDummyCache{T, AT<:AbstractArray{T}} <: AbstractCache +struct BFGSDummyCache{T, AT<:AbstractArray{T}} <: AbstractCache{T} function BFGSDummyCache(B::AbstractArray) new{eltype(B), typeof(zero(B))}() end diff --git a/src/optimizers/optimizer_caches.jl b/src/optimizers/optimizer_caches.jl index b8a21cb0a..398eaa562 100644 --- a/src/optimizers/optimizer_caches.jl +++ b/src/optimizers/optimizer_caches.jl @@ -7,12 +7,12 @@ AbstractCache has subtypes: All of them can be initialized with providing an array (also supporting manifold types). 
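
For illustration, a minimal sketch of this (not a doctest; it only assumes the constructors defined below in this file and `rand(StiefelLieAlgHorMatrix{T}, N, n)` as used in the array tests):

```julia
using GeometricMachineLearning

# a cache built from a plain array; its fields should be initialized via `zero` of that array
AdamCache(rand(Float32, 10, 5))

# caches can also be built from manifold-type arrays
MomentumCache(rand(StiefelLieAlgHorMatrix{Float32}, 10, 5))
```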
""" -abstract type AbstractCache end +abstract type AbstractCache{T} end ############################################################################# # All the definitions of the caches -struct AdamCache{T, AT <: AbstractArray{T}} <: AbstractCache +struct AdamCache{T, AT <: AbstractArray{T}} <: AbstractCache{T} B₁::AT B₂::AT function AdamCache(Y::AbstractArray) @@ -20,24 +20,19 @@ struct AdamCache{T, AT <: AbstractArray{T}} <: AbstractCache end end -struct MomentumCache{T, AT <: AbstractArray{T}} <:AbstractCache +struct MomentumCache{T, AT <: AbstractArray{T}} <:AbstractCache{T} B::AT function MomentumCache(Y::AbstractArray) new{eltype(Y), typeof(zero(Y))}(zero(Y)) end end -struct GradientCache <: AbstractCache end -GradientCache(::AbstractArray) = GradientCache() +struct GradientCache{T} <: AbstractCache{T} end +GradientCache(::AbstractArray{T}) where T = GradientCache{T}() ############################################################################# # All the setup_cache functions -# I don't really understand what we need these for ??? -# setup_adam_cache(B::AbstractArray) = reshape([setup_adam_cache(b) for b in B], size(B)) -# setup_momentum_cache(B::AbstractArray) = reshape([setup_momentum_cache(b) for b in B], size(B)) -# setup_gradient_cache(B::AbstractArray) = reshape([setup_gradient_cache(b) for b in B], size(B)) - setup_adam_cache(ps::NamedTuple) = apply_toNT(setup_adam_cache, ps) setup_momentum_cache(ps::NamedTuple) = apply_toNT(setup_momentum_cache, ps) setup_gradient_cache(ps::NamedTuple) = apply_toNT(setup_gradient_cache, ps) diff --git a/test/arrays/array_tests.jl b/test/arrays/array_tests.jl index c0da5eba9..c2badb77b 100644 --- a/test/arrays/array_tests.jl +++ b/test/arrays/array_tests.jl @@ -76,6 +76,11 @@ function stiefel_lie_alg_add_sub_test(N, n) @test all(abs.(projection(W₁ - W₂) .- S₄) .< 1e-10) end +function stiefel_lie_alg_vectorization_test(N, n; T=Float32) + A = rand(StiefelLieAlgHorMatrix{T}, N, n) + @test isapprox(StiefelLieAlgHorMatrix(vec(A), N, n), A) +end + # TODO: tests for ADAM functions # test everything for different n & N values @@ -96,4 +101,5 @@ for (N, n) ∈ zip(N_vec, n_vec) skew_mat_mul_test2(N) stiefel_proj_test(N,n) stiefel_lie_alg_add_sub_test(N,n) + stiefel_lie_alg_vectorization_test(N, n) end diff --git a/test/runtests.jl b/test/runtests.jl index 033b04f16..1b6cb1b4d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,13 +11,15 @@ using SafeTestsets @safetestset "Manifold Neural Network Layers " begin include("layers/manifold_layers.jl") end @safetestset "Custom AD rules for kernels " begin include("custom_ad_rules/kernel_pullbacks.jl") end @safetestset "ResNet " begin include("layers/resnet_tests.jl") end -@safetestset "Transformer Networks #1 " begin include("transformer_related/multi_head_attention_stiefel_optim_cache.jl") end -@safetestset "Transformer Networks #2 " begin include("transformer_related/multi_head_attention_stiefel_retraction.jl") end -@safetestset "Transformer Networks #3 " begin include("transformer_related/multi_head_attention_stiefel_setup.jl") end -@safetestset "Transformer Networks #4 " begin include("transformer_related/transformer_setup.jl") end -@safetestset "Transformer Networks #5 " begin include("transformer_related/transformer_application.jl") end -@safetestset "Transformer Networks #6 " begin include("transformer_related/transformer_gradient.jl") end -@safetestset "Transformer Networks #7 " begin include("transformer_related/transformer_optimizer.jl") end + +# transformer-related tests +@safetestset 
"Test setup of MultiHeadAttention layer Stiefel weights " begin include("transformer_related/multi_head_attention_stiefel_setup.jl") end +@safetestset "Test geodesic and Cayley retr for the MultiHeadAttention layer w/ St weights " begin include("transformer_related/multi_head_attention_stiefel_retraction.jl") end +@safetestset "Test the correct setup of the various optimizer caches for MultiHeadAttention " begin include("transformer_related/multi_head_attention_stiefel_optim_cache.jl") end +@safetestset "Check if the transformer can be applied to a tensor. " begin include("transformer_related/transformer_application.jl") end +@safetestset "Check if the gradient/pullback of MultiHeadAttention changes type in St case " begin include("transformer_related/transformer_gradient.jl") end +@safetestset "Check if the optimization_step! changes the parameters of the transformer " begin include("transformer_related/transformer_optimizer.jl") end + @safetestset "Attention layer #1 " begin include("attention_layer/attention_setup.jl") end @safetestset "(MultiHead)Attention " begin include("attention_layer/apply_multi_head_attention.jl") end @safetestset "Classification layer " begin include("layers/classification.jl") end diff --git a/test/transformer_related/multi_head_attention_stiefel_optim_cache.jl b/test/transformer_related/multi_head_attention_stiefel_optim_cache.jl index ca4209c39..c7614fc52 100644 --- a/test/transformer_related/multi_head_attention_stiefel_optim_cache.jl +++ b/test/transformer_related/multi_head_attention_stiefel_optim_cache.jl @@ -2,19 +2,10 @@ using GeometricMachineLearning, Test import Lux, Random, LinearAlgebra -dim = 64 -n_heads = 8 -Dₕ = dim÷8 -tol = eps(Float32) - -model = Chain(MultiHeadAttention(dim, n_heads), MultiHeadAttention(dim, n_heads, Stiefel=true)) -ps = initialparameters(CPU(), Float32, model) - -o₁ = Optimizer(AdamOptimizer(), ps) -o₂ = Optimizer(MomentumOptimizer(), ps) -o₃ = Optimizer(GradientOptimizer(), ps) - -function check_adam_cache(C::AbstractCache) +@doc raw""" +This checks if the Adam cache was set up in the right way +""" +function check_adam_cache(C::AbstractCache{T}, tol= T(10) * eps(T)) where T @test typeof(C) <: AdamCache @test propertynames(C) == (:B₁, :B₂) @test typeof(C.B₁) <: StiefelLieAlgHorMatrix @@ -24,7 +15,10 @@ function check_adam_cache(C::AbstractCache) end check_adam_cache(B::NamedTuple) = apply_toNT(check_adam_cache, B) -function check_momentum_cache(C::AbstractCache) +@doc raw""" +This checks if the momentum cache was set up in the right way +""" +function check_momentum_cache(C::AbstractCache{T}, tol= T(10) * eps(T)) where T @test typeof(C) <: MomentumCache @test propertynames(C) == (:B,) @test typeof(C.B) <: StiefelLieAlgHorMatrix @@ -32,12 +26,33 @@ function check_momentum_cache(C::AbstractCache) end check_momentum_cache(B::NamedTuple) = apply_toNT(check_momentum_cache, B) -function check_gradient_cache(C::AbstractCache) +@doc raw""" +This checks if the gradient cache was set up in the right way +""" +function check_gradient_cache(C::AbstractCache{T}) where T @test typeof(C) <: GradientCache @test propertynames(C) == () end check_gradient_cache(B::NamedTuple) = apply_toNT(check_gradient_cache, B) -check_adam_cache(o₁.cache[2]) -check_momentum_cache(o₂.cache[2]) -check_gradient_cache(o₃.cache[2]) +@doc raw""" +This checks if all the caches are set up in the right way for the `MultiHeadAttention` layer with Stiefel weights. + +TODO: +- [ ] `BFGSOptimizer` !! 
+""" +function test_cache_setups_for_optimizer_for_multihead_attention_layer(T::Type, dim::Int, n_heads::Int) + @assert dim % n_heads == 0 + model = MultiHeadAttention(dim, n_heads, Stiefel=true) + ps = initialparameters(CPU(), T, model) + + o₁ = Optimizer(AdamOptimizer(), ps) + o₂ = Optimizer(MomentumOptimizer(), ps) + o₃ = Optimizer(GradientOptimizer(), ps) + + check_adam_cache(o₁.cache) + check_momentum_cache(o₂.cache) + check_gradient_cache(o₃.cache) +end + +test_cache_setups_for_optimizer_for_multihead_attention_layer(Float32, 64, 8) \ No newline at end of file diff --git a/test/transformer_related/multi_head_attention_stiefel_retraction.jl b/test/transformer_related/multi_head_attention_stiefel_retraction.jl index 842069da5..54f6214e8 100644 --- a/test/transformer_related/multi_head_attention_stiefel_retraction.jl +++ b/test/transformer_related/multi_head_attention_stiefel_retraction.jl @@ -1,7 +1,3 @@ -""" -This is a test for that checks if the retractions (geodesic and Cayley for now) map from StiefelLieAlgHorMatrix to StiefelManifold when used with MultiHeadAttention. -""" - import Random, Test, Lux, LinearAlgebra, KernelAbstractions using GeometricMachineLearning, Test @@ -9,37 +5,40 @@ using GeometricMachineLearning: geodesic using GeometricMachineLearning: cayley using GeometricMachineLearning: init_optimizer_cache -dim = 64 -n_heads = 8 -Dₕ = dim÷8 -tol = eps(Float32) -T = Float32 -backend = KernelAbstractions.CPU() - -model = MultiHeadAttention(dim, n_heads, Stiefel=true) - -ps = initialparameters(backend, T, model) - -cache = init_optimizer_cache(MomentumOptimizer(), ps) - -E = StiefelProjection(dim, Dₕ, T) -function check_retraction_geodesic(A::AbstractMatrix) +@doc raw""" +This function computes the geodesic retraction of an element of `StiefelLieAlgHorMatrix` and then checks if the resulting element is `StiefelProjection`. +""" +function check_retraction_geodesic(A::AbstractMatrix{T}, tol=eps(T)) where T A_retracted = geodesic(A) @test typeof(A_retracted) <: StiefelManifold - @test LinearAlgebra.norm(A_retracted - E) < tol + @test LinearAlgebra.norm(A_retracted - StiefelProjection(A_retracted)) < tol end check_retraction_geodesic(cache::NamedTuple) = apply_toNT(check_retraction_geodesic, cache) check_retraction_geodesic(B::MomentumCache) = check_retraction_geodesic(B.B) -check_retraction_geodesic(cache) - -E = StiefelProjection(dim, Dₕ) -function check_retraction_cayley(A::AbstractMatrix) +@doc raw""" +This function computes the cayley retraction of an element of `StiefelLieAlgHorMatrix` and then checks if the resulting element is `StiefelProjection`. +""" +function check_retraction_cayley(A::AbstractMatrix{T}, tol=eps(T)) where T A_retracted = cayley(A) @test typeof(A_retracted) <: StiefelManifold - @test LinearAlgebra.norm(A_retracted - E) < tol + @test LinearAlgebra.norm(A_retracted - StiefelProjection(A_retracted)) < tol end check_retraction_cayley(cache::NamedTuple) = apply_toNT(check_retraction_cayley, cache) check_retraction_cayley(B::MomentumCache) = check_retraction_cayley(B.B) -check_retraction_cayley(cache) +@doc raw""" +This is a test for that checks if the retractions (geodesic and Cayley for now) map from `StiefelLieAlgHorMatrix` to `StiefelManifold` when used with `MultiHeadAttention`. 
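+
+The property being tested can be sketched as follows (not a doctest; `rand` for `StiefelLieAlgHorMatrix` is the constructor used in the array tests, and the sizes are arbitrary):
+
+```julia
+using GeometricMachineLearning
+using GeometricMachineLearning: geodesic, cayley
+
+B = rand(StiefelLieAlgHorMatrix{Float32}, 64, 8)
+Y₁ = geodesic(B)
+Y₂ = cayley(B)
+# both retractions should return elements of StiefelManifold
+typeof(Y₁) <: StiefelManifold && typeof(Y₂) <: StiefelManifold
+```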
+""" +function test_multi_head_attention_retraction(T::Type, dim, n_heads, tol=eps(T), backend=KernelAbstractions.CPU()) + model = MultiHeadAttention(dim, n_heads, Stiefel=true) + + ps = initialparameters(backend, T, model) + cache = init_optimizer_cache(MomentumOptimizer(), ps) + + check_retraction_geodesic(cache) + + check_retraction_cayley(cache) +end + +test_multi_head_attention_retraction(Float32, 64, 8) \ No newline at end of file diff --git a/test/transformer_related/multi_head_attention_stiefel_setup.jl b/test/transformer_related/multi_head_attention_stiefel_setup.jl index 3b8cc6cbf..742839331 100644 --- a/test/transformer_related/multi_head_attention_stiefel_setup.jl +++ b/test/transformer_related/multi_head_attention_stiefel_setup.jl @@ -3,25 +3,36 @@ import Random, Test, Lux, LinearAlgebra, KernelAbstractions using GeometricMachineLearning, Test using GeometricMachineLearning: init_optimizer_cache -T = Float32 -model = MultiHeadAttention(64, 8, Stiefel=true) -ps = initialparameters(KernelAbstractions.CPU(), T, model) -tol = 10*eps(T) - -function check_setup(A::AbstractMatrix) +@doc raw""" +This checks for an arbitrary matrix ``A\in\mathbb{R}^{N\times{}n}`` if ``A\in{}St(n,N)``. +""" +function check_setup(A::AbstractMatrix{T}, tol=T(10)*eps(T)) where T @test typeof(A) <: StiefelManifold @test check(A) < tol end check_setup(ps::NamedTuple) = apply_toNT(check_setup, ps) -check_setup(ps) -######## check if the gradients are set up the correct way -function check_grad_setup(B::AbstractMatrix) +@doc raw""" +This checks for an arbitrary matrix ``B\in\mathbb{R}^{N\times{}N}`` if ``B\in\mathfrak{g}^\mathrm{hor}``. +""" +function check_grad_setup(B::AbstractMatrix{T}, tol=T(10)*eps(T)) where T @test typeof(B) <: StiefelLieAlgHorMatrix @test LinearAlgebra.norm(B) < tol end check_grad_setup(gx::NamedTuple) = apply_toNT(check_grad_setup, gx) check_grad_setup(B::MomentumCache) = check_grad_setup(B.B) -gx = init_optimizer_cache(MomentumOptimizer(), ps) -check_grad_setup(gx) +@doc raw""" +Check if `initialparameters` and `init_optimizer_cache` do the right thing for `MultiHeadAttentionLayer`. 
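+
+Concretely, a sketch of the two checks defined above (not a doctest; the sizes are arbitrary):
+
+```julia
+ps = initialparameters(KernelAbstractions.CPU(), Float32, MultiHeadAttention(64, 8, Stiefel=true))
+# every weight should be a StiefelManifold matrix with check(A) close to zero
+gx = init_optimizer_cache(MomentumOptimizer(), ps)
+# every cache entry should store a StiefelLieAlgHorMatrix with norm close to zero
+```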
+"""
+function check_multi_head_attention_stiefel_setup(T::Type, N::Int, n::Int)
+    model = MultiHeadAttention(N, n, Stiefel=true)
+    ps = initialparameters(KernelAbstractions.CPU(), T, model)
+
+    check_setup(ps)
+
+    gx = init_optimizer_cache(MomentumOptimizer(), ps)
+    check_grad_setup(gx)
+end
+
+check_multi_head_attention_stiefel_setup(Float32, 64, 8)
\ No newline at end of file
diff --git a/test/transformer_related/transformer_application.jl b/test/transformer_related/transformer_application.jl
index 7fc159a8b..6729c889e 100644
--- a/test/transformer_related/transformer_application.jl
+++ b/test/transformer_related/transformer_application.jl
@@ -1,5 +1,8 @@
 using Test, KernelAbstractions, GeometricMachineLearning
 
+@doc raw"""
+This tests if the size of the input array is kept constant when it is fed into the transformer (both for a matrix and for a tensor input).
+"""
 function transformer_application_test(T, dim, n_heads, L, seq_length=8, batch_size=10)
     model₁ = Chain(Transformer(dim, n_heads, L, Stiefel=false), ResNet(dim))
     model₂ = Chain(Transformer(dim, n_heads, L, Stiefel=true), ResNet(dim))
@@ -8,11 +11,11 @@ function transformer_application_test(T, dim, n_heads, L, seq_length=8, batch_si
     ps₂ = initialparameters(KernelAbstractions.CPU(), T, model₂)
 
     input₁ = rand(T, dim, seq_length, batch_size)
-    input₂ = rand(T, dim, seq_length, batch_size)
-    
+    input₂ = rand(T, dim, seq_length)
+
     @test size(model₁(input₁, ps₁)) == size(input₁)
-    @test size(model₁(input₂, ps₁)) == size(input₂)
     @test size(model₂(input₁, ps₂)) == size(input₁)
+    @test size(model₁(input₂, ps₁)) == size(input₂)
     @test size(model₂(input₂, ps₂)) == size(input₂)
 end
 
diff --git a/test/transformer_related/transformer_gradient.jl b/test/transformer_related/transformer_gradient.jl
index e049933a6..2a53673fc 100644
--- a/test/transformer_related/transformer_gradient.jl
+++ b/test/transformer_related/transformer_gradient.jl
@@ -1,5 +1,8 @@
 using Test, KernelAbstractions, GeometricMachineLearning, Zygote, LinearAlgebra
 
+@doc raw"""
+This checks that the gradients of the transformer change their type when Stiefel weights are used, and that they keep the type of the weights in the case of regular weights.
+"""
 function transformer_gradient_test(T, dim, n_heads, L, seq_length=8, batch_size=10)
     model₁ = Chain(Transformer(dim, n_heads, L, Stiefel=false), ResNet(dim))
     model₂ = Chain(Transformer(dim, n_heads, L, Stiefel=true), ResNet(dim))
diff --git a/test/transformer_related/transformer_optimizer.jl b/test/transformer_related/transformer_optimizer.jl
index aa3968c62..e6421bdb7 100644
--- a/test/transformer_related/transformer_optimizer.jl
+++ b/test/transformer_related/transformer_optimizer.jl
@@ -1,5 +1,8 @@
 using Test, KernelAbstractions, GeometricMachineLearning, Zygote, LinearAlgebra
 
+@doc raw"""
+This function tests if the `GradientOptimizer`, `MomentumOptimizer`, `AdamOptimizer` and `BFGSOptimizer` act on the neural network weights via `optimization_step!`.
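+
+For each optimizer the test below roughly performs the following steps (a sketch; `model`, `ps` and `dx` are the network, its parameters and a loss gradient with respect to `ps`, e.g. obtained with `Zygote`):
+
+```julia
+o = Optimizer(BFGSOptimizer(), ps)
+ps_old = deepcopy(ps)
+optimization_step!(o, model, ps, dx)
+# the type of the parameters is preserved and the weights have actually been updated
+typeof(ps) == typeof(ps_old)
+ps[1].PQ.head_1 ≉ ps_old[1].PQ.head_1
+```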
+""" function transformer_gradient_test(T, dim, n_heads, L, seq_length=8, batch_size=10) model = Chain(Transformer(dim, n_heads, L, Stiefel=true), ResNet(dim)) model = Transformer(dim, n_heads, L, Stiefel=true) @@ -14,15 +17,19 @@ function transformer_gradient_test(T, dim, n_heads, L, seq_length=8, batch_size= o₁ = Optimizer(GradientOptimizer(), ps) o₂ = Optimizer(MomentumOptimizer(), ps) o₃ = Optimizer(AdamOptimizer(), ps) + o₄ = Optimizer(BFGSOptimizer(), ps) ps₁ = deepcopy(ps) ps₂ = deepcopy(ps) ps₃ = deepcopy(ps) + ps₄ = deepcopy(ps) optimization_step!(o₁, model, ps₁, dx) optimization_step!(o₂, model, ps₂, dx) optimization_step!(o₃, model, ps₃, dx) - @test typeof(ps₁) == typeof(ps₂) == typeof(ps₃) == typeof(ps) + optimization_step!(o₄, model, ps₄, dx) + @test typeof(ps₁) == typeof(ps₂) == typeof(ps₃) == typeof(ps₄) == typeof(ps) + @test ps₁[1].PQ.head_1 ≉ ps₂[1].PQ.head_1 ≉ ps₃[1].PQ.head_1 ≉ ps₄[1].PQ.head_1 ≉ ps[1].PQ.head_1 end transformer_gradient_test(Float32, 10, 5, 4) \ No newline at end of file diff --git a/test/transformer_related/transformer_setup.jl b/test/transformer_related/transformer_setup.jl index 42ac78995..1311ce26c 100644 --- a/test/transformer_related/transformer_setup.jl +++ b/test/transformer_related/transformer_setup.jl @@ -1,5 +1,8 @@ using Test, KernelAbstractions, GeometricMachineLearning +@doc raw""" +This function tests the setup of the transformer with Stiefel weights. +""" function transformer_setup_test(dim, n_heads, L, T) model = Transformer(dim, n_heads, L, Stiefel=true) ps = initialparameters(KernelAbstractions.CPU(), T, model)