From 4ddf1498331a5266bceafd5b590508a7523422e9 Mon Sep 17 00:00:00 2001
From: benedict-96
Date: Tue, 12 Dec 2023 11:37:56 +0100
Subject: [PATCH] Added the bias layer for the LA-SympNet.

---
 src/GeometricMachineLearning.jl |  1 +
 src/architectures/sympnet.jl    |  7 ++++---
 src/layers/bias_layer.jl        | 30 ++++++++++++++++++++++++++++++
 3 files changed, 35 insertions(+), 3 deletions(-)
 create mode 100644 src/layers/bias_layer.jl

diff --git a/src/GeometricMachineLearning.jl b/src/GeometricMachineLearning.jl
index 8d893d7a0..b357d839a 100644
--- a/src/GeometricMachineLearning.jl
+++ b/src/GeometricMachineLearning.jl
@@ -129,6 +129,7 @@ module GeometricMachineLearning
     include("optimizers/manifold_related/retractions.jl")
 
     include("layers/sympnets.jl")
+    include("layers/bias_layer.jl")
     include("layers/resnet.jl")
     include("layers/manifold_layer.jl")
     include("layers/stiefel_layer.jl")
diff --git a/src/architectures/sympnet.jl b/src/architectures/sympnet.jl
index 040f75714..8e017f125 100644
--- a/src/architectures/sympnet.jl
+++ b/src/architectures/sympnet.jl
@@ -59,7 +59,7 @@ end
 """
 function Chain(arch::GSympNet{AT, true}) where {AT}
     layers = ()
-    for i in 1:(arch.nhidden+1)
+    for _ in 1:(arch.nhidden+1)
         layers = (layers..., GradientLayerQ(arch.dim, arch.upscaling_dimension, arch.act), GradientLayerP(arch.dim, arch.upscaling_dimension, arch.act))
     end
     Chain(layers...)
@@ -67,7 +67,7 @@ end
 
 function Chain(arch::GSympNet{AT, false}) where {AT}
     layers = ()
-    for i in 1:(arch.nhidden+1)
+    for _ in 1:(arch.nhidden+1)
         layers = (layers..., GradientLayerP(arch.dim, arch.upscaling_dimension, arch.act), GradientLayerQ(arch.dim, arch.upscaling_dimension, arch.act))
     end
     Chain(layers...)
@@ -78,10 +78,11 @@ Build a chain for an LASympnet for which `init_upper_linear` is `true` and `init
 """
 function Chain(arch::LASympNet{AT, true, false}) where {AT}
     layers = ()
-    for i in 1:arch.nhidden
+    for _ in 1:arch.nhidden
         for j in 1:(arch.depth)
             layers = isodd(j) ? (layers..., LinearLayerQ(arch.dim)) : (layers..., LinearLayerP(arch.dim))
         end
+        layers = (layers..., BiasLayer(arch.dim))
         layers = (layers..., ActivationLayerP(arch.dim, arch.activation))
         layers = (layers..., ActivationLayerQ(arch.dim, arch.activation))
     end
diff --git a/src/layers/bias_layer.jl b/src/layers/bias_layer.jl
new file mode 100644
index 000000000..922a32ed6
--- /dev/null
+++ b/src/layers/bias_layer.jl
@@ -0,0 +1,30 @@
+@doc raw"""
+A *bias layer* that does nothing more than add a vector to the input. This is needed for *LA-SympNets*.
+"""
+struct BiasLayer{M} <: SympNetLayer{M, M}
+end
+
+function BiasLayer(M::Int)
+    BiasLayer{M}()
+end
+
+function initialparameters(backend::Backend, ::Type{T}, ::BiasLayer{M}; rng::AbstractRNG = Random.default_rng(), init_bias = GlorotUniform()) where {M, T}
+    q_part = KernelAbstractions.zeros(backend, T, M÷2)
+    p_part = KernelAbstractions.zeros(backend, T, M÷2)
+    init_bias(rng, q_part)
+    init_bias(rng, p_part)
+    return (q = q_part, p = p_part)
+end
+
+function parameterlength(::BiasLayer{M}) where M
+    M
+end
+
+# Vector input: shift the q and p parts by the corresponding bias vectors.
+(::BiasLayer{M})(z::NT, ps::NT) where {M, AT<:AbstractVector, NT<:NamedTuple{(:q, :p), Tuple{AT, AT}}} = (q = z.q + ps.q, p = z.p + ps.p)
+# Batched (matrix / 3-tensor) input: broadcast the bias vectors over the batch dimensions.
+(::BiasLayer{M})(z::NT1, ps::NT2) where {M, T, AT<:AbstractVector, BT<:Union{AbstractMatrix, AbstractArray{T, 3}}, NT1<:NamedTuple{(:q, :p), Tuple{BT, BT}}, NT2<:NamedTuple{(:q, :p), Tuple{AT, AT}}} = (q = z.q .+ ps.q, p = z.p .+ ps.p)
+
+function (d::BiasLayer{M})(z::AbstractArray, ps) where M
+    apply_layer_to_nt_and_return_array(z, d, ps)
+end