
Commit

Adjusted `initialparameters` interface to ANN v0.5.
benedict-96 committed Dec 12, 2024
1 parent a5f3029 commit eef6d29
Showing 83 changed files with 183 additions and 200 deletions.
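The substance of all 83 files is one mechanical rename: backend queries move from `KernelAbstractions.get_backend` to `networkbackend`, imported from AbstractNeuralNetworks (see the `import` hunk below). A minimal sketch of the pattern, assuming ANN v0.5 defines `networkbackend` for plain arrays the way the hunks below use it:

```julia
using KernelAbstractions
import AbstractNeuralNetworks: networkbackend

A = rand(Float32, 4, 4)   # placeholder array for illustration

backend_old = KernelAbstractions.get_backend(A)   # pre-v0.5 interface
backend_new = networkbackend(A)                   # v0.5 interface, same backend
```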
2 changes: 1 addition & 1 deletion docs/src/tutorials/mnist/mnist_tutorial.md
@@ -86,7 +86,7 @@ Here we have chosen a [`ClassificationTransformer`](@ref), i.e. a composition of
We now have to initialize the neural network weights. This is done with the constructor for `NeuralNetwork`:

```@example mnist
-backend = GeometricMachineLearning.get_backend(dl)
+backend = GeometricMachineLearning.networkbackend(dl)
T = eltype(dl)
nn1 = NeuralNetwork(model1, backend, T)
nn2 = NeuralNetwork(model2, backend, T)
1 change: 1 addition & 0 deletions src/GeometricMachineLearning.jl
@@ -30,6 +30,7 @@ module GeometricMachineLearning
import AbstractNeuralNetworks: GlorotUniform
import AbstractNeuralNetworks: params, architecture, model, dim
import AbstractNeuralNetworks: AbstractPullback, NetworkLoss, _compute_loss
+import AbstractNeuralNetworks: networkbackend
# export params, architecture, model
export dim
import GeometricIntegrators.Integrators: method, GeometricIntegrator
12 changes: 6 additions & 6 deletions src/architectures/autoencoder.jl
@@ -205,14 +205,14 @@ function encoder(nn::NeuralNetwork{<:AutoEncoder})
NeuralNetwork( UnknownEncoder(nn.architecture.full_dim, nn.architecture.reduced_dim, nn.architecture.n_encoder_blocks),
encoder_model(nn.architecture),
encoder_parameters(nn),
-get_backend(nn))
+networkbackend(nn))
end

function _encoder(nn::NeuralNetwork, full_dim::Integer, reduced_dim::Integer)
NeuralNetwork( UnknownEncoder(full_dim, reduced_dim, length(nn.model.layers)),
nn.model,
nn.params,
-get_backend(nn))
+networkbackend(nn))
end

function input_dimension(::AbstractExplicitLayer{M, N}) where {M, N}
@@ -242,11 +242,11 @@ end
Obtain the *decoder* from an [`AutoEncoder`](@ref) neural network.
"""
function decoder(nn::NeuralNetwork{<:AutoEncoder})
-NeuralNetwork(UnknownDecoder(nn.architecture.full_dim, nn.architecture.reduced_dim, nn.architecture.n_encoder_blocks), decoder_model(nn.architecture), decoder_parameters(nn), get_backend(nn))
+NeuralNetwork(UnknownDecoder(nn.architecture.full_dim, nn.architecture.reduced_dim, nn.architecture.n_encoder_blocks), decoder_model(nn.architecture), decoder_parameters(nn), networkbackend(nn))
end

function _decoder(nn::NeuralNetwork, full_dim::Integer, reduced_dim::Integer)
-NeuralNetwork(UnknownDecoder(full_dim, reduced_dim, length(nn.model.layers)), nn.model, nn.params, get_backend(nn))
+NeuralNetwork(UnknownDecoder(full_dim, reduced_dim, length(nn.model.layers)), nn.model, nn.params, networkbackend(nn))
end

@doc raw"""
@@ -263,9 +263,9 @@ function decoder(nn::NeuralNetwork)
end

function encoder(nn::NeuralNetwork{<:SymplecticCompression})
-NeuralNetwork(UnknownSymplecticEncoder(nn.architecture.full_dim, nn.architecture.reduced_dim, nn.architecture.n_encoder_blocks), encoder_model(nn.architecture), encoder_parameters(nn), get_backend(nn))
+NeuralNetwork(UnknownSymplecticEncoder(nn.architecture.full_dim, nn.architecture.reduced_dim, nn.architecture.n_encoder_blocks), encoder_model(nn.architecture), encoder_parameters(nn), networkbackend(nn))
end

function decoder(nn::NeuralNetwork{<:SymplecticCompression})
-NeuralNetwork(UnknownSymplecticDecoder(nn.architecture.full_dim, nn.architecture.reduced_dim, nn.architecture.n_encoder_blocks), decoder_model(nn.architecture), decoder_parameters(nn), get_backend(nn))
+NeuralNetwork(UnknownSymplecticDecoder(nn.architecture.full_dim, nn.architecture.reduced_dim, nn.architecture.n_encoder_blocks), decoder_model(nn.architecture), decoder_parameters(nn), networkbackend(nn))
end
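For orientation, here is a hedged sketch of the `encoder`/`decoder` split these methods implement. The concrete architecture and its constructor arguments are assumptions for illustration, not part of this diff:

```julia
using GeometricMachineLearning
import KernelAbstractions: CPU

arch = SymplecticAutoencoder(20, 4)        # assumed: full_dim = 20, reduced_dim = 4
nn   = NeuralNetwork(arch, CPU(), Float64)

Ψᵉ = encoder(nn)   # NeuralNetwork wrapping only the encoder layers
Ψᵈ = decoder(nn)   # NeuralNetwork wrapping only the decoder layers
# both reuse the parent's parameters and inherit networkbackend(nn)
```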
4 changes: 2 additions & 2 deletions src/architectures/neural_network_integrator.jl
@@ -22,7 +22,7 @@ abstract type NeuralNetworkIntegrator <: Architecture end
function Base.iterate(nn::NeuralNetwork{<:NeuralNetworkIntegrator}, ics::AT; n_points = 100) where {T, AT<:AbstractVector{T}}

n_dim = length(ics)
-backend = KernelAbstractions.get_backend(ics)
+backend = networkbackend(ics)

# Array to store the predictions
valuation = KernelAbstractions.allocate(backend, T, n_dim, n_points)
@@ -97,7 +97,7 @@ The number of integration steps that should be performed.
function Base.iterate(nn::NeuralNetwork{<:NeuralNetworkIntegrator}, ics::BT; n_points = 100) where {T, AT<:AbstractVector{T}, BT<:NamedTuple{(:q, :p), Tuple{AT, AT}}}

n_dim2 = length(ics.q)
-backend = KernelAbstractions.get_backend(ics.q)
+backend = networkbackend(ics.q)

# Array to store the predictions
valuation = (q = KernelAbstractions.allocate(backend, T, n_dim2, n_points), p = KernelAbstractions.allocate(backend, T, n_dim2, n_points))
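These `Base.iterate` methods roll a single-step integrator network out to a trajectory stored columnwise. A hedged usage sketch, assuming `nn` is a trained `NeuralNetwork{<:NeuralNetworkIntegrator}` for a one-degree-of-freedom system:

```julia
ics = (q = [0.5], p = [0.0])            # NamedTuple initial condition
sol = iterate(nn, ics; n_points = 200)  # sol.q and sol.p are 1×200 arrays
```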
4 changes: 2 additions & 2 deletions src/architectures/symplectic_autoencoder.jl
@@ -151,10 +151,10 @@ end

function encoder(nn::NeuralNetwork{<:SymplecticAutoencoder})
arch = NonLinearSymplecticEncoder(nn.architecture.full_dim, nn.architecture.reduced_dim, nn.architecture.n_encoder_layers, nn.architecture.n_encoder_blocks, nn.architecture.sympnet_upscale, nn.architecture.activation)
-NeuralNetwork(arch, encoder_model(nn.architecture), encoder_parameters(nn), get_backend(nn))
+NeuralNetwork(arch, encoder_model(nn.architecture), encoder_parameters(nn), networkbackend(nn))
end

function decoder(nn::NeuralNetwork{<:SymplecticAutoencoder})
arch = NonLinearSymplecticDecoder(nn.architecture.full_dim, nn.architecture.reduced_dim, nn.architecture.n_decoder_layers, nn.architecture.n_decoder_blocks, nn.architecture.sympnet_upscale, nn.architecture.activation)
-NeuralNetwork(arch, decoder_model(nn.architecture), decoder_parameters(nn), get_backend(nn))
+NeuralNetwork(arch, decoder_model(nn.architecture), decoder_parameters(nn), networkbackend(nn))
end
4 changes: 2 additions & 2 deletions src/architectures/transformer_integrator.jl
@@ -48,7 +48,7 @@ function Base.iterate(nn::NeuralNetwork{<:TransformerIntegrator}, ics::NamedTupl
seq_length = typeof(nn.architecture) <: StandardTransformerIntegrator ? size(ics.q, 2) : nn.architecture.seq_length

n_dim = size(ics.q, 1)
-backend = KernelAbstractions.get_backend(ics.q)
+backend = networkbackend(ics.q)

n_iterations = Int(ceil((n_points - seq_length) / prediction_window))
# Array to store the predictions
@@ -84,7 +84,7 @@ function Base.iterate(nn::NeuralNetwork{<:TransformerIntegrator}, ics::AT; n_poi
end

n_dim = size(ics, 1)
-backend = KernelAbstractions.get_backend(ics)
+backend = networkbackend(ics)

n_iterations = Int(ceil((n_points - seq_length) / prediction_window))
# Array to store the predictions
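The transformer variant is seeded with a short sequence of states instead of a single one. A sketch under the same assumptions, with `nn` now a trained `NeuralNetwork{<:TransformerIntegrator}` for a two-dimensional system:

```julia
ics = (q = rand(2, 4), p = rand(2, 4))  # seed trajectory of seq_length = 4
sol = iterate(nn, ics; n_points = 200)  # predictions appended window by window
```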
2 changes: 1 addition & 1 deletion src/arrays/grassmann_lie_algebra_horizontal.jl
@@ -80,7 +80,7 @@ end
Base.parent(A::GrassmannLieAlgHorMatrix) = (A.B, )
Base.size(A::GrassmannLieAlgHorMatrix) = (A.N, A.N)

-KernelAbstractions.get_backend(B::GrassmannLieAlgHorMatrix) = KernelAbstractions.get_backend(B.B)
+networkbackend(B::GrassmannLieAlgHorMatrix) = networkbackend(B.B)

function Base.getindex(A::GrassmannLieAlgHorMatrix{T}, i::Integer, j::Integer) where {T}
if i ≤ A.n
2 changes: 1 addition & 1 deletion src/arrays/lower_triangular.jl
@@ -81,7 +81,7 @@ end
function map_to_lo(A::AbstractMatrix{T}) where T
n = size(A, 1)
@assert size(A, 2) == n
-backend = KernelAbstractions.get_backend(A)
+backend = networkbackend(A)
S = KernelAbstractions.zeros(backend, T, n * (n - 1) ÷ 2)
assign_Skew_val! = assign_Skew_val_kernel!(backend)
for i in 2:n
12 changes: 6 additions & 6 deletions src/arrays/skew_symmetric.jl
@@ -110,7 +110,7 @@ end

function Base.:+(A::SkewSymMatrix{T}, B::AbstractMatrix{T}) where T
@assert size(A) == size(B)
-backend = KernelAbstractions.get_backend(B)
+backend = networkbackend(B)
addition! = addition_kernel!(backend)
C = KernelAbstractions.allocate(backend, T, size(A)...)
addition!(C, A.S, B; ndrange = size(A))
@@ -215,7 +215,7 @@ LinearAlgebra.rmul!(C::SkewSymMatrix, α::Real) = mul!(C, C, α)
function Base.:*(A::SkewSymMatrix{T}, B::AbstractMatrix{T}) where T
m1, m2 = size(B)
@assert m1 == A.n
-backend = KernelAbstractions.get_backend(A)
+backend = networkbackend(A)
C = KernelAbstractions.allocate(backend, T, A.n, m2)

skew_mat_mul! = skew_mat_mul_kernel!(backend)
@@ -245,7 +245,7 @@ function Base.:*(A::SkewSymMatrix, b::AbstractVector{T}) where T
end

function Base.one(A::SkewSymMatrix{T}) where T
-backend = KernelAbstractions.get_backend(A.S)
+backend = networkbackend(A.S)
unit_matrix = KernelAbstractions.zeros(backend, T, A.n, A.n)
write_ones! = write_ones_kernel!(backend)
write_ones!(unit_matrix, ndrange=A.n)
@@ -290,8 +290,8 @@ function Base.zero(A::SkewSymMatrix)
SkewSymMatrix(zero(A.S), A.n)
end

-function KernelAbstractions.get_backend(A::SkewSymMatrix)
-KernelAbstractions.get_backend(A.S)
+function networkbackend(A::SkewSymMatrix)
+networkbackend(A.S)
end

function assign!(B::SkewSymMatrix{T}, C::SkewSymMatrix{T}) where T
@@ -311,7 +311,7 @@ function map_to_Skew(A::AbstractMatrix{T}) where T
n = size(A, 1)
@assert size(A, 2) == n
A_skew = T(.5)*(A - A')
-backend = KernelAbstractions.get_backend(A)
+backend = networkbackend(A)
S = if n != 1
KernelAbstractions.zeros(backend, T, n * (n - 1) ÷ 2)
else
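For context: `SkewSymMatrix` stores only the n(n−1)/2 strictly lower-triangular entries in `A.S`, and `map_to_Skew` produces that vector from the projection (A − A′)/2. A hedged sketch, assuming the matrix constructor applies the same projection:

```julia
using LinearAlgebra
A = rand(4, 4)
W = SkewSymMatrix(A)  # assumed: stores the skew part (A - A') / 2
W' == -W              # true; skew-symmetry holds by construction
```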
6 changes: 3 additions & 3 deletions src/arrays/stiefel_lie_algebra_horizontal.jl
@@ -273,8 +273,8 @@ function Base.zero(B::StiefelLieAlgHorMatrix)
)
end

-function KernelAbstractions.get_backend(B::StiefelLieAlgHorMatrix)
-KernelAbstractions.get_backend(B.B)
+function networkbackend(B::StiefelLieAlgHorMatrix)
+networkbackend(B.B)
end

# assign function; also implement this for other arrays!
@@ -302,7 +302,7 @@ function assign!(A::AbstractArray, B::AbstractArray)
end

function Base.one(B::StiefelLieAlgHorMatrix{T}) where T
-backend = get_backend(B)
+backend = networkbackend(B)
oneB = KernelAbstractions.zeros(backend, T, B.N, B.N)
write_ones! = write_ones_kernel!(backend)
write_ones!(oneB; ndrange = B.N)
8 changes: 4 additions & 4 deletions src/arrays/stiefel_projection.jl
@@ -29,7 +29,7 @@ Extract necessary information from `A` and build an instance of `StiefelProjection`.
Necessary information here refers to the backend, the data type and the size of the matrix.
"""
function StiefelProjection(A::AbstractMatrix{T}) where T
-StiefelProjection(KernelAbstractions.get_backend(A), T, size(A)...)
+StiefelProjection(networkbackend(A), T, size(A)...)
end

@doc raw"""
@@ -58,7 +58,7 @@ true
```
"""
function StiefelProjection(B::AbstractLieAlgHorMatrix{T}) where T
-StiefelProjection(KernelAbstractions.get_backend(B), T, B.N, B.n)
+StiefelProjection(networkbackend(B), T, B.N, B.n)
end

@kernel function assign_ones_for_stiefel_projection_kernel!(A::AbstractArray{T}) where T
@@ -79,6 +79,6 @@ Base.vcat(E::StiefelProjection{T}, A::AbstractVecOrMat{T}) where {T<:Number} = v
Base.hcat(A::AbstractVecOrMat{T}, E::StiefelProjection{T}) where {T<:Number} = hcat(A, E.A)
Base.hcat(E::StiefelProjection{T}, A::AbstractVecOrMat{T}) where {T<:Number} = hcat(E.A, A)

-function KernelAbstractions.get_backend(E::StiefelProjection)
-KernelAbstractions.get_backend(E.A)
+function networkbackend(E::StiefelProjection)
+networkbackend(E.A)
end
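As the docstring above says, only the backend, the element type, and the size are read off the input; the result is the canonical embedding E = [I; 0]. A brief hedged sketch:

```julia
A = rand(Float32, 5, 3)
E = StiefelProjection(A)  # 5×3 matrix [I; 0] with eltype and backend of A
E' * E                    # the 3×3 identity
```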
12 changes: 6 additions & 6 deletions src/arrays/symmetric.jl
@@ -100,7 +100,7 @@ function map_to_S(A::AbstractMatrix{T}) where {T <: Number}
n = size(A, 1)
@assert size(A, 2) == n
A_sym = T(.5)*(A + A')
-backend = KernelAbstractions.get_backend(A)
+backend = networkbackend(A)
S = KernelAbstractions.zeros(backend, T, n*(n+1)÷2)
assign_S_val! = assign_S_val_kernel!(backend)
for i in 1:n
@@ -224,7 +224,7 @@ function LinearAlgebra.mul!(C::AbstractMatrix, A::SymmetricMatrix, B::AbstractMa
@assert A.n == size(B, 1)
@assert size(B, 2) == size(C, 2)
@assert A.n == size(C, 1)
-backend = KernelAbstractions.get_backend(A.S)
+backend = networkbackend(A.S)
symmetric_mat_mul! = symmetric_mat_mul_kernel!(backend)
symmetric_mat_mul!(C, A.S, B, A.n, ndrange=size(C))
end
@@ -244,13 +244,13 @@ end

function LinearAlgebra.mul!(c::AbstractVector, A::SymmetricMatrix, b::AbstractVector)
@assert A.n == length(c) == length(b)
-backend = KernelAbstractions.get_backend(A.S)
+backend = networkbackend(A.S)
symmetric_vector_mul! = symmetric_vector_mul_kernel!(backend)
symmetric_vector_mul!(c, A.S, b, A.n, ndrange=size(c))
end

function Base.:*(A::SymmetricMatrix{T}, B::AbstractMatrix{T}) where T
-backend = KernelAbstractions.get_backend(A.S)
+backend = networkbackend(A.S)
C = KernelAbstractions.allocate(backend, T, A.n, size(B, 2))
LinearAlgebra.mul!(C, A, B)
C
Expand All @@ -263,14 +263,14 @@ function Base.:*(A::SymmetricMatrix{T}, B::SymmetricMatrix{T}) where T
end

function Base.:*(A::SymmetricMatrix{T}, b::AbstractVector{T}) where T
-backend = KernelAbstractions.get_backend(A.S)
+backend = networkbackend(A.S)
c = KernelAbstractions.allocate(backend, T, A.n)
LinearAlgebra.mul!(c, A, b)
c
end

function Base.one(A::SymmetricMatrix{T}) where T
-backend = KernelAbstractions.get_backend(A.S)
+backend = networkbackend(A.S)
unit_matrix = KernelAbstractions.zeros(backend, T, A.n, A.n)
write_ones! = write_ones_kernel!(backend)
write_ones!(unit_matrix, ndrange=A.n)
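The kernels above back the dense products of `SymmetricMatrix`, which keeps the n(n+1)/2 entries of the lower triangle in `A.S`. A hedged usage sketch (constructor form assumed):

```julia
S = SymmetricMatrix(rand(4, 4))  # assumed: stores the symmetric part (A + A') / 2
b = rand(4)
c = S * b  # result allocated on networkbackend(S.S), filled by the mul! kernel
```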
6 changes: 3 additions & 3 deletions src/arrays/triangular.jl
@@ -88,7 +88,7 @@ LinearAlgebra.mul!(C::AT, α::Real, A::AT) where AT <: AbstractTriangular = mul!
LinearAlgebra.rmul!(C::AT, α::Real) where AT <: AbstractTriangular = mul!(C, C, α)

function Base.one(A::AbstractTriangular{T}) where T
-backend = KernelAbstractions.get_backend(A.S)
+backend = networkbackend(A.S)
unit_matrix = KernelAbstractions.zeros(backend, T, A.n, A.n)
write_ones! = write_ones_kernel!(backend)
write_ones!(unit_matrix, ndrange=A.n)
@@ -132,8 +132,8 @@ function Base.zero(A::AT) where AT <: AbstractTriangular
AT(zero(A.S), A.n)
end

-function KernelAbstractions.get_backend(A::AbstractTriangular)
-KernelAbstractions.get_backend(A.S)
+function networkbackend(A::AbstractTriangular)
+networkbackend(A.S)
end

function assign!(B::AT, C::AT) where AT <: AbstractTriangular
2 changes: 1 addition & 1 deletion src/arrays/upper_triangular.jl
@@ -81,7 +81,7 @@ end
function map_to_up(A::AbstractMatrix{T}) where T
n = size(A, 1)
@assert size(A, 2) == n
-backend = KernelAbstractions.get_backend(A)
+backend = networkbackend(A)
S = KernelAbstractions.zeros(backend, T, n * (n - 1) ÷ 2)
assign_Skew_val! = assign_Skew_val_kernel!(backend)
for i in 2:n
10 changes: 5 additions & 5 deletions src/data_loader/batch.jl
@@ -231,7 +231,7 @@ GeometricMachineLearning.convert_input_and_batch_indices_to_array(dl, batch, bat
```
"""
function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT}, batch::Batch, batch_indices_tuple::Vector{Tuple{Int, Int}}) where {T, AT<:AbstractArray{T, 3}, BT<:NamedTuple{(:q, :p), Tuple{AT, AT}}}
-backend = KernelAbstractions.get_backend(dl.input.q)
+backend = networkbackend(dl.input.q)

# the batch size is smaller for the last batch
_batch_size = length(batch_indices_tuple)
@@ -254,7 +254,7 @@ function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT}, batch::
end

function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT}, batch::Batch, batch_indices_tuple::Vector{Tuple{Int, Int}}) where {T, BT<:AbstractArray{T, 3}}
-backend = KernelAbstractions.get_backend(dl.input)
+backend = networkbackend(dl.input)

# the batch size is smaller for the last batch
_batch_size = length(batch_indices_tuple)
@@ -275,7 +275,7 @@ function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT}, batch::
end

function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT, Nothing, :RegularData}, batch::Batch, batch_indices_tuple::Vector{Tuple{Int, Int}}) where {T, AT<:AbstractArray{T, 3}, BT<:NamedTuple{(:q, :p), Tuple{AT, AT}}}
-backend = KernelAbstractions.get_backend(dl.input.q)
+backend = networkbackend(dl.input.q)

# the batch size is smaller for the last batch
_batch_size = length(batch_indices_tuple)
@@ -292,7 +292,7 @@ function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT, Nothing,
end

function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT, Nothing, :RegularData}, batch::Batch, batch_indices_tuple::Vector{Tuple{Int, Int}}) where {T, BT<:AbstractArray{T, 3}}
-backend = KernelAbstractions.get_backend(dl.input)
+backend = networkbackend(dl.input)

# the batch size is smaller for the last batch
_batch_size = length(batch_indices_tuple)
@@ -318,7 +318,7 @@ function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT, OT}, ::B
end

function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT, BT}, batch::Batch, batch_indices_tuple::Vector{Tuple{Int, Int}}) where {T, BT<:AbstractArray{T, 3}}
-backend = KernelAbstractions.get_backend(dl.input)
+backend = networkbackend(dl.input)

# the batch size is smaller for the last batch
_batch_size = length(batch_indices_tuple)
14 changes: 7 additions & 7 deletions src/data_loader/data_loader.jl
@@ -436,7 +436,7 @@ By default this inherits the autoencoder property from `dl`.
See the docstring for [`DataLoader(data::AbstractArray{<:Number, 3})`](@ref).
"""
function DataLoader(dl::DataLoader{T1, <:QPTOAT, Nothing, Type},
-backend::KernelAbstractions.Backend=KernelAbstractions.get_backend(dl),
+backend::KernelAbstractions.Backend=networkbackend(dl),
T::DataType=T1;
autoencoder = nothing
) where {T1, Type}
@@ -456,7 +456,7 @@ function DataLoader(dl::DataLoader{T1, <:QPTOAT, Nothing, Type},
end

new_input =
-if backend == KernelAbstractions.get_backend(dl)
+if backend == networkbackend(dl)
input
else
map_to_new_backend(input, backend)
@@ -473,7 +473,7 @@ function DataLoader(dl::DataLoader{T1, <:QPTOAT, Nothing, Type},
end

function DataLoader(dl::DataLoader, T::DataType; kwargs...)
-DataLoader(dl, KernelAbstractions.get_backend(dl), T; kwargs...)
+DataLoader(dl, networkbackend(dl), T; kwargs...)
end

@doc raw"""
@@ -486,7 +486,7 @@ This needs an instance of [`DataLoader`](@ref) that stores the *test data*.
function accuracy(model::Chain, ps::NeuralNetworkParameters, dl::DataLoader{T, AT, BT}) where {T, T1<:Integer, AT<:AbstractArray{T}, BT<:AbstractArray{T1}}
output_tensor = model(dl.input, ps)
output_estimate = assign_output_estimate(output_tensor, dl.output_time_steps)
-backend = KernelAbstractions.get_backend(output_estimate)
+backend = networkbackend(output_estimate)
tensor_of_maximum_elements = KernelAbstractions.zeros(backend, T1, size(output_estimate)...)
ind = argmax(output_estimate, dims=1)
# get tensor of maximum elements
@@ -505,7 +505,7 @@ accuracy(nn::NeuralNetwork, dl::DataLoader) = accuracy(nn.model, nn.params, dl)

Base.eltype(::DataLoader{T}) where T = T

-KernelAbstractions.get_backend(dl::DataLoader) = KernelAbstractions.get_backend(dl.input)
-function KernelAbstractions.get_backend(dl::DataLoader{T, <:QPT{T}}) where T
-KernelAbstractions.get_backend(dl.input.q)
+networkbackend(dl::DataLoader) = networkbackend(dl.input)
+function networkbackend(dl::DataLoader{T, <:QPT{T}}) where T
+networkbackend(dl.input.q)
end
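To close, a hedged sketch of the re-hosting constructor whose default backend argument changed above; `dl` is an existing `DataLoader`, and the GPU line is an assumption that requires CUDA.jl:

```julia
dl32 = DataLoader(dl, Float32)            # same backend, data converted to Float32
# using CUDA
# dl_gpu = DataLoader(dl, CUDABackend())  # assumed: moves data via map_to_new_backend
```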
