Transformer integrator #131

Merged (11 commits) on Apr 15, 2024
4 changes: 2 additions & 2 deletions docs/src/tutorials/sympnet_tutorial.md
@@ -173,8 +173,8 @@ ics = (q=qp_data.q[:,1], p=qp_data.p[:,1])
steps_to_plot = 200

#predictions
la_trajectory = Iterate_Sympnet(la_nn, ics; n_points = steps_to_plot)
g_trajectory = Iterate_Sympnet(g_nn, ics; n_points = steps_to_plot)
la_trajectory = iterate(la_nn, ics; n_points = steps_to_plot)
g_trajectory = iterate(g_nn, ics; n_points = steps_to_plot)

using Plots
p2 = plot(qp_data.q'[1:steps_to_plot], qp_data.p'[1:steps_to_plot], label="training data")
7 changes: 6 additions & 1 deletion src/GeometricMachineLearning.jl
@@ -158,7 +158,7 @@ module GeometricMachineLearning
export PSDLayer
export MultiHeadAttention
export VolumePreservingAttention
export ResNet
export ResNetLayer
export Transformer
export Classification
export VolumePreservingLowerLayer
@@ -246,6 +246,9 @@ module GeometricMachineLearning

#INCLUDE ARCHITECTURES
include("architectures/neural_network_integrator.jl")
include("architectures/resnet.jl")
include("architectures/transformer_integrator.jl")
include("architectures/regular_transformer_integrator.jl")
include("architectures/sympnet.jl")
include("architectures/autoencoder.jl")
include("architectures/fixed_width_network.jl")
@@ -265,6 +268,8 @@ module GeometricMachineLearning
export RecurrentNeuralNetwork
export LSTMNeuralNetwork
export ClassificationTransformer
export ResNet
export RegularTransformerIntegrator
export VolumePreservingFeedForward

export train!, apply!, jacobian!
42 changes: 42 additions & 0 deletions src/architectures/regular_transformer_integrator.jl
Member
There's a typo: "the defualt is"

@@ -0,0 +1,42 @@
@doc raw"""
The regular transformer used as an integrator (multi-step method).

The constructor is called with the following arguments:
- `sys_dim::Int`
- `transformer_dim::Int`: the default is `transformer_dim = sys_dim`.
- `n_blocks::Int`: the default is `1`.
- `n_heads::Int`: the number of heads in the multi-head attention layer (the default is `n_heads = sys_dim`).
- `L::Int`: the number of transformer blocks (the default is `L = 2`).
- `upscaling_activation`: the default is `identity`.
- `resnet_activation`: the default is `tanh`.
- `add_connection::Bool=true` (keyword argument): whether the input should be added to the output.
"""
struct RegularTransformerIntegrator{AT1, AT2} <: TransformerIntegrator
sys_dim::Int
transformer_dim::Int
n_heads::Int
n_blocks::Int
L::Int
upscaling_activation::AT1
resnet_activation::AT2
add_connection::Bool
end

# function RegularTransformerIntegrator(sys_dim::Int, transformer_dim::Int = sys_dim, n_heads::Int = sys_dim, n_blocks = 1, L::Int = 2, upscaling_activation = identity, resnet_activation = tanh; add_connection::Bool = true)
# RegularTransformerIntegrator{typeof(upscaling_activation), typeof(resnet_activation)}(sys_dim, transformer_dim, n_heads, n_blocks, L, upscaling_activation, resnet_activation, add_connection)
# end

function RegularTransformerIntegrator(sys_dim::Int, transformer_dim::Int = sys_dim, n_heads::Int = sys_dim; n_blocks = 1, L::Int = 2, upscaling_activation = identity, resnet_activation = tanh, add_connection::Bool = true)
RegularTransformerIntegrator(sys_dim, transformer_dim, n_heads, n_blocks, L, upscaling_activation, resnet_activation, add_connection)
end

function Chain(arch::RegularTransformerIntegrator)
layers = arch.sys_dim == arch.transformer_dim ? () : (Dense(arch.sys_dim, arch.transformer_dim, arch.upscaling_activation), )
for _ in 1:arch.L
layers = (layers..., MultiHeadAttention(arch.transformer_dim, arch.n_heads; add_connection = arch.add_connection))
layers = (layers..., Chain(ResNet(arch.transformer_dim, arch.n_blocks, arch.resnet_activation)).layers...)
end
layers = arch.sys_dim == arch.transformer_dim ? layers : (layers..., Dense(arch.transformer_dim, arch.sys_dim, identity))

Chain(layers...)
end
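A minimal usage sketch of the new architecture (the dimensions are illustrative; the `Chain` call is the one defined in this file):

```julia
using GeometricMachineLearning

# positional arguments: sys_dim = 4, transformer_dim = 8, n_heads = 2;
# the keyword arguments below just repeat the documented defaults
arch = RegularTransformerIntegrator(4, 8, 2; n_blocks = 1, L = 2, resnet_activation = tanh)

# lowering to a Chain: a Dense upscaling layer (since sys_dim ≠ transformer_dim),
# then L blocks of MultiHeadAttention followed by ResNet layers,
# then a Dense layer mapping back down to sys_dim
model = Chain(arch)
```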
18 changes: 18 additions & 0 deletions src/architectures/resnet.jl
@@ -0,0 +1,18 @@
struct ResNet{AT} <: NeuralNetworkIntegrator
sys_dim::Int
n_blocks::Int
activation::AT
end

function Chain(arch::ResNet{AT}) where AT
layers = ()
for _ in 1:arch.n_blocks
# nonlinear layers
layers = (layers..., ResNetLayer(arch.sys_dim, arch.activation; use_bias=true))
end

# linear layer for the output
layers = (layers..., ResNetLayer(arch.sys_dim, identity; use_bias=true))

Chain(layers...)
end
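For comparison, a short sketch of how the `ResNet` architecture above lowers to a chain of `ResNetLayer`s (sizes are illustrative):

```julia
# sys_dim = 4, n_blocks = 2, activation = tanh
arch = ResNet(4, 2, tanh)

# two nonlinear ResNetLayers followed by a final ResNetLayer with identity activation
model = Chain(arch)
```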
33 changes: 1 addition & 32 deletions src/architectures/sympnet.jl
@@ -4,7 +4,7 @@ SympNet type encompasses GSympNets and LASympnets.
TODO:
-[ ] add bias to `LASympNet`!
"""
abstract type SympNet{AT} <: Architecture end
abstract type SympNet{AT} <: NeuralNetworkIntegrator end

@doc raw"""
`LASympNet` is called with **a single input argument**, the **system dimension**, or with an instance of `DataLoader`. Optional input arguments are:
@@ -144,35 +144,4 @@ function Chain(arch::LASympNet{AT, true, true}) where {AT}
layers = isodd(j) ? (layers..., LinearLayerQ(arch.dim)) : (layers..., LinearLayerP(arch.dim))
end
Chain(layers...)
end

@doc raw"""
This function computes a trajectory for a SympNet that has already been trained for valuation purposes.

It takes as input:
- `nn`: a `NeuralNetwork` (that has been trained).
- `ics`: initial conditions (a `NamedTuple` of two vectors)
"""
function Iterate_Sympnet(nn::NeuralNetwork{<:SympNet}, ics::NamedTuple{(:q, :p), Tuple{AT, AT}}; n_points = 100) where {T, AT<:AbstractVector{T}}

n_dim = length(ics.q)
backend = KernelAbstractions.get_backend(ics.q)

# Array to store the predictions
q_valuation = KernelAbstractions.allocate(backend, T, n_dim, n_points)
p_valuation = KernelAbstractions.allocate(backend, T, n_dim, n_points)

# Initialisation
@views q_valuation[:,1] = ics.q
@views p_valuation[:,1] = ics.p

#Computation of phase space
@views for i in 2:n_points
qp_temp = (q=q_valuation[:,i-1], p=p_valuation[:,i-1])
qp_prediction = nn(qp_temp)
q_valuation[:,i] = qp_prediction.q
p_valuation[:,i] = qp_prediction.p
end

(q=q_valuation, p=p_valuation)
end
70 changes: 70 additions & 0 deletions src/architectures/transformer_integrator.jl
@@ -0,0 +1,70 @@
@doc raw"""
Encompasses various transformer architectures, such as the structure-preserving transformer and the linear symplectic transformer.
"""
abstract type TransformerIntegrator <: Architecture end

struct DummyTransformer <: TransformerIntegrator
seq_length::Int
end

@doc raw"""
This function computes a trajectory for a transformer that has already been trained, for validation purposes.

It takes as input:
- `nn`: a `NeuralNetwork` (that has been trained).
- `ics`: initial conditions (a matrix in ``\mathbb{R}^{2n\times\mathtt{seq\_length}}`` or `NamedTuple` of two matrices in ``\mathbb{R}^{n\times\mathtt{seq\_length}}``)
- `n_points::Int=100` (keyword argument): The number of steps for which we run the prediction.
- `prediction_window::Int=size(ics.q, 2)`: The prediction window (i.e. the number of steps we predict into the future) is equal to the sequence length (i.e. the number of input time steps) by default.
"""
function Base.iterate(nn::NeuralNetwork{<:TransformerIntegrator}, ics::NamedTuple{(:q, :p), Tuple{AT, AT}}; n_points::Int = 100, prediction_window::Union{Nothing, Int}=size(ics.q, 2)) where {T, AT<:AbstractMatrix{T}}

seq_length = nn.architecture.seq_length

n_dim = size(ics.q, 1)
backend = KernelAbstractions.get_backend(ics.q)

n_iterations = Int(ceil((n_points - seq_length) / prediction_window))

# Array to store the predictions
q_valuation = KernelAbstractions.allocate(backend, T, n_dim, seq_length + n_iterations * prediction_window)
p_valuation = KernelAbstractions.allocate(backend, T, n_dim, seq_length + n_iterations * prediction_window)

# Initialisation
q_valuation[:,1:seq_length] = ics.q
p_valuation[:,1:seq_length] = ics.p

# iteration in phase space
@views for i in 1:n_iterations
start_index = (i - 1) * prediction_window + 1
@views qp_temp = (q = q_valuation[:, start_index:(start_index + seq_length - 1)], p = p_valuation[:, start_index:(start_index + seq_length - 1)])
qp_prediction = nn(qp_temp)
q_valuation[:, (seq_length + (i - 1) * prediction_window + 1):(seq_length + i * prediction_window)] = qp_prediction.q[:, (seq_length - prediction_window + 1):end]
p_valuation[:, (seq_length + (i - 1) * prediction_window + 1):(seq_length + i * prediction_window)] = qp_prediction.p[:, (seq_length - prediction_window + 1):end]
end

(q=q_valuation[:, 1:n_points], p=p_valuation[:, 1:n_points])

end

function Base.iterate(nn::NeuralNetwork{<:TransformerIntegrator}, ics::AT; n_points::Int = 100, prediction_window::Union{Nothing, Int} = size(ics, 2)) where {T, AT<:AbstractMatrix{T}}

seq_length = typeof(nn.architecture) <: RegularTransformerIntegrator ? prediction_window : nn.architecture.seq_length

n_dim = size(ics, 1)
backend = KernelAbstractions.get_backend(ics)

n_iterations = Int(ceil((n_points - seq_length) / prediction_window))
# Array to store the predictions
valuation = KernelAbstractions.allocate(backend, T, n_dim, seq_length + n_iterations * prediction_window)

# Initialisation
valuation[:,1:seq_length] = ics

# iteration in phase space
@views for i in 1:n_iterations
start_index = (i - 1) * prediction_window + 1
temp = valuation[:, start_index:(start_index + seq_length - 1)]
prediction = nn(copy(temp))
valuation[:, (seq_length + (i - 1) * prediction_window + 1):(seq_length + i * prediction_window)] = prediction[:, (seq_length - prediction_window + 1):end]
end

valuation[:, 1:n_points]
end
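A hedged sketch of how the two `iterate` methods above might be called; `nn` is assumed to be an already-trained `NeuralNetwork` whose architecture is a `TransformerIntegrator` with `seq_length = 3`, and all sizes are illustrative:

```julia
# matrix initial condition: a 4-dimensional system with 3 input time steps
ics_mat = rand(Float64, 4, 3)
trajectory = iterate(nn, ics_mat; n_points = 100)    # 4 × 100 matrix of predicted states

# NamedTuple initial condition: q and p are 2 × 3 matrices (n = 2)
ics_qp = (q = rand(Float64, 2, 3), p = rand(Float64, 2, 3))
trajectory_qp = iterate(nn, ics_qp; n_points = 100)  # NamedTuple of 2 × 100 matrices
```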
47 changes: 39 additions & 8 deletions src/arrays/symplectic.jl
@@ -12,32 +12,63 @@
\end{pmatrix}
```
"""
function SymplecticPotential(backend, n2::Int, T::DataType=Float64)
struct SymplecticPotential{T, AT} <: AbstractMatrix{T}
J::AT
n::Int
end

Base.getindex(𝕁::SymplecticPotential, i, j) = getindex(𝕁.J, i, j)

Base.size(𝕁::SymplecticPotential) = size(𝕁.J)

function SymplecticPotential(backend::Backend, n2::Int, T::DataType)
@assert iseven(n2)
n = n2÷2
J = KernelAbstractions.zeros(backend, T, 2*n, 2*n)
assign_ones_for_symplectic_potential! = assign_ones_for_symplectic_potential_kernel!(backend)
assign_ones_for_symplectic_potential!(J, n, ndrange=n)
J
assign_ones_for_symplectic_potential!(J, n, ndrange=n2)

SymplecticPotential{T, typeof(J)}(J, n)
end

SymplecticPotential(n::Int, T::DataType=Float64) = SymplecticPotential(CPU(), n, T)
SymplecticPotential(bakend, T::DataType, n::Int) = SymplecticPotential(backend, n, T)
SymplecticPotential(n2::Int, T::DataType) = SymplecticPotential(CPU(), n2, T)

SymplecticPotential(T::DataType, n::Int) = SymplecticPotential(n, T)
SymplecticPotential(n2::Int) = SymplecticPotential(n2, Float64)

SymplecticPotential(backend::Backend, n2::Int) = SymplecticPotential(backend, n2, Float32)

SymplecticPotential(backend::CPU, n2::Int) = SymplecticPotential(backend, n2, Float64)

@kernel function assign_ones_for_symplectic_potential_kernel!(J::AbstractMatrix{T}, n::Int) where T
i = @index(Global)
J[map_index_for_symplectic_potential(i, n)...] = i ≤ n ? one(T) : -one(T)
end

Base.:*(𝕁::SymplecticPotential{T}, v::NamedTuple{(:q, :p), Tuple{AT, AT}}) where {T, AT <: AbstractVecOrMat{T}} = (q = v.p, p = -v.q)

function _vcat(v::NamedTuple{(:q, :p), Tuple{AT, AT}}) where {AT <: AbstractArray}
vcat(v.q, v.p)
end

Base.:*(𝕁::SymplecticPotential{T}, v::AbstractVector{T}) where T = _vcat(𝕁 * assign_q_and_p(v, 𝕁.n))
Base.:*(𝕁::SymplecticPotential{T}, v::AbstractMatrix{T}) where T = _vcat(𝕁 * assign_q_and_p(v, 𝕁.n))

function (𝕁::SymplecticPotential{T})(v₁::NT, v₂::NT) where {T, AT <: AbstractVector{T}, NT <: NamedTuple{(:q, :p), Tuple{AT, AT}}}
v₁.q' * v₂.p - v₁.p' * v₂.q
end

function (𝕁::SymplecticPotential{T})(v₁::AbstractVector{T}, v₂::AbstractVector{T}) where T
𝕁(assign_q_and_p(v₁, 𝕁.n), assign_q_and_p(v₂, 𝕁.n))
end

"""
This assigns the right index for the symplectic potential. To be used with `assign_ones_for_symplectic_potential_kernel!`.
"""
function map_index_for_symplectic_potential(i::Int, n::Int)
if i ≤ n
return (i, i+n)
return (i, i + n)
else
return (i, i-n)
return (i, i - n)
end
end
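A small sketch of the new `SymplecticPotential` type in action, assuming `assign_q_and_p` splits a vector into its first-half `q` and second-half `p` parts:

```julia
𝕁 = SymplecticPotential(4)             # 4 × 4 Poisson matrix for n = 2 (CPU, Float64)

# multiplication with a (q, p) NamedTuple: J * (q, p) = (p, -q)
𝕁 * (q = [1.0, 2.0], p = [3.0, 4.0])   # (q = [3.0, 4.0], p = [-1.0, -2.0])

# the symplectic two-form 𝕁(v₁, v₂) = q₁' * p₂ - p₁' * q₂
𝕁([1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0])   # 1.0
```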
7 changes: 2 additions & 5 deletions src/layers/multi_head_attention.jl
@@ -85,7 +85,7 @@ function compute_output_of_mha(d::MultiHeadAttention{M, M}, x::AbstractMatrix{T}
output = typeof(x)(zeros(T, 0, input_length))
for i in 1:d.n_heads
key = Symbol("head_"*string(i))
output = vcat(output, ps.PV[key]'*x*softmax((ps.PQ[key]'*x)'*(ps.PK[key]'*x)/T(sqrt(dim))))
output = vcat(output, ps.PV[key]' * x * softmax((ps.PQ[key]' * x)' * (ps.PK[key]' * x) / T(sqrt(dim))))
end
output
end
@@ -111,7 +111,6 @@ function compute_output_of_mha(d::MultiHeadAttention{M, M}, x::AbstractArray{T,

single_head_output = tensor_tensor_mul(V_tensor, softmax(QK_tensor/T(sqrt(dim))))
output = vcat(output, single_head_output)
# KernelAbstractions.synchronize(backend)
end
output
end
@@ -125,9 +124,7 @@ function (d::MultiHeadAttention{M, M, Stiefel, Retraction, false})(x::AbstractAr
end

import ChainRules
"""
This has to be extended to tensors; you should probably do a PR in ChainRules for this.
"""
# type piracy!
function ChainRules._adjoint_mat_pullback(y::AbstractArray{T, 3}, proj) where T
(NoTangent(), proj(tensor_transpose(y)))
end
18 changes: 9 additions & 9 deletions src/layers/resnet.jl
@@ -1,12 +1,12 @@
struct ResNet{M, N, use_bias, F1} <: AbstractExplicitLayer{M, N}
struct ResNetLayer{M, N, use_bias, F1} <: AbstractExplicitLayer{M, N}
activation::F1
end

function ResNet(dim::IT, activation=identity; use_bias::Bool=true) where {IT<:Int}
return ResNet{dim, dim, use_bias, typeof(activation)}(activation)
function ResNetLayer(dim::IT, activation=identity; use_bias::Bool=true) where {IT<:Int}
return ResNetLayer{dim, dim, use_bias, typeof(activation)}(activation)
end

function initialparameters(::ResNet{M, M, use_bias}, backend::KernelAbstractions.Backend, T::Type; rng::Random.AbstractRNG=Random.default_rng(), init_weight = GlorotUniform(), init_bias = ZeroInitializer()) where {M, use_bias}
function initialparameters(::ResNetLayer{M, M, use_bias}, backend::KernelAbstractions.Backend, T::Type; rng::Random.AbstractRNG=Random.default_rng(), init_weight = GlorotUniform(), init_bias = ZeroInitializer()) where {M, use_bias}
if use_bias
weight = KernelAbstractions.allocate(backend, T, M, M)
bias = KernelAbstractions.allocate(backend, T, M)
@@ -21,23 +21,23 @@ function initialparameters(::ResNet{M, M, use_bias}, backend::KernelAbstractions
end
end

function parameterlength(::ResNet{M, M, use_bias}) where {M, use_bias}
function parameterlength(::ResNetLayer{M, M, use_bias}) where {M, use_bias}
return use_bias ? M * (M + 1) : M * M
end

@inline function (d::ResNet{M, M, true})(x::AbstractVecOrMat, ps::NamedTuple) where {M}
@inline function (d::ResNetLayer{M, M, true})(x::AbstractVecOrMat, ps::NamedTuple) where {M}
return x + d.activation.(ps.weight * x .+ ps.bias)
end

@inline function (d::ResNet{M, M, false})(x::AbstractVecOrMat, ps::NamedTuple) where {M}
@inline function (d::ResNetLayer{M, M, false})(x::AbstractVecOrMat, ps::NamedTuple) where {M}
return x + d.activation.(ps.weight * x)
end

@inline function (d::ResNet{M, M, false})(x::AbstractArray{T, 3}, ps::NamedTuple) where {M, T}
@inline function (d::ResNetLayer{M, M, false})(x::AbstractArray{T, 3}, ps::NamedTuple) where {M, T}
return x + d.activation.(mat_tensor_mul(ps.weight, x))
end

@inline function (d::ResNet{M, M, true})(x::AbstractArray{T, 3}, ps::NamedTuple) where {M, T}
@inline function (d::ResNetLayer{M, M, true})(x::AbstractArray{T, 3}, ps::NamedTuple) where {M, T}
return x + d.activation.(mat_tensor_mul(ps.weight, x) .+ ps.bias)
end
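A brief sketch of the renamed layer in isolation; the `initialparameters` call mirrors the signature in this file, and the module-qualified name and `CPU()` backend are illustrative assumptions:

```julia
using KernelAbstractions: CPU

layer = ResNetLayer(4, tanh; use_bias = true)
ps = GeometricMachineLearning.initialparameters(layer, CPU(), Float64)

x = rand(Float64, 4)
y = layer(x, ps)    # computes x + tanh.(ps.weight * x .+ ps.bias)
```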

2 changes: 1 addition & 1 deletion src/layers/transformer.jl
@@ -20,7 +20,7 @@ function Transformer(dim::Integer, n_heads::Integer, L::Integer;
layers = ()
for _ in 1:L
layers = (layers..., MultiHeadAttention(dim, n_heads, Stiefel=Stiefel, retraction=retraction, add_connection=add_connection),
ResNet(dim, activation; use_bias=use_bias) )
ResNetLayer(dim, activation; use_bias=use_bias) )
end

Chain(layers...)
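The only change in this file is the layer rename; a hedged sketch of the existing constructor with the renamed block (keyword defaults assumed):

```julia
# dim = 8, n_heads = 2, L = 3 transformer blocks,
# each consisting of a MultiHeadAttention layer followed by a ResNetLayer
model = Transformer(8, 2, 3)
```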