Skip to content

Commit

Permalink
Added the bias layer for the LA-SympNet.
Browse files Browse the repository at this point in the history
  • Loading branch information
benedict-96 committed Dec 12, 2023
1 parent 8d3424d commit 4ddf149
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/GeometricMachineLearning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ module GeometricMachineLearning
include("optimizers/manifold_related/retractions.jl")

include("layers/sympnets.jl")
include("layers/bias_layer.jl")
include("layers/resnet.jl")
include("layers/manifold_layer.jl")
include("layers/stiefel_layer.jl")
Expand Down
7 changes: 4 additions & 3 deletions src/architectures/sympnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,15 @@ end
"""
"""
    Chain(arch::GSympNet{AT, true})

Build the layer chain for a `GSympNet` whose first layer acts on the
``q``-part (`init_upper == true`).

The chain consists of `arch.nhidden + 1` pairs of gradient layers, each pair
applying a `GradientLayerQ` followed by a `GradientLayerP`, so that the
``q``- and ``p``-components are updated alternately.
"""
function Chain(arch::GSympNet{AT, true}) where {AT}
    layers = ()
    # nhidden counts the *hidden* pairs; one extra pair forms the output stage.
    for _ in 1:(arch.nhidden + 1)
        layers = (layers...,
                  GradientLayerQ(arch.dim, arch.upscaling_dimension, arch.act),
                  GradientLayerP(arch.dim, arch.upscaling_dimension, arch.act))
    end
    Chain(layers...)
end

"""
    Chain(arch::GSympNet{AT, false})

Build the layer chain for a `GSympNet` whose first layer acts on the
``p``-part (`init_upper == false`).

Mirror image of the `init_upper == true` case: each of the
`arch.nhidden + 1` pairs applies a `GradientLayerP` *before* the
`GradientLayerQ`.
"""
function Chain(arch::GSympNet{AT, false}) where {AT}
    layers = ()
    # nhidden counts the *hidden* pairs; one extra pair forms the output stage.
    for _ in 1:(arch.nhidden + 1)
        layers = (layers...,
                  GradientLayerP(arch.dim, arch.upscaling_dimension, arch.act),
                  GradientLayerQ(arch.dim, arch.upscaling_dimension, arch.act))
    end
    Chain(layers...)
end
Expand All @@ -78,10 +78,11 @@ Build a chain for an LASympnet for which `init_upper_linear` is `true` and `init
"""
function Chain(arch::LASympNet{AT, true, false}) where {AT}
layers = ()
for i in 1:arch.nhidden
for _ in 1:arch.nhidden

Check warning on line 81 in src/architectures/sympnet.jl

View check run for this annotation

Codecov / codecov/patch

src/architectures/sympnet.jl#L81

Added line #L81 was not covered by tests
for j in 1:(arch.depth)
layers = isodd(j) ? (layers..., LinearLayerQ(arch.dim)) : (layers..., LinearLayerP(arch.dim))
end
layers = (layers..., BiasLayer(arch.dim))

Check warning on line 85 in src/architectures/sympnet.jl

View check run for this annotation

Codecov / codecov/patch

src/architectures/sympnet.jl#L85

Added line #L85 was not covered by tests
layers = (layers..., ActivationLayerP(arch.dim, arch.activation))
layers = (layers..., ActivationLayerQ(arch.dim, arch.activation))
end
Expand Down
28 changes: 28 additions & 0 deletions src/layers/bias_layer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
@doc raw"""
A *bias layer* that does nothing more than add a vector to the input. This is needed for *LA-SympNets*.

The two type parameters are the input and output dimension; for a bias layer
they are always equal, and the convenience constructor `BiasLayer(M)` produces
`BiasLayer{M, M}`.
"""
struct BiasLayer{M, N} <: SympNetLayer{M, N}
end

"""
    BiasLayer(M::Int)

Convenience constructor: build a bias layer acting on a phase space of
(even) dimension `M`, i.e. a `BiasLayer{M, M}`.
"""
function BiasLayer(M::Int)
    BiasLayer{M, M}()
end

"""
    initialparameters(backend, T, ::BiasLayer{M, M}; rng, init_bias)

Allocate and initialize the parameters of a `BiasLayer`: a `NamedTuple`
`(q = ..., p = ...)` of two vectors of length `M÷2` each, allocated on
`backend` with element type `T` and filled in place by `init_bias`
(default `GlorotUniform()`).
"""
function initialparameters(backend::Backend, ::Type{T}, ::BiasLayer{M, M}; rng::AbstractRNG = Random.default_rng(), init_bias = GlorotUniform()) where {M, T}
    # The bias splits into a q-part and a p-part of dimension M÷2 each.
    q_part = KernelAbstractions.zeros(backend, T, M÷2)
    p_part = KernelAbstractions.zeros(backend, T, M÷2)
    # init_bias overwrites the zero-initialized buffers in place.
    init_bias(rng, q_part)
    init_bias(rng, p_part)
    return (q = q_part, p = p_part)
end

# Total number of scalar parameters: M÷2 for the q-bias plus M÷2 for the
# p-bias (see `initialparameters`), i.e. M in total.
function parameterlength(::BiasLayer{M, M}) where M
    M
end

# Single-sample input: `z` and `ps` are both NamedTuples of vectors, so the
# bias is added directly.
(::BiasLayer{M, M})(z::NT, ps::NT) where {M, AT<:AbstractVector, NT<:NamedTuple{(:q, :p), Tuple{AT, AT}}} = (q = z.q + ps.q, p = z.p + ps.p)

# Batched input: `z` carries matrices / 3-tensors while the parameters `ps`
# stay vectors (that is what `initialparameters` produces), so the bias is
# broadcast along the extra (batch) dimensions.
# NOTE: the original typed `z` with vector entries and `ps` with matrix/tensor
# entries — backwards with respect to `initialparameters`, so the batched
# method could never match; the element types are swapped here.
(::BiasLayer{M, M})(z::NT1, ps::NT2) where {M, T, AT<:AbstractVector{T}, BT<:Union{AbstractMatrix{T}, AbstractArray{T, 3}}, NT1<:NamedTuple{(:q, :p), Tuple{BT, BT}}, NT2<:NamedTuple{(:q, :p), Tuple{AT, AT}}} = (q = z.q .+ ps.q, p = z.p .+ ps.p)
# Entry point for plain-array input: delegate to the generic helper, which
# splits `z` into its (q, p) parts, applies the layer, and reassembles the
# result into an array of the same layout.
function (d::BiasLayer{M, M})(z::AbstractArray, ps) where M
    apply_layer_to_nt_and_return_array(z, d, ps)
end

0 comments on commit 4ddf149

Please sign in to comment.