Skip to content

Commit

Permalink
resgated hetero (#391)
Browse files Browse the repository at this point in the history
  • Loading branch information
rbSparky authored Mar 4, 2024
1 parent 884b473 commit 5a6bb6a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -863,18 +863,19 @@ function ResGatedGraphConv(ch::Pair{Int, Int}, σ = identity;
return ResGatedGraphConv(A, B, U, V, b, σ)
end

function (l::ResGatedGraphConv)(g::GNNGraph, x::AbstractMatrix)
function (l::ResGatedGraphConv)(g::AbstractGNNGraph, x)
check_num_nodes(g, x)
xj, xi = expand_srcdst(g, x)

message(xi, xj, e) = sigmoid.(xi.Ax .+ xj.Bx) .* xj.Vx

Ax = l.A * x
Bx = l.B * x
Vx = l.V * x
Ax = l.A * xi
Bx = l.B * xj
Vx = l.V * xj

m = propagate(message, g, +, xi = (; Ax), xj = (; Bx, Vx))

return l.σ.(l.U * x .+ m .+ l.bias)
return l.σ.(l.U * xi .+ m .+ l.bias)
end

function Base.show(io::IO, l::ResGatedGraphConv)
Expand Down
8 changes: 8 additions & 0 deletions test/layers/heteroconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,12 @@
y = layers(hg, x);
@test size(y.A) == (2, 2) && size(y.B) == (2, 3)
end

@testset "ResGatedGraphConv" begin
x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3))
layers = HeteroGraphConv((:A, :to, :B) => ResGatedGraphConv(4 => 2),
(:B, :to, :A) => ResGatedGraphConv(4 => 2));
y = layers(hg, x);
@test size(y.A) == (2, 2) && size(y.B) == (2, 3)
end
end

0 comments on commit 5a6bb6a

Please sign in to comment.