Refactor GENConv to rely on new Aggregation #4866

Merged
merged 5 commits on Jun 26, 2022
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -20,8 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
- Added benchmarks via [`wandb`](https://wandb.ai/site) ([#4656](https://github.com/pyg-team/pytorch_geometric/pull/4656), [#4672](https://github.com/pyg-team/pytorch_geometric/pull/4672), [#4676](https://github.com/pyg-team/pytorch_geometric/pull/4676))
2 changes: 1 addition & 1 deletion test/nn/conv/test_gen_conv.py
@@ -6,7 +6,7 @@
from torch_geometric.testing import is_full_test


@pytest.mark.parametrize('aggr', ['softmax', 'softmax_sg', 'power'])
@pytest.mark.parametrize('aggr', ['softmax', 'powermean'])
def test_gen_conv(aggr):
x1 = torch.randn(4, 16)
x2 = torch.randn(2, 16)
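The parametrization above reflects the aggregator rename in this PR: `softmax_sg` is folded into `softmax` and `power` becomes `powermean`. A minimal, dependency-free sketch of that backward-compatibility mapping (the helper name `resolve_aggr` is illustrative, not part of the PyG API):

```python
# Hypothetical sketch: normalize legacy GENConv aggregator names to the
# names understood by the new aggregation package, and collect the kwargs
# that parametrize each aggregation module.
LEGACY_NAMES = {'softmax_sg': 'softmax', 'power': 'powermean'}

def resolve_aggr(aggr, t=1.0, learn_t=False, p=1.0, learn_p=False):
    aggr = LEGACY_NAMES.get(aggr, aggr)  # map legacy names first
    if aggr == 'softmax':
        return aggr, dict(t=t, learn=learn_t)
    if aggr == 'powermean':
        return aggr, dict(p=p, learn=learn_p)
    return aggr, {}  # 'add', 'mean', 'max', ... need no extra kwargs

print(resolve_aggr('power', p=2.0))  # ('powermean', {'p': 2.0, 'learn': False})
```

The refactored `GENConv.__init__` performs this normalization inline before forwarding `aggr` and `aggr_kwargs` to `MessagePassing`.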
71 changes: 15 additions & 56 deletions torch_geometric/nn/conv/gen_conv.py
@@ -1,18 +1,14 @@
from typing import List, Optional, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import (
BatchNorm1d,
Dropout,
InstanceNorm1d,
LayerNorm,
Parameter,
ReLU,
Sequential,
)
from torch_scatter import scatter, scatter_softmax
from torch_sparse import SparseTensor

from torch_geometric.nn.conv import MessagePassing
@@ -72,8 +68,8 @@ class GENConv(MessagePassing):
dimensionalities.
out_channels (int): Size of each output sample.
aggr (str, optional): The aggregation scheme to use (:obj:`"softmax"`,
:obj:`"softmax_sg"`, :obj:`"power"`, :obj:`"add"`, :obj:`"mean"`,
:obj:`max`). (default: :obj:`"softmax"`)
:obj:`"powermean"`, :obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
(default: :obj:`"softmax"`)
t (float, optional): Initial inverse temperature for softmax
aggregation. (default: :obj:`1.0`)
learn_t (bool, optional): If set to :obj:`True`, will learn the value
@@ -113,16 +109,22 @@ def __init__(self, in_channels: int, out_channels: int,
learn_msg_scale: bool = False, norm: str = 'batch',
num_layers: int = 2, eps: float = 1e-7, **kwargs):

kwargs.setdefault('aggr', None)
super().__init__(**kwargs)
# Backward compatibility:
aggr = 'softmax' if aggr == 'softmax_sg' else aggr
aggr = 'powermean' if aggr == 'power' else aggr

aggr_kwargs = {}
if aggr == 'softmax':
aggr_kwargs = dict(t=t, learn=learn_t)
elif aggr == 'powermean':
aggr_kwargs = dict(p=p, learn=learn_p)

super().__init__(aggr=aggr, aggr_kwargs=aggr_kwargs, **kwargs)

self.in_channels = in_channels
self.out_channels = out_channels
self.aggr = aggr
self.eps = eps

assert aggr in ['softmax', 'softmax_sg', 'power', 'add', 'mean', 'max']

channels = [in_channels]
for i in range(num_layers - 1):
channels.append(in_channels * 2)
@@ -131,32 +133,15 @@ def __init__(self, in_channels: int, out_channels: int,

self.msg_norm = MessageNorm(learn_msg_scale) if msg_norm else None

self.initial_t = t
self.initial_p = p

if learn_t and aggr == 'softmax':
self.t = Parameter(torch.Tensor([t]), requires_grad=True)
else:
self.t = t

if learn_p:
self.p = Parameter(torch.Tensor([p]), requires_grad=True)
else:
self.p = p

def reset_parameters(self):
reset(self.mlp)
self.aggr_module.reset_parameters()
if self.msg_norm is not None:
self.msg_norm.reset_parameters()
if self.t and isinstance(self.t, Tensor):
self.t.data.fill_(self.initial_t)
if self.p and isinstance(self.p, Tensor):
self.p.data.fill_(self.initial_p)

def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
edge_attr: OptTensor = None, size: Size = None) -> Tensor:
""""""

if isinstance(x, Tensor):
x: OptPairTensor = (x, x)

@@ -183,33 +168,7 @@ def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,

def message(self, x_j: Tensor, edge_attr: OptTensor) -> Tensor:
msg = x_j if edge_attr is None else x_j + edge_attr
return F.relu(msg) + self.eps

def aggregate(self, inputs: Tensor, index: Tensor,
dim_size: Optional[int] = None) -> Tensor:

if self.aggr == 'softmax':
out = scatter_softmax(inputs * self.t, index, dim=self.node_dim)
return scatter(inputs * out, index, dim=self.node_dim,
dim_size=dim_size, reduce='sum')

elif self.aggr == 'softmax_sg':
out = scatter_softmax(inputs * self.t, index,
dim=self.node_dim).detach()
return scatter(inputs * out, index, dim=self.node_dim,
dim_size=dim_size, reduce='sum')

elif self.aggr == 'power':
min_value, max_value = 1e-7, 1e1
torch.clamp_(inputs, min_value, max_value)
out = scatter(torch.pow(inputs, self.p), index, dim=self.node_dim,
dim_size=dim_size, reduce='mean')
torch.clamp_(out, min_value, max_value)
return torch.pow(out, 1 / self.p)

else:
return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
reduce=self.aggr)
return msg.relu() + self.eps

def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
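The `aggregate` branches removed above implemented softmax and power-mean pooling by hand with `scatter_softmax`/`scatter`; that logic now lives in the aggregation modules resolved from `aggr`. A dependency-free sketch of the same per-group math, using plain dicts in place of scatter ops (the helper names are illustrative, not PyG code):

```python
import math

def softmax_aggregate(values, index, t=1.0):
    """Per-group softmax-weighted sum, as the removed 'softmax' branch
    computed with scatter_softmax followed by a scatter-sum."""
    groups = {}
    for v, i in zip(values, index):
        groups.setdefault(i, []).append(v)
    out = {}
    for i, vs in groups.items():
        m = max(t * v for v in vs)                 # stabilize the exp
        w = [math.exp(t * v - m) for v in vs]
        z = sum(w)
        out[i] = sum(v * wi / z for v, wi in zip(vs, w))
    return out

def powermean_aggregate(values, index, p=1.0, lo=1e-7, hi=1e1):
    """Per-group power mean with the same clamping the removed 'power'
    branch applied before and after the mean."""
    groups = {}
    for v, i in zip(values, index):
        groups.setdefault(i, []).append(min(max(v, lo), hi))  # clamp inputs
    return {i: min(max(sum(v ** p for v in vs) / len(vs), lo), hi) ** (1 / p)
            for i, vs in groups.items()}
```

With `p=1` the power mean reduces to the arithmetic mean, and as `t` grows the softmax aggregation approaches a per-group max, which is why these two families subsume the plain `mean`/`max` reducers.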
3 changes: 1 addition & 2 deletions torch_geometric/nn/conv/sage_conv.py
@@ -104,8 +104,7 @@ def __init__(
def reset_parameters(self):
if self.project:
self.lin.reset_parameters()
if self.aggr is None:
self.lstm.reset_parameters()
self.aggr_module.reset_parameters()
self.lin_l.reset_parameters()
if self.root_weight:
self.lin_r.reset_parameters()
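The same delegation shows up in the `SAGEConv` change above: instead of special-casing the LSTM aggregator, `reset_parameters` now forwards one uniform call to whatever aggregation module was resolved. A toy sketch of the pattern (`MeanAggr`, `LSTMAggr`, and `Conv` are hypothetical stand-ins, not PyG classes):

```python
# Toy illustration of delegating parameter resets to a pluggable
# aggregation module, as SAGEConv now does via `self.aggr_module`.
class MeanAggr:
    def reset_parameters(self):
        pass  # stateless aggregator: nothing to reset

class LSTMAggr:
    def __init__(self):
        self.weight = 0.5
    def reset_parameters(self):
        self.weight = 0.0  # stand-in for reinitializing LSTM weights

class Conv:
    def __init__(self, aggr_module):
        self.aggr_module = aggr_module
    def reset_parameters(self):
        # One uniform call; no per-aggregator isinstance checks needed.
        self.aggr_module.reset_parameters()

conv = Conv(LSTMAggr())
conv.reset_parameters()
```

Each aggregation module owns its own state, so convolution layers no longer need to know which aggregator they were built with.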