First draft GPS conv layer #355

Open: wants to merge 1 commit into master
Conversation

@abieler (Contributor) commented Dec 28, 2023

This is only a first "mock" version of a GPSConv layer, to see whether we would want it in the repo in this form.

  • Adds a DotProductAttention layer that uses NNlib.dot_product_attention() (a rough sketch of such a wrapper follows this description)

  • Adds a GPSConv layer

    • uses the DotProductAttention as the global attention layer
    • takes a conv layer for local message passing
  • Not sure about the GNNChain() implementation: should it stay where it is or move into the struct?

  • JuliaFormatter() got a bit too greedy and made some changes here and there; I can revert those, of course

  • Did not check the implementation for correctness yet

Let me know what you think and I can adjust / keep going from here.

Closes #351
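
For illustration, a minimal wrapper around NNlib.dot_product_attention() could look like the sketch below. This is only a rough guess at the shape of the layer, not the PR's actual code; the struct name matches the PR, but the fields and projections are illustrative assumptions.

```julia
using Flux, NNlib

# Sketch only: single-head self-attention over node features of size
# (in, num_nodes); the PR's actual DotProductAttention may differ.
struct DotProductAttention{Q,K,V}
    Wq::Q
    Wk::K
    Wv::V
end

Flux.@functor DotProductAttention

DotProductAttention((in, out)::Pair{Int,Int}) =
    DotProductAttention(Dense(in => out), Dense(in => out), Dense(in => out))

function (a::DotProductAttention)(x::AbstractMatrix)
    q, k, v = a.Wq(x), a.Wk(x), a.Wv(x)
    # NNlib.dot_product_attention expects (features, len, batch) arrays,
    # so treat all nodes as one sequence with a singleton batch dimension.
    q3, k3, v3 = map(t -> reshape(t, size(t, 1), size(t, 2), 1), (q, k, v))
    y, _ = NNlib.dot_product_attention(q3, k3, v3)
    return reshape(y, size(y, 1), size(y, 2))
end
```

Under this reading, the ch in the gattn = DotProductAttention(ch) call quoted below would be an in => out pair.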

init = glorot_uniform,
bias::Bool = true,
add_self_loops = true,
use_edge_weight = false)
@CarloLucibello (Member) commented on the lines above:

This PR is doing a lot of unrelated changes; please remove them.

gattn = DotProductAttention(ch)
convlayer = GNNChain(gconv, Dropout(0.5), LayerNorm(out))
attnlayer = GNNChain(gattn, Dropout(0.5), LayerNorm(out))
ffn = GNNChain(Dense(in => 2 * in, σ), Dropout(0.5), Dense(2 * in => in), Dropout(0.5))
@CarloLucibello (Member) commented on the lines above:

Does the paper specify that we need dropout in the MLP?

Parallel(+, l.attnlayer, identity),
),
l.ffn
)
@CarloLucibello (Member) commented on the lines above:

We should avoid constructing a GNNChain in each forward pass. Also, the use of Parallel makes the sequence of operations hard to parse.
We should just have a plain sequence of transformations of the input here, like:

x1 = l.bn1(l.dropout1(l.conv(g, x)) .+ x)
x2 = l.bn2(l.dropout2(l.attn(x)) .+ x)
y = l.mlp(x1 + x2)  # not sure if we should also add a skip connection here

Notice the order of operations: it is different from what the layer currently does. The current implementation doesn't seem to follow Eqs. (6)-(11) in the paper.
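
For reference, my paraphrase of Eqs. (6)-(11) (up to notation; edge features omitted) is:

$$
\begin{aligned}
\hat{X}_M &= \mathrm{MPNN}(X, E, A) \\
\hat{X}_T &= \mathrm{GlobalAttn}(X) \\
X_M &= \mathrm{Norm}(\mathrm{Dropout}(\hat{X}_M) + X) \\
X_T &= \mathrm{Norm}(\mathrm{Dropout}(\hat{X}_T) + X) \\
X' &= \mathrm{MLP}(X_M + X_T)
\end{aligned}
$$

i.e. dropout and the residual connection are applied per branch before normalization, and the two branches are summed before the MLP.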

@CarloLucibello (Member) commented:

Thanks, this is a nice start. A few comments:

  • No need to introduce the DotProductAttention type; we can use the MultiHeadAttention from Flux. According to Tables A.2-A.5 in the paper, multi-head attention is the preferred choice. We should have an nheads argument in the constructor. See also https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GPSConv.html for the kind of flexibility we could try to achieve (in several PRs). A rough sketch along these lines follows this list.

  • Part of the contribution of the paper is the discussion of different types of embeddings. This package lacks many of these embeddings; I hope they will be added in the future, but in any case it is ok for this PR to only implement the layer.

  • I think the current order of operations is wrong; see the review comment above.

  • BatchNorm should be used instead of LayerNorm

  • In the paper it is not clear whether we should apply a residual connection after the MLP. From Figure D.1 it seems there is one, but there is none according to Eq. (11).
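
Putting these comments together, one possible shape for the layer is sketched below: Flux's MultiHeadAttention with an nheads argument, BatchNorm, the corrected order of operations, and no residual after the MLP. This is a sketch under those assumptions, not a final implementation; names, defaults, and the constructor signature are illustrative.

```julia
using Flux, GraphNeuralNetworks

# Illustrative sketch of GPSConv, not a final implementation.
struct GPSConv{C,A,N1,N2,M,D} <: GNNLayer
    conv::C      # local message-passing layer, e.g. a GINConv
    attn::A      # global attention: Flux.MultiHeadAttention
    norm1::N1
    norm2::N2
    mlp::M
    dropout::D
end

Flux.@functor GPSConv

function GPSConv(d::Int, conv; nheads::Int = 1, dropout_prob = 0.5)
    attn = MultiHeadAttention(d; nheads)    # nheads must divide d
    mlp = Chain(Dense(d => 2d, relu), Dense(2d => d))
    return GPSConv(conv, attn, BatchNorm(d), BatchNorm(d), mlp,
                   Dropout(dropout_prob))
end

function (l::GPSConv)(g::GNNGraph, x::AbstractMatrix)
    # local MPNN branch: conv -> dropout -> residual -> batch norm
    x1 = l.norm1(l.dropout(l.conv(g, x)) .+ x)
    # global attention branch; MultiHeadAttention expects
    # (features, len, batch) arrays, so add a singleton batch dimension
    h = reshape(x, size(x, 1), size(x, 2), 1)
    y, _ = l.attn(h)                        # self-attention
    x2 = l.norm2(l.dropout(reshape(y, size(x))) .+ x)
    # sum the two branches; no residual after the MLP (per Eq. 11)
    return l.mlp(x1 + x2)
end
```

One caveat: on a batched GNNGraph the attention above would let nodes attend across graph boundaries, so a per-graph mask would be needed for correctness; it is omitted here for brevity.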

@abieler (Contributor, Author) commented Jan 2, 2024

Thanks for the comments. I'll be going over the authors' codebase, then the paper, then the PyTorch implementation for implementation details for the next version.

