diff --git a/src/kernels/assign_q_and_p.jl b/src/kernels/assign_q_and_p.jl index 23dc31db0..5b6cbba37 100644 --- a/src/kernels/assign_q_and_p.jl +++ b/src/kernels/assign_q_and_p.jl @@ -45,7 +45,7 @@ function assign_q_and_p(x::AbstractVector, N::Int) p_kernel! = assign_second_half!(backend) q_kernel!(q, x, ndrange=size(q)) p_kernel!(p, x, N, ndrange=size(p)) - (q, p) + (q=q, p=p) end function assign_q_and_p(x::AbstractMatrix, N::Int) @@ -56,7 +56,7 @@ function assign_q_and_p(x::AbstractMatrix, N::Int) p_kernel! = assign_second_half!(backend) q_kernel!(q, x, ndrange=size(q)) p_kernel!(p, x, N, ndrange=size(p)) - (q, p) + (q=q, p=p) end function assign_q_and_p(x::AbstractArray{T, 3}, N::Int) where T @@ -67,5 +67,5 @@ function assign_q_and_p(x::AbstractArray{T, 3}, N::Int) where T p_kernel! = assign_second_half!(backend) q_kernel!(q, x, ndrange=size(q)) p_kernel!(p, x, N, ndrange=size(p)) - (q, p) + (q=q, p=p) end \ No newline at end of file diff --git a/src/kernels/kernel_ad_routines/assign_q_and_p.jl b/src/kernels/kernel_ad_routines/assign_q_and_p.jl index 18a47fc2a..bf3546a3e 100644 --- a/src/kernels/kernel_ad_routines/assign_q_and_p.jl +++ b/src/kernels/kernel_ad_routines/assign_q_and_p.jl @@ -3,11 +3,11 @@ This implements the custom pullback for assign_q_and_p """ function ChainRulesCore.rrule(::typeof(assign_q_and_p), x::AbstractArray, N::Integer) - q, p = assign_q_and_p(x, N) + qp = assign_q_and_p(x, N) function assign_q_and_p_pullback(qp_diff) f̄ = NoTangent() concat = @thunk vcat(qp_diff...) return f̄, concat, NoTangent() end - return (q, p), assign_q_and_p_pullback + return qp, assign_q_and_p_pullback end \ No newline at end of file diff --git a/src/layers/psd_like_layer.jl b/src/layers/psd_like_layer.jl index edd293ba4..7553b8266 100644 --- a/src/layers/psd_like_layer.jl +++ b/src/layers/psd_like_layer.jl @@ -29,18 +29,10 @@ function initialparameters(backend::KernelAbstractions.Backend, T::Type, ::PSDLa (weight = N > M ? rand(backend, rng, StiefelManifold{T}, N÷2, M÷2) : rand(backend, rng, StiefelManifold{T}, M÷2, N÷2), ) end -function (::PSDLayer{M, N})(x::AbstractVecOrMat, ps::NamedTuple) where {M, N} +function (::PSDLayer{M, N})(x::AbstractArray, ps::NamedTuple) where {M, N} dim = size(x, 1) @assert dim == M - q, p = assign_q_and_p(x, dim÷2) - N > M ? vcat(ps.weight*q, ps.weight*p) : vcat(ps.weight'*q, ps.weight'*p) -end - -function (::PSDLayer{M, N})(x::AbstractArray{T, 3}, ps::NamedTuple) where {M, N, T} - dim = size(x, 1) - @assert dim == M - - q, p = assign_q_and_p(x, dim÷2) - N > M ? vcat(mat_tensor_mul(ps.weight,q), mat_tensor_mul(ps.weight,p)) : vcat(mat_tensor_mul(ps.weight', q), mat_tensor_mul(ps.weight', p)) + qp = assign_q_and_p(x, dim÷2) + N > M ? vcat(custom_mat_mul(ps.weight, qp.q), custom_mat_mul(ps.weight, qp.p)) : vcat(custom_mat_mul(ps.weight', qp.q), custom_mat_mul(ps.weight', qp.p)) end \ No newline at end of file diff --git a/src/layers/sympnets.jl b/src/layers/sympnets.jl index c4c7c109d..d749bc7e2 100644 --- a/src/layers/sympnets.jl +++ b/src/layers/sympnets.jl @@ -182,6 +182,9 @@ function parameterlength(::LinearLayer{M, M}) where {M} (M÷2)*(M÷2+1)÷2 end +@doc raw""" +Multiplies a matrix with a vector, a matrix or a tensor. +""" custom_mat_mul(weight::AbstractMatrix, x::AbstractVecOrMat) = weight*x function custom_mat_mul(weight::AbstractMatrix, x::AbstractArray{T, 3}) where T mat_tensor_mul(weight, x) @@ -236,8 +239,8 @@ It converts the Array to a `NamedTuple` (via `assign_q_and_p`), then calls the S """ function apply_layer_to_nt_and_return_array(x::AbstractArray, d::SympNetLayer{M, M}, ps) where {M} N2 = size(x, 1)÷2 - q, p = assign_q_and_p(x, N2) - output = d((q=q, p=p), ps) + qp = assign_q_and_p(x, N2) + output = d(qp, ps) return vcat(output.q, output.p) end