From a73063112c37f4e973ee0c729871ecbf7a086c70 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 20 May 2022 09:26:36 -0700 Subject: [PATCH 01/15] initial commit --- torch_geometric/nn/aggr/__init__.py | 5 +++ torch_geometric/nn/aggr/aggr.py | 46 ++++++++++++++++++++ torch_geometric/transforms/base_transform.py | 3 +- 3 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 torch_geometric/nn/aggr/__init__.py create mode 100644 torch_geometric/nn/aggr/aggr.py diff --git a/torch_geometric/nn/aggr/__init__.py b/torch_geometric/nn/aggr/__init__.py new file mode 100644 index 000000000000..cd7a169f6371 --- /dev/null +++ b/torch_geometric/nn/aggr/__init__.py @@ -0,0 +1,5 @@ +from .aggr import Aggr + +__all__ = classes = [ + 'Aggr', +] diff --git a/torch_geometric/nn/aggr/aggr.py b/torch_geometric/nn/aggr/aggr.py new file mode 100644 index 000000000000..08e3f2806238 --- /dev/null +++ b/torch_geometric/nn/aggr/aggr.py @@ -0,0 +1,46 @@ +from abc import ABC +from typing import Optional + +import torch +from torch import Tensor +from torch_scatter import scatter, segment_csr + + +class Aggr(torch.nn.Module, ABC): + def forward( + self, + x: Tensor, + index: Tensor, + ptr: Optional[Tensor] = None, + dim: int = -2, + dim_size: Optional[None] = None, + ) -> Tensor: + raise NotImplementedError + + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' + + +class MeanAggr(Aggr): + def forward( + self, + x: Tensor, + index: Tensor, + ptr: Optional[Tensor] = None, + dim: int = -2, + dim_size: Optional[None] = None, + ) -> Tensor: + if ptr is not None: + ptr = expand_left(ptr, dim=dim, dims=x.dim()) + return segment_csr(x, ptr, reduce='mean') + else: + return scatter(x, index, dim=dim, dim_size=dim_size, reduce='mean') + + +############################################################################### + + +def expand_left(src: torch.Tensor, dim: int, dims: int) -> torch.Tensor: + for _ in range(dims + dim if dim < 0 else dim): + src = src.unsqueeze(0) + return src diff --git a/torch_geometric/transforms/base_transform.py b/torch_geometric/transforms/base_transform.py index 56f12fcfc9e7..8a3041c2cd7e 100644 --- a/torch_geometric/transforms/base_transform.py +++ b/torch_geometric/transforms/base_transform.py @@ -1,7 +1,8 @@ +from abc import ABC from typing import Any -class BaseTransform: +class BaseTransform(ABC): r"""An abstract base class for writing transforms. Transforms are a general way to modify and customize From a0972bf93a79cf96b8f2881af5952ff1ed51a886 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 20 May 2022 09:26:51 -0700 Subject: [PATCH 02/15] update --- torch_geometric/nn/aggr/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_geometric/nn/aggr/__init__.py b/torch_geometric/nn/aggr/__init__.py index cd7a169f6371..f51ed514e08c 100644 --- a/torch_geometric/nn/aggr/__init__.py +++ b/torch_geometric/nn/aggr/__init__.py @@ -1,5 +1,6 @@ -from .aggr import Aggr +from .aggr import Aggr, MeanAggr __all__ = classes = [ 'Aggr', + 'MeanAggr', ] From e10f7ce99ce671ba7ebb761ef690f0cf43141d96 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 20 May 2022 09:28:42 -0700 Subject: [PATCH 03/15] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b31ee021aaa..0933a7f67f03 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
## [2.0.5] - 2022-MM-DD ### Added +- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687)) - Added benchmarks via [`wandb`](https://wandb.ai/site) ([#4656](https://github.com/pyg-team/pytorch_geometric/pull/4656), [#4672](https://github.com/pyg-team/pytorch_geometric/pull/4672), [#4676](https://github.com/pyg-team/pytorch_geometric/pull/4676)) - Added `unbatch` functionality ([#4628](https://github.com/pyg-team/pytorch_geometric/pull/4628)) - Confirm that `to_hetero()` works with custom functions, *e.g.*, `dropout_adj` ([4653](https://github.com/pyg-team/pytorch_geometric/pull/4653)) From 277df134acb85aae8fbedefb6028dc58fc49dc8b Mon Sep 17 00:00:00 2001 From: lightaime Date: Wed, 25 May 2022 16:11:41 +0300 Subject: [PATCH 04/15] Added basic aggrs, gen aggrs and pna aggrs --- torch_geometric/nn/__init__.py | 1 + torch_geometric/nn/aggr/aggr.py | 156 +++++++++++++++++++++++++++++++- 2 files changed, 156 insertions(+), 1 deletion(-) diff --git a/torch_geometric/nn/__init__.py b/torch_geometric/nn/__init__.py index 0550d4c07b37..06435828e66b 100644 --- a/torch_geometric/nn/__init__.py +++ b/torch_geometric/nn/__init__.py @@ -12,6 +12,7 @@ from .dense import * # noqa from .models import * # noqa from .functional import * # noqa +from .aggr import * # noqa __all__ = [ 'MetaLayer', diff --git a/torch_geometric/nn/aggr/aggr.py b/torch_geometric/nn/aggr/aggr.py index 08e3f2806238..46d26354d27e 100644 --- a/torch_geometric/nn/aggr/aggr.py +++ b/torch_geometric/nn/aggr/aggr.py @@ -3,7 +3,8 @@ import torch from torch import Tensor -from torch_scatter import scatter, segment_csr +from torch.nn import Parameter +from torch_scatter import scatter, scatter_softmax, segment_csr class Aggr(torch.nn.Module, ABC): @@ -37,6 +38,159 @@ def forward( return scatter(x, index, dim=dim, dim_size=dim_size, reduce='mean') +class MaxAggr(Aggr): + def forward( + self, + x: Tensor, + index: Tensor, + ptr: Optional[Tensor] = None, + dim: int = -2, + dim_size: Optional[None] = None, + ) -> Tensor: + if ptr is not None: + ptr = expand_left(ptr, dim=dim, dims=x.dim()) + return segment_csr(x, ptr, reduce='max') + else: + return scatter(x, index, dim=dim, dim_size=dim_size, reduce='max') + + +class MinAggr(Aggr): + def forward( + self, + x: Tensor, + index: Tensor, + ptr: Optional[Tensor] = None, + dim: int = -2, + dim_size: Optional[None] = None, + ) -> Tensor: + if ptr is not None: + ptr = expand_left(ptr, dim=dim, dims=x.dim()) + return segment_csr(x, ptr, reduce='min') + else: + return scatter(x, index, dim=dim, dim_size=dim_size, reduce='min') + + +class SumAggr(Aggr): + def forward( + self, + x: Tensor, + index: Tensor, + ptr: Optional[Tensor] = None, + dim: int = -2, + dim_size: Optional[None] = None, + ) -> Tensor: + if ptr is not None: + ptr = expand_left(ptr, dim=dim, dims=x.dim()) + return segment_csr(x, ptr, reduce='sum') + else: + return scatter(x, index, dim=dim, dim_size=dim_size, reduce='sum') + + +class SoftmaxAggr(Aggr): + def __init__(self, t: float = 1.0, learn: bool = False): + super().__init__() + self.init_t = t + self.learn = learn + if learn: + self.t = Parameter(torch.Tensor([t]), requires_grad=True) + else: + self.t = t + + def reset_parameters(self): + if self.t and isinstance(self.t, Tensor): + self.t.data.fill_(self.init_t) + + def forward( + self, + x: Tensor, + index: Tensor, + ptr: Optional[Tensor] = None, + dim: int = -2, + dim_size: Optional[None] = None, + ) -> Tensor: + if ptr is not None: + raise NotImplementedError + else: + if 
self.learn: + w = scatter_softmax(x * self.t, index, dim=dim) + else: + with torch.no_grad(): + w = scatter_softmax(x * self.t, index, dim=dim) + return scatter(x * w, index, dim=dim, dim_size=dim_size, + reduce='add') + + +class PowermeanAggr(Aggr): + def __init__(self, p: float = 1.0, learn: bool = False): + super().__init__() + self.init_p = p + if learn: + self.p = Parameter(torch.Tensor([p]), requires_grad=True) + else: + self.p = p + + def reset_parameters(self): + if self.p and isinstance(self.p, Tensor): + self.p.data.fill_(self.init_p) + + def forward( + self, + x: Tensor, + index: Tensor, + ptr: Optional[Tensor] = None, + dim: int = -2, + dim_size: Optional[None] = None, + ) -> Tensor: + x = torch.pow(torch.abs(x), self.p) + if ptr is not None: + ptr = expand_left(ptr, dim=dim, dims=x.dim()) + return torch.pow(segment_csr(x, ptr, reduce='mean'), 1 / self.p) + else: + return torch.pow( + scatter(x, index, dim=dim, dim_size=dim_size, reduce='mean'), + 1 / self.p) + + +class VarAggr(Aggr): + def forward( + self, + x: Tensor, + index: Tensor, + ptr: Optional[Tensor] = None, + dim: int = -2, + dim_size: Optional[None] = None, + ) -> Tensor: + if ptr is not None: + ptr = expand_left(ptr, dim=dim, dims=x.dim()) + mean = segment_csr(x, ptr, reduce='mean') + mean_squares = segment_csr(x * x, ptr, reduce='mean') + else: + mean = scatter(x, index, dim=dim, dim_size=dim_size, reduce='mean') + mean_squares = scatter(x * x, index, dim=dim, dim_size=dim_size, + reduce='mean') + return mean_squares - mean * mean + + +class StdAggr(Aggr): + def forward( + self, + x: Tensor, + index: Tensor, + ptr: Optional[Tensor] = None, + dim: int = -2, + dim_size: Optional[None] = None, + ) -> Tensor: + if ptr is not None: + ptr = expand_left(ptr, dim=dim, dims=x.dim()) + mean = segment_csr(x, ptr, reduce='mean') + mean_squares = segment_csr(x * x, ptr, reduce='mean') + else: + mean = scatter(x, index, dim=dim, dim_size=dim_size, reduce='mean') + mean_squares = scatter(x * x, index, dim=dim, dim_size=dim_size, + reduce='mean') + return torch.sqrt(torch.relu(mean_squares - mean * mean) + 1e-5) + + ############################################################################### From 433aab66e655d88906d3bfee969e22edf8010fe8 Mon Sep 17 00:00:00 2001 From: lightaime Date: Wed, 25 May 2022 16:22:23 +0300 Subject: [PATCH 05/15] Formatted --- torch_geometric/nn/aggr/__init__.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch_geometric/nn/aggr/__init__.py b/torch_geometric/nn/aggr/__init__.py index f51ed514e08c..41b7b70ccc81 100644 --- a/torch_geometric/nn/aggr/__init__.py +++ b/torch_geometric/nn/aggr/__init__.py @@ -1,6 +1,7 @@ -from .aggr import Aggr, MeanAggr +from .aggr import (Aggr, MeanAggr, MaxAggr, MinAggr, SumAggr, SoftmaxAggr, + PowermeanAggr, VarAggr, StdAggr) __all__ = classes = [ - 'Aggr', - 'MeanAggr', + 'Aggr', 'MeanAggr', 'MaxAggr', 'MinAggr', 'SumAggr', 'SoftmaxAggr', + 'PowermeanAggr', 'VarAggr', 'StdAggr' ] From e2076fd887b30b65e7cb4be97304eeed1dd92d69 Mon Sep 17 00:00:00 2001 From: lightaime Date: Wed, 25 May 2022 16:32:33 +0300 Subject: [PATCH 06/15] Formatted --- torch_geometric/nn/aggr/__init__.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/torch_geometric/nn/aggr/__init__.py b/torch_geometric/nn/aggr/__init__.py index 41b7b70ccc81..d2307f54360c 100644 --- a/torch_geometric/nn/aggr/__init__.py +++ b/torch_geometric/nn/aggr/__init__.py @@ -2,6 +2,13 @@ PowermeanAggr, VarAggr, StdAggr) __all__ = classes = [ - 'Aggr', 'MeanAggr', 
'MaxAggr', 'MinAggr', 'SumAggr', 'SoftmaxAggr', - 'PowermeanAggr', 'VarAggr', 'StdAggr' + 'Aggr', + 'MeanAggr', + 'MaxAggr', + 'MinAggr', + 'SumAggr', + 'SoftmaxAggr', + 'PowermeanAggr', + 'VarAggr', + 'StdAggr', ] From b3b76747467e95934b77528c75d708a053fa4acf Mon Sep 17 00:00:00 2001 From: lightaime Date: Wed, 25 May 2022 17:55:14 +0300 Subject: [PATCH 07/15] Added test for aggr class --- test/nn/aggr/test_aggr.py | 47 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 test/nn/aggr/test_aggr.py diff --git a/test/nn/aggr/test_aggr.py b/test/nn/aggr/test_aggr.py new file mode 100644 index 000000000000..af89fac4d58b --- /dev/null +++ b/test/nn/aggr/test_aggr.py @@ -0,0 +1,47 @@ +import pytest +import torch + +from torch_geometric.nn import ( + MaxAggr, + MeanAggr, + MinAggr, + PowermeanAggr, + SoftmaxAggr, + StdAggr, + SumAggr, + VarAggr, +) + + +@pytest.mark.parametrize('aggr', [MeanAggr, MaxAggr, MinAggr, SumAggr]) +def test_basic_aggr(aggr): + src = torch.randn(6, 64) + index = torch.tensor([0, 1, 0, 1, 2, 1]) + out = aggr()(src, index) + assert out.shape[0] == index.unique().shape[0] + + +@pytest.mark.parametrize('aggr', [SoftmaxAggr, PowermeanAggr]) +def test_gen_aggr(aggr): + src = torch.randn(6, 64) + index = torch.tensor([0, 1, 0, 1, 2, 1]) + for learn in [True, False]: + if issubclass(aggr, SoftmaxAggr): + aggregator = aggr(t=1, learn=learn) + elif issubclass(aggr, PowermeanAggr): + aggregator = aggr(p=1, learn=learn) + else: + raise NotImplementedError + out = aggregator(src, index) + if any(map(lambda x: x.requires_grad, aggregator.parameters())): + out.mean().backward() + for param in aggregator.parameters(): + assert not torch.isnan(param.grad).any() + + +@pytest.mark.parametrize('aggr', [VarAggr, StdAggr]) +def test_stat_aggr(aggr): + src = torch.randn(6, 64) + index = torch.tensor([0, 1, 0, 1, 2, 1]) + out = aggr()(src, index) + assert out.shape[0] == index.unique().shape[0] From c6fe443945ad9b4386d71b8b48bab2fc64ba5512 Mon Sep 17 00:00:00 2001 From: lightaime Date: Wed, 25 May 2022 17:56:06 +0300 Subject: [PATCH 08/15] Formatted --- torch_geometric/nn/aggr/__init__.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/torch_geometric/nn/aggr/__init__.py b/torch_geometric/nn/aggr/__init__.py index d2307f54360c..eeeedd859816 100644 --- a/torch_geometric/nn/aggr/__init__.py +++ b/torch_geometric/nn/aggr/__init__.py @@ -1,5 +1,14 @@ -from .aggr import (Aggr, MeanAggr, MaxAggr, MinAggr, SumAggr, SoftmaxAggr, - PowermeanAggr, VarAggr, StdAggr) +from .aggr import ( + Aggr, + MeanAggr, + MaxAggr, + MinAggr, + SumAggr, + SoftmaxAggr, + PowermeanAggr, + VarAggr, + StdAggr, +) __all__ = classes = [ 'Aggr', From 74c3fda8552564b1b9bb8b20affb214dcc273bde Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 25 May 2022 18:00:45 +0200 Subject: [PATCH 09/15] update --- torch_geometric/nn/__init__.py | 2 +- torch_geometric/nn/aggr/__init__.py | 8 +- torch_geometric/nn/aggr/aggr.py | 266 ++++++++-------------------- torch_geometric/nn/aggr/base.py | 39 ++++ 4 files changed, 122 insertions(+), 193 deletions(-) create mode 100644 torch_geometric/nn/aggr/base.py diff --git a/torch_geometric/nn/__init__.py b/torch_geometric/nn/__init__.py index 06435828e66b..dd66031497f2 100644 --- a/torch_geometric/nn/__init__.py +++ b/torch_geometric/nn/__init__.py @@ -4,6 +4,7 @@ from .data_parallel import DataParallel from .to_hetero_transformer import to_hetero from .to_hetero_with_bases_transformer import to_hetero_with_bases +from .aggr 
import * # noqa from .conv import * # noqa from .norm import * # noqa from .glob import * # noqa @@ -12,7 +13,6 @@ from .dense import * # noqa from .models import * # noqa from .functional import * # noqa -from .aggr import * # noqa __all__ = [ 'MetaLayer', diff --git a/torch_geometric/nn/aggr/__init__.py b/torch_geometric/nn/aggr/__init__.py index eeeedd859816..16c97947e89f 100644 --- a/torch_geometric/nn/aggr/__init__.py +++ b/torch_geometric/nn/aggr/__init__.py @@ -1,9 +1,9 @@ +from .base import BaseAggr from .aggr import ( - Aggr, MeanAggr, + SumAggr, MaxAggr, MinAggr, - SumAggr, SoftmaxAggr, PowermeanAggr, VarAggr, @@ -11,11 +11,11 @@ ) __all__ = classes = [ - 'Aggr', + 'BaseAggr', 'MeanAggr', + 'SumAggr', 'MaxAggr', 'MinAggr', - 'SumAggr', 'SoftmaxAggr', 'PowermeanAggr', 'VarAggr', diff --git a/torch_geometric/nn/aggr/aggr.py b/torch_geometric/nn/aggr/aggr.py index 46d26354d27e..82cafc6cb155 100644 --- a/torch_geometric/nn/aggr/aggr.py +++ b/torch_geometric/nn/aggr/aggr.py @@ -1,200 +1,90 @@ -from abc import ABC from typing import Optional import torch from torch import Tensor from torch.nn import Parameter -from torch_scatter import scatter, scatter_softmax, segment_csr - - -class Aggr(torch.nn.Module, ABC): - def forward( - self, - x: Tensor, - index: Tensor, - ptr: Optional[Tensor] = None, - dim: int = -2, - dim_size: Optional[None] = None, - ) -> Tensor: - raise NotImplementedError - - def __repr__(self) -> str: - return f'{self.__class__.__name__}()' - - -class MeanAggr(Aggr): - def forward( - self, - x: Tensor, - index: Tensor, - ptr: Optional[Tensor] = None, - dim: int = -2, - dim_size: Optional[None] = None, - ) -> Tensor: - if ptr is not None: - ptr = expand_left(ptr, dim=dim, dims=x.dim()) - return segment_csr(x, ptr, reduce='mean') - else: - return scatter(x, index, dim=dim, dim_size=dim_size, reduce='mean') - - -class MaxAggr(Aggr): - def forward( - self, - x: Tensor, - index: Tensor, - ptr: Optional[Tensor] = None, - dim: int = -2, - dim_size: Optional[None] = None, - ) -> Tensor: - if ptr is not None: - ptr = expand_left(ptr, dim=dim, dims=x.dim()) - return segment_csr(x, ptr, reduce='max') - else: - return scatter(x, index, dim=dim, dim_size=dim_size, reduce='max') - - -class MinAggr(Aggr): - def forward( - self, - x: Tensor, - index: Tensor, - ptr: Optional[Tensor] = None, - dim: int = -2, - dim_size: Optional[None] = None, - ) -> Tensor: - if ptr is not None: - ptr = expand_left(ptr, dim=dim, dims=x.dim()) - return segment_csr(x, ptr, reduce='min') - else: - return scatter(x, index, dim=dim, dim_size=dim_size, reduce='min') - - -class SumAggr(Aggr): - def forward( - self, - x: Tensor, - index: Tensor, - ptr: Optional[Tensor] = None, - dim: int = -2, - dim_size: Optional[None] = None, - ) -> Tensor: - if ptr is not None: - ptr = expand_left(ptr, dim=dim, dims=x.dim()) - return segment_csr(x, ptr, reduce='sum') - else: - return scatter(x, index, dim=dim, dim_size=dim_size, reduce='sum') - - -class SoftmaxAggr(Aggr): + +from torch_geometric.nn.aggr import BaseAggr +from torch_geometric.utils import softmax + + +class MeanAggr(BaseAggr): + def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, + ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor: + return self.reduce(x, index, dim_size, ptr, dim, reduce='mean') + + +class SumAggr(BaseAggr): + def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, + ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor: + return self.reduce(x, index, dim_size, ptr, dim, reduce='sum') + + 
+class MaxAggr(BaseAggr): + def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, + ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor: + return self.reduce(x, index, dim_size, ptr, dim, reduce='max') + + +class MinAggr(BaseAggr): + def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, + ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor: + return self.reduce(x, index, dim_size, ptr, dim, reduce='min') + + +class SoftmaxAggr(BaseAggr): def __init__(self, t: float = 1.0, learn: bool = False): + # TODO Learn distinct `t` per channel. super().__init__() - self.init_t = t - self.learn = learn - if learn: - self.t = Parameter(torch.Tensor([t]), requires_grad=True) - else: - self.t = t + self._init_t = t + self.t = Parameter(torch.Tensor(1)) if learn else t + self.reset_parameters() def reset_parameters(self): - if self.t and isinstance(self.t, Tensor): - self.t.data.fill_(self.init_t) - - def forward( - self, - x: Tensor, - index: Tensor, - ptr: Optional[Tensor] = None, - dim: int = -2, - dim_size: Optional[None] = None, - ) -> Tensor: - if ptr is not None: - raise NotImplementedError - else: - if self.learn: - w = scatter_softmax(x * self.t, index, dim=dim) - else: - with torch.no_grad(): - w = scatter_softmax(x * self.t, index, dim=dim) - return scatter(x * w, index, dim=dim, dim_size=dim_size, - reduce='add') - - -class PowermeanAggr(Aggr): + if isinstance(self.t, Tensor): + self.t.data.fill_(self._init_t) + + def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, + ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor: + + if not isinstance(self.t, (int, float)) or self.t != 1: + alpha = x * self.t + alpha = softmax(alpha, index, ptr, dim_size, dim) + return self.reduce(x * alpha, index, dim_size, ptr, dim, reduce='sum') + + +class PowerMeanAggr(BaseAggr): def __init__(self, p: float = 1.0, learn: bool = False): + # TODO Learn distinct `p` per channel. 
super().__init__() - self.init_p = p - if learn: - self.p = Parameter(torch.Tensor([p]), requires_grad=True) - else: - self.p = p + self._init_p = p + self.p = Parameter(torch.Tensor(1)) if learn else p + self.reset_parameters() - def reset_parameters(self): - if self.p and isinstance(self.p, Tensor): - self.p.data.fill_(self.init_p) - - def forward( - self, - x: Tensor, - index: Tensor, - ptr: Optional[Tensor] = None, - dim: int = -2, - dim_size: Optional[None] = None, - ) -> Tensor: - x = torch.pow(torch.abs(x), self.p) - if ptr is not None: - ptr = expand_left(ptr, dim=dim, dims=x.dim()) - return torch.pow(segment_csr(x, ptr, reduce='mean'), 1 / self.p) - else: - return torch.pow( - scatter(x, index, dim=dim, dim_size=dim_size, reduce='mean'), - 1 / self.p) - - -class VarAggr(Aggr): - def forward( - self, - x: Tensor, - index: Tensor, - ptr: Optional[Tensor] = None, - dim: int = -2, - dim_size: Optional[None] = None, - ) -> Tensor: - if ptr is not None: - ptr = expand_left(ptr, dim=dim, dims=x.dim()) - mean = segment_csr(x, ptr, reduce='mean') - mean_squares = segment_csr(x * x, ptr, reduce='mean') - else: - mean = scatter(x, index, dim=dim, dim_size=dim_size, reduce='mean') - mean_squares = scatter(x * x, index, dim=dim, dim_size=dim_size, - reduce='mean') - return mean_squares - mean * mean - - -class StdAggr(Aggr): - def forward( - self, - x: Tensor, - index: Tensor, - ptr: Optional[Tensor] = None, - dim: int = -2, - dim_size: Optional[None] = None, - ) -> Tensor: - if ptr is not None: - ptr = expand_left(ptr, dim=dim, dims=x.dim()) - mean = segment_csr(x, ptr, reduce='mean') - mean_squares = segment_csr(x * x, ptr, reduce='mean') - else: - mean = scatter(x, index, dim=dim, dim_size=dim_size, reduce='mean') - mean_squares = scatter(x * x, index, dim=dim, dim_size=dim_size, - reduce='mean') - return torch.sqrt(torch.relu(mean_squares - mean * mean) + 1e-5) - - -############################################################################### - - -def expand_left(src: torch.Tensor, dim: int, dims: int) -> torch.Tensor: - for _ in range(dims + dim if dim < 0 else dim): - src = src.unsqueeze(0) - return src + if isinstance(self.p, Tensor): + self.p.data.fill_(self._init_p) + + def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, + ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor: + + out = self.reduce(x, index, dim_size, ptr, dim, reduce='mean') + if isinstance(self.p, (int, float)) and self.p == 1: + return out + return out.pow(1. 
/ self.p) + + +class VarAggr(BaseAggr): + def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, + ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor: + + mean = self.reduce(x, index, dim_size, ptr, dim, reduce='mean') + mean_2 = self.reduce(x * x, index, dim_size, ptr, dim, reduce='mean') + return mean_2 - mean * mean + + +class StdAggr(VarAggr): + def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, + ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor: + + var = self(x, index, ptr, dim, dim_size) + return torch.sqrt(var.relu() + 1e-5) diff --git a/torch_geometric/nn/aggr/base.py b/torch_geometric/nn/aggr/base.py new file mode 100644 index 000000000000..4097352b9c36 --- /dev/null +++ b/torch_geometric/nn/aggr/base.py @@ -0,0 +1,39 @@ +from abc import ABC, abstractmethod +from typing import Optional + +import torch +from torch import Tensor +from torch_scatter import scatter, segment_csr + + +class BaseAggr(torch.nn.Module, ABC): + r"""An abstract base class for writing aggregations.""" + @abstractmethod + def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, + ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor: + pass + + def reset_parameters(self): + pass + + def reduce(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, + ptr: Optional[Tensor] = None, dim: int = -2, + reduce: str = 'add') -> Tensor: + + if ptr is not None: + ptr = expand_left(ptr, dim, dims=x.dim()) + return segment_csr(x, ptr, reduce=reduce) + + return scatter(x, index, dim=dim, dim_size=dim_size, reduce=reduce) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' + + +############################################################################### + + +def expand_left(src: torch.Tensor, dim: int, dims: int) -> torch.Tensor: + for _ in range(dims + dim if dim < 0 else dim): + src = src.unsqueeze(0) + return src From fd940a5a067fa9d4867d72bae6795831fa18a548 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 25 May 2022 19:01:35 +0200 Subject: [PATCH 10/15] update --- torch_geometric/nn/aggr/__init__.py | 4 ++-- torch_geometric/nn/aggr/base.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/torch_geometric/nn/aggr/__init__.py b/torch_geometric/nn/aggr/__init__.py index 16c97947e89f..864834a139af 100644 --- a/torch_geometric/nn/aggr/__init__.py +++ b/torch_geometric/nn/aggr/__init__.py @@ -5,7 +5,7 @@ MaxAggr, MinAggr, SoftmaxAggr, - PowermeanAggr, + PowerMeanAggr, VarAggr, StdAggr, ) @@ -17,7 +17,7 @@ 'MaxAggr', 'MinAggr', 'SoftmaxAggr', - 'PowermeanAggr', + 'PowerMeanAggr', 'VarAggr', 'StdAggr', ] diff --git a/torch_geometric/nn/aggr/base.py b/torch_geometric/nn/aggr/base.py index 4097352b9c36..155101ad54ad 100644 --- a/torch_geometric/nn/aggr/base.py +++ b/torch_geometric/nn/aggr/base.py @@ -16,7 +16,8 @@ def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, def reset_parameters(self): pass - def reduce(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, + @staticmethod + def reduce(x: Tensor, index: Tensor, dim_size: Optional[int] = None, ptr: Optional[Tensor] = None, dim: int = -2, reduce: str = 'add') -> Tensor: From 5c2a37df9be80818eb0d44fc52d144effe492089 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 25 May 2022 19:58:51 +0200 Subject: [PATCH 11/15] update --- test/nn/aggr/test_aggr.py | 65 +++++++++++++----------- torch_geometric/nn/aggr/__init__.py | 8 +-- torch_geometric/nn/aggr/aggr.py | 79 ++++++++++++++++------------- 
torch_geometric/nn/aggr/base.py | 26 ++++++---- 4 files changed, 98 insertions(+), 80 deletions(-) diff --git a/test/nn/aggr/test_aggr.py b/test/nn/aggr/test_aggr.py index af89fac4d58b..e2b672e38906 100644 --- a/test/nn/aggr/test_aggr.py +++ b/test/nn/aggr/test_aggr.py @@ -5,7 +5,7 @@ MaxAggr, MeanAggr, MinAggr, - PowermeanAggr, + PowerMeanAggr, SoftmaxAggr, StdAggr, SumAggr, @@ -13,35 +13,38 @@ ) -@pytest.mark.parametrize('aggr', [MeanAggr, MaxAggr, MinAggr, SumAggr]) -def test_basic_aggr(aggr): - src = torch.randn(6, 64) - index = torch.tensor([0, 1, 0, 1, 2, 1]) - out = aggr()(src, index) - assert out.shape[0] == index.unique().shape[0] - - -@pytest.mark.parametrize('aggr', [SoftmaxAggr, PowermeanAggr]) -def test_gen_aggr(aggr): - src = torch.randn(6, 64) - index = torch.tensor([0, 1, 0, 1, 2, 1]) - for learn in [True, False]: - if issubclass(aggr, SoftmaxAggr): - aggregator = aggr(t=1, learn=learn) - elif issubclass(aggr, PowermeanAggr): - aggregator = aggr(p=1, learn=learn) - else: - raise NotImplementedError - out = aggregator(src, index) - if any(map(lambda x: x.requires_grad, aggregator.parameters())): - out.mean().backward() - for param in aggregator.parameters(): - assert not torch.isnan(param.grad).any() +@pytest.mark.parametrize( + 'Aggr', [MeanAggr, SumAggr, MaxAggr, MinAggr, VarAggr, StdAggr]) +def test_basic_aggr(Aggr): + x = torch.randn(6, 16) + index = torch.tensor([0, 0, 1, 1, 1, 2]) + ptr = torch.tensor([0, 2, 5, 6]) + + aggr = Aggr() + assert str(aggr) == f'{Aggr.__name__}()' + + out = aggr(x, index) + assert out.size() == (3, x.size(1)) + assert torch.allclose(out, aggr(x, ptr=ptr)) + +@pytest.mark.parametrize('Aggr', [SoftmaxAggr, PowerMeanAggr]) +@pytest.mark.parametrize('learn', [True, False]) +def test_gen_aggr(Aggr, learn): + x = torch.randn(6, 16) + index = torch.tensor([0, 0, 1, 1, 1, 2]) + ptr = torch.tensor([0, 2, 5, 6]) -@pytest.mark.parametrize('aggr', [VarAggr, StdAggr]) -def test_stat_aggr(aggr): - src = torch.randn(6, 64) - index = torch.tensor([0, 1, 0, 1, 2, 1]) - out = aggr()(src, index) - assert out.shape[0] == index.unique().shape[0] + aggr = Aggr(learn=learn) + assert str(aggr) == f'{Aggr.__name__}()' + + out = aggr(x, index) + assert out.size() == (3, x.size(1)) + assert torch.allclose(out, aggr(x, ptr=ptr)) + + if learn: + if any(map(lambda x: x.requires_grad, aggr.parameters())): + out.mean().backward() + for param in aggr.parameters(): + print(param.grad) + assert not torch.isnan(param.grad).any() diff --git a/torch_geometric/nn/aggr/__init__.py b/torch_geometric/nn/aggr/__init__.py index 864834a139af..3703be951604 100644 --- a/torch_geometric/nn/aggr/__init__.py +++ b/torch_geometric/nn/aggr/__init__.py @@ -4,10 +4,10 @@ SumAggr, MaxAggr, MinAggr, - SoftmaxAggr, - PowerMeanAggr, VarAggr, StdAggr, + SoftmaxAggr, + PowerMeanAggr, ) __all__ = classes = [ @@ -16,8 +16,8 @@ 'SumAggr', 'MaxAggr', 'MinAggr', - 'SoftmaxAggr', - 'PowerMeanAggr', 'VarAggr', 'StdAggr', + 'SoftmaxAggr', + 'PowerMeanAggr', ] diff --git a/torch_geometric/nn/aggr/aggr.py b/torch_geometric/nn/aggr/aggr.py index 82cafc6cb155..fac900932de4 100644 --- a/torch_geometric/nn/aggr/aggr.py +++ b/torch_geometric/nn/aggr/aggr.py @@ -9,27 +9,50 @@ class MeanAggr(BaseAggr): - def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, - ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor: - return self.reduce(x, index, dim_size, ptr, dim, reduce='mean') + def forward(self, x: Tensor, index: Optional[Tensor] = None, *, + ptr: Optional[Tensor] = None, dim_size: 
Optional[int] = None, + dim: int = -2) -> Tensor: + return self.reduce(x, index, ptr, dim_size, dim, reduce='mean') class SumAggr(BaseAggr): - def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, - ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor: - return self.reduce(x, index, dim_size, ptr, dim, reduce='sum') + def forward(self, x: Tensor, index: Optional[Tensor] = None, *, + ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, + dim: int = -2) -> Tensor: + return self.reduce(x, index, ptr, dim_size, dim, reduce='sum') class MaxAggr(BaseAggr): - def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, - ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor: - return self.reduce(x, index, dim_size, ptr, dim, reduce='max') + def forward(self, x: Tensor, index: Optional[Tensor] = None, *, + ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, + dim: int = -2) -> Tensor: + return self.reduce(x, index, ptr, dim_size, dim, reduce='max') class MinAggr(BaseAggr): - def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, - ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor: - return self.reduce(x, index, dim_size, ptr, dim, reduce='min') + def forward(self, x: Tensor, index: Optional[Tensor] = None, *, + ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, + dim: int = -2) -> Tensor: + return self.reduce(x, index, ptr, dim_size, dim, reduce='min') + + +class VarAggr(BaseAggr): + def forward(self, x: Tensor, index: Optional[Tensor] = None, *, + ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, + dim: int = -2) -> Tensor: + + mean = self.reduce(x, index, ptr, dim_size, dim, reduce='mean') + mean_2 = self.reduce(x * x, index, ptr, dim_size, dim, reduce='mean') + return mean_2 - mean * mean + + +class StdAggr(VarAggr): + def forward(self, x: Tensor, index: Optional[Tensor] = None, *, + ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, + dim: int = -2) -> Tensor: + + var = super().forward(x, index, ptr=ptr, dim_size=dim_size, dim=dim) + return torch.sqrt(var.relu() + 1e-5) class SoftmaxAggr(BaseAggr): @@ -44,13 +67,15 @@ def reset_parameters(self): if isinstance(self.t, Tensor): self.t.data.fill_(self._init_t) - def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, - ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor: + def forward(self, x: Tensor, index: Optional[Tensor] = None, *, + ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, + dim: int = -2) -> Tensor: + alpha = x if not isinstance(self.t, (int, float)) or self.t != 1: alpha = x * self.t alpha = softmax(alpha, index, ptr, dim_size, dim) - return self.reduce(x * alpha, index, dim_size, ptr, dim, reduce='sum') + return self.reduce(x * alpha, index, ptr, dim_size, dim, reduce='sum') class PowerMeanAggr(BaseAggr): @@ -64,27 +89,11 @@ def __init__(self, p: float = 1.0, learn: bool = False): if isinstance(self.p, Tensor): self.p.data.fill_(self._init_p) - def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, - ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor: + def forward(self, x: Tensor, index: Optional[Tensor] = None, *, + ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, + dim: int = -2) -> Tensor: - out = self.reduce(x, index, dim_size, ptr, dim, reduce='mean') + out = self.reduce(x, index, ptr, dim_size, dim, reduce='mean') if isinstance(self.p, (int, float)) and self.p == 1: return out return out.pow(1. 
/ self.p) - - -class VarAggr(BaseAggr): - def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, - ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor: - - mean = self.reduce(x, index, dim_size, ptr, dim, reduce='mean') - mean_2 = self.reduce(x * x, index, dim_size, ptr, dim, reduce='mean') - return mean_2 - mean * mean - - -class StdAggr(VarAggr): - def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, - ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor: - - var = self(x, index, ptr, dim, dim_size) - return torch.sqrt(var.relu() + 1e-5) diff --git a/torch_geometric/nn/aggr/base.py b/torch_geometric/nn/aggr/base.py index 155101ad54ad..1a3d48de95e2 100644 --- a/torch_geometric/nn/aggr/base.py +++ b/torch_geometric/nn/aggr/base.py @@ -9,23 +9,29 @@ class BaseAggr(torch.nn.Module, ABC): r"""An abstract base class for writing aggregations.""" @abstractmethod - def forward(self, x: Tensor, index: Tensor, dim_size: Optional[int] = None, - ptr: Optional[Tensor] = None, dim: int = -2) -> Tensor: + def forward(self, x: Tensor, index: Optional[Tensor] = None, *, + ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, + dim: int = -2) -> Tensor: pass def reset_parameters(self): pass - @staticmethod - def reduce(x: Tensor, index: Tensor, dim_size: Optional[int] = None, - ptr: Optional[Tensor] = None, dim: int = -2, - reduce: str = 'add') -> Tensor: + def reduce(self, x: Tensor, index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, + dim: int = -2, reduce: str = 'add') -> Tensor: + + assert index is not None or ptr is not None if ptr is not None: ptr = expand_left(ptr, dim, dims=x.dim()) return segment_csr(x, ptr, reduce=reduce) - return scatter(x, index, dim=dim, dim_size=dim_size, reduce=reduce) + if index is not None: + return scatter(x, index, dim=dim, dim_size=dim_size, reduce=reduce) + + raise ValueError(f"Error in '{self.__class__.__name__}': " + f"Both 'index' and 'ptr' are undefined") def __repr__(self) -> str: return f'{self.__class__.__name__}()' @@ -34,7 +40,7 @@ def __repr__(self) -> str: ############################################################################### -def expand_left(src: torch.Tensor, dim: int, dims: int) -> torch.Tensor: +def expand_left(ptr: Tensor, dim: int, dims: int) -> Tensor: for _ in range(dims + dim if dim < 0 else dim): - src = src.unsqueeze(0) - return src + ptr = ptr.unsqueeze(0) + return ptr From 37aa3a2d617686c910b443cade4d463c6dbfea76 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 25 May 2022 20:11:03 +0200 Subject: [PATCH 12/15] update --- test/nn/aggr/test_aggr.py | 8 +++----- torch_geometric/nn/aggr/aggr.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/test/nn/aggr/test_aggr.py b/test/nn/aggr/test_aggr.py index e2b672e38906..e18d1fdf15f6 100644 --- a/test/nn/aggr/test_aggr.py +++ b/test/nn/aggr/test_aggr.py @@ -43,8 +43,6 @@ def test_gen_aggr(Aggr, learn): assert torch.allclose(out, aggr(x, ptr=ptr)) if learn: - if any(map(lambda x: x.requires_grad, aggr.parameters())): - out.mean().backward() - for param in aggr.parameters(): - print(param.grad) - assert not torch.isnan(param.grad).any() + out.mean().backward() + for param in aggr.parameters(): + assert not torch.isnan(param.grad).any() diff --git a/torch_geometric/nn/aggr/aggr.py b/torch_geometric/nn/aggr/aggr.py index fac900932de4..19ec0b523a5d 100644 --- a/torch_geometric/nn/aggr/aggr.py +++ b/torch_geometric/nn/aggr/aggr.py @@ -96,4 +96,4 @@ def forward(self, x: Tensor, 
index: Optional[Tensor] = None, *, out = self.reduce(x, index, ptr, dim_size, dim, reduce='mean') if isinstance(self.p, (int, float)) and self.p == 1: return out - return out.pow(1. / self.p) + return out.clamp_(min=0, max=100).pow(1. / self.p) From c391a2a6102022092547e71661984617556585ba Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 25 May 2022 20:14:50 +0200 Subject: [PATCH 13/15] update --- test/nn/aggr/{test_aggr.py => test_basic.py} | 37 +++++++++--------- torch_geometric/nn/aggr/__init__.py | 38 +++++++++---------- torch_geometric/nn/aggr/base.py | 2 +- torch_geometric/nn/aggr/{aggr.py => basic.py} | 18 ++++----- 4 files changed, 49 insertions(+), 46 deletions(-) rename test/nn/aggr/{test_aggr.py => test_basic.py} (50%) rename torch_geometric/nn/aggr/{aggr.py => basic.py} (90%) diff --git a/test/nn/aggr/test_aggr.py b/test/nn/aggr/test_basic.py similarity index 50% rename from test/nn/aggr/test_aggr.py rename to test/nn/aggr/test_basic.py index e18d1fdf15f6..1e1dd07a1ed3 100644 --- a/test/nn/aggr/test_aggr.py +++ b/test/nn/aggr/test_basic.py @@ -2,41 +2,44 @@ import torch from torch_geometric.nn import ( - MaxAggr, - MeanAggr, - MinAggr, - PowerMeanAggr, - SoftmaxAggr, - StdAggr, - SumAggr, - VarAggr, + MaxAggregation, + MeanAggregation, + MinAggregation, + PowerMeanAggregation, + SoftmaxAggregation, + StdAggregation, + SumAggregation, + VarAggregation, ) -@pytest.mark.parametrize( - 'Aggr', [MeanAggr, SumAggr, MaxAggr, MinAggr, VarAggr, StdAggr]) -def test_basic_aggr(Aggr): +@pytest.mark.parametrize('Aggregation', [ + MeanAggregation, SumAggregation, MaxAggregation, MinAggregation, + VarAggregation, StdAggregation +]) +def test_basic_aggregation(Aggregation): x = torch.randn(6, 16) index = torch.tensor([0, 0, 1, 1, 1, 2]) ptr = torch.tensor([0, 2, 5, 6]) - aggr = Aggr() - assert str(aggr) == f'{Aggr.__name__}()' + aggr = Aggregation() + assert str(aggr) == f'{Aggregation.__name__}()' out = aggr(x, index) assert out.size() == (3, x.size(1)) assert torch.allclose(out, aggr(x, ptr=ptr)) -@pytest.mark.parametrize('Aggr', [SoftmaxAggr, PowerMeanAggr]) +@pytest.mark.parametrize('Aggregation', + [SoftmaxAggregation, PowerMeanAggregation]) @pytest.mark.parametrize('learn', [True, False]) -def test_gen_aggr(Aggr, learn): +def test_gen_aggregation(Aggregation, learn): x = torch.randn(6, 16) index = torch.tensor([0, 0, 1, 1, 1, 2]) ptr = torch.tensor([0, 2, 5, 6]) - aggr = Aggr(learn=learn) - assert str(aggr) == f'{Aggr.__name__}()' + aggr = Aggregation(learn=learn) + assert str(aggr) == f'{Aggregation.__name__}()' out = aggr(x, index) assert out.size() == (3, x.size(1)) diff --git a/torch_geometric/nn/aggr/__init__.py b/torch_geometric/nn/aggr/__init__.py index 3703be951604..dbe1e4f086db 100644 --- a/torch_geometric/nn/aggr/__init__.py +++ b/torch_geometric/nn/aggr/__init__.py @@ -1,23 +1,23 @@ -from .base import BaseAggr -from .aggr import ( - MeanAggr, - SumAggr, - MaxAggr, - MinAggr, - VarAggr, - StdAggr, - SoftmaxAggr, - PowerMeanAggr, +from .base import Aggregation +from .basic import ( + MeanAggregation, + SumAggregation, + MaxAggregation, + MinAggregation, + VarAggregation, + StdAggregation, + SoftmaxAggregation, + PowerMeanAggregation, ) __all__ = classes = [ - 'BaseAggr', - 'MeanAggr', - 'SumAggr', - 'MaxAggr', - 'MinAggr', - 'VarAggr', - 'StdAggr', - 'SoftmaxAggr', - 'PowerMeanAggr', + 'Aggregation', + 'MeanAggregation', + 'SumAggregation', + 'MaxAggregation', + 'MinAggregation', + 'VarAggregation', + 'StdAggregation', + 'SoftmaxAggregation', + 'PowerMeanAggregation', ] diff 
--git a/torch_geometric/nn/aggr/base.py b/torch_geometric/nn/aggr/base.py index 1a3d48de95e2..2aac8b923c3a 100644 --- a/torch_geometric/nn/aggr/base.py +++ b/torch_geometric/nn/aggr/base.py @@ -6,7 +6,7 @@ from torch_scatter import scatter, segment_csr -class BaseAggr(torch.nn.Module, ABC): +class Aggregation(torch.nn.Module, ABC): r"""An abstract base class for writing aggregations.""" @abstractmethod def forward(self, x: Tensor, index: Optional[Tensor] = None, *, diff --git a/torch_geometric/nn/aggr/aggr.py b/torch_geometric/nn/aggr/basic.py similarity index 90% rename from torch_geometric/nn/aggr/aggr.py rename to torch_geometric/nn/aggr/basic.py index 19ec0b523a5d..3b52fc225fad 100644 --- a/torch_geometric/nn/aggr/aggr.py +++ b/torch_geometric/nn/aggr/basic.py @@ -4,39 +4,39 @@ from torch import Tensor from torch.nn import Parameter -from torch_geometric.nn.aggr import BaseAggr +from torch_geometric.nn.aggr import Aggregation from torch_geometric.utils import softmax -class MeanAggr(BaseAggr): +class MeanAggregation(Aggregation): def forward(self, x: Tensor, index: Optional[Tensor] = None, *, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: return self.reduce(x, index, ptr, dim_size, dim, reduce='mean') -class SumAggr(BaseAggr): +class SumAggregation(Aggregation): def forward(self, x: Tensor, index: Optional[Tensor] = None, *, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: return self.reduce(x, index, ptr, dim_size, dim, reduce='sum') -class MaxAggr(BaseAggr): +class MaxAggregation(Aggregation): def forward(self, x: Tensor, index: Optional[Tensor] = None, *, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: return self.reduce(x, index, ptr, dim_size, dim, reduce='max') -class MinAggr(BaseAggr): +class MinAggregation(Aggregation): def forward(self, x: Tensor, index: Optional[Tensor] = None, *, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: return self.reduce(x, index, ptr, dim_size, dim, reduce='min') -class VarAggr(BaseAggr): +class VarAggregation(Aggregation): def forward(self, x: Tensor, index: Optional[Tensor] = None, *, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: @@ -46,7 +46,7 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *, return mean_2 - mean * mean -class StdAggr(VarAggr): +class StdAggregation(VarAggregation): def forward(self, x: Tensor, index: Optional[Tensor] = None, *, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: @@ -55,7 +55,7 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *, return torch.sqrt(var.relu() + 1e-5) -class SoftmaxAggr(BaseAggr): +class SoftmaxAggregation(Aggregation): def __init__(self, t: float = 1.0, learn: bool = False): # TODO Learn distinct `t` per channel. super().__init__() @@ -78,7 +78,7 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *, return self.reduce(x * alpha, index, ptr, dim_size, dim, reduce='sum') -class PowerMeanAggr(BaseAggr): +class PowerMeanAggregation(Aggregation): def __init__(self, p: float = 1.0, learn: bool = False): # TODO Learn distinct `p` per channel. 
super().__init__()

From ab1086de05687efb3f31e9aa31812bd1d0e21436 Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Wed, 25 May 2022 20:18:46 +0200
Subject: [PATCH 14/15] docstring

---
 torch_geometric/nn/aggr/base.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/torch_geometric/nn/aggr/base.py b/torch_geometric/nn/aggr/base.py
index 2aac8b923c3a..e2f19e946b77 100644
--- a/torch_geometric/nn/aggr/base.py
+++ b/torch_geometric/nn/aggr/base.py
@@ -12,6 +12,22 @@ class Aggregation(torch.nn.Module, ABC):
     def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                 ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                 dim: int = -2) -> Tensor:
+        r"""
+        Args:
+            x (torch.Tensor): The source tensor.
+            index (torch.LongTensor, optional): The indices of elements for
+                applying the aggregation.
+                One of :obj:`index` or :obj:`ptr` must be defined.
+                (default: :obj:`None`)
+            ptr (torch.LongTensor, optional): If given, computes the
+                aggregation based on sorted inputs in CSR representation.
+                One of :obj:`index` or :obj:`ptr` must be defined.
+                (default: :obj:`None`)
+            dim_size (int, optional): The size of the output tensor at
+                dimension :obj:`dim` after aggregation. (default: :obj:`None`)
+            dim (int, optional): The dimension in which to aggregate.
+                (default: :obj:`-2`)
+        """
         pass
 
     def reset_parameters(self):
@@ -31,7 +47,7 @@ def reduce(self, x: Tensor, index: Optional[Tensor] = None,
             return scatter(x, index, dim=dim, dim_size=dim_size, reduce=reduce)
 
         raise ValueError(f"Error in '{self.__class__.__name__}': "
-                         f"Both 'index' and 'ptr' are undefined")
+                         f"One of 'index' or 'ptr' must be defined")
 
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'

From 33506930edd72cf5495149d69ea1a104a08b6d0d Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Wed, 25 May 2022 20:19:14 +0200
Subject: [PATCH 15/15] typo

---
 torch_geometric/nn/aggr/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_geometric/nn/aggr/base.py b/torch_geometric/nn/aggr/base.py
index e2f19e946b77..f36efc7bd09a 100644
--- a/torch_geometric/nn/aggr/base.py
+++ b/torch_geometric/nn/aggr/base.py
@@ -7,7 +7,7 @@
 
 
 class Aggregation(torch.nn.Module, ABC):
-    r"""An abstract base class for writing aggregations."""
+    r"""An abstract base class for implementing custom aggregations."""
     @abstractmethod
     def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                 ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
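
A minimal usage sketch of the public API as it stands after [PATCH 15/15],
assuming a torch_geometric checkout with this series applied and
torch-scatter installed. The shapes and index values mirror those in
test/nn/aggr/test_basic.py; everything else is illustrative:

    import torch
    from torch_geometric.nn import MeanAggregation, SoftmaxAggregation

    x = torch.randn(6, 16)                    # six rows of 16 features each
    index = torch.tensor([0, 0, 1, 1, 1, 2])  # assigns each row to a group
    ptr = torch.tensor([0, 2, 5, 6])          # the same grouping in CSR form

    aggr = MeanAggregation()
    out = aggr(x, index)                      # shape [3, 16], one row per group
    assert torch.allclose(out, aggr(x, ptr=ptr))

    # With `learn=True`, the softmax temperature `t` becomes a Parameter and
    # receives gradients through the aggregation weights.
    aggr = SoftmaxAggregation(learn=True)
    aggr(x, index).mean().backward()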
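
The Aggregation base class introduced here is meant to be subclassed:
forward() is the only abstract method, and the inherited reduce() helper
dispatches to segment_csr (when ptr is given) or scatter (when index is
given). A hypothetical root-mean-square aggregation, not part of this
series, sketches the pattern:

    from typing import Optional

    from torch import Tensor

    from torch_geometric.nn.aggr import Aggregation

    class RMSAggregation(Aggregation):
        # Hypothetical example; not defined anywhere in this patch series.
        def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                    ptr: Optional[Tensor] = None,
                    dim_size: Optional[int] = None, dim: int = -2) -> Tensor:
            # E[x^2] per group, then a square root (eps added for numerical
            # stability), reusing the reduce() call pattern of VarAggregation.
            mean_2 = self.reduce(x * x, index, ptr, dim_size, dim,
                                 reduce='mean')
            return (mean_2 + 1e-5).sqrt()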