Skip to content

Commit

Permalink
Merge pull request #95 from JuliaGNI/convert_tuple_functions_to_named…
Browse files Browse the repository at this point in the history
…_tuple_functions

Internally now only working with `NamedTuples`, not with `Tuples` when dealing with `q` and `p`.
  • Loading branch information
michakraus authored Dec 12, 2023
2 parents 858d440 + a205414 commit 98083dd
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 18 deletions.
6 changes: 3 additions & 3 deletions src/kernels/assign_q_and_p.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function assign_q_and_p(x::AbstractVector, N::Int)
p_kernel! = assign_second_half!(backend)
q_kernel!(q, x, ndrange=size(q))
p_kernel!(p, x, N, ndrange=size(p))
(q, p)
(q=q, p=p)
end

function assign_q_and_p(x::AbstractMatrix, N::Int)
Expand All @@ -56,7 +56,7 @@ function assign_q_and_p(x::AbstractMatrix, N::Int)
p_kernel! = assign_second_half!(backend)
q_kernel!(q, x, ndrange=size(q))
p_kernel!(p, x, N, ndrange=size(p))
(q, p)
(q=q, p=p)
end

function assign_q_and_p(x::AbstractArray{T, 3}, N::Int) where T
Expand All @@ -67,5 +67,5 @@ function assign_q_and_p(x::AbstractArray{T, 3}, N::Int) where T
p_kernel! = assign_second_half!(backend)
q_kernel!(q, x, ndrange=size(q))
p_kernel!(p, x, N, ndrange=size(p))
(q, p)
(q=q, p=p)
end
4 changes: 2 additions & 2 deletions src/kernels/kernel_ad_routines/assign_q_and_p.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ This implements the custom pullback for assign_q_and_p
"""

function ChainRulesCore.rrule(::typeof(assign_q_and_p), x::AbstractArray, N::Integer)
q, p = assign_q_and_p(x, N)
qp = assign_q_and_p(x, N)
function assign_q_and_p_pullback(qp_diff)
= NoTangent()
concat = @thunk vcat(qp_diff...)
return f̄, concat, NoTangent()
end
return (q, p), assign_q_and_p_pullback
return qp, assign_q_and_p_pullback
end
14 changes: 3 additions & 11 deletions src/layers/psd_like_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,10 @@ function initialparameters(backend::KernelAbstractions.Backend, T::Type, ::PSDLa
(weight = N > M ? rand(backend, rng, StiefelManifold{T}, N÷2, M÷2) : rand(backend, rng, StiefelManifold{T}, M÷2, N÷2), )
end

function (::PSDLayer{M, N})(x::AbstractVecOrMat, ps::NamedTuple) where {M, N}
function (::PSDLayer{M, N})(x::AbstractArray, ps::NamedTuple) where {M, N}
dim = size(x, 1)
@assert dim == M

q, p = assign_q_and_p(x, dim÷2)
N > M ? vcat(ps.weight*q, ps.weight*p) : vcat(ps.weight'*q, ps.weight'*p)
end

function (::PSDLayer{M, N})(x::AbstractArray{T, 3}, ps::NamedTuple) where {M, N, T}
dim = size(x, 1)
@assert dim == M

q, p = assign_q_and_p(x, dim÷2)
N > M ? vcat(mat_tensor_mul(ps.weight,q), mat_tensor_mul(ps.weight,p)) : vcat(mat_tensor_mul(ps.weight', q), mat_tensor_mul(ps.weight', p))
qp = assign_q_and_p(x, dim÷2)
N > M ? vcat(custom_mat_mul(ps.weight, qp.q), custom_mat_mul(ps.weight, qp.p)) : vcat(custom_mat_mul(ps.weight', qp.q), custom_mat_mul(ps.weight', qp.p))
end
7 changes: 5 additions & 2 deletions src/layers/sympnets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ function parameterlength(::LinearLayer{M, M}) where {M}
(M÷2)*(M÷2+1)÷2
end

@doc raw"""
Multiplies a matrix with a vector, a matrix or a tensor.
"""
custom_mat_mul(weight::AbstractMatrix, x::AbstractVecOrMat) = weight*x
function custom_mat_mul(weight::AbstractMatrix, x::AbstractArray{T, 3}) where T
mat_tensor_mul(weight, x)
Expand Down Expand Up @@ -236,8 +239,8 @@ It converts the Array to a `NamedTuple` (via `assign_q_and_p`), then calls the S
"""
function apply_layer_to_nt_and_return_array(x::AbstractArray, d::SympNetLayer{M, M}, ps) where {M}
N2 = size(x, 1)÷2
q, p = assign_q_and_p(x, N2)
output = d((q=q, p=p), ps)
qp = assign_q_and_p(x, N2)
output = d(qp, ps)
return vcat(output.q, output.p)
end

Expand Down

0 comments on commit 98083dd

Please sign in to comment.