Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Graph embedding and graph backbones #592

Merged
merged 205 commits into from
Nov 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
205 commits
Select commit Hold shift + click to select a range
ea4e7b6
Initial structure of GraphClassification model.py
Feb 2, 2021
1255f8f
Improvement of model.py. Still need to debug etc
Feb 2, 2021
365863f
BasicDataset Implemented
Feb 6, 2021
02b0f6a
Create __init__.py
Feb 6, 2021
f28e949
Implemented dataset and DataModule as for image processing
Feb 7, 2021
ad76827
Pipeline taken from images.
Feb 8, 2021
ea6ee9d
Initial structure of GraphClassification model.py
Feb 2, 2021
8b93a4a
Improvement of model.py. Still need to debug etc
Feb 2, 2021
6b4d7e3
BasicDataset Implemented
Feb 6, 2021
48dcf2d
Implemented dataset and DataModule as for image processing
Feb 7, 2021
49dfe4d
Pipeline taken from images.
Feb 8, 2021
151f7d9
Choice of model implemented (you can pass a model to GraphClassifier)
Mar 21, 2021
6236d95
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2021
28e315f
Initial readaptation of the structure
May 14, 2021
08ace7e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2021
35a79cb
Minimal structure of how to structure data.py files
May 14, 2021
93dd638
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2021
82576a5
Minor corrections
May 17, 2021
089cb07
update
tchaton May 17, 2021
920fc68
i
tchaton May 17, 2021
c970b5f
update
tchaton May 17, 2021
fe41405
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 17, 2021
868b8d7
Added auto_dataset.num_features
May 24, 2021
debf5da
Deleted manually included num_features so that it is extracted from G…
May 24, 2021
1e6b2b0
Test for GraphClassification implemented
May 30, 2021
faa8709
Documentation for GraphClassification included
May 30, 2021
072a35b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2021
be015df
Creation of from_pygdatasequence method in DataModule and GraphSequen…
Jun 8, 2021
1fb160b
Update graph_classification.py
Jun 7, 2021
bb3b941
Update datatype_graph.txt
Jun 7, 2021
3583d5c
Tests and docs for the from_pygdatasequence method
Jun 8, 2021
193c2bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2021
f59a2fc
Graph requirements
Jun 8, 2021
71a15bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2021
370ea24
Update CHANGELOG.md
Jun 8, 2021
e9d4e93
Update requirements with pytorch geometric libraries
Jun 8, 2021
a2b208e
Simplified, version with only the DataSource
Jul 12, 2021
7c3eaf4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2021
809b615
Minor tweaks
Jul 12, 2021
bf4a9b6
Merge branch 'master' of https://github.com/PabloAMC/lightning-flash
Jul 12, 2021
3089d94
Update the flash_example to reflect the new template
Jul 12, 2021
338c3ca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2021
0f39778
Delete IMDB-BINARY_A.txt
Jul 12, 2021
b051f36
Delete IMDB-BINARY_graph_indicator.txt
Jul 12, 2021
4e23aff
Delete IMDB-BINARY_graph_labels.txt
Jul 12, 2021
1ad2ce3
Class method from_pygdatasequence from flash/core/data/data_module.py
Jul 12, 2021
4e9178c
Merge branch 'master' of https://github.com/PabloAMC/lightning-flash
Jul 12, 2021
2b1bbd9
Creating backbones
Jul 14, 2021
94d4e60
Merge branch 'master' into master
ethanwharris Jul 14, 2021
c2cac21
Changing GRAPH_BACKBONES to GRAPH_CLASSIFICATION_BACKBONES
Jul 14, 2021
73663bd
Update docs
ethanwharris Jul 14, 2021
287132f
Merge branch 'master' into master
ethanwharris Jul 14, 2021
2631bd4
fix imports.py
ethanwharris Jul 14, 2021
b4d3b41
remove unused imports
ethanwharris Jul 14, 2021
b19b5b8
clean init.py
ethanwharris Jul 14, 2021
7a4a914
updates
ethanwharris Jul 14, 2021
c831751
Minor tweaks in the docs and change from Graph_backbones to graph_cla…
Jul 14, 2021
13aa012
Updates
ethanwharris Jul 14, 2021
e9cedb0
Updates
ethanwharris Jul 14, 2021
fe95a77
Updates
ethanwharris Jul 14, 2021
54f6a88
Graph embedding task implemented, modulo corrections
Jul 14, 2021
667140a
Error corrections
Jul 14, 2021
db0b599
Merge branch 'master' into task_a_thon
Jul 14, 2021
bbdad91
Updates
ethanwharris Jul 14, 2021
d5deb38
Update docs
ethanwharris Jul 14, 2021
f634e9f
Update docs
ethanwharris Jul 14, 2021
428f313
Update docs
ethanwharris Jul 14, 2021
435cc95
fix tests
ethanwharris Jul 14, 2021
b54e543
fix tests
ethanwharris Jul 14, 2021
4453818
Add API reference
ethanwharris Jul 14, 2021
b4877b1
Try fix
ethanwharris Jul 14, 2021
aceef22
Merge branch 'master' into task_a_thon
Jul 14, 2021
c6e3b84
Included Networkx as requirement for graph library
Jul 14, 2021
8fe2813
Try fix
ethanwharris Jul 14, 2021
113b6d0
Try fix
ethanwharris Jul 14, 2021
bba0395
Merge branch 'master' into task_a_thon
Jul 14, 2021
d6dec3d
batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])…
Jul 14, 2021
d8d26ab
Update flash/core/data/auto_dataset.py
ethanwharris Jul 14, 2021
7b2734f
Update docstring
ethanwharris Jul 14, 2021
a09254f
Correction of minor errors
Jul 14, 2021
62ff1b3
Updating docs
Jul 14, 2021
97b5e07
Merge branch 'master' into task_a_thon
Jul 14, 2021
8d9db4c
Update graph_embedding docs
Jul 14, 2021
2021a03
Minor tweaks
Jul 14, 2021
b064be3
Merge remote-tracking branch 'upstream/master' into task_a_thon
Jul 15, 2021
31354fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2021
b1fc948
Correct documentation
Jul 15, 2021
a1e5bb7
Creating a head suited for PyG backbones
Jul 15, 2021
bc06e43
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2021
335d30c
Update the head of the embedding model
Jul 15, 2021
56598f7
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
Jul 15, 2021
7de0a50
Update graph_classification example
Jul 15, 2021
628ef9b
Update backbones to match how they will work in Pytorch Geometric
Jul 16, 2021
bed29f6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 16, 2021
e49c290
Update CHANGELOG.md
Jul 17, 2021
95abca4
Merge remote-tracking branch 'upstream/master' into task_a_thon
Jul 17, 2021
0903585
Update graph requirements
Jul 17, 2021
cf45529
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2021
a872b75
Update test_model.py
Jul 17, 2021
88812b3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2021
360250b
Update data_module.py
Jul 17, 2021
41db022
Update data_module.py
Jul 17, 2021
077b34b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2021
5d3e1b4
Update test_model.py
Jul 17, 2021
70d0570
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
Jul 17, 2021
9a04c2f
_TORCH_GEOMETRIC_AVAILABLE to _GRAPH_AVAILABLE
Jul 17, 2021
7d63be0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2021
ac4616a
Adding num_features to uses of backbones
Jul 17, 2021
53395e9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2021
f902c5f
Update graph_classification.py
Jul 17, 2021
c9c3c0d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2021
093a266
Update test_model.py
Jul 17, 2021
c1b5cab
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
Jul 17, 2021
f2962b5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2021
89a8e2d
Backbone kwargs default changed from None to {}
Jul 17, 2021
56e9c2e
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
Jul 17, 2021
9012d61
error correction in graph/embedding/test_model.py: num_classes -> emb…
Jul 17, 2021
76c6d55
Pretrained option not implemented for backbones error
Jul 17, 2021
8be004b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2021
2bdf6a6
Update model.py
Jul 17, 2021
9ecff7e
Small error correction
Jul 17, 2021
cc1ee34
Merge branch 'master' into task_a_thon
ethanwharris Jul 21, 2021
63577b9
Updates
ethanwharris Jul 21, 2021
62802a0
models adapted to torch geometric basic models
Jul 21, 2021
7590175
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
Jul 21, 2021
92d6b3c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 21, 2021
3428fd6
Merge remote-tracking branch 'upstream/master' into task_a_thon
Sep 20, 2021
5b72219
Merge branch 'master' into task_a_thon
Sep 20, 2021
c467e60
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2021
ff5a592
Updating minor corrections
Sep 20, 2021
de9c45a
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
Sep 20, 2021
c878c11
Eliminating needless stuff
Sep 20, 2021
667f5d0
Changed descriptions
Sep 20, 2021
5a7060d
Update CHANGELOG.md
Sep 23, 2021
ce168a7
Update flash/graph/embedding/__init__.py
Sep 23, 2021
d9d9298
Update flash_examples/graph_classification.py
Sep 23, 2021
6816d76
Update docs/source/reference/graph_embedding.rst
Sep 23, 2021
caad522
Updates based on ethan comments
Sep 23, 2021
fe92e11
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2021
472e00d
Importing partial
Sep 23, 2021
4b1923f
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
Sep 23, 2021
1c76bd3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2021
2de9f26
Update backbones.py
Sep 23, 2021
de9577a
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
Sep 23, 2021
4b8f4b5
Update backbones.py
Sep 23, 2021
11bd028
Update model.py
Sep 23, 2021
c40ad61
Eliminating pretrained option
Sep 23, 2021
3453e2d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2021
85dc93e
Update backbones.py
Sep 23, 2021
e4de58d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2021
6d631e6
Update backbones.py
Sep 23, 2021
e74f17b
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
Sep 23, 2021
b886e96
Update backbones.py
Sep 23, 2021
38196c2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2021
90cdfa8
Update backbones.py
Sep 23, 2021
ee7cb69
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
Sep 23, 2021
14cd39c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2021
4199aab
Update backbones.py
Sep 23, 2021
9a2d046
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
Sep 23, 2021
4858083
Merge branch 'master' into task_a_thon
Sep 23, 2021
80e4e3e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2021
65e499a
Update backbones.py
Sep 23, 2021
54f54d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2021
18d78d7
Update backbones.py
Sep 23, 2021
a69e307
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
Sep 23, 2021
ac9bb8b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2021
6658640
Update backbones.py
Sep 23, 2021
11b07f9
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
Sep 23, 2021
89e7040
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2021
18c69b6
Update backbones.py
Sep 23, 2021
e980c25
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
Sep 23, 2021
8f9217d
Update backbones.py
Sep 23, 2021
0e9968a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2021
a440d47
Update backbones.py
Sep 23, 2021
e6e84df
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
Sep 23, 2021
58b9348
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2021
fdbc4db
Minor correction in model.py
Sep 24, 2021
c708d04
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 24, 2021
994dd30
self.backbone(x.x, x.edge_index)
Sep 24, 2021
57571f8
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
Sep 24, 2021
8b9d733
Update model.py
Sep 24, 2021
a10749c
Update model.py
Sep 24, 2021
21e8078
update
tchaton Sep 27, 2021
3aa0e95
Merge branch 'master' into task_a_thon
Oct 16, 2021
6bab7b1
Merge branch 'master' into task_a_thon
ethanwharris Nov 9, 2021
d3148be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 9, 2021
80bb1e8
Updates
ethanwharris Nov 9, 2021
86ff38b
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
ethanwharris Nov 9, 2021
b0429ca
Update CHANGELOG.md
ethanwharris Nov 9, 2021
dc27f8a
Docs fixes
ethanwharris Nov 9, 2021
00b7159
Updates
ethanwharris Nov 9, 2021
da8fefb
Docs
ethanwharris Nov 9, 2021
2738e75
Tests
ethanwharris Nov 9, 2021
cdb8217
Pre-commit
ethanwharris Nov 9, 2021
3a67f66
Merge branch 'master' into task_a_thon
ethanwharris Nov 9, 2021
47dae8a
Update requirements
ethanwharris Nov 9, 2021
9ca6084
Formatting
ethanwharris Nov 9, 2021
acf92e2
Fix reqs.
ethanwharris Nov 9, 2021
bff6228
Merge branch 'master' into task_a_thon
ethanwharris Nov 9, 2021
8b490b0
Fixes
ethanwharris Nov 9, 2021
54d5de8
Multiple pooling functions
ethanwharris Nov 9, 2021
c59c766
Fixes
ethanwharris Nov 9, 2021
8ee6262
Fixes
ethanwharris Nov 9, 2021
dfa14e6
Try fix
ethanwharris Nov 9, 2021
f357f62
Speed up CI
ethanwharris Nov 9, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ jobs:
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.9.0+cpu.html
pip install torch-sparse -f https://data.pyg.org/whl/torch-1.9.0+cpu.html
pip install torch-cluster -f https://data.pyg.org/whl/torch-1.9.0+cpu.html

