torch_geometric.nn.aggr package with base class #4687

Merged: 16 commits, May 25, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687))
- Added benchmarks via [`wandb`](https://wandb.ai/site) ([#4656](https://github.com/pyg-team/pytorch_geometric/pull/4656), [#4672](https://github.com/pyg-team/pytorch_geometric/pull/4672), [#4676](https://github.com/pyg-team/pytorch_geometric/pull/4676))
- Added `unbatch` functionality ([#4628](https://github.com/pyg-team/pytorch_geometric/pull/4628))
- Confirm that `to_hetero()` works with custom functions, *e.g.*, `dropout_adj` ([#4653](https://github.com/pyg-team/pytorch_geometric/pull/4653))
47 changes: 47 additions & 0 deletions test/nn/aggr/test_aggr.py
@@ -0,0 +1,47 @@
import pytest
import torch

from torch_geometric.nn import (
    MaxAggr,
    MeanAggr,
    MinAggr,
    PowerMeanAggr,
    SoftmaxAggr,
    StdAggr,
    SumAggr,
    VarAggr,
)


@pytest.mark.parametrize('aggr', [MeanAggr, MaxAggr, MinAggr, SumAggr])
def test_basic_aggr(aggr):
    src = torch.randn(6, 64)
    index = torch.tensor([0, 1, 0, 1, 2, 1])
    out = aggr()(src, index)
    assert out.shape[0] == index.unique().shape[0]


@pytest.mark.parametrize('aggr', [SoftmaxAggr, PowerMeanAggr])
def test_gen_aggr(aggr):
    src = torch.randn(6, 64)
    index = torch.tensor([0, 1, 0, 1, 2, 1])
    for learn in [True, False]:
        if issubclass(aggr, SoftmaxAggr):
            aggregator = aggr(t=1, learn=learn)
        elif issubclass(aggr, PowerMeanAggr):
            aggregator = aggr(p=1, learn=learn)
        else:
            raise NotImplementedError
        out = aggregator(src, index)
        if any(param.requires_grad for param in aggregator.parameters()):
            out.mean().backward()
            for param in aggregator.parameters():
                assert not torch.isnan(param.grad).any()


@pytest.mark.parametrize('aggr', [VarAggr, StdAggr])
def test_stat_aggr(aggr):
    src = torch.randn(6, 64)
    index = torch.tensor([0, 1, 0, 1, 2, 1])
    out = aggr()(src, index)
    assert out.shape[0] == index.unique().shape[0]
1 change: 1 addition & 0 deletions torch_geometric/nn/__init__.py
@@ -4,6 +4,7 @@
from .data_parallel import DataParallel
from .to_hetero_transformer import to_hetero
from .to_hetero_with_bases_transformer import to_hetero_with_bases
from .aggr import * # noqa
from .conv import * # noqa
from .norm import * # noqa
from .glob import * # noqa
23 changes: 23 additions & 0 deletions torch_geometric/nn/aggr/__init__.py
@@ -0,0 +1,23 @@
from .base import BaseAggr
from .aggr import (
    MeanAggr,
    SumAggr,
    MaxAggr,
    MinAggr,
    SoftmaxAggr,
    PowerMeanAggr,
    VarAggr,
    StdAggr,
)

__all__ = classes = [
    'BaseAggr',
    'MeanAggr',
    'SumAggr',
    'MaxAggr',
    'MinAggr',
    'SoftmaxAggr',
    'PowerMeanAggr',
    'VarAggr',
    'StdAggr',
]
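
Since `torch_geometric/nn/__init__.py` pulls this package in via a star import (see above), the new classes resolve through both paths. A quick import sanity check, assuming a development install of this branch:

from torch_geometric.nn import SoftmaxAggr       # via the top-level re-export
from torch_geometric.nn.aggr import SoftmaxAggr  # via the package itself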
90 changes: 90 additions & 0 deletions torch_geometric/nn/aggr/aggr.py
@@ -0,0 +1,90 @@
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.aggr import BaseAggr
from torch_geometric.utils import softmax


class MeanAggr(BaseAggr):
    def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None,
                ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor:
        return self.reduce(x, index, dim_size, ptr, dim, reduce='mean')


class SumAggr(BaseAggr):
    def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None,
                ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor:
        return self.reduce(x, index, dim_size, ptr, dim, reduce='sum')


class MaxAggr(BaseAggr):
    def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None,
                ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor:
        return self.reduce(x, index, dim_size, ptr, dim, reduce='max')


class MinAggr(BaseAggr):
    def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None,
                ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor:
        return self.reduce(x, index, dim_size, ptr, dim, reduce='min')

class SoftmaxAggr(BaseAggr):
    def __init__(self, t: float = 1.0, learn: bool = False):
        # TODO Learn distinct `t` per channel.
        super().__init__()
        self._init_t = t
        self.t = Parameter(torch.Tensor(1)) if learn else t
        self.reset_parameters()

    def reset_parameters(self):
        if isinstance(self.t, Tensor):
            self.t.data.fill_(self._init_t)

    def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None,
                ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor:

        alpha = x  # Initialize `alpha` so the unscaled (`t == 1`) path works.
        if not isinstance(self.t, (int, float)) or self.t != 1:
            alpha = x * self.t
        alpha = softmax(alpha, index, ptr, dim_size, dim)
        return self.reduce(x * alpha, index, dim_size, ptr, dim, reduce='sum')


class PowerMeanAggr(BaseAggr):
    def __init__(self, p: float = 1.0, learn: bool = False):
        # TODO Learn distinct `p` per channel.
        super().__init__()
        self._init_p = p
        self.p = Parameter(torch.Tensor(1)) if learn else p
        self.reset_parameters()

    def reset_parameters(self):
        if isinstance(self.p, Tensor):
            self.p.data.fill_(self._init_p)

    def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None,
                ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor:

        if isinstance(self.p, (int, float)) and self.p == 1:
            return self.reduce(x, index, dim_size, ptr, dim, reduce='mean')
        # Power mean: (mean(x^p))^(1/p). Inputs are clamped so that `pow`
        # (and its gradient w.r.t. a learnable `p`) stays NaN-free on
        # negative entries.
        out = self.reduce(x.clamp(min=0, max=100).pow(self.p), index,
                          dim_size, ptr, dim, reduce='mean')
        return out.clamp(min=0, max=100).pow(1. / self.p)


class VarAggr(BaseAggr):
    def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None,
                ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor:

        mean = self.reduce(x, index, dim_size, ptr, dim, reduce='mean')
        mean_2 = self.reduce(x * x, index, dim_size, ptr, dim, reduce='mean')
        return mean_2 - mean * mean


class StdAggr(VarAggr):
    def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None,
                ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor:

        var = super().forward(x, index, dim_size, ptr, dim)
        return torch.sqrt(var.relu() + 1e-5)
40 changes: 40 additions & 0 deletions torch_geometric/nn/aggr/base.py
@@ -0,0 +1,40 @@
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import Tensor
from torch_scatter import scatter, segment_csr


class BaseAggr(torch.nn.Module, ABC):
    r"""An abstract base class for writing aggregations."""
    @abstractmethod
    def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None,
                ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor:
        pass

    def reset_parameters(self):
        pass

    @staticmethod
    def reduce(x: Tensor, index: Tensor, dim_size: Optional[int] = None,
               ptr: Optional[Tensor] = None, dim: int = -2,
               reduce: str = 'sum') -> Tensor:
        # Default to 'sum' rather than 'add': `scatter` accepts both, but
        # `segment_csr` only knows 'sum'.

        if ptr is not None:
            ptr = expand_left(ptr, dim, dims=x.dim())
            return segment_csr(x, ptr, reduce=reduce)

        return scatter(x, index, dim=dim, dim_size=dim_size, reduce=reduce)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'


###############################################################################


def expand_left(src: torch.Tensor, dim: int, dims: int) -> torch.Tensor:
    # Left-pad `src` with singleton dimensions until it broadcasts against a
    # `dims`-dimensional tensor along `dim`.
    for _ in range(dims + dim if dim < 0 else dim):
        src = src.unsqueeze(0)
    return src
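
With this base class, a new aggregation only needs a `forward()` that composes the shared `reduce()` helper. A minimal sketch of a custom subclass (`RmsAggr` is hypothetical, not part of this PR):

from typing import Optional

import torch
from torch import Tensor

from torch_geometric.nn.aggr import BaseAggr


class RmsAggr(BaseAggr):
    r"""Root-mean-square aggregation: sqrt(mean(x^2)) per group."""
    def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None,
                ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor:
        mean_2 = self.reduce(x * x, index, dim_size, ptr, dim, reduce='mean')
        return torch.sqrt(mean_2.clamp(min=0))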
3 changes: 2 additions & 1 deletion torch_geometric/transforms/base_transform.py
@@ -1,7 +1,8 @@
from abc import ABC
from typing import Any


class BaseTransform:
class BaseTransform(ABC):
r"""An abstract base class for writing transforms.

Transforms are a general way to modify and customize