From 23d4eef3c378c3dd878d58e842d04a18d0a413e9 Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Wed, 23 Nov 2022 09:19:42 +0000
Subject: [PATCH 1/8] update

---
 test/nn/aggr/test_fused.py       |  3 +-
 test/nn/aggr/test_multi.py       |  4 +-
 torch_geometric/nn/aggr/fused.py | 16 ++++---
 torch_geometric/nn/aggr/multi.py | 74 +++++++++++++++++++++++---------
 4 files changed, 68 insertions(+), 29 deletions(-)

diff --git a/test/nn/aggr/test_fused.py b/test/nn/aggr/test_fused.py
index 62e66cc71493..7e0684a20e42 100644
--- a/test/nn/aggr/test_fused.py
+++ b/test/nn/aggr/test_fused.py
@@ -46,7 +46,8 @@ def test_fused_aggregation(aggrs):
     aggrs = ['sum', 'mean', 'max', 'std']
     aggrs = [aggregation_resolver(aggr) for aggr in aggrs]
 
-    fused_aggr = FusedAggregation(aggrs)
+    from torch_geometric.nn.aggr import MultiAggregation
+    fused_aggr = MultiAggregation(aggrs)
 
     num_warmups, num_steps = (500, 1000)
 
diff --git a/test/nn/aggr/test_multi.py b/test/nn/aggr/test_multi.py
index 920b58760cd0..ccfa70dbfbc4 100644
--- a/test/nn/aggr/test_multi.py
+++ b/test/nn/aggr/test_multi.py
@@ -32,9 +32,11 @@ def test_multi_aggr(multi_aggr_tuple):
     assert str(aggr) == ('MultiAggregation([\n'
                          '  MeanAggregation(),\n'
                          '  SumAggregation(),\n'
-                         '  MaxAggregation()\n'
+                         '  MaxAggregation(),\n'
                          f"], mode={aggr_kwargs['mode']})")
 
     out = aggr(x, index)
     assert torch.allclose(out, aggr(x, ptr=ptr))
     assert out.size() == (4, expand * x.size(1))
+
+    # TODO test JIT support
diff --git a/torch_geometric/nn/aggr/fused.py b/torch_geometric/nn/aggr/fused.py
index 3778037247b5..40f2c04a01c0 100644
--- a/torch_geometric/nn/aggr/fused.py
+++ b/torch_geometric/nn/aggr/fused.py
@@ -3,8 +3,8 @@
 import torch
 from torch import Tensor
 
-from torch_geometric.nn import (
-    Aggregation,
+from torch_geometric.nn.aggr.base import Aggregation
+from torch_geometric.nn.aggr.basic import (
     MaxAggregation,
     MeanAggregation,
     MinAggregation,
@@ -52,6 +52,8 @@ class FusedAggregation(Aggregation):
 
     Args:
         aggrs (list): The list of aggregation schemes to use.
+        cat (bool, optional): Whether to concatenate the output.
+            (default: :obj:`True`)
     """
     # We can fuse all aggregations together that rely on `scatter` directives.
     FUSABLE_AGGRS = {
@@ -89,8 +91,9 @@ class FusedAggregation(Aggregation):
         StdAggregation: 'pow_sum',
     }
 
-    def __init__(self, aggrs: List[Union[Aggregation, str]]):
+    def __init__(self, aggrs: List[Union[Aggregation, str]], cat: bool = True):
         super().__init__()
+        self.cat = cat
 
         if not isinstance(aggrs, (list, tuple)):
             raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should "
@@ -288,6 +291,7 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
 
         #######################################################################
 
-        out = torch.cat(outs, dim=-1)
-
-        return out
+        if self.cat:
+            return torch.cat(outs, dim=-1) if len(outs) > 1 else outs[0]
+        else:
+            return outs
diff --git a/torch_geometric/nn/aggr/multi.py b/torch_geometric/nn/aggr/multi.py
index 49b210a091c0..e52bc17774b2 100644
--- a/torch_geometric/nn/aggr/multi.py
+++ b/torch_geometric/nn/aggr/multi.py
@@ -6,6 +6,7 @@
 from torch.nn import Linear, MultiheadAttention
 
 from torch_geometric.nn.aggr import Aggregation
+from torch_geometric.nn.aggr.fused import FusedAggregation
 from torch_geometric.nn.resolver import aggregation_resolver
 
 
@@ -68,10 +69,30 @@ def __init__(
             for aggr, aggr_kwargs in zip(aggrs, aggrs_kwargs)
         ])
 
+        # Divide the set into fusable and non-fusable aggregations:
+        fused_aggrs = []
+        self.fused_out_index = []
+        self.non_fused_aggrs = []
+        self.non_fused_out_index = []
+        for i, aggr in enumerate(self.aggrs):
+            if aggr.__class__ in FusedAggregation.FUSABLE_AGGRS:
+                fused_aggrs.append(aggr)
+                self.fused_out_index.append(i)
+            else:
+                self.non_fused_aggrs.append(aggr)
+                self.non_fused_out_index.append(i)
+
+        if len(fused_aggrs) > 0:
+            self.fused_aggr = FusedAggregation(fused_aggrs, cat=False)
+        else:
+            self.fused_aggr = None
+
         self.mode = mode
-        mode_kwargs = copy.copy(mode_kwargs or {})
+        mode_kwargs = copy.copy(mode_kwargs) or {}
+
         self.in_channels = mode_kwargs.pop('in_channels', None)
         self.out_channels = mode_kwargs.pop('out_channels', None)
+
         if mode == 'proj' or mode == 'attn':
             if len(aggrs) == 1:
                 raise ValueError("Multiple aggregations are required for "
@@ -83,7 +104,7 @@ def __init__(
                 f"and `out_channels` specified.")
 
         if isinstance(self.in_channels, int):
-            self.in_channels = (self.in_channels, ) * len(aggrs)
+            self.in_channels = [self.in_channels] * len(aggrs)
 
         if mode == 'proj':
             self.lin = Linear(
@@ -92,7 +113,7 @@ def __init__(
                 **mode_kwargs,
             )
 
-        if mode == 'attn':
+        elif mode == 'attn':
             self.lin_heads = torch.nn.ModuleList([
                 Linear(channels, self.out_channels)
                 for channels in self.in_channels
@@ -113,18 +134,17 @@ def __init__(
     def reset_parameters(self):
         for aggr in self.aggrs:
             aggr.reset_parameters()
-        if hasattr(self, 'lin'):
+        if self.mode == 'proj':
             self.lin.reset_parameters()
-        if hasattr(self, 'lin_heads'):
+        if self.mode == 'attn':
             for lin in self.lin_heads:
                 lin.reset_parameters()
-        if hasattr(self, 'multihead_attn'):
             self.multihead_attn._reset_parameters()
 
     def get_out_channels(self, in_channels: int) -> int:
         if self.out_channels is not None:
             return self.out_channels
-        # TODO: Support having customized `out_channels` in each aggregation
+        # TODO Support having customized `out_channels` in each aggregation.
         if self.mode == 'cat':
             return in_channels * len(self.aggrs)
         return in_channels
@@ -132,18 +152,34 @@ def get_out_channels(self, in_channels: int) -> int:
 
     def forward(self, x: Tensor, index: Optional[Tensor] = None,
                 ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                 dim: int = -2) -> Tensor:
-        outs = []
-        for aggr in self.aggrs:
-            outs.append(aggr(x, index, ptr, dim_size, dim))
-        return self.combine(outs) if len(outs) > 1 else outs[0]
+        # `FusedAggregation` is currently limited to two-dimensional inputs:
+        if index is None or x.dim() != 2 or self.fused_aggr is None:
+            outs = [aggr(x, index, ptr, dim_size, dim) for aggr in self.aggrs]
+            return self.combine(outs)
+
+        outs = [None] * len(self.aggrs)
+
+        fused_outs = self.fused_aggr(x, index, ptr, dim_size, dim)
+        for i, out in zip(self.fused_out_index, fused_outs):
+            outs[i] = out
+
+        for i, aggr in zip(self.non_fused_out_index, self.non_fused_aggrs):
+            outs[i] = aggr(x, index, ptr, dim_size, dim)
+
+        return self.combine(outs)
 
     def combine(self, inputs: List[Tensor]) -> Tensor:
-        if self.mode in ['cat', 'proj']:
-            out = torch.cat(inputs, dim=-1)
-            return self.lin(out) if hasattr(self, 'lin') else out
+        if len(inputs) == 1:
+            return inputs[0]
+
+        if self.mode == 'cat':
+            return torch.cat(inputs, dim=-1)
+
+        if self.mode == 'proj':
+            return self.lin(torch.cat(inputs, dim=-1))
 
-        if hasattr(self, 'multihead_attn'):
+        if self.mode == 'attn':
             x = torch.stack(
                 [head(x) for x, head in zip(inputs, self.lin_heads)],
                 dim=0,
@@ -158,9 +194,5 @@ def combine(self, inputs: List[Tensor]) -> Tensor:
         raise ValueError(f"Combine mode '{self.mode}' is not supported.")
 
     def __repr__(self) -> str:
-        args = [f'  {aggr}' for aggr in self.aggrs]
-        return '{}([\n{}\n], mode={})'.format(
-            self.__class__.__name__,
-            ',\n'.join(args),
-            self.mode,
-        )
+        aggrs = ',\n'.join([f'  {aggr}' for aggr in self.aggrs]) + ',\n'
+        return f'{self.__class__.__name__}([\n{aggrs}], mode={self.mode})'

From 060a2858b873323f14152c2e98db5759dd91e008 Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Wed, 23 Nov 2022 09:20:52 +0000
Subject: [PATCH 2/8] changelog

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 284f69db881c..1d2264f1eab3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.2.0] - 2022-MM-DD
 ### Added
-- Added `FusedAggregation` of simple scatter reductions ([#6036](https://github.com/pyg-team/pytorch_geometric/pull/6036))
+- Allow for fused aggregation in `MultiAggregation` ([#6036](https://github.com/pyg-team/pytorch_geometric/pull/6036), [#6040](https://github.com/pyg-team/pytorch_geometric/pull/6040))
 - Added `HeteroData` support for `to_captum_model` and added `to_captum_input` ([#5934](https://github.com/pyg-team/pytorch_geometric/pull/5934))
 - Added `HeteroData` support in `RandomNodeLoader` ([#6007](https://github.com/pyg-team/pytorch_geometric/pull/6007))
 - Added bipartite `GraphSAGE` example ([#5834](https://github.com/pyg-team/pytorch_geometric/pull/5834))

From bb7a86fd4910624ffccd62c926ba739a3b5a0fed Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Wed, 23 Nov 2022 09:21:08 +0000
Subject: [PATCH 3/8] changelog

---
 test/nn/aggr/test_fused.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/test/nn/aggr/test_fused.py b/test/nn/aggr/test_fused.py
index 7e0684a20e42..62e66cc71493 100644
--- a/test/nn/aggr/test_fused.py
+++ b/test/nn/aggr/test_fused.py
@@ -46,8 +46,7 @@ def test_fused_aggregation(aggrs):
     aggrs = ['sum', 'mean', 'max', 'std']
     aggrs = [aggregation_resolver(aggr) for aggr in aggrs]
 
-    from torch_geometric.nn.aggr import MultiAggregation
-    fused_aggr = MultiAggregation(aggrs)
+    fused_aggr = FusedAggregation(aggrs)
 
     num_warmups, num_steps = (500, 1000)

From 04df79259fcfed47fbc1bf98c90eb673c7a1e71c Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Wed, 23 Nov 2022 09:22:11 +0000
Subject: [PATCH 4/8] typo

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1d2264f1eab3..d8324e0f66dd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.2.0] - 2022-MM-DD
 ### Added
-- Allow for fused aggregation in `MultiAggregation` ([#6036](https://github.com/pyg-team/pytorch_geometric/pull/6036), [#6040](https://github.com/pyg-team/pytorch_geometric/pull/6040))
+- Allow for fused aggregations in `MultiAggregation` ([#6036](https://github.com/pyg-team/pytorch_geometric/pull/6036), [#6040](https://github.com/pyg-team/pytorch_geometric/pull/6040))
 - Added `HeteroData` support for `to_captum_model` and added `to_captum_input` ([#5934](https://github.com/pyg-team/pytorch_geometric/pull/5934))
 - Added `HeteroData` support in `RandomNodeLoader` ([#6007](https://github.com/pyg-team/pytorch_geometric/pull/6007))
 - Added bipartite `GraphSAGE` example ([#5834](https://github.com/pyg-team/pytorch_geometric/pull/5834))

From 5dcdbd780104734b701aa80b7fe030ea1e1afe78 Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Wed, 23 Nov 2022 13:51:57 +0000
Subject: [PATCH 5/8] update

---
 torch_geometric/nn/aggr/basic.py |   7 +-
 torch_geometric/nn/aggr/fused.py | 194 ++++++++++++++++++++-----------
 torch_geometric/nn/aggr/multi.py |  25 ++--
 3 files changed, 141 insertions(+), 85 deletions(-)

diff --git a/torch_geometric/nn/aggr/basic.py b/torch_geometric/nn/aggr/basic.py
index af1cae157769..f798f8834d93 100644
--- a/torch_geometric/nn/aggr/basic.py
+++ b/torch_geometric/nn/aggr/basic.py
@@ -180,10 +180,9 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
         alpha = x * t
 
         if not self.learn and self.semi_grad:
-            with torch.no_grad():
-                alpha = softmax(alpha, index, ptr, dim_size, dim)
-        else:
-            alpha = softmax(alpha, index, ptr, dim_size, dim)
+            alpha = alpha.detach()
+
+        alpha = softmax(alpha, index, ptr, dim_size, dim)
         return self.reduce(x * alpha, index, ptr, dim_size, dim, reduce='sum')
 
     def __repr__(self) -> str:
diff --git a/torch_geometric/nn/aggr/fused.py b/torch_geometric/nn/aggr/fused.py
index 40f2c04a01c0..9ba35a40325e 100644
--- a/torch_geometric/nn/aggr/fused.py
+++ b/torch_geometric/nn/aggr/fused.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -52,8 +52,6 @@ class FusedAggregation(Aggregation):
 
     Args:
         aggrs (list): The list of aggregation schemes to use.
-        cat (bool, optional): Whether to concatenate the output.
-            (default: :obj:`True`)
     """
     # We can fuse all aggregations together that rely on `scatter` directives.
     FUSABLE_AGGRS = {
@@ -82,18 +80,17 @@ class FusedAggregation(Aggregation):
 
     # Map aggregations to `reduce` options in `scatter` directives.
     REDUCE = {
-        SumAggregation: 'sum',
-        MeanAggregation: 'sum',
-        MinAggregation: 'amin',
-        MaxAggregation: 'amax',
-        MulAggregation: 'prod',
-        VarAggregation: 'pow_sum',
-        StdAggregation: 'pow_sum',
+        'SumAggregation': 'sum',
+        'MeanAggregation': 'sum',
+        'MinAggregation': 'amin',
+        'MaxAggregation': 'amax',
+        'MulAggregation': 'prod',
+        'VarAggregation': 'pow_sum',
+        'StdAggregation': 'pow_sum',
     }
 
-    def __init__(self, aggrs: List[Union[Aggregation, str]], cat: bool = True):
+    def __init__(self, aggrs: List[Union[Aggregation, str]]):
         super().__init__()
-        self.cat = cat
 
         if not isinstance(aggrs, (list, tuple)):
             raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should "
@@ -104,10 +101,14 @@ def __init__(self, aggrs: List[Union[Aggregation, str]], cat: bool = True):
             f"not be empty.")
 
         aggrs = [aggregation_resolver(aggr) for aggr in aggrs]
-        self.aggr_cls = [aggr.__class__ for aggr in aggrs]
-        self.aggr_index = {cls: i for i, cls in enumerate(self.aggr_cls)}
-
-        for cls in self.aggr_cls:
+        aggr_classes = [aggr.__class__ for aggr in aggrs]
+        self.aggr_names = [cls.__name__ for cls in aggr_classes]
+        self.aggr_index: Dict[str, int] = {
+            name: i
+            for i, name in enumerate(self.aggr_names)
+        }
+
+        for cls in aggr_classes:
             if cls not in self.FUSABLE_AGGRS:
                 raise ValueError(f"Received aggregation '{cls.__name__}' in "
                                  f"'{self.__class__.__name__}' which is not "
@@ -115,13 +116,13 @@ def __init__(self, aggrs: List[Union[Aggregation, str]], cat: bool = True):
 
         # Check whether we need to compute degree information:
         self.need_degree = False
-        for cls in self.aggr_cls:
+        for cls in aggr_classes:
            if cls in self.DEGREE_BASED_AGGRS:
                 self.need_degree = True
 
         # Check whether we need to compute mask information:
         self.requires_mask = False
-        for cls in self.aggr_cls:
+        for cls in aggr_classes:
             if cls in self.MASK_REQUIRED_AGGRS:
                 self.requires_mask = True
 
@@ -130,63 +131,85 @@ def __init__(self, aggrs: List[Union[Aggregation, str]], cat: bool = True):
        # outputs from other aggregators.
         self.reduce_ops: List[Optional[str]] = []
         # Determine which `(Aggregator, index)` to use as intermediate output:
-        self.lookup_ops: List[Optional[Tuple[Any, int]]] = []
+        self.lookup_ops: List[Optional[Tuple[str, int]]] = []
 
-        for cls in self.aggr_cls:
-            if cls == MeanAggregation:
+        for name in self.aggr_names:
+            if name == 'MeanAggregation':
                 # Directly use output of `SumAggregation`:
-                if SumAggregation in self.aggr_index:
+                if 'SumAggregation' in self.aggr_index:
                     self.reduce_ops.append(None)
-                    self.lookup_ops.append(
-                        (SumAggregation, self.aggr_index[SumAggregation]))
+                    self.lookup_ops.append((
+                        'SumAggregation',
+                        self.aggr_index['SumAggregation'],
+                    ))
                 else:
-                    self.reduce_ops.append(self.REDUCE[cls])
+                    self.reduce_ops.append(self.REDUCE[name])
                     self.lookup_ops.append(None)
 
-            elif cls == VarAggregation:
-                if MeanAggregation in self.aggr_index:
-                    self.reduce_ops.append(self.REDUCE[cls])
-                    self.lookup_ops.append(
-                        (MeanAggregation, self.aggr_index[MeanAggregation]))
-                elif SumAggregation in self.aggr_index:
-                    self.reduce_ops.append(self.REDUCE[cls])
-                    self.lookup_ops.append(
-                        (SumAggregation, self.aggr_index[SumAggregation]))
+            elif name == 'VarAggregation':
+                if 'MeanAggregation' in self.aggr_index:
+                    self.reduce_ops.append(self.REDUCE[name])
+                    self.lookup_ops.append((
+                        'MeanAggregation',
+                        self.aggr_index['MeanAggregation'],
+                    ))
+                elif 'SumAggregation' in self.aggr_index:
+                    self.reduce_ops.append(self.REDUCE[name])
+                    self.lookup_ops.append((
+                        'SumAggregation',
+                        self.aggr_index['SumAggregation'],
+                    ))
                 else:
-                    self.reduce_ops.append(self.REDUCE[cls])
+                    self.reduce_ops.append(self.REDUCE[name])
                     self.lookup_ops.append(None)
 
-            elif cls == StdAggregation:
+            elif name == 'StdAggregation':
                 # Directly use output of `VarAggregation`:
-                if VarAggregation in self.aggr_index:
+                if 'VarAggregation' in self.aggr_index:
                     self.reduce_ops.append(None)
-                    self.lookup_ops.append(
-                        (VarAggregation, self.aggr_index[VarAggregation]))
-                elif MeanAggregation in self.aggr_index:
-                    self.reduce_ops.append(self.REDUCE[cls])
-                    self.lookup_ops.append(
-                        (MeanAggregation, self.aggr_index[MeanAggregation]))
-                elif SumAggregation in self.aggr_index:
-                    self.reduce_ops.append(self.REDUCE[cls])
-                    self.lookup_ops.append(
-                        (SumAggregation, self.aggr_index[SumAggregation]))
+                    self.lookup_ops.append((
+                        'VarAggregation',
+                        self.aggr_index['VarAggregation'],
+                    ))
+                elif 'MeanAggregation' in self.aggr_index:
+                    self.reduce_ops.append(self.REDUCE[name])
+                    self.lookup_ops.append((
+                        'MeanAggregation',
+                        self.aggr_index['MeanAggregation'],
+                    ))
+                elif 'SumAggregation' in self.aggr_index:
+                    self.reduce_ops.append(self.REDUCE[name])
+                    self.lookup_ops.append((
+                        'SumAggregation',
+                        self.aggr_index['SumAggregation'],
+                    ))
                 else:
-                    self.reduce_ops.append(self.REDUCE[cls])
+                    self.reduce_ops.append(self.REDUCE[name])
                     self.lookup_ops.append(None)
 
             else:
-                self.reduce_ops.append(self.REDUCE[cls])
+                self.reduce_ops.append(self.REDUCE[name])
                 self.lookup_ops.append(None)
 
     def forward(self, x: Tensor, index: Optional[Tensor] = None,
                 ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
-                dim: int = -2) -> Tensor:
+                dim: int = -2) -> List[Tensor]:
         # Assert two-dimensional input for now to simplify computation:
         # TODO refactor this to support any dimension.
         self.assert_index_present(index)
         self.assert_two_dimensional_input(x, dim)
 
+        assert index is not None
+
+        if dim_size is None:
+            if ptr is not None:
+                dim_size = ptr.numel() - 1
+            else:
+                dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
+
+        count: Optional[Tensor] = None
+        mask: Optional[Tensor] = None
         if self.need_degree:
             count = x.new_zeros(dim_size)
             count.scatter_add_(0, index, x.new_ones(x.size(0)))
@@ -210,6 +233,7 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
             if reduce is None:
                 outs.append(None)
                 continue
+            assert isinstance(reduce, str)
 
             src = x * x if reduce == 'pow_sum' else x
             reduce = 'sum' if reduce == 'pow_sum' else reduce
@@ -226,72 +250,104 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
                 out = x.new_full((dim_size, num_feats), fill_value)
                 out.scatter_reduce_(0, index, src, reduce, include_self=True)
                 if fill_value != 0.0:
+                    assert mask is not None
                     out = out.masked_fill(mask.view(-1, 1), 0.0)
 
             outs.append(out)
 
         #######################################################################
 
         # Compute `MeanAggregation` first to be able to re-use it:
-        i = self.aggr_index.get(MeanAggregation)
+        i = self.aggr_index.get('MeanAggregation')
         if i is not None:
+            assert count is not None
+
             if self.lookup_ops[i] is None:
                 sum_ = outs[i]
             else:
-                tmp_aggr, j = self.lookup_ops[i]
-                assert tmp_aggr == SumAggregation
+                lookup_op = self.lookup_ops[i]
+                assert lookup_op is not None
+                tmp_aggr, j = lookup_op
+                assert tmp_aggr == 'SumAggregation'
+
                 sum_ = outs[j]
 
+            assert sum_ is not None
             outs[i] = sum_ / count
 
         # Compute `VarAggregation` second to be able to re-use it:
-        i = self.aggr_index.get(VarAggregation)
+        i = self.aggr_index.get('VarAggregation')
         if i is not None:
+            assert count is not None
+
             if self.lookup_ops[i] is None:
                 mean = x.new_zeros(dim_size, num_feats)
                 mean.scatter_reduce_(0, index, x, 'sum', include_self=True)
                 mean = mean / count
             else:
-                tmp_aggr, j = self.lookup_ops[i]
-                if tmp_aggr == SumAggregation:
-                    mean = outs[j] / count
-                elif tmp_aggr == MeanAggregation:
+                lookup_op = self.lookup_ops[i]
+                assert lookup_op is not None
+                tmp_aggr, j = lookup_op
+
+                if tmp_aggr == 'SumAggregation':
+                    sum_ = outs[j]
+                    assert sum_ is not None
+                    mean = sum_ / count
+                elif tmp_aggr == 'MeanAggregation':
                     mean = outs[j]
                 else:
                     raise NotImplementedError
 
             pow_sum = outs[i]
+
+            assert pow_sum is not None
+            assert mean is not None
             outs[i] = (pow_sum / count) - (mean * mean)
 
         # Compute `StdAggregation` last:
-        i = self.aggr_index.get(StdAggregation)
+        i = self.aggr_index.get('StdAggregation')
         if i is not None:
-            var = None
+            var: Optional[Tensor] = None
+            pow_sum: Optional[Tensor] = None
+            mean: Optional[Tensor] = None
+
             if self.lookup_ops[i] is None:
                 pow_sum = outs[i]
                 mean = x.new_zeros(dim_size, num_feats)
                 mean.scatter_reduce_(0, index, x, 'sum', include_self=True)
+                assert count is not None
                 mean = mean / count
             else:
-                tmp_aggr, j = self.lookup_ops[i]
-                if tmp_aggr == VarAggregation:
+                lookup_op = self.lookup_ops[i]
+                assert lookup_op is not None
+                tmp_aggr, j = lookup_op
+
+                if tmp_aggr == 'VarAggregation':
                     var = outs[j]
-                elif tmp_aggr == SumAggregation:
+                elif tmp_aggr == 'SumAggregation':
                     pow_sum = outs[i]
-                    mean = outs[j] / count
-                elif tmp_aggr == MeanAggregation:
+                    sum_ = outs[j]
+                    assert sum_ is not None
+                    assert count is not None
+                    mean = sum_ / count
+                elif tmp_aggr == 'MeanAggregation':
                     pow_sum = outs[i]
                     mean = outs[j]
                 else:
                     raise NotImplementedError
 
             if var is None:
+                assert pow_sum is not None
+                assert count is not None
+                assert mean is not None
                 var = (pow_sum / count) - (mean * mean)
 
             outs[i] = (var.relu() + 1e-5).sqrt()
 
         #######################################################################
 
-        if self.cat:
-            return torch.cat(outs, dim=-1) if len(outs) > 1 else outs[0]
-        else:
-            return outs
+        vals: List[Tensor] = []
+        for out in outs:
+            assert out is not None
+            vals.append(out)
+
+        return vals
diff --git a/torch_geometric/nn/aggr/multi.py b/torch_geometric/nn/aggr/multi.py
index e52bc17774b2..a6fac5cd28cf 100644
--- a/torch_geometric/nn/aggr/multi.py
+++ b/torch_geometric/nn/aggr/multi.py
@@ -70,20 +70,20 @@ def __init__(
         ])
 
         # Divide the set into fusable and non-fusable aggregations:
-        fused_aggrs = []
-        self.fused_out_index = []
-        self.non_fused_aggrs = []
-        self.non_fused_out_index = []
+        fused_aggrs: List[Aggregation] = []
+        self.fused_out_index: List[int] = []
+        # self.non_fused_aggrs: List[Aggregation] = []
+        self.is_fused_aggr: List[bool] = []
         for i, aggr in enumerate(self.aggrs):
             if aggr.__class__ in FusedAggregation.FUSABLE_AGGRS:
                 fused_aggrs.append(aggr)
                 self.fused_out_index.append(i)
+                self.is_fused_aggr.append(True)
             else:
-                self.non_fused_aggrs.append(aggr)
-                self.non_fused_out_index.append(i)
+                self.is_fused_aggr.append(False)
 
         if len(fused_aggrs) > 0:
-            self.fused_aggr = FusedAggregation(fused_aggrs, cat=False)
+            self.fused_aggr = FusedAggregation(fused_aggrs)
         else:
             self.fused_aggr = None
 
@@ -158,14 +158,15 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
             outs = [aggr(x, index, ptr, dim_size, dim) for aggr in self.aggrs]
             return self.combine(outs)
 
-        outs = [None] * len(self.aggrs)
+        outs: List[Tensor] = [x] * len(self.aggrs)  # Fill with dummy tensors.
 
         fused_outs = self.fused_aggr(x, index, ptr, dim_size, dim)
         for i, out in zip(self.fused_out_index, fused_outs):
             outs[i] = out
 
-        for i, aggr in zip(self.non_fused_out_index, self.non_fused_aggrs):
-            outs[i] = aggr(x, index, ptr, dim_size, dim)
+        for i, aggr in enumerate(self.aggrs):
+            if not self.is_fused_aggr[i]:
+                outs[i] = aggr(x, index, ptr, dim_size, dim)
 
         return self.combine(outs)
 
@@ -176,10 +177,10 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
         if self.mode == 'cat':
             return torch.cat(inputs, dim=-1)
 
-        if self.mode == 'proj':
+        if hasattr(self, 'lin'):
             return self.lin(torch.cat(inputs, dim=-1))
 
-        if self.mode == 'attn':
+        if hasattr(self, 'multihead_attn'):
             x = torch.stack(
                 [head(x) for x, head in zip(inputs, self.lin_heads)],
                 dim=0,

From ed28cecaf9f1e70a7ad13cd2ab2f52f17dc036ad Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Wed, 23 Nov 2022 13:54:03 +0000
Subject: [PATCH 6/8] reset

---
 torch_geometric/nn/aggr/basic.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/torch_geometric/nn/aggr/basic.py b/torch_geometric/nn/aggr/basic.py
index f798f8834d93..af1cae157769 100644
--- a/torch_geometric/nn/aggr/basic.py
+++ b/torch_geometric/nn/aggr/basic.py
@@ -180,9 +180,10 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
         alpha = x * t
 
         if not self.learn and self.semi_grad:
-            alpha = alpha.detach()
-
-        alpha = softmax(alpha, index, ptr, dim_size, dim)
+            with torch.no_grad():
+                alpha = softmax(alpha, index, ptr, dim_size, dim)
+        else:
+            alpha = softmax(alpha, index, ptr, dim_size, dim)
         return self.reduce(x * alpha, index, ptr, dim_size, dim, reduce='sum')
 
     def __repr__(self) -> str:

From 80b8df07691aaa0e67b088ee9bc32c625975182b Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Wed, 23 Nov 2022 13:58:30 +0000
Subject: [PATCH 7/8] update

---
 test/nn/aggr/test_fused.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/nn/aggr/test_fused.py b/test/nn/aggr/test_fused.py
index 370a162db4c9..46d9f0cfcc22 100644
--- a/test/nn/aggr/test_fused.py
+++ b/test/nn/aggr/test_fused.py
@@ -26,7 +26,7 @@ def test_fused_aggregation(aggrs):
     aggr = FusedAggregation(aggrs)
     assert str(aggr) == 'FusedAggregation()'
 
-    out = aggr(x, index)
+    out = torch.cat(aggr(x, index), dim=-1)
 
     expected = torch.cat([aggr(y, index) for aggr in aggrs], dim=-1)
     assert torch.allclose(out, expected)
@@ -97,7 +97,7 @@ def test_fused_aggregation(aggrs):
             torch.cuda.synchronize()
             t_start = time.perf_counter()
 
-            out = fused_aggr(x, index, dim_size=num_nodes)
+            out = torch.cat(fused_aggr(x, index, dim_size=num_nodes), dim=-1)
 
             torch.cuda.synchronize()
             if i >= num_warmups:

From d02deba48779118867904acce3b42ab5e07d981d Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Wed, 23 Nov 2022 14:00:37 +0000
Subject: [PATCH 8/8] update

---
 torch_geometric/nn/aggr/fused.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_geometric/nn/aggr/fused.py b/torch_geometric/nn/aggr/fused.py
index 9ba35a40325e..b8fc6b673496 100644
--- a/torch_geometric/nn/aggr/fused.py
+++ b/torch_geometric/nn/aggr/fused.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
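
Note (usage sketch, not part of any patch above): the series routes every
scatter-based reduction in `MultiAggregation` through a single
`FusedAggregation` call, and, as of PATCH 5/8, `FusedAggregation` returns one
tensor per aggregation instead of a concatenated tensor. A minimal sketch of
the post-series behavior; the two entry points come from the diffs, while the
tensor shapes and group indices below are illustrative assumptions:

    import torch
    from torch_geometric.nn.aggr import FusedAggregation, MultiAggregation

    x = torch.randn(6, 16)                    # 6 elements with 16 features each
    index = torch.tensor([0, 0, 1, 1, 1, 2])  # element-to-group assignment

    # All four reductions are in `FUSABLE_AGGRS`, so `MultiAggregation` takes
    # the fused path for two-dimensional inputs with an `index` vector:
    multi_aggr = MultiAggregation(['sum', 'mean', 'max', 'std'], mode='cat')
    out = multi_aggr(x, index)                # shape: [3, 4 * 16]

    # `FusedAggregation` itself now returns a list of per-aggregation tensors,
    # which is why the tests in PATCH 7/8 concatenate explicitly:
    fused_aggr = FusedAggregation(['sum', 'mean', 'max', 'std'])
    outs = fused_aggr(x, index)               # list of four [3, 16] tensors
    assert torch.allclose(torch.cat(outs, dim=-1), out)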
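Note (fusion arithmetic, not part of any patch above): PATCH 5/8 keeps a
`reduce_ops`/`lookup_ops` plan so that `mean` can reuse the `sum` output,
`var` can reuse `mean` (or `sum`), and `std` can reuse `var`. Stripped of that
bookkeeping, the reuse comes down to the identities below. A standalone sketch
assuming a two-dimensional input, with the patch's empty-group masking
replaced by a simple clamp; all names are local to the example:

    import torch

    x = torch.randn(6, 16)
    index = torch.tensor([0, 0, 1, 1, 1, 2])
    dim_size = 3

    # Per-group element counts (clamped so empty groups do not divide by zero):
    count = torch.zeros(dim_size).scatter_add_(0, index, torch.ones(6))
    count = count.clamp_(min=1).view(-1, 1)

    # Broadcast the index over the feature dimension, as the patch does:
    index2d = index.view(-1, 1).expand(-1, x.size(-1))
    sum_ = torch.zeros(dim_size, x.size(-1)).scatter_reduce_(
        0, index2d, x, 'sum', include_self=True)
    pow_sum = torch.zeros(dim_size, x.size(-1)).scatter_reduce_(
        0, index2d, x * x, 'sum', include_self=True)

    mean = sum_ / count                   # 'mean' reuses the 'sum' buffer
    var = pow_sum / count - mean * mean   # 'var' reuses 'mean' (or 'sum')
    std = (var.relu() + 1e-5).sqrt()      # 'std' reuses 'var', as in the patch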