- name: Install dependencies
run: |
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added backbones for `GraphClassifier` ([#592](https://github.com/PyTorchLightning/lightning-flash/pull/592))

- Added `GraphEmbedder` task ([#592](https://github.com/PyTorchLightning/lightning-flash/pull/592))

### Changed

- Changed `Preprocess` to `InputTransform` ([#951](https://github.com/PyTorchLightning/lightning-flash/pull/951))
Expand Down
10 changes: 10 additions & 0 deletions docs/source/api/graph.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ ______________

classification.data.GraphClassificationInputTransform

Embedding
_________

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

~embedding.model.GraphEmbedder

flash.graph.data
________________

Expand Down
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ Lightning Flash
:caption: Graph

reference/graph_classification
reference/graph_embedder

.. toctree::
:maxdepth: 1
Expand Down
28 changes: 28 additions & 0 deletions docs/source/reference/graph_embedder.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
.. _graph_embedder:

##############
Graph Embedder
##############

********
The Task
********
This task consists of creating an embedding of a graph. That is, a vector of features which can be used for a downstream task.
The :class:`~flash.graph.classification.model.GraphEmbedder` and :class:`~flash.graph.classification.data.GraphClassificationData` classes internally rely on `pytorch-geometric <https://github.com/rusty1s/pytorch_geometric>`_.

------

*******
Example
*******

Let's look at generating embeddings of graphs from the KKI data set from `TU Dortmund University <https://chrsmrrs.github.io/datasets>`_.

We start by creating the `TUDataset <https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/tu_dataset.html#TUDataset>`.
Next, we load a trained :class:`~flash.graph.classification.model.GraphEmbedder` (from a previously trained :class:`~flash.graph.classification.model.GraphClassifier`).
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/graph_embedder.py
:language: python
:lines: 14
5 changes: 4 additions & 1 deletion flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def _compare_version(package: str, op, version) -> bool:
_TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
_TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric")
_NETWORKX_AVAILABLE = _module_available("networkx")
_TORCHAUDIO_AVAILABLE = _module_available("torchaudio")
_SENTENCEPIECE_AVAILABLE = _module_available("sentencepiece")
_DATASETS_AVAILABLE = _module_available("datasets")
Expand Down Expand Up @@ -143,7 +144,9 @@ class Image:
_SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE
_POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE and _TORCHVISION_AVAILABLE
_AUDIO_AVAILABLE = all([_TORCHAUDIO_AVAILABLE, _LIBROSA_AVAILABLE, _TRANSFORMERS_AVAILABLE])
_GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE
_GRAPH_AVAILABLE = (
_TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE and _NETWORKX_AVAILABLE
)

_EXTRAS_AVAILABLE = {
"image": _IMAGE_AVAILABLE,
Expand Down
1 change: 1 addition & 0 deletions flash/core/utilities/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,4 @@ def __str__(self):
_PYTORCHVIDEO = Provider("Facebook Research/PyTorchVideo", "https://github.com/facebookresearch/pytorchvideo")
_VISSL = Provider("Facebook Research/vissl", "https://github.com/facebookresearch/vissl")
_PYTORCH_FORECASTING = Provider("jdb78/PyTorch-Forecasting", "https://github.com/jdb78/pytorch-forecasting")
_PYTORCH_GEOMETRIC = Provider("PyG/PyTorch Geometric", "https://github.com/pyg-team/pytorch_geometric")
1 change: 1 addition & 0 deletions flash/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from flash.graph.classification import GraphClassificationData, GraphClassifier # noqa: F401
from flash.graph.embedding import GraphEmbedder # noqa: F401
41 changes: 41 additions & 0 deletions flash/graph/backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _GRAPH_AVAILABLE
from flash.core.utilities.providers import _PYTORCH_GEOMETRIC

if _GRAPH_AVAILABLE:
from torch_geometric.nn.models import GAT, GCN, GIN, GraphSAGE
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we automate this and get all their models ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's not possible to do cleanly at the moment as the models package also contains many other things that wouldn't work in the GraphClassifier


MODELS = {"GCN": GCN, "GraphSAGE": GraphSAGE, "GAT": GAT, "GIN": GIN}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we could integrate GraphGym directly there ?

else:
MODELS = {}

GRAPH_BACKBONES = FlashRegistry("backbones")


def _load_graph_backbone(
model_name: str,
in_channels: int,
hidden_channels: int = 512,
num_layers: int = 4,
):
model = MODELS[model_name]
return model(in_channels, hidden_channels, num_layers)


for model_name in MODELS.keys():
GRAPH_BACKBONES(name=model_name, providers=_PYTORCH_GEOMETRIC)(partial(_load_graph_backbone, model_name))
167 changes: 61 additions & 106 deletions flash/graph/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,157 +11,90 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Optional, Type, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
from torch import nn, Tensor
from torch import nn
from torch.nn import functional as F
from torch.nn import Linear

from flash.core.classification import ClassificationTask
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _GRAPH_AVAILABLE
from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE
from flash.graph.backbones import GRAPH_BACKBONES

if _GRAPH_AVAILABLE:
from torch_geometric.nn import BatchNorm, GCNConv, global_mean_pool, MessagePassing
else:
MessagePassing = object
GCNConv = object


class GraphBlock(nn.Module):
"""Graph convolutional block.

Args:
nc_input: number of input channels
nc_output: number of output channels
conv_cls: graph convolutional class to use
act: activation function to use
**conv_kwargs: additional kwargs used for initialization of convolutional operator
"""

def __init__(
self,
nc_input: int,
nc_output: int,
conv_cls: nn.Module,
act: Union[Callable, nn.Module] = nn.ReLU(),
**conv_kwargs
):
super().__init__()
self.conv = conv_cls(nc_input, nc_output, **conv_kwargs)
self.norm = BatchNorm(nc_output)
self.act = act
from torch_geometric.nn import global_add_pool, global_max_pool, global_mean_pool

def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None) -> Tensor:
x = self.conv(x, edge_index, edge_weight=edge_weight)
x = self.norm(x)
return self.act(x)


class BaseGraphModel(nn.Module):
"""Base convolutional graph model.

Args:
num_features: number of input features
hidden_channels: list of integers with the number of channels in all the hidden layers.
The length of the list determines the depth of the network.
num_classes: integer determining the number of classes
conv_cls: graph convolutional class to use as building blocks
act: activation function to use between layers
**conv_kwargs: additional kwargs used for initialization of convolutional operator
"""

def __init__(
self,
num_features: int,
hidden_channels: List[int],
num_classes: int,
conv_cls: Type[MessagePassing],
act: Union[Callable, nn.Module] = nn.ReLU(),
**conv_kwargs: Any
):
super().__init__()

self.blocks = nn.ModuleList()
hidden_channels = [num_features] + hidden_channels

nc_output = num_features

for idx in range(len(hidden_channels) - 1):
nc_input = hidden_channels[idx]
nc_output = hidden_channels[idx + 1]
graph_block = GraphBlock(nc_input, nc_output, conv_cls, act, **conv_kwargs)
self.blocks.append(graph_block)

self.lin = Linear(nc_output, num_classes)

def forward(self, data: Any) -> Tensor:
x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
# 1. Obtain node embeddings
for block in self.blocks:
x = block(x, edge_index, edge_weight)

# 2. Readout layer
x = global_mean_pool(x, data.batch) # [batch_size, hidden_channels]

# 3. Apply a final classifier
x = F.dropout(x, p=0.5, training=self.training)
x = self.lin(x)
return x
POOLING_FUNCTIONS = {"mean": global_mean_pool, "add": global_add_pool, "max": global_max_pool}
else:
POOLING_FUNCTIONS = {}


class GraphClassifier(ClassificationTask):
"""The ``GraphClassifier`` is a :class:`~flash.Task` for classifying graphs. For more details, see
:ref:`graph_classification`.

Args:
num_features: Number of columns in table (not including target column).
num_classes: Number of classes to classify.
hidden_channels: Hidden dimension sizes.
learning_rate: Learning rate to use for training, defaults to `1e-3`
num_features (int): The number of features in the input.
num_classes (int): Number of classes to classify.
backbone: Name of the backbone to use.
backbone_kwargs: Dictionary dependent on the backbone, containing for example in_channels, out_channels,
hidden_channels or depth (number of layers).
pooling_fn: The global pooling operation to use (one of: "max", "max", "add" or a callable).
head: The head to use.
loss_fn: Loss function for training, defaults to cross entropy.
learning_rate: Learning rate to use for training.
optimizer: Optimizer to use for training.
lr_scheduler: The LR scheduler to use during training.
metrics: Metrics to compute for training and evaluation.
model: GraphNN used, defaults to BaseGraphModel.
conv_cls: kind of convolution used in model, defaults to GCNConv
**conv_kwargs: additional kwargs used for initialization of convolutional operator
"""

required_extras = "graph"
backbones: FlashRegistry = GRAPH_BACKBONES

required_extras: str = "graph"

def __init__(
self,
num_features: int,
num_classes: int,
hidden_channels: Union[List[int], int] = 512,
model: torch.nn.Module = None,
backbone: Union[str, Tuple[nn.Module, int]] = "GCN",
backbone_kwargs: Optional[Dict] = {},
pooling_fn: Optional[Union[str, Callable]] = "mean",
head: Optional[Union[Callable, nn.Module]] = None,
loss_fn: LOSS_FN_TYPE = F.cross_entropy,
learning_rate: float = 1e-3,
optimizer: OPTIMIZER_TYPE = "Adam",
lr_scheduler: LR_SCHEDULER_TYPE = None,
metrics: METRICS_TYPE = None,
conv_cls: Type[MessagePassing] = GCNConv,
**conv_kwargs
):

self.save_hyperparameters()

if isinstance(hidden_channels, int):
hidden_channels = [hidden_channels]

if not model:
model = BaseGraphModel(num_features, hidden_channels, num_classes, conv_cls, **conv_kwargs)

super().__init__(
model=model,
loss_fn=loss_fn,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
metrics=metrics,
learning_rate=learning_rate,
)

self.save_hyperparameters()

if isinstance(backbone, tuple):
self.backbone, num_out_features = backbone
else:
self.backbone = self.backbones.get(backbone)(in_channels=num_features, **backbone_kwargs)
num_out_features = self.backbone.hidden_channels

self.pooling_fn = POOLING_FUNCTIONS[pooling_fn] if isinstance(pooling_fn, str) else pooling_fn

if head is not None:
self.head = head
else:
self.head = DefaultGraphHead(num_out_features, num_classes)

def training_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch, batch.y)
return super().training_step(batch, batch_idx)
Expand All @@ -176,3 +109,25 @@ def test_step(self, batch: Any, batch_idx: int) -> Any:

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)

def forward(self, data) -> torch.Tensor:
x = self.backbone(data.x, data.edge_index)
x = self.pooling_fn(x, data.batch)
return self.head(x)


class DefaultGraphHead(torch.nn.Module):
def __init__(self, hidden_channels, num_classes, dropout=0.5):
super().__init__()
self.lin1 = Linear(hidden_channels, hidden_channels)
self.lin2 = Linear(hidden_channels, num_classes)
self.dropout = dropout

def reset_parameters(self):
self.lin1.reset_parameters()
self.lin2.reset_parameters()

def forward(self, x):
x = F.relu(self.lin1(x))
x = F.dropout(x, p=self.dropout, training=self.training)
return self.lin2(x)
2 changes: 2 additions & 0 deletions flash/graph/embedding/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from flash.graph.classification.data import GraphClassificationData # noqa: F401
from flash.graph.embedding.model import GraphEmbedder # noqa: F401
Loading