-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #101 from JuliaGNI/increase_data_loader_test_coverage
Increase data loader test coverage
- Loading branch information
Showing
12 changed files
with
196 additions
and
130 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,75 +1,43 @@ | ||
|
||
@doc raw""" | ||
`SymplecticMatrix(n)` | ||
`SymplecticPotential(n)` | ||
Returns a symplectic matrix of size 2n x 2n | ||
```math | ||
\begin{pmatrix} | ||
0 & & & 1 & & & \\ | ||
& \ddots & & & \ddots & & \\ | ||
& & 0 & & & 1 \\ | ||
-1 & & & 0 & & & \\ | ||
& \ddots & & & \ddots & & \\ | ||
& & -1 & & 0 & \\ | ||
\mathbb{O} & \mathbb{I} \\ | ||
\mathbb{O} & -\mathbb{I} \\ | ||
\end{pmatrix} | ||
``` | ||
`SymplecticProjection(N,n)` | ||
Returns the symplectic projection matrix E of the Stiefel manifold, i.e. π: Sp(2N) → Sp(2n,2N), A ↦ AE | ||
""" | ||
#= | ||
function SymplecticMatrix(n::Int, T::DataType=Float64) | ||
BandedMatrix((n => ones(T,n), -n => -ones(T,n)), (2n,2n)) | ||
end | ||
SymplecticMatrix(T::DataType, n::Int) = SymplecticMatrix(n, T) | ||
@doc raw""" | ||
```math | ||
\begin{pmatrix} | ||
I & 0 \\ | ||
0 & 0 \\ | ||
0 & I \\ | ||
0 & 0 \\ | ||
\end{pmatrix} | ||
``` | ||
""" | ||
=# | ||
|
||
function SymplecticPotential(n::Int, T::DataType=Float64) | ||
J = zeros(T, 2*n, 2*n) | ||
J[1:n, (n+1):2*n] = one(ones(T, n, n)) | ||
J[(n+1):2*n, 1:n] = -one(ones(T, n, n)) | ||
function SymplecticPotential(backend, n2::Int, T::DataType=Float64) | ||
@assert iseven(n2) | ||
n = n2÷2 | ||
J = KernelAbstractions.zeros(backend, T, 2*n, 2*n) | ||
assign_ones_for_symplectic_potential! = assign_ones_for_symplectic_potential_kernel!(backend) | ||
assign_ones_for_symplectic_potential!(J, n, ndrange=n) | ||
J | ||
end | ||
|
||
SymplecticPotential(n::Int, T::DataType=Float64) = SymplecticPotential(CPU(), n, T) | ||
SymplecticPotential(bakend, T::DataType, n::Int) = SymplecticPotential(backend, n, T) | ||
|
||
SymplecticPotential(T::DataType, n::Int) = SymplecticPotential(n, T) | ||
|
||
struct SymplecticProjection{T} <: AbstractMatrix{T} | ||
N::Int | ||
n::Int | ||
SymplecticProjection(N, n, T = Float64) = new{T}(N,n) | ||
@kernel function assign_ones_for_symplectic_potential_kernel!(J::AbstractMatrix{T}, n::Int) where T | ||
i = @index(Global) | ||
J[map_index_for_symplectic_potential(i, n)...] = i ≤ n ? one(T) : -one(T) | ||
end | ||
|
||
function Base.getindex(E::SymplecticProjection,i,j) | ||
if i ≤ E.n | ||
if j == i | ||
return 1. | ||
end | ||
return 0. | ||
end | ||
if j > E.n | ||
if (j-E.n) == (i-E.N) | ||
return 1. | ||
end | ||
return 0. | ||
""" | ||
This assigns the right index for the symplectic potential. To be used with `assign_ones_for_symplectic_potential_kernel!`. | ||
""" | ||
function map_index_for_symplectic_potential(i::Int, n::Int) | ||
if i ≤ n | ||
return (i, i+n) | ||
else | ||
return (i, i-n) | ||
end | ||
return 0. | ||
end | ||
|
||
|
||
Base.parent(E::SymplecticProjection) = (E.N,E.n) | ||
Base.size(E::SymplecticProjection) = (2*E.N,2*E.n) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
using GeometricMachineLearning | ||
using Test | ||
|
||
function dummy_qp_data_matrix(dim=2, number_data_points=200, T=Float32) | ||
(q = rand(T, dim, number_data_points), p = (rand(T, dim, number_data_points))) | ||
end | ||
|
||
function dummy_qp_data_tensor(dim=2, number_of_time_steps=100, number_of_parameters=20, T=Float32) | ||
(q = rand(T, dim, number_of_time_steps, number_of_parameters), p = (rand(T, dim, number_of_time_steps, number_of_parameters))) | ||
end | ||
|
||
function test_data_loader(dim=2, number_of_time_steps=100, number_of_parameters=20, batch_size=10, T=Float32) | ||
|
||
dl1 = DataLoader(dummy_qp_data_matrix(dim, number_of_time_steps, T)) | ||
dl2 = DataLoader(dummy_qp_data_tensor(dim, number_of_time_steps, number_of_parameters)) | ||
|
||
arch1 = GSympNet(dl1) | ||
arch2 = GSympNet(dl2) | ||
|
||
nn1 = NeuralNetwork(arch1, CPU(), T) | ||
nn2 = NeuralNetwork(arch2, CPU(), T) | ||
|
||
loss1 = GeometricMachineLearning.loss(nn1, dl1) | ||
loss2 = GeometricMachineLearning.loss(nn2, dl2) | ||
|
||
batch = Batch(batch_size) | ||
o₁ = Optimizer(GradientOptimizer(), nn1) | ||
# o₂ = Optimizer(GradientOptimizer(), nn2) | ||
|
||
o₁(nn1, dl1, batch) | ||
# o₂(nn2, dl2, batch) | ||
end | ||
|
||
test_data_loader() |
3 changes: 3 additions & 0 deletions
3
test/data_loader/data_loader.jl → ...a_loader/data_loader_optimization_step.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.