First draft GPS conv layer #355
base: master
Conversation
init = glorot_uniform,
bias::Bool = true,
add_self_loops = true,
use_edge_weight = false)
This PR includes a lot of unrelated changes, please remove them.
gattn = DotProductAttention(ch)
convlayer = GNNChain(gconv, Dropout(0.5), LayerNorm(out))
attnlayer = GNNChain(gattn, Dropout(0.5), LayerNorm(out))
ffn = GNNChain(Dense(in => 2 * in, σ), Dropout(0.5), Dense(2 * in => in), Dropout(0.5))
Does the paper specify that we need dropout in the MLP?
Parallel(+, l.attnlayer, identity),
    ),
    l.ffn
)
We should avoid constructing a GNNChain in each forward pass. Also, the use of Parallel makes the sequence of operations hard to parse.
We should just have a plain sequence of transformations of the input here, like:
x1 = l.bn1(l.dropout1(l.conv(g, x)) .+ x)
x2 = l.bn2(l.dropout2(l.attn(x)) .+ x)
y = l.mlp(x1 + x2)  # not sure if we should also add a skip connection here
Notice the order of operations: it is different from what the layer is currently doing. The current implementation doesn't seem to follow eqs. (6)-(11) in the paper.
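For reference, a minimal sketch of what such a plain forward pass could look like, assuming the struct stores its sub-layers in fields named conv, attn, dropout1, dropout2, norm1, norm2, and mlp (these names are illustrative, not the PR's actual fields):

```julia
# Sketch of a GPSConv forward pass written as a plain sequence of operations,
# following the order suggested above. Field names are assumptions.
function (l::GPSConv)(g::GNNGraph, x::AbstractMatrix)
    # local message-passing branch: conv -> dropout -> residual -> norm
    xm = l.norm1(l.dropout1(l.conv(g, x)) .+ x)
    # global attention branch: attention -> dropout -> residual -> norm
    xt = l.norm2(l.dropout2(l.attn(x)) .+ x)
    # combine the two branches and apply the feed-forward block
    return l.mlp(xm .+ xt)
end
```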
Thanks, this is a nice start. A few comments:
Thanks for the comments. I'll be going over them.
This is only a first "mock" version of a `GPSConv` layer, to see if we would want it in the repo in that form.
- Adds a `DotProductAttention` layer that uses `NNlib.dot_product_attention()` (see the sketch below)
- Adds a `GPSConv` layer with `DotProductAttention` as the global attention layer
- Not sure about the `GNNChain()` implementation, whether it should stay where it is or move into the struct?
- `JuliaFormatter` got a bit too greedy and made some changes here and there, I can revert those of course
- Did not check the correctness of the implementation yet

Let me know what you think and I can adjust / keep going from here.
Close #351
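For context, a minimal sketch of how a self-attention wrapper around `NNlib.dot_product_attention` could look; the struct layout, the separate Q/K/V projections, and the single-graph handling are assumptions for illustration, not necessarily what this PR implements:

```julia
using Flux, NNlib

# Hypothetical self-attention layer built on NNlib.dot_product_attention.
# All field names and design choices here are illustrative assumptions.
struct DotProductAttention
    Wq::Dense
    Wk::Dense
    Wv::Dense
    nheads::Int
end

Flux.@functor DotProductAttention

DotProductAttention(dim::Int; nheads::Int = 1) =
    DotProductAttention(Dense(dim => dim), Dense(dim => dim), Dense(dim => dim), nheads)

function (l::DotProductAttention)(x::AbstractMatrix)
    # x has size (dim, num_nodes); add a singleton batch dimension so the
    # node axis is treated as the sequence axis by dot_product_attention.
    q = reshape(l.Wq(x), size(x, 1), size(x, 2), 1)
    k = reshape(l.Wk(x), size(x, 1), size(x, 2), 1)
    v = reshape(l.Wv(x), size(x, 1), size(x, 2), 1)
    y, _ = NNlib.dot_product_attention(q, k, v; nheads = l.nheads)
    return reshape(y, size(x))
end
```

Usage would then look something like `attn = DotProductAttention(64; nheads = 4); attn(x)` for node features `x` of size `(64, num_nodes)`.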