Skip to content

Commit

Permalink
Now internally only working with NamedTuples instead of Tuples.
Browse files Browse the repository at this point in the history
  • Loading branch information
benedict-96 committed Dec 12, 2023
1 parent b345600 commit a205414
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 13 deletions.
14 changes: 3 additions & 11 deletions src/layers/psd_like_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,10 @@ function initialparameters(backend::KernelAbstractions.Backend, T::Type, ::PSDLa
(weight = N > M ? rand(backend, rng, StiefelManifold{T}, N÷2, M÷2) : rand(backend, rng, StiefelManifold{T}, M÷2, N÷2), )
end

function (::PSDLayer{M, N})(x::AbstractVecOrMat, ps::NamedTuple) where {M, N}
function (::PSDLayer{M, N})(x::AbstractArray, ps::NamedTuple) where {M, N}
dim = size(x, 1)
@assert dim == M

q, p = assign_q_and_p(x, dim÷2)
N > M ? vcat(ps.weight*q, ps.weight*p) : vcat(ps.weight'*q, ps.weight'*p)
end

function (::PSDLayer{M, N})(x::AbstractArray{T, 3}, ps::NamedTuple) where {M, N, T}
dim = size(x, 1)
@assert dim == M

q, p = assign_q_and_p(x, dim÷2)
N > M ? vcat(mat_tensor_mul(ps.weight,q), mat_tensor_mul(ps.weight,p)) : vcat(mat_tensor_mul(ps.weight', q), mat_tensor_mul(ps.weight', p))
qp = assign_q_and_p(x, dim÷2)
N > M ? vcat(custom_mat_mul(ps.weight, qp.q), custom_mat_mul(ps.weight, qp.p)) : vcat(custom_mat_mul(ps.weight', qp.q), custom_mat_mul(ps.weight', qp.p))
end
7 changes: 5 additions & 2 deletions src/layers/sympnets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ function parameterlength(::LinearLayer{M, M}) where {M}
(M÷2)*(M÷2+1)÷2
end

@doc raw"""
Multiplies a matrix with a vector, a matrix or a tensor.
"""
custom_mat_mul(weight::AbstractMatrix, x::AbstractVecOrMat) = weight*x
function custom_mat_mul(weight::AbstractMatrix, x::AbstractArray{T, 3}) where T
mat_tensor_mul(weight, x)
Expand Down Expand Up @@ -236,8 +239,8 @@ It converts the Array to a `NamedTuple` (via `assign_q_and_p`), then calls the S
"""
function apply_layer_to_nt_and_return_array(x::AbstractArray, d::SympNetLayer{M, M}, ps) where {M}
N2 = size(x, 1)÷2
q, p = assign_q_and_p(x, N2)
output = d((q=q, p=p), ps)
qp = assign_q_and_p(x, N2)
output = d(qp, ps)
return vcat(output.q, output.p)
end

Expand Down

0 comments on commit a205414

Please sign in to comment.