Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate FusedAggregation into MultiAggregation #6040

Merged
merged 9 commits into from
Nov 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.2.0] - 2022-MM-DD
### Added
- Added `FusedAggregation` of simple scatter reductions ([#6036](https://github.com/pyg-team/pytorch_geometric/pull/6036))
- Allow for fused aggregations in `MultiAggregation` ([#6036](https://github.com/pyg-team/pytorch_geometric/pull/6036), [#6040](https://github.com/pyg-team/pytorch_geometric/pull/6040))
- Added `HeteroData` support for `to_captum_model` and added `to_captum_input` ([#5934](https://github.com/pyg-team/pytorch_geometric/pull/5934))
- Added `HeteroData` support in `RandomNodeLoader` ([#6007](https://github.com/pyg-team/pytorch_geometric/pull/6007))
- Added bipartite `GraphSAGE` example ([#5834](https://github.com/pyg-team/pytorch_geometric/pull/5834))
Expand Down
4 changes: 2 additions & 2 deletions test/nn/aggr/test_fused.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_fused_aggregation(aggrs):

aggr = FusedAggregation(aggrs)
assert str(aggr) == 'FusedAggregation()'
out = aggr(x, index)
out = torch.cat(aggr(x, index), dim=-1)

expected = torch.cat([aggr(y, index) for aggr in aggrs], dim=-1)
assert torch.allclose(out, expected)
Expand Down Expand Up @@ -97,7 +97,7 @@ def test_fused_aggregation(aggrs):
torch.cuda.synchronize()

t_start = time.perf_counter()
out = fused_aggr(x, index, dim_size=num_nodes)
out = torch.cat(fused_aggr(x, index, dim_size=num_nodes), dim=-1)

torch.cuda.synchronize()
if i >= num_warmups:
Expand Down
4 changes: 3 additions & 1 deletion test/nn/aggr/test_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ def test_multi_aggr(multi_aggr_tuple):
assert str(aggr) == ('MultiAggregation([\n'
' MeanAggregation(),\n'
' SumAggregation(),\n'
' MaxAggregation()\n'
' MaxAggregation(),\n'
f"], mode={aggr_kwargs['mode']})")

out = aggr(x, index)
assert torch.allclose(out, aggr(x, ptr=ptr))
assert out.size() == (4, expand * x.size(1))

# TODO test JIT support
190 changes: 125 additions & 65 deletions torch_geometric/nn/aggr/fused.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Any, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor

from torch_geometric.nn import (
Aggregation,
from torch_geometric.nn.aggr.base import Aggregation
from torch_geometric.nn.aggr.basic import (
MaxAggregation,
MeanAggregation,
MinAggregation,
Expand Down Expand Up @@ -80,13 +80,13 @@ class FusedAggregation(Aggregation):

# Map aggregations to `reduce` options in `scatter` directives.
REDUCE = {
SumAggregation: 'sum',
MeanAggregation: 'sum',
MinAggregation: 'amin',
MaxAggregation: 'amax',
MulAggregation: 'prod',
VarAggregation: 'pow_sum',
StdAggregation: 'pow_sum',
'SumAggregation': 'sum',
'MeanAggregation': 'sum',
'MinAggregation': 'amin',
'MaxAggregation': 'amax',
'MulAggregation': 'prod',
'VarAggregation': 'pow_sum',
'StdAggregation': 'pow_sum',
}

def __init__(self, aggrs: List[Union[Aggregation, str]]):
Expand All @@ -101,24 +101,28 @@ def __init__(self, aggrs: List[Union[Aggregation, str]]):
f"not be empty.")

aggrs = [aggregation_resolver(aggr) for aggr in aggrs]
self.aggr_cls = [aggr.__class__ for aggr in aggrs]
self.aggr_index = {cls: i for i, cls in enumerate(self.aggr_cls)}

for cls in self.aggr_cls:
aggr_classes = [aggr.__class__ for aggr in aggrs]
self.aggr_names = [cls.__name__ for cls in aggr_classes]
self.aggr_index: Dict[str, int] = {
name: i
for i, name in enumerate(self.aggr_names)
}

for cls in aggr_classes:
if cls not in self.FUSABLE_AGGRS:
raise ValueError(f"Received aggregation '{cls.__name__}' in "
f"'{self.__class__.__name__}' which is not "
f"fusable")

# Check whether we need to compute degree information:
self.need_degree = False
for cls in self.aggr_cls:
for cls in aggr_classes:
if cls in self.DEGREE_BASED_AGGRS:
self.need_degree = True

# Check whether we need to compute mask information:
self.requires_mask = False
for cls in self.aggr_cls:
for cls in aggr_classes:
if cls in self.MASK_REQUIRED_AGGRS:
self.requires_mask = True

Expand All @@ -127,63 +131,85 @@ def __init__(self, aggrs: List[Union[Aggregation, str]]):
# outputs from other aggregators.
self.reduce_ops: List[Optional[str]] = []
# Determine which `(name, index)` entry to use as intermediate output:
self.lookup_ops: List[Optional[Tuple[Any, int]]] = []
self.lookup_ops: List[Optional[Tuple[str, int]]] = []

for cls in self.aggr_cls:
if cls == MeanAggregation:
for name in self.aggr_names:
if name == 'MeanAggregation':
# Directly use output of `SumAggregation`:
if SumAggregation in self.aggr_index:
if 'SumAggregation' in self.aggr_index:
self.reduce_ops.append(None)
self.lookup_ops.append(
(SumAggregation, self.aggr_index[SumAggregation]))
self.lookup_ops.append((
'SumAggregation',
self.aggr_index['SumAggregation'],
))
else:
self.reduce_ops.append(self.REDUCE[cls])
self.reduce_ops.append(self.REDUCE[name])
self.lookup_ops.append(None)

elif cls == VarAggregation:
if MeanAggregation in self.aggr_index:
self.reduce_ops.append(self.REDUCE[cls])
self.lookup_ops.append(
(MeanAggregation, self.aggr_index[MeanAggregation]))
elif SumAggregation in self.aggr_index:
self.reduce_ops.append(self.REDUCE[cls])
self.lookup_ops.append(
(SumAggregation, self.aggr_index[SumAggregation]))
elif name == 'VarAggregation':
if 'MeanAggregation' in self.aggr_index:
self.reduce_ops.append(self.REDUCE[name])
self.lookup_ops.append((
'MeanAggregation',
self.aggr_index['MeanAggregation'],
))
elif 'SumAggregation' in self.aggr_index:
self.reduce_ops.append(self.REDUCE[name])
self.lookup_ops.append((
'SumAggregation',
self.aggr_index['SumAggregation'],
))
else:
self.reduce_ops.append(self.REDUCE[cls])
self.reduce_ops.append(self.REDUCE[name])
self.lookup_ops.append(None)

elif cls == StdAggregation:
elif name == 'StdAggregation':
# Directly use output of `VarAggregation`:
if VarAggregation in self.aggr_index:
if 'VarAggregation' in self.aggr_index:
self.reduce_ops.append(None)
self.lookup_ops.append(
(VarAggregation, self.aggr_index[VarAggregation]))
elif MeanAggregation in self.aggr_index:
self.reduce_ops.append(self.REDUCE[cls])
self.lookup_ops.append(
(MeanAggregation, self.aggr_index[MeanAggregation]))
elif SumAggregation in self.aggr_index:
self.reduce_ops.append(self.REDUCE[cls])
self.lookup_ops.append(
(SumAggregation, self.aggr_index[SumAggregation]))
self.lookup_ops.append((
'VarAggregation',
self.aggr_index['VarAggregation'],
))
elif 'MeanAggregation' in self.aggr_index:
self.reduce_ops.append(self.REDUCE[name])
self.lookup_ops.append((
'MeanAggregation',
self.aggr_index['MeanAggregation'],
))
elif 'SumAggregation' in self.aggr_index:
self.reduce_ops.append(self.REDUCE[name])
self.lookup_ops.append((
'SumAggregation',
self.aggr_index['SumAggregation'],
))
else:
self.reduce_ops.append(self.REDUCE[cls])
self.reduce_ops.append(self.REDUCE[name])
self.lookup_ops.append(None)

else:
self.reduce_ops.append(self.REDUCE[cls])
self.reduce_ops.append(self.REDUCE[name])
self.lookup_ops.append(None)

def forward(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:
dim: int = -2) -> List[Tensor]:

# Assert two-dimensional input for now to simplify computation:
# TODO refactor this to support any dimension.
self.assert_index_present(index)
self.assert_two_dimensional_input(x, dim)

assert index is not None

if dim_size is None:
if ptr is not None:
dim_size = ptr.numel() - 1
else:
dim_size = int(index.max()) + 1 if index.numel() > 0 else 0

count: Optional[Tensor] = None
mask: Optional[Tensor] = None
if self.need_degree:
count = x.new_zeros(dim_size)
count.scatter_add_(0, index, x.new_ones(x.size(0)))
Expand All @@ -207,6 +233,7 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
if reduce is None:
outs.append(None)
continue
assert isinstance(reduce, str)

src = x * x if reduce == 'pow_sum' else x
reduce = 'sum' if reduce == 'pow_sum' else reduce
Expand All @@ -223,71 +250,104 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
out = x.new_full((dim_size, num_feats), fill_value)
out.scatter_reduce_(0, index, src, reduce, include_self=True)
if fill_value != 0.0:
assert mask is not None
out = out.masked_fill(mask.view(-1, 1), 0.0)
outs.append(out)

#######################################################################

# Compute `MeanAggregation` first to be able to re-use it:
i = self.aggr_index.get(MeanAggregation)
i = self.aggr_index.get('MeanAggregation')
if i is not None:
assert count is not None

if self.lookup_ops[i] is None:
sum_ = outs[i]
else:
tmp_aggr, j = self.lookup_ops[i]
assert tmp_aggr == SumAggregation
lookup_op = self.lookup_ops[i]
assert lookup_op is not None
tmp_aggr, j = lookup_op
assert tmp_aggr == 'SumAggregation'

sum_ = outs[j]

assert sum_ is not None
outs[i] = sum_ / count

# Compute `VarAggregation` second to be able to re-use it:
i = self.aggr_index.get(VarAggregation)
i = self.aggr_index.get('VarAggregation')
if i is not None:
assert count is not None

if self.lookup_ops[i] is None:
mean = x.new_zeros(dim_size, num_feats)
mean.scatter_reduce_(0, index, x, 'sum', include_self=True)
mean = mean / count
else:
tmp_aggr, j = self.lookup_ops[i]
if tmp_aggr == SumAggregation:
mean = outs[j] / count
elif tmp_aggr == MeanAggregation:
lookup_op = self.lookup_ops[i]
assert lookup_op is not None
tmp_aggr, j = lookup_op

if tmp_aggr == 'SumAggregation':
sum_ = outs[j]
assert sum_ is not None
mean = sum_ / count
elif tmp_aggr == 'MeanAggregation':
mean = outs[j]
else:
raise NotImplementedError

pow_sum = outs[i]

assert pow_sum is not None
assert mean is not None
outs[i] = (pow_sum / count) - (mean * mean)

# Compute `StdAggregation` last:
i = self.aggr_index.get(StdAggregation)
i = self.aggr_index.get('StdAggregation')
if i is not None:
var = None
var: Optional[Tensor] = None
pow_sum: Optional[Tensor] = None
mean: Optional[Tensor] = None

if self.lookup_ops[i] is None:
pow_sum = outs[i]
mean = x.new_zeros(dim_size, num_feats)
mean.scatter_reduce_(0, index, x, 'sum', include_self=True)
assert count is not None
mean = mean / count
else:
tmp_aggr, j = self.lookup_ops[i]
if tmp_aggr == VarAggregation:
lookup_op = self.lookup_ops[i]
assert lookup_op is not None
tmp_aggr, j = lookup_op

if tmp_aggr == 'VarAggregation':
var = outs[j]
elif tmp_aggr == SumAggregation:
elif tmp_aggr == 'SumAggregation':
pow_sum = outs[i]
mean = outs[j] / count
elif tmp_aggr == MeanAggregation:
sum_ = outs[j]
assert sum_ is not None
assert count is not None
mean = sum_ / count
elif tmp_aggr == 'MeanAggregation':
pow_sum = outs[i]
mean = outs[j]
else:
raise NotImplementedError

if var is None:
assert pow_sum is not None
assert count is not None
assert mean is not None
var = (pow_sum / count) - (mean * mean)

outs[i] = (var.relu() + 1e-5).sqrt()

#######################################################################

out = torch.cat(outs, dim=-1)
vals: List[Tensor] = []
for out in outs:
assert out is not None
vals.append(out)

return out
return vals
Loading