-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #98 from JuliaGNI/multiplication_and_addition_test…
…s_for_custom_arrays Multiplication and addition tests for custom arrays
- Loading branch information
Showing
8 changed files
with
121 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,41 @@ | ||
#= | ||
This tests addition for all custom arrays. Note that these tests will also have to be performed on GPU! | ||
=# | ||
|
||
using LinearAlgebra | ||
using Random | ||
using Test | ||
using GeometricMachineLearning | ||
|
||
function test_addition_for_symmetric_matrix(n::Int, T::Type) | ||
A = rand(SymmetricMatrix{T}, n) | ||
end | ||
using GeometricMachineLearning, Test | ||
|
||
@doc raw""" | ||
This function tests addition for various custom arrays, i.e. if \(A + B\) is performed in the correct way. | ||
""" | ||
function addition_tests_for_custom_arrays(n::Int, N::Int, T::Type) | ||
A = rand(T, n, n) | ||
B = rand(T, n, n) | ||
|
||
# SymmetricMatrix | ||
AB_sym = SymmetricMatrix(A + B) | ||
AB_sym2 = SymmetricMatrix(A) + SymmetricMatrix(B) | ||
@test AB_sym ≈ AB_sym2 | ||
@test typeof(AB_sym) <: SymmetricMatrix{T} | ||
@test typeof(AB_sym2) <: SymmetricMatrix{T} | ||
|
||
# SkewSymMatrix | ||
AB_skew = SkewSymMatrix(A + B) | ||
AB_skew2 = SkewSymMatrix(A) + SkewSymMatrix(B) | ||
@test AB_skew ≈ AB_skew2 | ||
@test typeof(AB_skew) <: SkewSymMatrix{T} | ||
@test typeof(AB_skew2) <: SkewSymMatrix{T} | ||
|
||
C = rand(T, N, N) | ||
D = rand(T, N, N) | ||
|
||
# StiefelLieAlgHorMatrix | ||
CD_slahm = StiefelLieAlgHorMatrix(C + D, n) | ||
CD_slahm2 = StiefelLieAlgHorMatrix(C, n) + StiefelLieAlgHorMatrix(D, n) | ||
@test CD_slahm ≈ CD_slahm2 | ||
@test typeof(CD_slahm) <: StiefelLieAlgHorMatrix{T} | ||
@test typeof(CD_slahm2) <: StiefelLieAlgHorMatrix{T} | ||
|
||
CD_glahm = GrassmannLieAlgHorMatrix(C + D, n) | ||
CD_glahm2 = GrassmannLieAlgHorMatrix(C, n) + GrassmannLieAlgHorMatrix(D, n) | ||
@test CD_glahm ≈ CD_glahm2 | ||
@test typeof(CD_glahm) <: GrassmannLieAlgHorMatrix{T} | ||
@test typeof(CD_glahm2) <: GrassmannLieAlgHorMatrix{T} | ||
end | ||
|
||
addition_tests_for_custom_arrays(5, 10, Float32) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
using GeometricMachineLearning, Test | ||
|
||
@doc raw""" | ||
This function tests matrix multiplication for various custom arrays, i.e. if \((A,\alpha) \mapsto \alpha{}A\) is performed in the correct way. | ||
""" | ||
function matrix_multiplication_tests_for_custom_arrays(n::Int, N::Int, T::Type) | ||
A = rand(T, n, n) | ||
B = rand(T, n, N) | ||
|
||
# SymmetricMatrix | ||
A_sym = SymmetricMatrix(A) | ||
@test A_sym * B ≈ Matrix{T}(A_sym) * B | ||
@test B' * A_sym ≈ B' * Matrix{T}(A_sym) | ||
|
||
# SkewSymMatrix | ||
A_skew = SkewSymMatrix(A) | ||
@test A_skew * B ≈ Matrix{T}(A_skew) * B | ||
@test B' * A_skew ≈ B' * Matrix{T}(A_skew) | ||
end | ||
|
||
matrix_multiplication_tests_for_custom_arrays(5, 10, Float32) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
using GeometricMachineLearning, Test | ||
|
||
@doc raw""" | ||
This function tests scalar multiplication for various custom arrays, i.e. if \((A,\alpha) \mapsto \alpha{}A\) is performed in the correct way. | ||
""" | ||
function scalar_multiplication_for_custom_arrays(n::Int, N::Int, T::Type) | ||
A = rand(T, n, n) | ||
α = rand(T) | ||
|
||
# SymmetricMatrix | ||
Aα_sym = SymmetricMatrix(α * A) | ||
Aα_sym2 = α * SymmetricMatrix(A) | ||
@test Aα_sym ≈ Aα_sym2 | ||
@test typeof(Aα_sym) <: SymmetricMatrix{T} | ||
@test typeof(Aα_sym2) <: SymmetricMatrix{T} | ||
|
||
# SkewSymMatrix | ||
Aα_skew = SkewSymMatrix(α * A) | ||
Aα_skew2 = α * SkewSymMatrix(A) | ||
@test Aα_skew ≈ Aα_skew2 | ||
@test typeof(Aα_skew) <: SkewSymMatrix{T} | ||
@test typeof(Aα_skew2) <: SkewSymMatrix{T} | ||
|
||
C = rand(T, N, N) | ||
|
||
# StiefelLieAlgHorMatrix | ||
Cα_slahm = StiefelLieAlgHorMatrix(α * C, n) | ||
Cα_slahm2 = α * StiefelLieAlgHorMatrix(C, n) | ||
@test Cα_slahm ≈ Cα_slahm2 | ||
@test typeof(Cα_slahm) <: StiefelLieAlgHorMatrix{T} | ||
@test typeof(Cα_slahm2) <: StiefelLieAlgHorMatrix{T} | ||
|
||
Cα_glahm = GrassmannLieAlgHorMatrix(α * C, n) | ||
Cα_glahm2 = α * GrassmannLieAlgHorMatrix(C, n) | ||
@test Cα_glahm ≈ Cα_glahm2 | ||
@test typeof(Cα_glahm) <: GrassmannLieAlgHorMatrix{T} | ||
@test typeof(Cα_glahm2) <: GrassmannLieAlgHorMatrix{T} | ||
end | ||
|
||
scalar_multiplication_for_custom_arrays(5, 10, Float32) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters