Skip to content

Commit

Permalink
Merge pull request #98 from JuliaGNI/multiplication_and_addition_test…
Browse files Browse the repository at this point in the history
…s_for_custom_arrays

Multiplication and addition tests for custom arrays
  • Loading branch information
michakraus authored Dec 20, 2023
2 parents 9dbdb58 + 1e659d7 commit eb94a21
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/GeometricMachineLearning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ module GeometricMachineLearning
export convert_to_dev, Device, CPUDevice

# INCLUDE ARRAYS
include("arrays/skew_symmetric.jl")
include("arrays/symmetric.jl")
include("arrays/symplectic.jl")
include("arrays/skew_symmetric.jl")
include("arrays/abstract_lie_algebra_horizontal.jl")
include("arrays/stiefel_lie_algebra_horizontal.jl")
include("arrays/grassmann_lie_algebra_horizontal.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/arrays/skew_symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ end

# the first matrix is multiplied onto A2 in order for it to not be SkewSymMatrix!
function Base.:*(A1::SkewSymMatrix{T}, A2::SkewSymMatrix{T}) where T
A1*(one(A2)*A2)
A1 * (one(A2) * A2)
end

@doc raw"""
Expand Down
14 changes: 14 additions & 0 deletions src/arrays/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,27 @@ function Base.:*(A::SymmetricMatrix{T}, B::AbstractMatrix{T}) where T
C
end

Base.:*(B::AbstractMatrix{T}, A::SymmetricMatrix{T}) where T = (A * B')'

function Base.:*(A::SymmetricMatrix{T}, B::SymmetricMatrix{T}) where T
A * (B * one(B))
end

function Base.:*(A::SymmetricMatrix{T}, b::AbstractVector{T}) where T
backend = KernelAbstractions.get_backend(A.S)
c = KernelAbstractions.allocate(backend, T, A.n)
LinearAlgebra.mul!(c, A, b)
c
end

function Base.one(A::SymmetricMatrix{T}) where T
backend = KernelAbstractions.get_backend(A.S)
unit_matrix = KernelAbstractions.zeros(backend, T, A.n, A.n)
write_ones! = write_ones_kernel!(backend)
write_ones!(unit_matrix, ndrange=A.n)
unit_matrix
end

# define routines for generalizing ChainRulesCore to SymmetricMatrix
ChainRulesCore.ProjectTo(A::SymmetricMatrix) = ProjectTo{SymmetricMatrix}(; symmetric=ProjectTo(A.S))
(project::ProjectTo{SymmetricMatrix})(dA::AbstractMatrix) = SymmetricMatrix(project.symmetric(map_to_S(dA)), size(dA, 2))
Expand Down
53 changes: 41 additions & 12 deletions test/arrays/addition_tests_for_custom_arrays.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,41 @@
#=
This tests addition for all custom arrays. Note that these tests will also have to be performed on GPU!
=#

using LinearAlgebra
using Random
using Test
using GeometricMachineLearning

function test_addition_for_symmetric_matrix(n::Int, T::Type)
A = rand(SymmetricMatrix{T}, n)
end
using GeometricMachineLearning, Test

@doc raw"""
This function tests addition for various custom arrays, i.e. if \(A + B\) is performed in the correct way.
"""
function addition_tests_for_custom_arrays(n::Int, N::Int, T::Type)
A = rand(T, n, n)
B = rand(T, n, n)

# SymmetricMatrix
AB_sym = SymmetricMatrix(A + B)
AB_sym2 = SymmetricMatrix(A) + SymmetricMatrix(B)
@test AB_sym AB_sym2
@test typeof(AB_sym) <: SymmetricMatrix{T}
@test typeof(AB_sym2) <: SymmetricMatrix{T}

# SkewSymMatrix
AB_skew = SkewSymMatrix(A + B)
AB_skew2 = SkewSymMatrix(A) + SkewSymMatrix(B)
@test AB_skew AB_skew2
@test typeof(AB_skew) <: SkewSymMatrix{T}
@test typeof(AB_skew2) <: SkewSymMatrix{T}

C = rand(T, N, N)
D = rand(T, N, N)

# StiefelLieAlgHorMatrix
CD_slahm = StiefelLieAlgHorMatrix(C + D, n)
CD_slahm2 = StiefelLieAlgHorMatrix(C, n) + StiefelLieAlgHorMatrix(D, n)
@test CD_slahm CD_slahm2
@test typeof(CD_slahm) <: StiefelLieAlgHorMatrix{T}
@test typeof(CD_slahm2) <: StiefelLieAlgHorMatrix{T}

CD_glahm = GrassmannLieAlgHorMatrix(C + D, n)
CD_glahm2 = GrassmannLieAlgHorMatrix(C, n) + GrassmannLieAlgHorMatrix(D, n)
@test CD_glahm CD_glahm2
@test typeof(CD_glahm) <: GrassmannLieAlgHorMatrix{T}
@test typeof(CD_glahm2) <: GrassmannLieAlgHorMatrix{T}
end

addition_tests_for_custom_arrays(5, 10, Float32)
1 change: 0 additions & 1 deletion test/arrays/array_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ function skew_mat_mul_test2(n, T=Float64)
AS2 = A*Matrix{T}(S)
@test isapprox(AS1, AS2)
end

# test Stiefel manifold projection test
function stiefel_proj_test(N,n)
In = I(n)
Expand Down
21 changes: 21 additions & 0 deletions test/arrays/matrix_multiplication_for_custom_arrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using GeometricMachineLearning, Test

@doc raw"""
This function tests matrix multiplication for various custom arrays, i.e. if \((A,\alpha) \mapsto \alpha{}A\) is performed in the correct way.
"""
function matrix_multiplication_tests_for_custom_arrays(n::Int, N::Int, T::Type)
A = rand(T, n, n)
B = rand(T, n, N)

# SymmetricMatrix
A_sym = SymmetricMatrix(A)
@test A_sym * B Matrix{T}(A_sym) * B
@test B' * A_sym B' * Matrix{T}(A_sym)

# SkewSymMatrix
A_skew = SkewSymMatrix(A)
@test A_skew * B Matrix{T}(A_skew) * B
@test B' * A_skew B' * Matrix{T}(A_skew)
end

matrix_multiplication_tests_for_custom_arrays(5, 10, Float32)
40 changes: 40 additions & 0 deletions test/arrays/scalar_multiplication_for_custom_arrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using GeometricMachineLearning, Test

@doc raw"""
This function tests scalar multiplication for various custom arrays, i.e. if \((A,\alpha) \mapsto \alpha{}A\) is performed in the correct way.
"""
function scalar_multiplication_for_custom_arrays(n::Int, N::Int, T::Type)
A = rand(T, n, n)
α = rand(T)

# SymmetricMatrix
Aα_sym = SymmetricMatrix* A)
Aα_sym2 = α * SymmetricMatrix(A)
@test Aα_sym Aα_sym2
@test typeof(Aα_sym) <: SymmetricMatrix{T}
@test typeof(Aα_sym2) <: SymmetricMatrix{T}

# SkewSymMatrix
Aα_skew = SkewSymMatrix* A)
Aα_skew2 = α * SkewSymMatrix(A)
@test Aα_skew Aα_skew2
@test typeof(Aα_skew) <: SkewSymMatrix{T}
@test typeof(Aα_skew2) <: SkewSymMatrix{T}

C = rand(T, N, N)

# StiefelLieAlgHorMatrix
Cα_slahm = StiefelLieAlgHorMatrix* C, n)
Cα_slahm2 = α * StiefelLieAlgHorMatrix(C, n)
@test Cα_slahm Cα_slahm2
@test typeof(Cα_slahm) <: StiefelLieAlgHorMatrix{T}
@test typeof(Cα_slahm2) <: StiefelLieAlgHorMatrix{T}

Cα_glahm = GrassmannLieAlgHorMatrix* C, n)
Cα_glahm2 = α * GrassmannLieAlgHorMatrix(C, n)
@test Cα_glahm Cα_glahm2
@test typeof(Cα_glahm) <: GrassmannLieAlgHorMatrix{T}
@test typeof(Cα_glahm2) <: GrassmannLieAlgHorMatrix{T}
end

scalar_multiplication_for_custom_arrays(5, 10, Float32)
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ using SafeTestsets
@safetestset "Check parameterlength " begin include("parameterlength/check_parameterlengths.jl") end
@safetestset "Arrays #1 " begin include("arrays/array_tests.jl") end
@safetestset "Sampling of arrays " begin include("arrays/random_generation_of_custom_arrays.jl") end
@safetestset "Addition tests for custom arrays " begin include("arrays/addition_tests_for_custom_arrays.jl") end
@safetestset "Scalar multiplication tests for custom arrays " begin include("arrays/scalar_multiplication_for_custom_arrays.jl") end
@safetestset "Matrix multiplication tests for custom arrays " begin include("arrays/matrix_multiplication_for_custom_arrays.jl") end
@safetestset "Manifolds (Grassmann): " begin include("manifolds/grassmann_manifold.jl") end
@safetestset "Gradient Layer " begin include("layers/gradient_layer_tests.jl") end
@safetestset "Test symplecticity of upscaling layer " begin include("layers/sympnet_layers_test.jl") end
Expand Down

0 comments on commit eb94a21

Please sign in to comment.