From 2d77047811311544d4861fa9e80fd2ec60ab6d01 Mon Sep 17 00:00:00 2001
From: Buddhi Kothalawala
Date: Fri, 7 Oct 2022 18:33:26 +1100
Subject: [PATCH 1/5] Fix the output shape of index-based aggregation example
 code (#5621)

Since the index vector contains the three distinct values 0, 1, and 2, the
output shape should be [3, 64]; running the example code confirms this.
---
 torch_geometric/nn/aggr/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_geometric/nn/aggr/base.py b/torch_geometric/nn/aggr/base.py
index 23be5a53ae16..cfb52742bd5c 100644
--- a/torch_geometric/nn/aggr/base.py
+++ b/torch_geometric/nn/aggr/base.py
@@ -32,7 +32,7 @@ class Aggregation(torch.nn.Module):
 
         # Assign each element to one of three sets:
         index = torch.tensor([0, 0, 1, 0, 2, 0, 2, 1, 0, 2])
-        output = aggr(x, index)  # Output shape: [4, 64]
+        output = aggr(x, index)  # Output shape: [3, 64]
 
     Alternatively, aggregation can be achieved via a "compressed" index vector
     called :obj:`ptr`. Here, elements within the same set need to be grouped
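The corrected comment is easy to verify. A minimal sketch, assuming the
concrete `MeanAggregation` module from `torch_geometric.nn.aggr` (any
aggregation module infers the output size the same way, as `index.max() + 1`):

    import torch
    from torch_geometric.nn.aggr import MeanAggregation

    aggr = MeanAggregation()
    x = torch.randn(10, 64)
    # Three distinct set indices (0, 1, 2) yield three output rows:
    index = torch.tensor([0, 0, 1, 0, 2, 0, 2, 1, 0, 2])
    output = aggr(x, index)
    print(output.shape)  # torch.Size([3, 64])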
From d5e2e1eec7ff25435eef3a1904a9c19a3d4cfc82 Mon Sep 17 00:00:00 2001
From: Guohao Li
Date: Sat, 8 Oct 2022 20:34:37 +0300
Subject: [PATCH 2/5] Support `in_channels` with `tuple` in `GENConv` for
 bipartite message passing (#5627)

* changelog
* Support in_channels with tuple for bipartite message passing
* changelog
* update

Co-authored-by: Matthias Fey

* support lazy init
* Added test

Co-authored-by: Guohao Li
Co-authored-by: Matthias Fey
---
 CHANGELOG.md                        |  1 +
 test/nn/conv/test_gen_conv.py       | 26 ++++++++++++
 torch_geometric/nn/conv/gen_conv.py | 66 ++++++++++++++++++++---------
 3 files changed, 74 insertions(+), 19 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 02fc79e616a6..ea2c5c36e49f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -37,6 +37,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
 - Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
 ### Changed
+- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing
 - Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
 - Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603))
 - Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
diff --git a/test/nn/conv/test_gen_conv.py b/test/nn/conv/test_gen_conv.py
index 76ee8c398ec3..aebedf817444 100644
--- a/test/nn/conv/test_gen_conv.py
+++ b/test/nn/conv/test_gen_conv.py
@@ -74,3 +74,29 @@ def test_gen_conv(aggr):
         jit = torch.jit.script(conv.jittable(t))
         assert torch.allclose(jit((x1, x2), adj1.t()), out21, atol=1e-6)
         assert torch.allclose(jit((x1, x2), adj2.t()), out22, atol=1e-6)
+
+    x1 = torch.randn(4, 8)
+    x2 = torch.randn(2, 16)
+    adj = adj1.sparse_resize((4, 2))
+    conv = GENConv((8, 16), 32, aggr)
+    assert str(conv) == f'GENConv((8, 16), 32, aggr={aggr})'
+    out1 = conv((x1, x2), edge_index)
+    out2 = conv((x1, None), edge_index, size=(4, 2))
+    assert out1.size() == (2, 32)
+    assert out2.size() == (2, 32)
+    assert conv((x1, x2), edge_index, size=(4, 2)).tolist() == out1.tolist()
+    assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
+    assert conv((x1, None), adj.t()).tolist() == out2.tolist()
+
+    if is_full_test():
+        t = '(OptPairTensor, Tensor, Size) -> Tensor'
+        jit = torch.jit.script(conv.jittable(t))
+        assert jit((x1, x2), edge_index).tolist() == out1.tolist()
+        assert jit((x1, x2), edge_index, size=(4, 2)).tolist() == out1.tolist()
+        assert jit((x1, None), edge_index,
+                   size=(4, 2)).tolist() == out2.tolist()
+
+        t = '(OptPairTensor, SparseTensor, Size) -> Tensor'
+        jit = torch.jit.script(conv.jittable(t))
+        assert jit((x1, x2), adj.t()).tolist() == out1.tolist()
+        assert jit((x1, None), adj.t()).tolist() == out2.tolist()
diff --git a/torch_geometric/nn/conv/gen_conv.py b/torch_geometric/nn/conv/gen_conv.py
index 013003c54573..0c7f2ecbb910 100644
--- a/torch_geometric/nn/conv/gen_conv.py
+++ b/torch_geometric/nn/conv/gen_conv.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 from torch import Tensor
 from torch.nn import (
@@ -62,8 +62,10 @@ class GENConv(MessagePassing):
         ogbn_proteins_deepgcn.py>`_.
 
     Args:
-        in_channels (int): Size of each input sample, or :obj:`-1` to derive
-            the size from the first input(s) to the forward method.
+        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
+            derive the size from the first input(s) to the forward method.
+            A tuple corresponds to the sizes of source and target
+            dimensionalities.
         out_channels (int): Size of each output sample.
         aggr (str, optional): The aggregation scheme to use (:obj:`"softmax"`,
             :obj:`"powermean"`, :obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
@@ -101,21 +103,36 @@ class GENConv(MessagePassing):
         - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
           :math:`(|\mathcal{V}_t|, F_{out})` if bipartite
     """
-    def __init__(self, in_channels: int, out_channels: int,
-                 aggr: str = 'softmax', t: float = 1.0, learn_t: bool = False,
-                 p: float = 1.0, learn_p: bool = False, msg_norm: bool = False,
-                 learn_msg_scale: bool = False, norm: str = 'batch',
-                 num_layers: int = 2, eps: float = 1e-7, **kwargs):
+    def __init__(
+        self,
+        in_channels: Union[int, Tuple[int, int]],
+        out_channels: int,
+        aggr: str = 'softmax',
+        t: float = 1.0,
+        learn_t: bool = False,
+        p: float = 1.0,
+        learn_p: bool = False,
+        msg_norm: bool = False,
+        learn_msg_scale: bool = False,
+        norm: str = 'batch',
+        num_layers: int = 2,
+        eps: float = 1e-7,
+        **kwargs,
+    ):
 
         # Backward compatibility:
+        semi_grad = True if aggr == 'softmax_sg' else False
         aggr = 'softmax' if aggr == 'softmax_sg' else aggr
         aggr = 'powermean' if aggr == 'power' else aggr
 
-        aggr_kwargs = {}
-        if aggr == 'softmax':
-            aggr_kwargs = dict(t=t, learn=learn_t)
-        elif aggr == 'powermean':
-            aggr_kwargs = dict(p=p, learn=learn_p)
+        aggr_kwargs = kwargs.get('aggr_kwargs', {})
+
+        # Override args of aggregator if `aggr_kwargs` is specified
+        if aggr_kwargs == {}:
+            if aggr == 'softmax':
+                aggr_kwargs = dict(t=t, learn=learn_t, semi_grad=semi_grad)
+            elif aggr == 'powermean':
+                aggr_kwargs = dict(p=p, learn=learn_p)
 
         super().__init__(aggr=aggr, aggr_kwargs=aggr_kwargs, **kwargs)
 
@@ -123,19 +140,28 @@ def __init__(self, in_channels: int, out_channels: int,
         self.out_channels = out_channels
         self.eps = eps
 
-        channels = [in_channels]
+        if isinstance(in_channels, int):
+            in_channels = (in_channels, in_channels)
+
+        channels = [in_channels[0]]
         for i in range(num_layers - 1):
-            channels.append(in_channels * 2)
+            channels.append(out_channels * 2)
         channels.append(out_channels)
         self.mlp = MLP(channels, norm=norm)
 
-        self.msg_norm = MessageNorm(learn_msg_scale) if msg_norm else None
+        if in_channels[0] != in_channels[1] and in_channels[0] != -1:
+            self.lin_r = Linear(in_channels[1], in_channels[0], bias=True)
+
+        if msg_norm:
+            self.msg_norm = MessageNorm(learn_msg_scale)
 
     def reset_parameters(self):
         reset(self.mlp)
         self.aggr_module.reset_parameters()
-        if self.msg_norm is not None:
+        if hasattr(self, 'msg_norm'):
             self.msg_norm.reset_parameters()
+        if hasattr(self, 'lin_r'):
+            self.lin_r.reset_parameters()
 
     def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                 edge_attr: OptTensor = None, size: Size = None) -> Tensor:
@@ -155,11 +181,13 @@ def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
 
         # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
         out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
 
-        if self.msg_norm is not None:
-            out = self.msg_norm(x[0], out)
+        if hasattr(self, 'msg_norm'):
+            out = self.msg_norm(x[1] if x[1] is not None else x[0], out)
 
         x_r = x[1]
         if x_r is not None:
+            if hasattr(self, 'lin_r'):
+                x_r = self.lin_r(x_r)
             out = out + x_r
 
         return self.mlp(out)
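The new tuple support can be exercised end to end. A minimal usage sketch
mirroring the added test above (node counts and feature sizes are
illustrative, not prescribed by the patch):

    import torch
    from torch_geometric.nn import GENConv

    x_src = torch.randn(4, 8)   # 4 source nodes with 8 features each
    x_dst = torch.randn(2, 16)  # 2 target nodes with 16 features each
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [0, 0, 1, 1]])

    # A (src_channels, dst_channels) tuple enables bipartite message passing:
    conv = GENConv((8, 16), 32, aggr='softmax')
    out = conv((x_src, x_dst), edge_index)
    print(out.shape)  # torch.Size([2, 32]), one row per target node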
From a6b92fd8e3fcdaf2041e203343a966a264586069 Mon Sep 17 00:00:00 2001
From: Padarn Wilson
Date: Mon, 10 Oct 2022 15:44:40 +0800
Subject: [PATCH 3/5] Add ratio adjustment in `RandomLinkSplit` when not
 enough negative edges exist (#5642)

* add handling for not enough negative edges in random link split
* add readme
* Update torch_geometric/transforms/random_link_split.py

Co-authored-by: Matthias Fey

* Update torch_geometric/transforms/random_link_split.py

Co-authored-by: Matthias Fey

* update tests
* typos
* typo

Co-authored-by: Matthias Fey
---
 CHANGELOG.md                                    |  1 +
 test/transforms/test_random_link_split.py       | 14 ++++++++++++++
 torch_geometric/transforms/random_link_split.py | 16 +++++++++++++++-
 3 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ea2c5c36e49f..0b53c18f8c76 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -37,6 +37,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
 - Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
 ### Changed
+- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
 - Support `in_channels` with `tuple` in `GENConv` for bipartite message passing
 - Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
 - Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603))
diff --git a/test/transforms/test_random_link_split.py b/test/transforms/test_random_link_split.py
index b4b049ac3c88..21269800a28e 100644
--- a/test/transforms/test_random_link_split.py
+++ b/test/transforms/test_random_link_split.py
@@ -196,3 +196,17 @@ def test_random_link_split_on_undirected_hetero_data():
                                 rev_edge_types=('p', 'p'))
     train_data, val_data, test_data = transform(data)
     assert train_data['p', 'p'].is_undirected()
+
+
+def test_random_link_split_insufficient_negative_edges():
+    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2], [1, 3, 0, 2, 0, 1]])
+    data = Data(edge_index=edge_index, num_nodes=4)
+
+    transform = RandomLinkSplit(num_val=0.34, num_test=0.34,
+                                is_undirected=False, neg_sampling_ratio=2,
+                                split_labels=True)
+    train_data, val_data, test_data = transform(data)
+
+    assert train_data.neg_edge_label_index.size() == (2, 2)
+    assert val_data.neg_edge_label_index.size() == (2, 2)
+    assert test_data.neg_edge_label_index.size() == (2, 2)
diff --git a/torch_geometric/transforms/random_link_split.py b/torch_geometric/transforms/random_link_split.py
index 4b4bfc6fe817..ad23d160ee2b 100644
--- a/torch_geometric/transforms/random_link_split.py
+++ b/torch_geometric/transforms/random_link_split.py
@@ -1,3 +1,4 @@
+import warnings
 from copy import copy
 from typing import List, Optional, Union
 
@@ -80,7 +81,7 @@ class RandomLinkSplit(BaseTransform):
             The reverse edge types of :obj:`edge_types` in case of operating
             on :class:`~torch_geometric.data.HeteroData` objects. This will
             ensure that edges of the reverse direction will be
-            splitted accordingly to prevent any data leakage.
+            split accordingly to prevent any data leakage.
             Can be :obj:`None` in case no reverse connection exists.
             (default: :obj:`None`)
     """
@@ -168,6 +169,7 @@ def __call__(self, data: Union[Data, HeteroData]):
             num_test = int(num_test * perm.numel())
 
         num_train = perm.numel() - num_val - num_test
+
         if num_train <= 0:
             raise ValueError("Insufficient number of edges for training")
 
@@ -208,6 +210,18 @@ def __call__(self, data: Union[Data, HeteroData]):
                 num_neg_samples=num_neg, method='sparse')
 
+            # Adjust ratio if not enough negative edges exist
+            if neg_edge_index.size(1) < num_neg:
+                num_neg_found = neg_edge_index.size(1)
+                ratio = num_neg_found / num_neg
+                warnings.warn(
+                    "There are not enough negative edges to satisfy "
+                    "the provided sampling ratio. The ratio will be "
+                    f"adjusted to {ratio:.2f}.")
+                num_neg_train = int((num_neg_train / num_neg) * num_neg_found)
+                num_neg_val = int((num_neg_val / num_neg) * num_neg_found)
+                num_neg_test = num_neg_found - num_neg_train - num_neg_val
+
             # Create labels:
             if num_disjoint > 0:
                 train_edges = train_edges[:num_disjoint]
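A small end-to-end sketch of the adjusted behavior, reusing the graph from the
added test: with 4 nodes and 6 directed edges there are far fewer candidate
non-edges than the 12 negatives a 2x ratio requests, so the transform now
warns and rescales instead of failing:

    import torch
    from torch_geometric.data import Data
    from torch_geometric.transforms import RandomLinkSplit

    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2],
                               [1, 3, 0, 2, 0, 1]])
    data = Data(edge_index=edge_index, num_nodes=4)

    transform = RandomLinkSplit(num_val=0.34, num_test=0.34,
                                is_undirected=False, neg_sampling_ratio=2,
                                split_labels=True)
    # Emits a UserWarning and rescales the number of negatives per split:
    train_data, val_data, test_data = transform(data)
    print(train_data.neg_edge_label_index.size())  # torch.Size([2, 2])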
From 3733006ff51335c9d147e95410fd2171c9c7481f Mon Sep 17 00:00:00 2001
From: Guohao Li
Date: Tue, 11 Oct 2022 17:24:17 +0300
Subject: [PATCH 4/5] Refactor `GENConv` linear layers for lazy init and fix
 doc for bipartite graphs (#5641)

* changelog
* fix gen doc
* changelog
* update
* Refactor linear layers
* add test
* Fix edge linear layer
* update
* update
* update

Co-authored-by: Matthias Fey

* update

Co-authored-by: Matthias Fey

* update

Co-authored-by: Jinu Sunil

* update

Co-authored-by: Jinu Sunil

* update

Co-authored-by: Guohao Li
Co-authored-by: Matthias Fey
Co-authored-by: Jinu Sunil
---
 CHANGELOG.md                        |   2 +-
 test/nn/conv/test_gen_conv.py       |  15 +++-
 torch_geometric/nn/conv/gen_conv.py | 103 +++++++++++++++++++---------
 3 files changed, 87 insertions(+), 33 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0b53c18f8c76..2ae1619b691f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -37,8 +37,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
 - Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
 ### Changed
+- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
 - Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
-- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing
 - Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
 - Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603))
 - Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
diff --git a/test/nn/conv/test_gen_conv.py b/test/nn/conv/test_gen_conv.py
index aebedf817444..7bf5d4a592f1 100644
--- a/test/nn/conv/test_gen_conv.py
+++ b/test/nn/conv/test_gen_conv.py
@@ -16,7 +16,7 @@ def test_gen_conv(aggr):
     adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
     adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
 
-    conv = GENConv(16, 32, aggr)
+    conv = GENConv(16, 32, aggr, edge_dim=16)
     assert conv.__repr__() == f'GENConv(16, 32, aggr={aggr})'
     out11 = conv(x1, edge_index)
     assert out11.size() == (4, 32)
@@ -88,6 +88,19 @@ def test_gen_conv(aggr):
     assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
     assert conv((x1, None), adj.t()).tolist() == out2.tolist()
 
+    value = torch.randn(row.size(0), 4)
+    adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 2))
+    conv = GENConv((-1, -1), 32, aggr, edge_dim=-1)
+    assert str(conv) == f'GENConv((-1, -1), 32, aggr={aggr})'
+    out1 = conv((x1, x2), edge_index, value)
+    out2 = conv((x1, None), edge_index, value, size=(4, 2))
+    assert out1.size() == (2, 32)
+    assert out2.size() == (2, 32)
+    assert conv((x1, x2), edge_index, value,
+                size=(4, 2)).tolist() == out1.tolist()
+    assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
+    assert conv((x1, None), adj.t()).tolist() == out2.tolist()
+
     if is_full_test():
         t = '(OptPairTensor, Tensor, Size) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
diff --git a/torch_geometric/nn/conv/gen_conv.py b/torch_geometric/nn/conv/gen_conv.py
index 0c7f2ecbb910..07e9e4d8ace6 100644
--- a/torch_geometric/nn/conv/gen_conv.py
+++ b/torch_geometric/nn/conv/gen_conv.py
@@ -11,6 +11,7 @@
 )
 from torch_sparse import SparseTensor
 
+from torch_geometric.nn.aggr import Aggregation, MultiAggregation
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.nn.norm import MessageNorm
@@ -67,9 +68,10 @@ class GENConv(MessagePassing):
             A tuple corresponds to the sizes of source and target
             dimensionalities.
         out_channels (int): Size of each output sample.
-        aggr (str, optional): The aggregation scheme to use (:obj:`"softmax"`,
-            :obj:`"powermean"`, :obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
-            (default: :obj:`"softmax"`)
+        aggr (str or Aggregation, optional): The aggregation scheme to use.
+            Any aggregation of :obj:`torch_geometric.nn.aggr` can be used
+            (:obj:`"softmax"`, :obj:`"powermean"`, :obj:`"add"`, :obj:`"mean"`,
+            :obj:`"max"`). (default: :obj:`"softmax"`)
         t (float, optional): Initial inverse temperature for softmax
             aggregation. (default: :obj:`1.0`)
         learn_t (bool, optional): If set to :obj:`True`, will learn the value
@@ -88,15 +90,24 @@ class GENConv(MessagePassing):
             :obj:`"layer"`, :obj:`"instance"`) (default: :obj:`batch`)
         num_layers (int, optional): The number of MLP layers.
             (default: :obj:`2`)
+        expansion (int, optional): The expansion factor of hidden channels in
+            MLP layers. (default: :obj:`2`)
         eps (float, optional): The epsilon value of the message construction
             function. (default: :obj:`1e-7`)
+        bias (bool, optional): If set to :obj:`False`, the layer will not
+            learn an additive bias. (default: :obj:`True`)
+        edge_dim (int, optional): Edge feature dimensionality. If set to
+            :obj:`None`, edge feature dimensionality is expected to match
+            :obj:`out_channels`. Otherwise, edge features are linearly
+            transformed to match the :obj:`out_channels` of the node
+            features. (default: :obj:`None`)
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.GenMessagePassing`.
 
     Shapes:
         - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or
-          :math:`((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in})`
+          :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
          if bipartite,
          edge indices :math:`(2, |\mathcal{E}|)`,
          edge attributes :math:`(|\mathcal{E}|, D)` *(optional)*
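The new `edge_dim` argument composes with lazy initialization as documented
above. A usage sketch mirroring the added test (feature sizes are
illustrative):

    import torch
    from torch_geometric.nn import GENConv

    x_src = torch.randn(4, 8)
    x_dst = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [0, 0, 1, 1]])
    edge_attr = torch.randn(4, 4)  # 4 edges with 4 features each

    # `-1` defers shape inference to the first forward call; `lin_edge`
    # projects the 4-dimensional edge features to `out_channels`:
    conv = GENConv((-1, -1), 32, aggr='softmax', edge_dim=-1)
    out = conv((x_src, x_dst), edge_index, edge_attr)
    print(out.shape)  # torch.Size([2, 32])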
@@ -107,7 +118,7 @@ def __init__(
         self,
         in_channels: Union[int, Tuple[int, int]],
         out_channels: int,
-        aggr: str = 'softmax',
+        aggr: Optional[Union[str, List[str], Aggregation]] = 'softmax',
         t: float = 1.0,
         learn_t: bool = False,
         p: float = 1.0,
@@ -116,7 +127,10 @@ def __init__(
         learn_msg_scale: bool = False,
         norm: str = 'batch',
         num_layers: int = 2,
+        expansion: int = 2,
         eps: float = 1e-7,
+        bias: bool = True,
+        edge_dim: Optional[int] = None,
         **kwargs,
     ):
 
@@ -125,16 +139,15 @@ def __init__(
         aggr = 'softmax' if aggr == 'softmax_sg' else aggr
         aggr = 'powermean' if aggr == 'power' else aggr
 
-        aggr_kwargs = kwargs.get('aggr_kwargs', {})
-
         # Override args of aggregator if `aggr_kwargs` is specified
-        if aggr_kwargs == {}:
+        if 'aggr_kwargs' not in kwargs:
             if aggr == 'softmax':
-                aggr_kwargs = dict(t=t, learn=learn_t, semi_grad=semi_grad)
+                kwargs['aggr_kwargs'] = dict(t=t, learn=learn_t,
+                                             semi_grad=semi_grad)
             elif aggr == 'powermean':
-                aggr_kwargs = dict(p=p, learn=learn_p)
+                kwargs['aggr_kwargs'] = dict(p=p, learn=learn_p)
 
-        super().__init__(aggr=aggr, aggr_kwargs=aggr_kwargs, **kwargs)
+        super().__init__(aggr=aggr, **kwargs)
 
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -143,14 +156,29 @@ def __init__(
         if isinstance(in_channels, int):
             in_channels = (in_channels, in_channels)
 
-        channels = [in_channels[0]]
+        if in_channels[0] != out_channels:
+            self.lin_src = Linear(in_channels[0], out_channels, bias=bias)
+
+        if edge_dim is not None and edge_dim != out_channels:
+            self.lin_edge = Linear(edge_dim, out_channels, bias=bias)
+
+        if isinstance(self.aggr_module, MultiAggregation):
+            aggr_out_channels = self.aggr_module.get_out_channels(out_channels)
+        else:
+            aggr_out_channels = out_channels
+
+        if aggr_out_channels != out_channels:
+            self.lin_aggr_out = Linear(aggr_out_channels, out_channels,
+                                       bias=bias)
+
+        if in_channels[1] != out_channels:
+            self.lin_dst = Linear(in_channels[1], out_channels, bias=bias)
+
+        channels = [out_channels]
         for i in range(num_layers - 1):
-            channels.append(out_channels * 2)
+            channels.append(out_channels * expansion)
         channels.append(out_channels)
-        self.mlp = MLP(channels, norm=norm)
-
-        if in_channels[0] != in_channels[1] and in_channels[0] != -1:
-            self.lin_r = Linear(in_channels[1], in_channels[0], bias=True)
+        self.mlp = MLP(channels, norm=norm, bias=bias)
 
         if msg_norm:
             self.msg_norm = MessageNorm(learn_msg_scale)
@@ -160,8 +188,14 @@ def reset_parameters(self):
         reset(self.mlp)
         self.aggr_module.reset_parameters()
         if hasattr(self, 'msg_norm'):
             self.msg_norm.reset_parameters()
-        if hasattr(self, 'lin_r'):
-            self.lin_r.reset_parameters()
+        if hasattr(self, 'lin_src'):
+            self.lin_src.reset_parameters()
+        if hasattr(self, 'lin_edge'):
+            self.lin_edge.reset_parameters()
+        if hasattr(self, 'lin_aggr_out'):
+            self.lin_aggr_out.reset_parameters()
+        if hasattr(self, 'lin_dst'):
+            self.lin_dst.reset_parameters()
 
     def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                 edge_attr: OptTensor = None, size: Size = None) -> Tensor:
@@ -169,26 +203,33 @@ def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
         if isinstance(x, Tensor):
             x: OptPairTensor = (x, x)
 
-        # Node and edge feature dimensionalities need to match.
-        if isinstance(edge_index, Tensor):
-            if edge_attr is not None:
-                assert x[0].size(-1) == edge_attr.size(-1)
-        elif isinstance(edge_index, SparseTensor):
+        if hasattr(self, 'lin_src'):
+            x = (self.lin_src(x[0]), x[1])
+
+        if isinstance(edge_index, SparseTensor):
             edge_attr = edge_index.storage.value()
-            if edge_attr is not None:
-                assert x[0].size(-1) == edge_attr.size(-1)
+
+        if edge_attr is not None and hasattr(self, 'lin_edge'):
+            edge_attr = self.lin_edge(edge_attr)
+
+        # Node and edge feature dimensionalities need to match.
+        if edge_attr is not None:
+            assert x[0].size(-1) == edge_attr.size(-1)
 
         # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
         out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
 
+        if hasattr(self, 'lin_aggr_out'):
+            out = self.lin_aggr_out(out)
+
         if hasattr(self, 'msg_norm'):
             out = self.msg_norm(x[1] if x[1] is not None else x[0], out)
 
-        x_r = x[1]
-        if x_r is not None:
-            if hasattr(self, 'lin_r'):
-                x_r = self.lin_r(x_r)
-            out = out + x_r
+        x_dst = x[1]
+        if x_dst is not None:
+            if hasattr(self, 'lin_dst'):
+                x_dst = self.lin_dst(x_dst)
+            out = out + x_dst
 
         return self.mlp(out)
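The `lin_aggr_out` projection introduced above also makes multi-aggregation
usable here. A sketch, assuming the list form of `aggr` is resolved to a
concatenating `MultiAggregation` by `MessagePassing`:

    import torch
    from torch_geometric.nn import GENConv

    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [0, 0, 1, 1]])

    # Three concatenated aggregations yield 3 * 32 channels internally,
    # which `lin_aggr_out` projects back down to `out_channels`:
    conv = GENConv(16, 32, aggr=['softmax', 'mean', 'max'])
    out = conv(x, edge_index)
    print(out.shape)  # torch.Size([4, 32])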
From 6bda07558398bcbe8611c17995ed244796341cb0 Mon Sep 17 00:00:00 2001
From: Rishi Puri
Date: Wed, 12 Oct 2022 06:39:42 -0700
Subject: [PATCH 5/5] Minor fix for `pyg-lib` usage in `HeteroLinear` and
 `RGCNConv` (#5510)

The pyg-lib pathway was chosen whenever CUDA and pyg-lib are available, but if
the input tensor is not on a CUDA device, the pyg-lib pathway should not be
used. This is a simple fix; without it, a large number of CI tests fail when
running `RGCNConv` and `HeteroLinear` with CPU inputs
([errors.txt](https://github.com/pyg-team/pytorch_geometric/files/9629902/errors.txt)).

Co-authored-by: Matthias Fey
---
 .github/workflows/full_testing.yml    |  6 ++++-
 .github/workflows/testing.yml         |  4 +--
 test/nn/conv/test_rgcn_conv.py        | 38 +++++++++++++--------------
 test/nn/test_to_hetero_transformer.py |  4 +--
 torch_geometric/nn/conv/rgcn_conv.py  |  3 ++-
 torch_geometric/nn/dense/linear.py    |  2 +-
 6 files changed, 30 insertions(+), 27 deletions(-)

diff --git a/.github/workflows/full_testing.yml b/.github/workflows/full_testing.yml
index a3820e94fd2a..08af480ece87 100644
--- a/.github/workflows/full_testing.yml
+++ b/.github/workflows/full_testing.yml
@@ -40,9 +40,13 @@ jobs:
 
       - name: Install internal dependencies
         run: |
-          pip install pyg-lib -f https://data.pyg.org/whl/nightly/torch-${{ matrix.torch-version }}+cpu.html
           pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
 
+      - name: Install pyg-lib
+        if: ${{ runner.os == 'Linux' }}  # pyg-lib is currently only available on Linux.
+ run: | + pip install pyg-lib -f https://data.pyg.org/whl/nightly/torch-${{ matrix.torch-version }}+cpu.html + - name: Install main package run: | pip install -e .[full,test] diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 2e10b18c8a66..561bc245f8f3 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -16,10 +16,8 @@ jobs: matrix: os: [ubuntu-latest] python-version: [3.9] - torch-version: [1.11.0, 1.12.0] + torch-version: [1.12.0] include: - - torch-version: 1.11.0 - torchvision-version: 0.12.0 - torch-version: 1.12.0 torchvision-version: 0.13.0 diff --git a/test/nn/conv/test_rgcn_conv.py b/test/nn/conv/test_rgcn_conv.py index 214ca64c57ad..cae1554b2f95 100644 --- a/test/nn/conv/test_rgcn_conv.py +++ b/test/nn/conv/test_rgcn_conv.py @@ -26,10 +26,10 @@ def test_rgcn_conv_equality(conf): edge_type = torch.tensor([0, 1, 1, 0, 0, 1, 2, 3, 3, 2, 2, 3]) torch.manual_seed(12345) - conv1 = RGCNConv(4, 32, 4, num_bases, num_blocks) + conv1 = RGCNConv(4, 32, 4, num_bases, num_blocks, aggr='sum') torch.manual_seed(12345) - conv2 = FastRGCNConv(4, 32, 4, num_bases, num_blocks) + conv2 = FastRGCNConv(4, 32, 4, num_bases, num_blocks, aggr='sum') out1 = conv1(x1, edge_index, edge_type) out2 = conv2(x1, edge_index, edge_type) @@ -54,50 +54,50 @@ def test_rgcn_conv(cls, conf): row, col = edge_index adj = SparseTensor(row=row, col=col, value=edge_type, sparse_sizes=(4, 4)) - conv = cls(4, 32, 2, num_bases, num_blocks) + conv = cls(4, 32, 2, num_bases, num_blocks, aggr='sum') assert conv.__repr__() == f'{cls.__name__}(4, 32, num_relations=2)' out1 = conv(x1, edge_index, edge_type) assert out1.size() == (4, 32) - assert conv(x1, adj.t()).tolist() == out1.tolist() + assert torch.allclose(conv(x1, adj.t()), out1, atol=1e-6) if num_blocks is None: out2 = conv(None, edge_index, edge_type) assert out2.size() == (4, 32) - assert conv(None, adj.t()).tolist() == out2.tolist() + assert torch.allclose(conv(None, adj.t()), out2, atol=1e-6) if is_full_test(): t = '(OptTensor, Tensor, OptTensor) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert jit(x1, edge_index, edge_type).tolist() == out1.tolist() + assert torch.allclose(jit(x1, edge_index, edge_type), out1) if num_blocks is None: - assert jit(idx1, edge_index, edge_type).tolist() == out2.tolist() - assert jit(None, edge_index, edge_type).tolist() == out2.tolist() + assert torch.allclose(jit(idx1, edge_index, edge_type), out2) + assert torch.allclose(jit(None, edge_index, edge_type), out2) t = '(OptTensor, SparseTensor, OptTensor) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert jit(x1, adj.t()).tolist() == out1.tolist() + assert torch.allclose(jit(x1, adj.t()), out1) if num_blocks is None: - assert jit(idx1, adj.t()).tolist() == out2.tolist() - assert jit(None, adj.t()).tolist() == out2.tolist() + assert torch.allclose(jit(idx1, adj.t()), out2, atol=1e-6) + assert torch.allclose(jit(None, adj.t()), out2, atol=1e-6) adj = adj.sparse_resize((4, 2)) - conv = cls((4, 16), 32, 2, num_bases, num_blocks) + conv = cls((4, 16), 32, 2, num_bases, num_blocks, aggr='sum') assert conv.__repr__() == f'{cls.__name__}((4, 16), 32, num_relations=2)' out1 = conv((x1, x2), edge_index, edge_type) assert out1.size() == (2, 32) - assert conv((x1, x2), adj.t()).tolist() == out1.tolist() + assert torch.allclose(conv((x1, x2), adj.t()), out1, atol=1e-6) if num_blocks is None: out2 = conv((None, idx2), edge_index, edge_type) assert out2.size() == (2, 32) assert torch.allclose(conv((idx1, idx2), edge_index, 
edge_type), out2) - assert conv((None, idx2), adj.t()).tolist() == out2.tolist() - assert conv((idx1, idx2), adj.t()).tolist() == out2.tolist() + assert torch.allclose(conv((None, idx2), adj.t()), out2, atol=1e-6) + assert torch.allclose(conv((idx1, idx2), adj.t()), out2, atol=1e-6) if is_full_test(): t = '(Tuple[OptTensor, Tensor], Tensor, OptTensor) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert jit((x1, x2), edge_index, edge_type).tolist() == out1.tolist() + assert torch.allclose(jit((x1, x2), edge_index, edge_type), out1) if num_blocks is None: assert torch.allclose(jit((None, idx2), edge_index, edge_type), out2) @@ -106,7 +106,7 @@ def test_rgcn_conv(cls, conf): t = '(Tuple[OptTensor, Tensor], SparseTensor, OptTensor) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert jit((x1, x2), adj.t()).tolist() == out1.tolist() + assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6) if num_blocks is None: - assert jit((None, idx2), adj.t()).tolist() == out2.tolist() - assert jit((idx1, idx2), adj.t()).tolist() == out2.tolist() + assert torch.allclose(jit((None, idx2), adj.t()), out2, atol=1e-6) + assert torch.allclose(jit((idx1, idx2), adj.t()), out2, atol=1e-6) diff --git a/test/nn/test_to_hetero_transformer.py b/test/nn/test_to_hetero_transformer.py index a516ab34a51e..61d566715d9a 100644 --- a/test/nn/test_to_hetero_transformer.py +++ b/test/nn/test_to_hetero_transformer.py @@ -312,7 +312,7 @@ def test_to_hetero_with_basic_model(): class GraphConv(MessagePassing): def __init__(self, in_channels, out_channels): - super().__init__(aggr='mean') + super().__init__(aggr='sum') self.lin = Linear(in_channels, out_channels, bias=False) def reset_parameters(self): @@ -351,7 +351,7 @@ def test_to_hetero_and_rgcn_equal_output(): edge_type[(row >= 6) & (col < 6)] = 2 assert edge_type.min() == 0 - conv = RGCNConv(16, 32, num_relations=3) + conv = RGCNConv(16, 32, num_relations=3, aggr='sum') out1 = conv(x, edge_index, edge_type) # Run `to_hetero`: diff --git a/torch_geometric/nn/conv/rgcn_conv.py b/torch_geometric/nn/conv/rgcn_conv.py index 7e5e85af6813..6044b5113ae3 100644 --- a/torch_geometric/nn/conv/rgcn_conv.py +++ b/torch_geometric/nn/conv/rgcn_conv.py @@ -112,7 +112,7 @@ def __init__( ): kwargs.setdefault('aggr', aggr) super().__init__(node_dim=0, **kwargs) - self._WITH_PYG_LIB = torch.cuda.is_available() and _WITH_PYG_LIB + self._WITH_PYG_LIB = _WITH_PYG_LIB if num_bases is not None and num_blocks is not None: raise ValueError('Can not apply both basis-decomposition and ' @@ -263,6 +263,7 @@ def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]], def message(self, x_j: Tensor, edge_type_ptr: OptTensor) -> Tensor: if edge_type_ptr is not None: + # TODO Re-weight according to edge type degree for `aggr=mean`. return segment_matmul(x_j, edge_type_ptr, self.weight) return x_j diff --git a/torch_geometric/nn/dense/linear.py b/torch_geometric/nn/dense/linear.py index d8e222ade7dd..1f8e5206dcf6 100644 --- a/torch_geometric/nn/dense/linear.py +++ b/torch_geometric/nn/dense/linear.py @@ -220,7 +220,7 @@ def __init__(self, in_channels: int, out_channels: int, num_types: int, self.is_sorted = is_sorted self.kwargs = kwargs - self._WITH_PYG_LIB = torch.cuda.is_available() and _WITH_PYG_LIB + self._WITH_PYG_LIB = _WITH_PYG_LIB if self._WITH_PYG_LIB: self.weight = torch.nn.Parameter(
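A hedged sketch of the dispatch rule described in the commit message
(hypothetical helper and flag names; this is not pyg-lib's actual API):
library availability is a static flag, but the per-call decision should also
consider the input's device, so CPU inputs never take the pyg-lib pathway.

    import torch

    _WITH_PYG_LIB = True  # assumption: set once at import time if pyg-lib loads


    def use_pyg_lib_pathway(x: torch.Tensor) -> bool:
        # The static flag alone is not enough: the input must also live on a
        # device the pyg-lib kernels support (CUDA, per the commit message).
        return _WITH_PYG_LIB and x.is_cuda


    print(use_pyg_lib_pathway(torch.randn(2, 3)))  # CPU input -> False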