diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 09efea74e..37abd7006 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -863,18 +863,19 @@ function ResGatedGraphConv(ch::Pair{Int, Int}, σ = identity; return ResGatedGraphConv(A, B, U, V, b, σ) end -function (l::ResGatedGraphConv)(g::GNNGraph, x::AbstractMatrix) +function (l::ResGatedGraphConv)(g::AbstractGNNGraph, x) check_num_nodes(g, x) + xj, xi = expand_srcdst(g, x) message(xi, xj, e) = sigmoid.(xi.Ax .+ xj.Bx) .* xj.Vx - Ax = l.A * x - Bx = l.B * x - Vx = l.V * x + Ax = l.A * xi + Bx = l.B * xj + Vx = l.V * xj m = propagate(message, g, +, xi = (; Ax), xj = (; Bx, Vx)) - return l.σ.(l.U * x .+ m .+ l.bias) + return l.σ.(l.U * xi .+ m .+ l.bias) end function Base.show(io::IO, l::ResGatedGraphConv) diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index e4d0fd40a..40baf4430 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -116,4 +116,12 @@ y = layers(hg, x); @test size(y.A) == (2, 2) && size(y.B) == (2, 3) end + + @testset "ResGatedGraphConv" begin + x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) + layers = HeteroGraphConv((:A, :to, :B) => ResGatedGraphConv(4 => 2), + (:B, :to, :A) => ResGatedGraphConv(4 => 2)); + y = layers(hg, x); + @test size(y.A) == (2, 2) && size(y.B) == (2, 3) + end end