Increase data loader test coverage #101

Merged: 32 commits, Dec 20, 2023
Changes from all commits
Commits (32)
35b3db6
Renamed file to something more descriptive.
benedict-96 Dec 19, 2023
66f3279
Renamed file to something meaningful and made the test more readable.
benedict-96 Dec 19, 2023
257aeee
Added descriptions and a wrapper for using the custom loss function w…
benedict-96 Dec 19, 2023
9ca60ed
Added descriptions and a wrapper for using the custom loss function w…
benedict-96 Dec 19, 2023
fced8d8
Added documentation.
benedict-96 Dec 19, 2023
d95268d
Added constructor for optimizer if input arguments are flipped.
benedict-96 Dec 19, 2023
2abed87
Added a comment saying that the constructor can be called with DataLo…
benedict-96 Dec 19, 2023
0f62b8f
Added default for number of epochs.
benedict-96 Dec 19, 2023
3ab42f2
Combined matrix and tensor routines into one. Added another loss for …
benedict-96 Dec 19, 2023
5d22855
Commented out a section that is probably not needed.
benedict-96 Dec 19, 2023
86485ad
Test data loader for qp data.
benedict-96 Dec 19, 2023
cac1f18
Renamed and added tests.
benedict-96 Dec 19, 2023
4528053
Adjusted symplectic matrix.
benedict-96 Dec 19, 2023
6ed5c62
Renamed file to something more descriptive.
benedict-96 Dec 19, 2023
f4ddd3f
Renamed file to something meaningful and made the test more readable.
benedict-96 Dec 19, 2023
92ac2fe
Added descriptions and a wrapper for using the custom loss function w…
benedict-96 Dec 19, 2023
cb19f5b
Added descriptions and a wrapper for using the custom loss function w…
benedict-96 Dec 19, 2023
cb7725f
Added documentation.
benedict-96 Dec 19, 2023
90bbaa7
Added constructor for optimizer if input arguments are flipped.
benedict-96 Dec 19, 2023
137d640
Added a comment saying that the constructor can be called with DataLo…
benedict-96 Dec 19, 2023
b57d148
Added default for number of epochs.
benedict-96 Dec 19, 2023
3100b0b
Combined matrix and tensor routines into one. Added another loss for …
benedict-96 Dec 19, 2023
718174f
Commented out a section that is probably not needed.
benedict-96 Dec 19, 2023
370ee05
Test data loader for qp data.
benedict-96 Dec 19, 2023
14d6ae2
Renamed and added tests.
benedict-96 Dec 19, 2023
9edec7b
Adjusted symplectic matrix.
benedict-96 Dec 19, 2023
2acdd24
dim was missing for specifying default.
benedict-96 Dec 19, 2023
f8b8fff
dl has no field data.
benedict-96 Dec 19, 2023
8c4c73b
Fixed typos.
benedict-96 Dec 19, 2023
fc92faa
Merge branch 'increase_data_loader_test_coverage' of https://github.c…
benedict-96 Dec 19, 2023
d072594
Removed method that appeared twice for some reason.
benedict-96 Dec 19, 2023
d071839
Forgot to commit before.
benedict-96 Dec 19, 2023
6 changes: 3 additions & 3 deletions src/architectures/sympnet.jl
@@ -7,7 +7,7 @@ TODO:
abstract type SympNet{AT} <: Architecture end

@doc raw"""
`LASympNet` is called with **a single input argument**, the **system dimension**. Optional input arguments are:
`LASympNet` is called with **a single input argument**, the **system dimension**, or with an instance of `DataLoader`. Optional input arguments are:
- `depth::Int`: The number of linear layers that are applied. The default is 5.
- `nhidden::Int`: The number of hidden layers (i.e. layers that are **not** input or output layers). The default is 2.
- `activation`: The activation function that is applied. By default this is `tanh`.
@@ -32,7 +32,7 @@ end
@inline AbstractNeuralNetworks.dim(arch::SympNet) = arch.dim

@doc raw"""
`GSympNet` is called with **a single input argument**, the **system dimension**. Optional input arguments are:
`GSympNet` is called with **a single input argument**, the **system dimension**, or with an instance of `DataLoader`. Optional input arguments are:
- `upscaling_dimension::Int`: The *upscaling dimension* of the gradient layer. See the documentation for `GradientLayerQ` and `GradientLayerP` for further explanation. The default is `2*dim`.
- `nhidden::Int`: The number of hidden layers (i.e. layers that are **not** input or output layers). The default is 2.
- `activation`: The activation function that is applied. By default this is `tanh`.
@@ -49,7 +49,7 @@ struct GSympNet{AT, InitUpper} <: SympNet{AT} where {InitUpper}
end


function GSympNet(dl::DataLoader; upscaling_dimension=2*dim, nhidden=2, activation=tanh, init_upper=true)
function GSympNet(dl::DataLoader; upscaling_dimension=2*dl.input_dim, nhidden=2, activation=tanh, init_upper=true)
new{typeof(activation), init_upper}(dl.input_dim, upscaling_dimension, nhidden, activation)
end
end
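To make the new constructors concrete: both `LASympNet` and `GSympNet` can now be built directly from a `DataLoader`, which supplies the system dimension. The sketch below is a hedged usage example; the data shapes and keyword values are illustrative assumptions, not part of this PR.

```julia
using GeometricMachineLearning

# dummy (q, p) training data: 2 coordinates each, 200 snapshots (assumed shapes)
data = (q = rand(Float32, 2, 200), p = rand(Float32, 2, 200))
dl = DataLoader(data)

# construct the architecture either from the system dimension ...
arch_from_dim = GSympNet(4; nhidden = 2, activation = tanh)
# ... or, as added in this PR, directly from the DataLoader
arch_from_dl = GSympNet(dl)

nn = NeuralNetwork(arch_from_dl, CPU(), Float32)
```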
80 changes: 24 additions & 56 deletions src/arrays/symplectic.jl
@@ -1,75 +1,43 @@

@doc raw"""

`SymplecticMatrix(n)`
`SymplecticPotential(n)`

Returns a symplectic matrix of size 2n x 2n

```math
\begin{pmatrix}
\mathbb{O} & \mathbb{I} \\
-\mathbb{I} & \mathbb{O} \\
\end{pmatrix}
```

`SymplecticProjection(N,n)`
Returns the symplectic projection matrix E of the Stiefel manifold, i.e. π: Sp(2N) → Sp(2n,2N), A ↦ AE

"""
#=
function SymplecticMatrix(n::Int, T::DataType=Float64)
BandedMatrix((n => ones(T,n), -n => -ones(T,n)), (2n,2n))
end

SymplecticMatrix(T::DataType, n::Int) = SymplecticMatrix(n, T)

@doc raw"""
```math
\begin{pmatrix}
I & 0 \\
0 & 0 \\
0 & I \\
0 & 0 \\
\end{pmatrix}
```
"""
=#

function SymplecticPotential(n::Int, T::DataType=Float64)
J = zeros(T, 2*n, 2*n)
J[1:n, (n+1):2*n] = one(ones(T, n, n))
J[(n+1):2*n, 1:n] = -one(ones(T, n, n))
function SymplecticPotential(backend, n2::Int, T::DataType=Float64)
@assert iseven(n2)
n = n2÷2
J = KernelAbstractions.zeros(backend, T, 2*n, 2*n)
assign_ones_for_symplectic_potential! = assign_ones_for_symplectic_potential_kernel!(backend)
assign_ones_for_symplectic_potential!(J, n, ndrange=n2)

Codecov warning: added lines src/arrays/symplectic.jl#L15-L20 were not covered by tests.
J
end

SymplecticPotential(n::Int, T::DataType=Float64) = SymplecticPotential(CPU(), n, T)
SymplecticPotential(backend, T::DataType, n::Int) = SymplecticPotential(backend, n, T)

Codecov warning: added lines src/arrays/symplectic.jl#L24-L25 were not covered by tests.

SymplecticPotential(T::DataType, n::Int) = SymplecticPotential(n, T)

struct SymplecticProjection{T} <: AbstractMatrix{T}
N::Int
n::Int
SymplecticProjection(N, n, T = Float64) = new{T}(N,n)
@kernel function assign_ones_for_symplectic_potential_kernel!(J::AbstractMatrix{T}, n::Int) where T
i = @index(Global)
J[map_index_for_symplectic_potential(i, n)...] = i ≤ n ? one(T) : -one(T)

Codecov warning: added lines src/arrays/symplectic.jl#L29-L31 were not covered by tests.
end

function Base.getindex(E::SymplecticProjection,i,j)
if i ≤ E.n
if j == i
return 1.
end
return 0.
end
if j > E.n
if (j-E.n) == (i-E.N)
return 1.
end
return 0.
"""
This assigns the right index for the symplectic potential. To be used with `assign_ones_for_symplectic_potential_kernel!`.
"""
function map_index_for_symplectic_potential(i::Int, n::Int)
if i ≤ n
return (i, i+n)

Codecov warning: added lines src/arrays/symplectic.jl#L37-L39 were not covered by tests.
else
return (i, i-n)

Codecov warning: added line src/arrays/symplectic.jl#L41 was not covered by tests.
end
return 0.
end


Base.parent(E::SymplecticProjection) = (E.N,E.n)
Base.size(E::SymplecticProjection) = (2*E.N,2*E.n)
end
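As a small illustration of `map_index_for_symplectic_potential`, the CPU convenience method should produce the block matrix with the identity in the upper-right block and its negative in the lower-left block. The spelled-out result below is an expectation based on the kernel above, not verified output.

```julia
using GeometricMachineLearning

J = SymplecticPotential(4)   # n2 = 4, i.e. n = 2; the backend defaults to CPU()

# expected layout (rows i ≤ n get +1 at column i+n, rows i > n get -1 at column i-n):
#  0.0   0.0   1.0   0.0
#  0.0   0.0   0.0   1.0
# -1.0   0.0   0.0   0.0
#  0.0  -1.0   0.0   0.0
```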
21 changes: 18 additions & 3 deletions src/data_loader/batch.jl
@@ -15,6 +15,9 @@ end
hasseqlength(::Batch{<:Integer}) = true
hasseqlength(::Batch{<:Nothing}) = false

@doc raw"""
The functor for `Batch` is called with an instance of `DataLoader`. It returns a tuple of batch indices ``(\mathcal{I}_1, \ldots, \mathcal{I}_{\lceil\mathtt{dl.n\_params/batch\_size}\rceil})``, where the number of batches is the number of parameters divided by the batch size, rounded up.
"""
function (batch::Batch{<:Nothing})(dl::DataLoader{T, AT}) where {T, AT<:AbstractArray{T, 3}}
indices = shuffle(1:dl.n_params)
n_batches = Int(ceil(dl.n_params/batch.batch_size))
@@ -28,6 +31,9 @@ function (batch::Batch{<:Nothing})(dl::DataLoader{T, AT}) where {T, AT<:Abstract
batches
end
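A short sketch of how the batch functor is used on tensor-valued data, following the docstring above; the array shapes are assumptions chosen only to make the index arithmetic visible.

```julia
using GeometricMachineLearning

data  = rand(Float32, 4, 50, 32)   # (system dimension, time steps, 32 parameters)
dl    = DataLoader(data)
batch = Batch(10)                  # batch size 10

batches = batch(dl)                # tuple of shuffled index batches
@assert length(batches) == Int(ceil(dl.n_params / 10))   # ⌈32 / 10⌉ = 4 batches
```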

@doc raw"""
The functor for `Batch` is called with an instance of `DataLoader`. It returns a tuple of batch indices ``(\mathcal{I}_1, \ldots, \mathcal{I}_{\lceil\mathtt{(dl.input\_time\_steps-1)/batch\_size}\rceil})``, where the number of batches is the number of input time steps minus one, divided by the batch size and rounded up.
"""
function (batch::Batch{<:Nothing})(dl::DataLoader{T, AT}) where {T, BT<:AbstractMatrix{T}, AT<:Union{BT, NamedTuple{(:q, :p), Tuple{BT, BT}}}}
indices = shuffle(1:dl.input_time_steps)
n_batches = Int(ceil((dl.input_time_steps-1)/batch.batch_size))
@@ -88,7 +94,7 @@ function optimize_for_one_epoch!(opt::Optimizer, nn::NeuralNetwork, dl::DataLoad
end

"""
TODO: Add ProgressMeter!!!
This routine is called if a `DataLoader` storing *symplectic data* (i.e. a `NamedTuple` with fields `q` and `p`) is supplied.
"""
function optimize_for_one_epoch!(opt::Optimizer, model, ps::Union{Tuple, NamedTuple}, dl::DataLoader{T, AT}, batch::Batch, loss) where {T, AT<:NamedTuple}
count = 0
@@ -107,7 +113,16 @@ function optimize_for_one_epoch!(opt::Optimizer, model, ps::Union{Tuple, NamedTu
total_error/count
end


@doc raw"""
A functor for `Optimizer`. It is called with:
- `nn::NeuralNetwork`
- `dl::DataLoader`
- `batch::Batch`
- `n_epochs::Int`
- `loss`

The last argument is a function through which `Zygote` differentiates. This argument is optional; if it is not supplied, `GeometricMachineLearning` defaults to an appropriate loss for the `DataLoader`.
"""
function (o::Optimizer)(nn::NeuralNetwork, dl::DataLoader, batch::Batch, n_epochs::Int, loss)
progress_object = ProgressMeter.Progress(n_epochs; enabled=true)
loss_array = zeros(n_epochs)
@@ -118,6 +133,6 @@ function (o::Optimizer)(nn::NeuralNetwork, dl::DataLoader, batch::Batch, n_epoch
loss_array
end

function (o::Optimizer)(nn::NeuralNetwork, dl::DataLoader, batch::Batch, n_epochs::Int)
function (o::Optimizer)(nn::NeuralNetwork, dl::DataLoader, batch::Batch, n_epochs::Int=1)
o(nn, dl, batch, n_epochs, loss)
end
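Putting the pieces of this file together, here is a hedged end-to-end sketch of the `Optimizer` functor with the new defaults (a single epoch and an automatically selected loss); the data shapes and hyperparameters are assumptions that mirror the new test file rather than a canonical recipe.

```julia
using GeometricMachineLearning

data = (q = rand(Float32, 2, 200), p = rand(Float32, 2, 200))
dl   = DataLoader(data)
nn   = NeuralNetwork(GSympNet(dl), CPU(), Float32)

opt   = Optimizer(GradientOptimizer(), nn)   # the new method also allows Optimizer(nn, GradientOptimizer())
batch = Batch(10)

loss_history = opt(nn, dl, batch)       # n_epochs defaults to 1; the loss is chosen automatically
loss_history = opt(nn, dl, batch, 20)   # train for 20 epochs instead
```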
35 changes: 26 additions & 9 deletions src/data_loader/data_loader.jl
@@ -93,25 +93,35 @@ It takes as input:
"""
function loss(model::Union{Chain, AbstractExplicitLayer}, ps::Union{Tuple, NamedTuple}, input::AT, output::BT) where {T, T1, AT<:AbstractArray{T, 3}, BT<:AbstractArray{T1, 3}}
output_estimate = model(input, ps)
norm(output - output_estimate)/norm(output) # /T(sqrt(size(output, 2)*size(output, 3)))
norm(output - output_estimate) / norm(output) # /T(sqrt(size(output, 2)*size(output, 3)))
end

function loss(model::Chain, ps::Tuple, input::BT) where {T, BT<:AbstractArray{T, 3}}
@doc raw"""
The *autoencoder loss*. It computes the relative reconstruction error ``\lVert\mathcal{NN}(\mathtt{input}) - \mathtt{input}\rVert / \lVert\mathtt{input}\rVert``.
"""
function loss(model::Chain, ps::Tuple, input::BT) where {T, BT<:AbstractArray{T}}
output_estimate = model(input, ps)
norm(output_estimate - input)/norm(input) # /T(sqrt(size(input, 2)*size(input, 3)))
norm(output_estimate - input) / norm(input) # /T(sqrt(size(input, 2)*size(input, 3)))
end

function loss(model::Chain, ps::Tuple, input::BT) where {T, BT<:AbstractArray{T, 2}}
nt_diff(A, B) = (q = A.q - B.q, p = A.p - B.p)
nt_norm(A) = norm(A.q) + norm(A.p)

function loss(model::Chain, ps::Tuple, input::NT) where {T, AT<:AbstractArray{T}, NT<:NamedTuple{(:q, :p,), Tuple{AT, AT}}}
output_estimate = model(input, ps)
norm(output_estimate - input)/norm(input) # /T(sqrt(size(input, 2)))
nt_norm(nt_diff(output_estimate, input)) / nt_norm(input)
end

nt_diff(A, B) = (q = A.q - B.q, p = A.p - B.p)
nt_norm(A) = norm(A.q) + norm(A.p)
@doc raw"""
Loss function that takes a `NamedTuple` as input. This should be used with a SympNet (or other neural network-based integrator). It computes:

```math
\mathtt{loss}(\mathcal{NN}, \mathtt{ps}, \begin{pmatrix} q \\ p \end{pmatrix}, \begin{pmatrix} q' \\ p' \end{pmatrix}) \mapsto \left\lVert \mathcal{NN}\begin{pmatrix} q \\ p \end{pmatrix} - \begin{pmatrix} q' \\ p' \end{pmatrix} \right\rVert \bigg/ \left\lVert \begin{pmatrix} q \\ p \end{pmatrix} \right\rVert
```
"""
function loss(model::Chain, ps::Tuple, input::NamedTuple, output::NamedTuple)
output_estimate = model(input, ps)
nt_norm(nt_diff(output_estimate, output))/nt_norm(input)
nt_norm(nt_diff(output_estimate, output)) / nt_norm(input)
end

@doc raw"""
@@ -133,7 +143,14 @@ function loss(model::Chain, ps::Tuple, dl::DataLoader{T, BT, Nothing}) where {T,
end

function loss(model::Chain, ps::Tuple, dl::DataLoader{T, BT}) where {T, BT<:NamedTuple}
loss(model, ps, dl.data)
loss(model, ps, dl.input)
end

@doc raw"""
Wrapper that allows calling `loss` with a `NeuralNetwork` and a `DataLoader` directly; it forwards to `loss(nn.model, nn.params, dl)`.
"""
function loss(nn::NeuralNetwork, dl::DataLoader)
loss(nn.model, nn.params, dl)
end

@doc raw"""
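To tie the loss methods above together, a short sketch of calling them directly; the shapes are illustrative and the returned values are data-dependent relative errors.

```julia
using GeometricMachineLearning

data = (q = rand(Float32, 2, 200), p = rand(Float32, 2, 200))
dl   = DataLoader(data)
nn   = NeuralNetwork(GSympNet(dl), CPU(), Float32)

# NamedTuple input/output method (the SympNet case documented above)
input  = (q = rand(Float32, 2, 10), p = rand(Float32, 2, 10))
output = (q = rand(Float32, 2, 10), p = rand(Float32, 2, 10))
err = GeometricMachineLearning.loss(nn.model, nn.params, input, output)

# new wrapper added in this PR: call loss with a network and a data loader
err_dl = GeometricMachineLearning.loss(nn, dl)
```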
2 changes: 2 additions & 0 deletions src/data_loader/tensor_assign.jl
@@ -55,6 +55,7 @@ function assign_output_estimate(full_output::AbstractArray{T, 3}, prediction_win
output_estimate
end

#=
"""
This function draws random time steps and parameters and based on these assign the batch and the output.

@@ -101,6 +102,7 @@ function draw_batch!(batch::AT, output::BT, data::AT, target::BT) where {T, T2,
assign_batch!(batch, data, params, time_steps, ndrange=size(batch))
assign_batch!(output, target, params, time_steps, ndrange=size(output))
end
=#

"""
Used for differentiating assign_output_estimate (this appears in the loss).
2 changes: 2 additions & 0 deletions src/optimizers/optimizer.jl
@@ -22,6 +22,8 @@
Optimizer(m, nn.params)
end

Optimizer(nn::NeuralNetwork, m::OptimizerMethod) = Optimizer(m, nn)

Codecov warning: added line src/optimizers/optimizer.jl#L25 was not covered by tests.

#######################################################################################
# optimization step function

33 changes: 0 additions & 33 deletions test/data_loader/batch.jl

This file was deleted.

34 changes: 34 additions & 0 deletions test/data_loader/batch_data_loader_qp_test.jl
@@ -0,0 +1,34 @@
using GeometricMachineLearning
using Test

function dummy_qp_data_matrix(dim=2, number_data_points=200, T=Float32)
(q = rand(T, dim, number_data_points), p = rand(T, dim, number_data_points))
end

function dummy_qp_data_tensor(dim=2, number_of_time_steps=100, number_of_parameters=20, T=Float32)
(q = rand(T, dim, number_of_time_steps, number_of_parameters), p = rand(T, dim, number_of_time_steps, number_of_parameters))
end

function test_data_loader(dim=2, number_of_time_steps=100, number_of_parameters=20, batch_size=10, T=Float32)

dl1 = DataLoader(dummy_qp_data_matrix(dim, number_of_time_steps, T))
dl2 = DataLoader(dummy_qp_data_tensor(dim, number_of_time_steps, number_of_parameters))

arch1 = GSympNet(dl1)
arch2 = GSympNet(dl2)

nn1 = NeuralNetwork(arch1, CPU(), T)
nn2 = NeuralNetwork(arch2, CPU(), T)

loss1 = GeometricMachineLearning.loss(nn1, dl1)
loss2 = GeometricMachineLearning.loss(nn2, dl2)

batch = Batch(batch_size)
o₁ = Optimizer(GradientOptimizer(), nn1)
# o₂ = Optimizer(GradientOptimizer(), nn2)

o₁(nn1, dl1, batch)
# o₂(nn2, dl2, batch)
end

test_data_loader()
@@ -1,5 +1,8 @@
using GeometricMachineLearning, Test, Zygote

@doc raw"""
This tests the gradient optimizer called together with the `DataLoader` (applied to a tensor).
"""
function test_data_loader(sys_dim, n_time_steps, n_params, T=Float32)
data = randn(T, sys_dim, n_time_steps, n_params)
dl = DataLoader(data)