From 03e8414fb1380ce162828bfa7c0e04ead1833e9e Mon Sep 17 00:00:00 2001 From: Sachin Sharma <40523048+sachinsharma9780@users.noreply.github.com> Date: Thu, 13 Oct 2022 12:29:05 +0530 Subject: [PATCH 1/3] [Type Hints] `DenseSAGEConv` and `DenseGCNConv` (#5664) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Matthias Fey --- CHANGELOG.md | 2 +- test/nn/dense/test_dense_gcn_conv.py | 5 +++++ test/nn/dense/test_dense_sage_conv.py | 5 +++++ torch_geometric/nn/dense/dense_gcn_conv.py | 13 +++++++++++-- torch_geometric/nn/dense/dense_sage_conv.py | 16 +++++++++++++--- 5 files changed, 35 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 343e24c880a1..938c7a258a4c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,7 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641)) - Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642)) - Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610)) -- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669)) +- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669)) - Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601)) - Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614)) - Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602)) diff --git a/test/nn/dense/test_dense_gcn_conv.py b/test/nn/dense/test_dense_gcn_conv.py index af17cae33089..14f926548eea 100644 --- a/test/nn/dense/test_dense_gcn_conv.py +++ b/test/nn/dense/test_dense_gcn_conv.py @@ -1,6 +1,7 @@ import torch from torch_geometric.nn import DenseGCNConv, GCNConv +from torch_geometric.testing import is_full_test def test_dense_gcn_conv(): @@ -38,6 +39,10 @@ def test_dense_gcn_conv(): dense_out = dense_conv(x, adj, mask) assert dense_out.size() == (2, 3, channels) + if is_full_test(): + jit = torch.jit.script(dense_conv) + assert torch.allclose(jit(x, adj, mask), dense_out) + assert dense_out[1, 2].abs().sum().item() == 0 dense_out = dense_out.view(6, channels)[:-1] assert torch.allclose(sparse_out, dense_out, atol=1e-04) diff --git a/test/nn/dense/test_dense_sage_conv.py b/test/nn/dense/test_dense_sage_conv.py index f5055cea1b7f..12d2204e7990 100644 --- a/test/nn/dense/test_dense_sage_conv.py +++ b/test/nn/dense/test_dense_sage_conv.py @@ -1,6 +1,7 @@ import torch from torch_geometric.nn import DenseSAGEConv, SAGEConv +from torch_geometric.testing import is_full_test def test_dense_sage_conv(): @@ -38,6 +39,10 @@ def test_dense_sage_conv(): dense_out = dense_conv(x, adj, mask) assert dense_out.size() == (2, 3, channels) + if is_full_test(): + jit = torch.jit.script(dense_conv) + assert torch.allclose(jit(x, adj, mask), dense_out) + assert dense_out[1, 2].abs().sum().item() == 0 dense_out = dense_out.view(6, channels)[:-1] assert torch.allclose(sparse_out, dense_out, atol=1e-04) diff --git a/torch_geometric/nn/dense/dense_gcn_conv.py b/torch_geometric/nn/dense/dense_gcn_conv.py index f0809a371aba..1d5acec31f3f 100644 --- a/torch_geometric/nn/dense/dense_gcn_conv.py +++ b/torch_geometric/nn/dense/dense_gcn_conv.py @@ -1,14 +1,22 @@ import torch +from torch import Tensor from torch.nn import Parameter from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import zeros +from torch_geometric.typing import OptTensor class DenseGCNConv(torch.nn.Module): r"""See :class:`torch_geometric.nn.conv.GCNConv`. """ - def __init__(self, in_channels, out_channels, improved=False, bias=True): + def __init__( + self, + in_channels: int, + out_channels: int, + improved: bool = False, + bias: bool = True, + ): super().__init__() self.in_channels = in_channels @@ -29,7 +37,8 @@ def reset_parameters(self): self.lin.reset_parameters() zeros(self.bias) - def forward(self, x, adj, mask=None, add_loop=True): + def forward(self, x: Tensor, adj: Tensor, mask: OptTensor = None, + add_loop: bool = True) -> Tensor: r""" Args: x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B diff --git a/torch_geometric/nn/dense/dense_sage_conv.py b/torch_geometric/nn/dense/dense_sage_conv.py index 652e04e992cb..7243875d1192 100644 --- a/torch_geometric/nn/dense/dense_sage_conv.py +++ b/torch_geometric/nn/dense/dense_sage_conv.py @@ -1,7 +1,10 @@ import torch import torch.nn.functional as F +from torch import Tensor from torch.nn import Linear +from torch_geometric.typing import OptTensor + class DenseSAGEConv(torch.nn.Module): r"""See :class:`torch_geometric.nn.conv.SAGEConv`. @@ -14,7 +17,13 @@ class DenseSAGEConv(torch.nn.Module): use :class:`torch_geometric.nn.dense.DenseGraphConv` instead. """ - def __init__(self, in_channels, out_channels, normalize=False, bias=True): + def __init__( + self, + in_channels: int, + out_channels: int, + normalize: bool = False, + bias: bool = True, + ): super().__init__() self.in_channels = in_channels @@ -30,7 +39,8 @@ def reset_parameters(self): self.lin_rel.reset_parameters() self.lin_root.reset_parameters() - def forward(self, x, adj, mask=None): + def forward(self, x: Tensor, adj: Tensor, + mask: OptTensor = None) -> Tensor: r""" Args: x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B @@ -54,7 +64,7 @@ def forward(self, x, adj, mask=None): out = self.lin_rel(out) + self.lin_root(x) if self.normalize: - out = F.normalize(out, p=2, dim=-1) + out = F.normalize(out, p=2.0, dim=-1) if mask is not None: out = out * mask.view(B, N, 1).to(x.dtype) From 891eb51ae22668c588964487e56ee0cddf341525 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 13 Oct 2022 09:09:01 +0200 Subject: [PATCH 2/3] fix test --- test/loader/test_link_neighbor_loader.py | 2 ++ test/loader/test_neighbor_loader.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/test/loader/test_link_neighbor_loader.py b/test/loader/test_link_neighbor_loader.py index 59905f33edcd..ebfd648f0f8b 100644 --- a/test/loader/test_link_neighbor_loader.py +++ b/test/loader/test_link_neighbor_loader.py @@ -3,6 +3,7 @@ from torch_geometric.data import Data, HeteroData from torch_geometric.loader import LinkNeighborLoader +from torch_geometric.testing import withPackage from torch_geometric.testing.feature_store import MyFeatureStore from torch_geometric.testing.graph_store import MyGraphStore @@ -181,6 +182,7 @@ def test_link_neighbor_loader_edge_label(): assert torch.all(batch.edge_label[10:] == 0) +@withPackage('pyg_lib') def test_temporal_heterogeneous_link_neighbor_loader(): data = HeteroData() diff --git a/test/loader/test_neighbor_loader.py b/test/loader/test_neighbor_loader.py index a9d08d53c278..bf517692c994 100644 --- a/test/loader/test_neighbor_loader.py +++ b/test/loader/test_neighbor_loader.py @@ -280,6 +280,7 @@ def forward(self, x, edge_index, edge_weight): assert torch.allclose(out1, out2, atol=1e-6) +@withPackage('pyg_lib') def test_temporal_heterogeneous_neighbor_loader_on_cora(get_dataset): dataset = get_dataset(name='Cora') data = dataset[0] @@ -380,6 +381,7 @@ def test_custom_neighbor_loader(FeatureStore, GraphStore): 'author', 'to', 'paper'].edge_index.size()) +@withPackage('pyg_lib') @pytest.mark.parametrize('FeatureStore', [MyFeatureStore, HeteroData]) @pytest.mark.parametrize('GraphStore', [MyGraphStore, HeteroData]) def test_temporal_custom_neighbor_loader_on_cora(get_dataset, FeatureStore, From 13028afd7488f104021786e5c3f4596533a2c8ca Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Thu, 13 Oct 2022 09:14:33 +0200 Subject: [PATCH 3/3] Add `CHANGELOG` note to contributing guidelines (#5674) --- CONTRIBUTING.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 333b800a9299..e5c03abbfdc3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -96,6 +96,9 @@ Everytime you send a Pull Request, your commit will be built and checked against (which runs a set of additional but time-consuming tests) dependening on your needs. +3. Add your feature/bugfix to the [`CHANGELOG.md`](https://github.com/pyg-team/pytorch_geometric/blob/master/CHANGELOG.md?plain=1). + If multiple PRs move towards integrating a single feature, it is advised to group them together into one bullet point. + ## Building Documentation To build the documentation: