From 5a72e579116b8b66f5e74824513972191161408d Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Wed, 16 Nov 2022 10:55:49 +0100
Subject: [PATCH 1/2] concrete type parameters; remove closures

---
 Project.toml               |   1 +
 src/GraphNeuralNetworks.jl |   5 +
 src/layers/conv.jl         | 255 +++++++++++++++++++------------------
 src/msgpass.jl             |  75 ++++++++---
 test/layers/conv.jl        |   3 +-
 5 files changed, 196 insertions(+), 143 deletions(-)

diff --git a/Project.toml b/Project.toml
index 6643f2eea..82a370ef0 100644
--- a/Project.toml
+++ b/Project.toml
@@ -8,6 +8,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
+FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl
index a7d5af45e..c3a122538 100644
--- a/src/GraphNeuralNetworks.jl
+++ b/src/GraphNeuralNetworks.jl
@@ -12,6 +12,11 @@ using NNlib: scatter, gather
 using ChainRulesCore
 using Reexport
 
+# use `@closure` in conv layers in order to avoid issues
+# https://github.com/JuliaLang/julia/issues/15276
+# https://github.com/FluxML/Zygote.jl/issues/1317
+# using FastClosures: @closure
+
 using SparseArrays, Graphs # not needed but if removed Documenter will complain
 
 include("GNNGraphs/GNNGraphs.jl")
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index facd71785..8fa290acd 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -53,8 +53,8 @@ l = GCNConv(3 => 5, use_edge_weight=true)
 y = l(g, x) # same as l(g, x, w)
 ```
 """
-struct GCNConv{A<:AbstractMatrix, B, F} <: GNNLayer
-    weight::A
+struct GCNConv{W<:AbstractMatrix, B, F} <: GNNLayer
+    weight::W
     bias::B
     σ::F
     add_self_loops::Bool
@@ -159,8 +159,8 @@ with ``\hat{L}`` the [`scaled_laplacian`](@ref).
 - `bias`: Add learnable bias.
 - `init`: Weights' initializer.
 """
-struct ChebConv{A<:AbstractArray{<:Number,3}, B} <: GNNLayer
-    weight::A
+struct ChebConv{W<:AbstractArray{<:Number,3}, B} <: GNNLayer
+    weight::W
     bias::B
     k::Int
 end
@@ -221,12 +221,12 @@ where the aggregation type is selected by `aggr`.
 - `bias`: Add learnable bias.
 - `init`: Weights' initializer.
 """
-struct GraphConv{A<:AbstractMatrix, B} <: GNNLayer
-    weight1::A
-    weight2::A
+struct GraphConv{W<:AbstractMatrix,B,F,A} <: GNNLayer
+    weight1::W
+    weight2::W
     bias::B
-    σ
-    aggr
+    σ::F
+    aggr::A
 end
 
 @functor GraphConv
@@ -291,12 +291,12 @@ and the attention coefficients will be calculated as
 - `negative_slope`: The parameter of LeakyReLU. Default `0.2`.
 - `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`.
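+
+# Examples
+
+A minimal usage sketch (the random graph and feature sizes are illustrative):
+
+```julia
+g = rand_graph(10, 30)
+x = randn(Float32, 3, g.num_nodes)
+
+l = GATConv(3 => 5, relu; heads=2, concat=true)
+y = l(g, x)    # size: (5 * 2, g.num_nodes)
+```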
""" -struct GATConv{DX<:Dense,DE<:Union{Dense,Nothing}, T, A<:AbstractMatrix, B} <: GNNLayer +struct GATConv{DX<:Dense,DE<:Union{Dense,Nothing},T,A<:AbstractMatrix,F,B} <: GNNLayer dense_x::DX dense_e::DE bias::B a::A - σ + σ::F negative_slope::T channel::Pair{NTuple{2,Int}, Int} heads::Int @@ -343,20 +343,7 @@ function (l::GATConv)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing,AbstractM Wx = l.dense_x(x) Wx = reshape(Wx, chout, heads, :) # chout × nheads × nnodes - function message(Wxi, Wxj, e) - if e === nothing - Wxx = vcat(Wxi, Wxj) - else - We = l.dense_e(e) - We = reshape(We, chout, heads, :) # chout × nheads × nnodes - Wxx = vcat(Wxi, Wxj, We) - end - aWW = sum(l.a .* Wxx, dims=1) # 1 × nheads × nedges - α = exp.(leakyrelu.(aWW, l.negative_slope)) - return (α = α, β = α .* Wxj) - end - - m = propagate(message, g, +; xi=Wx, xj=Wx, e) ## chout × nheads × nnodes + m = propagate(message, g, +, l; xi=Wx, xj=Wx, e) ## chout × nheads × nnodes x = m.β ./ m.α if !l.concat @@ -368,6 +355,22 @@ function (l::GATConv)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing,AbstractM return x end +function message(l::GATConv, Wxi, Wxj, e) + _, chout = l.channel + heads = l.heads + + if e === nothing + Wxx = vcat(Wxi, Wxj) + else + We = l.dense_e(e) + We = reshape(We, chout, heads, :) # chout × nheads × nnodes + Wxx = vcat(Wxi, Wxj, We) + end + aWW = sum(l.a .* Wxx, dims=1) # 1 × nheads × nedges + α = exp.(leakyrelu.(aWW, l.negative_slope)) + return (α = α, β = α .* Wxj) +end + function Base.show(io::IO, l::GATConv) (in, ein), out = l.channel print(io, "GATConv(", ein == 0 ? in : (in, ein), " => ", out ÷ l.heads) @@ -412,13 +415,13 @@ and the attention coefficients will be calculated as - `negative_slope`: The parameter of LeakyReLU.Default `0.2`. - `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`. """ -struct GATv2Conv{T, A1, A2, A3, B, C<:AbstractMatrix} <: GNNLayer +struct GATv2Conv{T,A1,A2,A3,B,C<:AbstractMatrix,F} <: GNNLayer dense_i::A1 dense_j::A2 dense_e::A3 bias::B a::C - σ + σ::F negative_slope::T channel::Pair{NTuple{2,Int},Int} heads::Int @@ -477,18 +480,7 @@ function (l::GATv2Conv)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing, Abstra Wix = reshape(l.dense_i(x), out, heads, :) # out × heads × nnodes Wjx = reshape(l.dense_j(x), out, heads, :) # out × heads × nnodes - - function message(Wix, Wjx, e) - Wx = Wix + Wjx # Note: this is equivalent to W * vcat(x_i, x_j) as in "How Attentive are Graph Attention Networks?" - if e !== nothing - Wx += reshape(l.dense_e(e), out, heads, :) - end - eij = sum(l.a .* leakyrelu.(Wx, l.negative_slope), dims=1) # 1 × heads × nedges - α = exp.(eij) - return (α = α, β = α .* Wjx) - end - - m = propagate(message, g, +; xi=Wix, xj=Wjx, e) # out × heads × nnodes + m = propagate(message, g, +, l; xi=Wix, xj=Wjx, e) # out × heads × nnodes x = m.β ./ m.α if !l.concat @@ -499,6 +491,18 @@ function (l::GATv2Conv)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing, Abstra return x end +function message(l::GATv2Conv, Wix, Wjx, e) + _, out = l.channel + heads = l.heads + + Wx = Wix + Wjx # Note: this is equivalent to W * vcat(x_i, x_j) as in "How Attentive are Graph Attention Networks?" 
+ if e !== nothing + Wx += reshape(l.dense_e(e), out, heads, :) + end + eij = sum(l.a .* leakyrelu.(Wx, l.negative_slope), dims=1) # 1 × heads × nedges + α = exp.(eij) + return (α = α, β = α .* Wjx) +end function Base.show(io::IO, l::GATv2Conv) (in, ein), out = l.channel @@ -531,12 +535,12 @@ where ``\mathbf{h}^{(l)}_i`` denotes the ``l``-th hidden variables passing throu - `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`). - `init`: Weight initialization function. """ -struct GatedGraphConv{A<:AbstractArray{<:Number,3}, R} <: GNNLayer - weight::A +struct GatedGraphConv{W<:AbstractArray{<:Number,3},R,A} <: GNNLayer + weight::W gru::R out_ch::Int num_layers::Int - aggr + aggr::A end @functor GatedGraphConv @@ -591,9 +595,9 @@ where `nn` generally denotes a learnable function, e.g. a linear layer or a mult - `nn`: A (possibly learnable) function. - `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`). """ -struct EdgeConv <: GNNLayer - nn - aggr +struct EdgeConv{NN,A} <: GNNLayer + nn::NN + aggr::A end @functor EdgeConv @@ -602,8 +606,8 @@ EdgeConv(nn; aggr=max) = EdgeConv(nn, aggr) function (l::EdgeConv)(g::GNNGraph, x::AbstractMatrix) check_num_nodes(g, x) - message(xi, xj, e) = l.nn(vcat(xi, xj .- xi)) - x = propagate(message, g, l.aggr, xi=x, xj=x) + message(l, xi, xj, e) = l.nn(vcat(xi, xj .- xi)) + x = propagate(message, g, l.aggr, l, xi=x, xj=x) return x end @@ -630,10 +634,10 @@ where ``f_\Theta`` typically denotes a learnable function, e.g. a linear layer o - `f`: A (possibly learnable) function acting on node features. - `ϵ`: Weighting factor. """ -struct GINConv{R<:Real} <: GNNLayer - nn +struct GINConv{R<:Real,NN,A} <: GNNLayer + nn::NN ϵ::R - aggr + aggr::A end @functor GINConv @@ -683,12 +687,12 @@ For convenience, also functions returning a single `(out*in, num_edges)` matrix - `bias`: Add learnable bias. - `init`: Weights' initializer. """ -struct NNConv <: GNNLayer - weight - bias - nn - σ - aggr +struct NNConv{W,B,NN,F,A} <: GNNLayer + weight::W + bias::B + nn::NN + σ::F + aggr::A end @functor NNConv @@ -701,20 +705,20 @@ function NNConv(ch::Pair{Int,Int}, nn, σ=identity; aggr=+, bias=true, init=glor end function (l::NNConv)(g::GNNGraph, x::AbstractMatrix, e) - check_num_nodes(g, x) + check_num_nodes(g, x) - function message(xi, xj, e) - nin, nedges = size(xj) - W = reshape(l.nn(e), (:, nin, nedges)) - xj = reshape(xj, (nin, 1, nedges)) # needed by batched_mul - m = NNlib.batched_mul(W, xj) - return reshape(m, :, nedges) - end - - m = propagate(message, g, l.aggr, xj=x, e=e) + m = propagate(message, g, l.aggr, l, xj=x, e=e) return l.σ.(l.weight*x .+ m .+ l.bias) end +function message(l::NNConv, xi, xj, e) + nin, nedges = size(xj) + W = reshape(l.nn(e), (:, nin, nedges)) + xj = reshape(xj, (nin, 1, nedges)) # needed by batched_mul + m = NNlib.batched_mul(W, xj) + return reshape(m, :, nedges) +end + (l::NNConv)(g::GNNGraph) = GNNGraph(g, ndata=l(g, node_features(g), edge_features(g))) function Base.show(io::IO, l::NNConv) @@ -746,11 +750,11 @@ where the aggregation type is selected by `aggr`. - `bias`: Add learnable bias. - `init`: Weights' initializer. """ -struct SAGEConv{A<:AbstractMatrix, B} <: GNNLayer - weight::A +struct SAGEConv{W<:AbstractMatrix, B, F, A} <: GNNLayer + weight::W bias::B - σ - aggr + σ::F + aggr::A end @functor SAGEConv @@ -803,13 +807,13 @@ where the edge gates ``\eta_{ij}`` are given by - `init`: Weight matrices' initializing function. 
- `bias`: Learn an additive bias if true. """ -struct ResGatedGraphConv <: GNNLayer - A - B - U - V - bias - σ +struct ResGatedGraphConv{W,B,F} <: GNNLayer + A::W + B::W + U::W + V::W + bias::B + σ::F end @functor ResGatedGraphConv @@ -893,10 +897,10 @@ l = CGConv(2 => 4, tanh) y = l(g, x) # size: (4, num_nodes) ``` """ -struct CGConv <: GNNLayer +struct CGConv{D1,D2} <: GNNLayer ch::Pair{NTuple{2,Int},Int} - dense_f::Dense - dense_s::Dense + dense_f::D1 + dense_s::D2 residual::Bool end @@ -917,16 +921,7 @@ function (l::CGConv)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing, AbstractM check_num_edges(g, e) end - function message(xi, xj, e) - if e !== nothing - z = vcat(xi, xj, e) - else - z = vcat(xi, xj) - end - return l.dense_f(z) .* l.dense_s(z) - end - - m = propagate(message, g, +, xi=x, xj=x, e=e) + m = propagate(message, g, +, l, xi=x, xj=x, e=e) if l.residual if size(x, 1) == size(m, 1) @@ -939,6 +934,15 @@ function (l::CGConv)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing, AbstractM return m end +function message(l::CGConv, xi, xj, e) + if e !== nothing + z = vcat(xi, xj, e) + else + z = vcat(xi, xj) + end + return l.dense_f(z) .* l.dense_s(z) +end + (l::CGConv)(g::GNNGraph) = GNNGraph(g, ndata=l(g, node_features(g), edge_features(g))) function Base.show(io::IO, l::CGConv) @@ -1031,10 +1035,10 @@ m = MEGNetConv(3 => 3) x′, e′ = m(g, x, e) ``` """ -struct MEGNetConv <: GNNLayer - ϕe - ϕv - aggr +struct MEGNetConv{TE,TV,A} <: GNNLayer + ϕe::TE + ϕv::TV + aggr::A end @functor MEGNetConv @@ -1309,8 +1313,8 @@ end @doc raw""" - EdgeConv((in, ein) => out, hidden_size) - EdgeConv(in => out, hidden_size=2*in) + EdgeConv((in, ein) => out; hidden_size=2in, residual=false) + EdgeConv(in => out; hidden_size=2in, residual=false) Equivariant Graph Convolutional Layer from [E(n) Equivariant Graph Neural Networks](https://arxiv.org/abs/2102.09844). @@ -1362,20 +1366,20 @@ egnn = EGNNConv(5 => 6, 10) hnew, xnew = egnn(g, h, x) ``` """ -struct EGNNConv <: GNNLayer - ϕe::Chain - ϕx::Chain - ϕh::Chain - num_features::NamedTuple +struct EGNNConv{TE,TX,TH,NF} <: GNNLayer + ϕe::TE + ϕx::TX + ϕh::TH + num_features::NF residual::Bool end @functor EGNNConv -EGNNConv(ch::Pair{Int,Int}, hidden_size=2*ch[1]) = EGNNConv((ch[1], 0) => ch[2], hidden_size) +EGNNConv(ch::Pair{Int,Int}, hidden_size=2*ch[1]; residual=false) = EGNNConv((ch[1], 0) => ch[2]; hidden_size, residual) #Follows reference implementation at https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py -function EGNNConv(ch::Pair{NTuple{2, Int}, Int}, hidden_size::Int, residual=false) +function EGNNConv(ch::Pair{NTuple{2, Int}, Int}; hidden_size::Int=2*ch[1][1], residual=false) (in_size, edge_feat_size), out_size = ch act_fn = swish @@ -1389,7 +1393,7 @@ function EGNNConv(ch::Pair{NTuple{2, Int}, Int}, hidden_size::Int, residual=fals ϕx = Chain(Dense(hidden_size, hidden_size, swish), Dense(hidden_size, 1, bias=false)) - num_features = (in=in_size, edge=edge_feat_size, out=out_size) + num_features = (in=in_size, edge=edge_feat_size, out=out_size, hidden=hidden_size) if residual @assert in_size == out_size "Residual connection only possible if in_size == out_size" end @@ -1402,26 +1406,12 @@ function (l::EGNNConv)(g::GNNGraph, h::AbstractMatrix, x::AbstractMatrix, e=noth end @assert size(h, 1) == l.num_features.in "Input features must match layer input size." 
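+    # The lines below materialize the geometric edge features consumed by
+    # `message`: coordinate differences x_i - x_j, their squared norms, and
+    # the normalized directions (passed as `e.x_diff` and `e.sqnorm_xdiff`).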
- - @show size(x) size(h) - - function message(xi, xj, e) - if l.num_features.edge > 0 - f = vcat(xi.h, xj.h, e.sqnorm_xdiff, e.e) - else - f = vcat(xi.h, xj.h, e.sqnorm_xdiff) - end - - msg_h = l.ϕe(f) - msg_x = l.ϕx(msg_h) .* e.x_diff - return (; x=msg_x, h=msg_h) - end - x_diff = apply_edges(xi_sub_xj, g, x, x) sqnorm_xdiff = sum(x_diff .^ 2, dims=1) x_diff = x_diff ./ (sqrt.(sqnorm_xdiff) .+ 1f-6) - msg = apply_edges(message, g, xi=(; h), xj=(; h), e=(; e, x_diff, sqnorm_xdiff)) + msg = apply_edges(message, g, l, + xi=(; h), xj=(; h), e=(; e, x_diff, sqnorm_xdiff)) h_aggr = aggregate_neighbors(g, +, msg.h) x_aggr = aggregate_neighbors(g, mean, msg.x) @@ -1432,6 +1422,29 @@ function (l::EGNNConv)(g::GNNGraph, h::AbstractMatrix, x::AbstractMatrix, e=noth h = hnew end x = x .+ x_aggr - return h, x end + +function message(l::EGNNConv, xi, xj, e) + if l.num_features.edge > 0 + f = vcat(xi.h, xj.h, e.sqnorm_xdiff, e.e) + else + f = vcat(xi.h, xj.h, e.sqnorm_xdiff) + end + + msg_h = l.ϕe(f) + msg_x = l.ϕx(msg_h) .* e.x_diff + return (; x=msg_x, h=msg_h) +end + +function Base.show(io::IO, l::EGNNConv) + ne = l.num_features.edge + nin = l.num_features.in + nout = l.num_features.out + nh = l.num_features.hidden + print(io, "EGNNConv(($nin, $ne) => $nout; hidden_size=$nh") + if l.residual + print(io, ", residual=true") + end + print(io, ")") +end diff --git a/src/msgpass.jl b/src/msgpass.jl index 61af0c15b..7359bc3a3 100644 --- a/src/msgpass.jl +++ b/src/msgpass.jl @@ -1,16 +1,19 @@ """ - propagate(f, g, aggr; [xi, xj, e]) -> m̄ - propagate(f, g, aggr, xi, xj, e=nothing) + propagate(fmsg, g, aggr [layer]; [xi, xj, e]) + propagate(fmsg, g, aggr, [layer,] xi, xj, e=nothing) Performs message passing on graph `g`. Takes care of materializing the node features on each edge, -applying the message function, and returning an aggregated message ``\\bar{\\mathbf{m}}`` -(depending on the return value of `f`, an array or a named tuple of +applying the message function `fmsg`, and returning an aggregated message ``\\bar{\\mathbf{m}}`` +(depending on the return value of `fmsg`, an array or a named tuple of arrays with last dimension's size `g.num_nodes`). +If also a [GNNLayer](@ref) `layer` is provided, it will be passed to `fmsg` +as a first argument. + It can be decomposed in two steps: ```julia -m = apply_edges(f, g, xi, xj, e) +m = apply_edges(fmsg, g, xi, xj, e) m̄ = aggregate_neighbors(g, aggr, m) ``` @@ -25,13 +28,17 @@ providing as input `f` a closure. target node of each edge (see also [`edge_index`](@ref)). - `xj`: As `xj`, but to be materialized on edges' sources. - `e`: An array or a named tuple containing arrays whose last dimension's size is `g.num_edges`. -- `f`: A generic function that will be passed over to [`apply_edges`](@ref). +- `fmsg`: A generic function that will be passed over to [`apply_edges`](@ref). Has to take as inputs the edge-materialized `xi`, `xj`, and `e` (arrays or named tuples of arrays whose last dimension' size is the size of a batch of edges). Its output has to be an array or a named tuple of arrays - with the same batch size. + with the same batch size. If also `layer` is passed to propagate, + the signature of `fmsg` has to be `fmsg(layer, xi, xj, e)` + instead of `fmsg(xi, xj, e)`. +- `layer`: A [GNNLayer](@ref). If provided it will be passed to `fmsg` as a first argument. - `aggr`: Neighborhood aggregation operator. Use `+`, `mean`, `max`, or `min`. + # Examples ```julia @@ -66,30 +73,44 @@ See also [`apply_edges`](@ref) and [`aggregate_neighbors`](@ref). 
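+
+With the layer-passing form, the message function is a top-level method taking
+the layer as its first argument. A sketch (`MyConv` and its field `W` are
+hypothetical):
+
+```julia
+struct MyConv <: GNNLayer
+    W
+end
+
+message(l::MyConv, xi, xj, e) = l.W * xj
+
+function (l::MyConv)(g::GNNGraph, x)
+    return propagate(message, g, +, l; xi=x, xj=x)
+end
+```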
""" function propagate end -propagate(l, g::GNNGraph, aggr; xi=nothing, xj=nothing, e=nothing) = - propagate(l, g, aggr, xi, xj, e) +propagate(f, g::GNNGraph, aggr; xi=nothing, xj=nothing, e=nothing) = + propagate(f, g, aggr, xi, xj, e) -function propagate(l, g::GNNGraph, aggr, xi, xj, e=nothing) - m = apply_edges(l, g, xi, xj, e) +function propagate(f, g::GNNGraph, aggr, xi, xj, e=nothing) + m = apply_edges(f, g, xi, xj, e) m̄ = aggregate_neighbors(g, aggr, m) return m̄ end + +## convenience methods for working around performance issues +# https://github.com/JuliaLang/julia/issues/15276 +## and zygote issues +# https://github.com/FluxML/Zygote.jl/issues/1317 +propagate(f, g::GNNGraph, aggr, l::GNNLayer; xi=nothing, xj=nothing, e=nothing) = + propagate((xi, xj, e) -> f(l, xi, xj, e), g, aggr, xi, xj, e) +propagate(f, g::GNNGraph, aggr, l::GNNLayer, xi, xj, e=nothing) = + propagate((xi, xj, e) -> f(l, xi, xj, e), g, aggr, xi, xj, e) + ## APPLY EDGES """ - apply_edges(f, g; [xi, xj, e]) - apply_edges(f, g, xi, xj, e=nothing) + apply_edges(fmsg, g, [layer]; [xi, xj, e]) + apply_edges(fmsg, g, [layer,] xi, xj, e=nothing) -Returns the message from node `j` to node `i` . +Returns the message from node `j` to node `i` applying +the message function `fmsg` on the edges in graph `g`. In the message-passing scheme, the incoming messages from the neighborhood of `i` will later be aggregated -in order to update the features of node `i`. +in order to update the features of node `i` (see [`aggregate_neighbors`](@ref)). -The function operates on batches of edges, therefore +The function `fmsg` operates on batches of edges, therefore `xi`, `xj`, and `e` are tensors whose last dimension is the batch size, or can be named tuples of such tensors. + +If also a [GNNLayer](@ref) `layer` is provided, it will be passed to `fmsg` +as a first argument. # Arguments @@ -99,17 +120,21 @@ such tensors. target node of each edge (see also [`edge_index`](@ref)). - `xj`: As `xi`, but now to be materialized on each edge's source node. - `e`: An array or a named tuple containing arrays whose last dimension's size is `g.num_edges`. -- `f`: A function that takes as inputs the edge-materialized `xi`, `xj`, and `e`. +- `fmsg`: A function that takes as inputs the edge-materialized `xi`, `xj`, and `e`. These are arrays (or named tuples of arrays) whose last dimension' size is the size of a batch of edges. The output of `f` has to be an array (or a named tuple of arrays) - with the same batch size. + with the same batch size. If also `layer` is passed to propagate, + the signature of `fmsg` has to be `fmsg(layer, xi, xj, e)` + instead of `fmsg(xi, xj, e)`. +- `layer`: A [GNNLayer](@ref). If provided it will be passed to `fmsg` as a first argument. See also [`propagate`](@ref) and [`aggregate_neighbors`](@ref). 
""" function apply_edges end -apply_edges(l, g::GNNGraph; xi=nothing, xj=nothing, e=nothing) = - apply_edges(l, g, xi, xj, e) +apply_edges(f, g::GNNGraph; xi=nothing, xj=nothing, e=nothing) = + apply_edges(f, g, xi, xj, e) + function apply_edges(f, g::GNNGraph, xi, xj, e=nothing) check_num_nodes(g, xi) @@ -122,6 +147,16 @@ function apply_edges(f, g::GNNGraph, xi, xj, e=nothing) return m end +## convenience methods for working around performance issues +# https://github.com/JuliaLang/julia/issues/15276 +## and zygote issues +# https://github.com/FluxML/Zygote.jl/issues/1317 +apply_edges(f, g::GNNGraph, l::GNNLayer; xi=nothing, xj=nothing, e=nothing) = + apply_edges((xi, xj, e) -> f(l, xi, xj, e), g, xi, xj, e) + +apply_edges(f, g::GNNGraph, l::GNNLayer, xi, xj, e=nothing) = + apply_edges((xi, xj, e) -> f(l, xi, xj, e), g, xi, xj, e) + ## AGGREGATE NEIGHBORS @doc raw""" aggregate_neighbors(g::GNNGraph, aggr, m) diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 6d8687a84..5ebf3d810 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -207,8 +207,7 @@ l = NNConv(in_channel => out_channel, nn, tanh, bias=true, aggr=+) for g in test_graphs g = GNNGraph(g, edata=rand(T, edim, g.num_edges)) - # FIXME issue #208 - # test_layer(l, g, rtol=RTOL_HIGH, outsize=(out_channel, g.num_nodes)) + test_layer(l, g, rtol=RTOL_HIGH, outsize=(out_channel, g.num_nodes)) end end From 55e6ed85c8a8d4e072c26597d2c90eeeb89206e4 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 16 Nov 2022 12:00:17 +0100 Subject: [PATCH 2/2] cleanup --- Project.toml | 1 - src/GraphNeuralNetworks.jl | 6 ------ 2 files changed, 7 deletions(-) diff --git a/Project.toml b/Project.toml index 82a370ef0..6643f2eea 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,6 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index c3a122538..db8e7be16 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -11,12 +11,6 @@ using NNlib, NNlibCUDA using NNlib: scatter, gather using ChainRulesCore using Reexport - -# use `@closure` in conv layers in ored to avoid issues -# https://github.com/JuliaLang/julia/issues/15276 -# https://github.com/FluxML/Zygote.jl/issues/1317 -# using FastClosures: @closure - using SparseArrays, Graphs # not needed but if removed Documenter will complain include("GNNGraphs/GNNGraphs.jl")