-
Notifications
You must be signed in to change notification settings - Fork 3.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add unbatch
functionality
#4628
Conversation
Unbatching data from DataLoader batch to a list.
Unbatching data from DataLoader batch to a list.
Unbatching data from DataLoader batch to a list. This is useful for GNN edge classifier, in which case, graphs are required to be reconstructed with edge predictions.
for more information, see https://pre-commit.ci
Codecov Report
@@ Coverage Diff @@
## master #4628 +/- ##
=======================================
Coverage 82.97% 82.98%
=======================================
Files 316 317 +1
Lines 16784 16792 +8
=======================================
+ Hits 13927 13935 +8
Misses 2857 2857
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think in order to get this merged, we would need to have a pure PyTorch function without for-loops. IMO, this is achievable by
src.split(degree(index).tolist())
In addition, we would need to add this function to utils
rather than loader/utils.py
, add doccumentation, and write tests for it. WDYT?
For unbatching node features, for example, use ```src=data.x``` and ```index =data.x_batch``` (assume ```follow_batch``` is set to ```x```).
for more information, see https://pre-commit.ci
@rusty1s Thank you very much! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks better. Thank you! Can we also add some basic tests in test/utils/test_unbatch.py
and add this operator to torch_geometric/utils/__init__.py
?
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
unbatch
functionality
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates!
Unbatching data from DataLoader batch to a list. This is useful for GNN edge classifier, in which case, graphs are required to be reconstructed with edge predictions.