diff --git a/CHANGELOG.md b/CHANGELOG.md index f0004c387f85..8a651f33b5f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -72,7 +72,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222)) - Added `torch_geometric.explain` module with base functionality for explainability methods ([#5804](https://github.com/pyg-team/pytorch_geometric/pull/5804)) ### Changed -- Optimized scatter implementations for CPU/GPU, both with and without backward computation ([#6051](https://github.com/pyg-team/pytorch_geometric/pull/6051)) +- Optimized scatter implementations for CPU/GPU, both with and without backward computation ([#6051](https://github.com/pyg-team/pytorch_geometric/pull/6051), [#6052](https://github.com/pyg-team/pytorch_geometric/pull/6052)) - Support temperature value in `dense_mincut_pool` ([#5908](https://github.com/pyg-team/pytorch_geometric/pull/5908)) - Fixed a bug in which `VirtualNode` mistakenly treated node features as edge features ([#5819](https://github.com/pyg-team/pytorch_geometric/pull/5819)) - Fixed `setter` and `getter` handling in `BaseStorage` ([#5815](https://github.com/pyg-team/pytorch_geometric/pull/5815)) diff --git a/test/data/test_lightning_datamodule.py b/test/data/test_lightning_datamodule.py index 0dbaa71b23ff..269fee8fa237 100644 --- a/test/data/test_lightning_datamodule.py +++ b/test/data/test_lightning_datamodule.py @@ -13,7 +13,7 @@ ) from torch_geometric.nn import global_mean_pool from torch_geometric.sampler import BaseSampler, NeighborSampler -from torch_geometric.testing import onlyFullTest, withCUDA, withPackage +from torch_geometric.testing import onlyCUDA, onlyFullTest, withPackage from torch_geometric.testing.feature_store import MyFeatureStore from torch_geometric.testing.graph_store import MyGraphStore @@ -66,7 +66,7 @@ def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=0.01) -@withCUDA +@onlyCUDA @onlyFullTest @withPackage('pytorch_lightning') @pytest.mark.parametrize('strategy_type', [None, 'ddp_spawn']) @@ -157,7 +157,7 @@ def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=0.01) -@withCUDA +@onlyCUDA @onlyFullTest @withPackage('pytorch_lightning') @pytest.mark.parametrize('loader', ['full', 'neighbor']) @@ -252,7 +252,7 @@ def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=0.01) -@withCUDA +@onlyCUDA @onlyFullTest @withPackage('pytorch_lightning') def test_lightning_hetero_node_data(get_dataset): @@ -303,7 +303,7 @@ def sample_from_nodes(self, *args, **kwargs): assert isinstance(datamodule.neighbor_sampler, DummySampler) -@withCUDA +@onlyCUDA @onlyFullTest @withPackage('pytorch_lightning') def test_lightning_hetero_link_data(): diff --git a/test/nn/aggr/test_fused.py b/test/nn/aggr/test_fused.py index 33d053252554..d2b92b9accd4 100644 --- a/test/nn/aggr/test_fused.py +++ b/test/nn/aggr/test_fused.py @@ -29,7 +29,10 @@ def test_fused_aggregation(aggrs): out = torch.cat(aggr(x, index), dim=-1) expected = torch.cat([aggr(y, index) for aggr in aggrs], dim=-1) - assert torch.allclose(out, expected) + assert torch.allclose(out, expected, atol=1e-6) + + jit = torch.jit.script(aggr) + assert torch.allclose(torch.cat(jit(x, index), dim=-1), out, atol=1e-6) out.mean().backward() assert x.grad is not None @@ -41,9 +44,6 @@ def test_fused_aggregation(aggrs): if __name__ == '__main__': import argparse import time - import 
warnings - - warnings.filterwarnings('ignore', '.*is in beta and the API may change.*') parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') diff --git a/test/nn/conv/test_fused_gat_conv.py b/test/nn/conv/test_fused_gat_conv.py index 3407fc087a6a..92709751c091 100644 --- a/test/nn/conv/test_fused_gat_conv.py +++ b/test/nn/conv/test_fused_gat_conv.py @@ -1,10 +1,10 @@ import torch from torch_geometric.nn import FusedGATConv -from torch_geometric.testing import withCUDA, withPackage +from torch_geometric.testing import onlyCUDA, withPackage -@withCUDA +@onlyCUDA @withPackage('dgNN') def test_fused_gat_conv(): device = torch.device('cuda') diff --git a/test/nn/models/test_basic_gnn.py b/test/nn/models/test_basic_gnn.py index da24bc943354..288f460135e0 100644 --- a/test/nn/models/test_basic_gnn.py +++ b/test/nn/models/test_basic_gnn.py @@ -9,7 +9,7 @@ from torch_geometric.loader import NeighborLoader from torch_geometric.nn import SAGEConv from torch_geometric.nn.models import GAT, GCN, GIN, PNA, EdgeCNN, GraphSAGE -from torch_geometric.testing import withPackage, withPython +from torch_geometric.testing import onlyPython, withPackage out_dims = [None, 8] dropouts = [0.0, 0.5] @@ -138,7 +138,7 @@ def test_basic_gnn_inference(get_dataset, jk): assert 'n_id' not in data -@withPython('3.7', '3.8', '3.9') # Packaging does not support Python 3.10 yet. +@onlyPython('3.7', '3.8', '3.9') # Packaging does not support Python 3.10 yet. def test_packaging(): os.makedirs(torch.hub._get_torch_home(), exist_ok=True) diff --git a/test/nn/test_data_parallel.py b/test/nn/test_data_parallel.py index ce91edde3174..fd973e8c991e 100644 --- a/test/nn/test_data_parallel.py +++ b/test/nn/test_data_parallel.py @@ -3,10 +3,10 @@ from torch_geometric.data import Data from torch_geometric.nn import DataParallel -from torch_geometric.testing import withCUDA +from torch_geometric.testing import onlyCUDA -@withCUDA +@onlyCUDA @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='No multiple GPUs') def test_data_parallel(): module = DataParallel(None) diff --git a/test/profile/test_profile.py b/test/profile/test_profile.py index 36b2faf31359..aa398f5bce25 100644 --- a/test/profile/test_profile.py +++ b/test/profile/test_profile.py @@ -11,10 +11,10 @@ timeit, ) from torch_geometric.profile.profile import torch_profile -from torch_geometric.testing import onlyFullTest, withCUDA +from torch_geometric.testing import onlyCUDA, onlyFullTest -@withCUDA +@onlyCUDA @onlyFullTest def test_profile(get_dataset): dataset = get_dataset(name='PubMed') diff --git a/test/profile/test_utils.py b/test/profile/test_utils.py index 84675faea4fe..456b2c940de0 100644 --- a/test/profile/test_utils.py +++ b/test/profile/test_utils.py @@ -11,7 +11,7 @@ get_gpu_memory_from_nvidia_smi, get_model_size, ) -from torch_geometric.testing import withCUDA +from torch_geometric.testing import onlyCUDA def test_count_parameters(): @@ -40,7 +40,7 @@ def test_get_cpu_memory_from_gc(): assert new_mem - old_mem == 10 * 128 * 4 -@withCUDA +@onlyCUDA def test_get_cpu_memory_from_gc(): old_mem = get_gpu_memory_from_gc() _ = torch.randn(10, 128, device='cuda') @@ -48,7 +48,7 @@ def test_get_cpu_memory_from_gc(): assert new_mem - old_mem == 10 * 128 * 4 -@withCUDA +@onlyCUDA def test_get_gpu_memory_from_nvidia_smi(): free_mem, used_mem = get_gpu_memory_from_nvidia_smi(device=0, digits=2) assert free_mem >= 0 diff --git a/test/utils/test_scatter.py b/test/utils/test_scatter.py index 1478e3b3324b..9ffd5841d9b1 100644 --- 
a/test/utils/test_scatter.py +++ b/test/utils/test_scatter.py @@ -2,62 +2,57 @@ import torch import torch_scatter -from torch_geometric.testing import withPackage +from torch_geometric.testing import withCUDA from torch_geometric.utils import scatter -@pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'min', 'max', 'mul']) -def test_scatter(reduce): - torch.manual_seed(12345) - - src = torch.randn(8, 100, 32) +def test_scatter_validate(): + src = torch.randn(100, 32) index = torch.randint(0, 10, (100, ), dtype=torch.long) - out1 = scatter(src, index, dim=1, reduce=reduce) - out2 = torch_scatter.scatter(src, index, dim=1, reduce=reduce) - assert torch.allclose(out1, out2, atol=1e-6) + with pytest.raises(ValueError, match="must be one-dimensional"): + scatter(src, index.view(-1, 1)) + with pytest.raises(ValueError, match="must lay between 0 and 1"): + scatter(src, index, dim=2) -@pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'min', 'max']) -def test_pytorch_scatter_backward(reduce): - torch.manual_seed(12345) - - src = torch.randn(8, 100, 32).requires_grad_(True) - index = torch.randint(0, 10, (100, ), dtype=torch.long) + with pytest.raises(ValueError, match="invalid `reduce` argument 'std'"): + scatter(src, index, reduce='std') - out = scatter(src, index, dim=1, reduce=reduce).relu() - assert src.grad is None - out.mean().backward() - assert src.grad is not None - - -@withPackage('torch>=1.12.0') -@pytest.mark.parametrize('reduce', ['min', 'max']) -def test_pytorch_scatter_inplace_backward(reduce): - torch.manual_seed(12345) +@withCUDA +@pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'min', 'max']) +def test_scatter(reduce, device): + src = torch.randn(100, 16, device=device) + index = torch.randint(0, 8, (100, ), device=device) - src = torch.randn(8, 100, 32).requires_grad_(True) - index = torch.randint(0, 10, (100, ), dtype=torch.long) + out1 = scatter(src, index, dim=0, reduce=reduce) + out2 = torch_scatter.scatter(src, index, dim=0, reduce=reduce) + assert out1.device == device + assert torch.allclose(out1, out2, atol=1e-6) - out = scatter(src, index, dim=1, reduce=reduce).relu_() + jit = torch.jit.script(scatter) + out3 = jit(src, index, dim=0, reduce=reduce) + assert torch.allclose(out1, out3, atol=1e-6) - with pytest.raises(RuntimeError, match="modified by an inplace operation"): - out.mean().backward() + src = torch.randn(8, 100, 16, device=device) + out1 = scatter(src, index, dim=1, reduce=reduce) + out2 = torch_scatter.scatter(src, index, dim=1, reduce=reduce) + assert out1.device == device + assert torch.allclose(out1, out2, atol=1e-6) -@pytest.mark.parametrize('reduce', ['sum', 'add', 'min', 'max', 'mul']) -def test_scatter_with_out(reduce): - torch.manual_seed(12345) +@withCUDA +@pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'min', 'max']) +def test_scatter_backward(reduce, device): + src = torch.randn(8, 100, 16, device=device, requires_grad=True) + index = torch.randint(0, 8, (100, ), device=device) - src = torch.randn(8, 100, 32) - index = torch.randint(0, 10, (100, ), dtype=torch.long) - out = torch.randn(8, 10, 32) + out = scatter(src, index, dim=1, reduce=reduce) - out1 = scatter(src, index, dim=1, out=out.clone(), reduce=reduce) - out2 = torch_scatter.scatter(src, index, dim=1, out=out.clone(), - reduce=reduce) - assert torch.allclose(out1, out2, atol=1e-6) + assert src.grad is None + out.mean().backward() + assert src.grad is not None if __name__ == '__main__': @@ -81,12 +76,9 @@ def test_scatter_with_out(reduce): import argparse 
import time - import warnings import torch_scatter - warnings.filterwarnings('ignore', '.*is in beta and the API may change.*') - parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--backward', action='store_true') @@ -168,6 +160,34 @@ def test_scatter_with_out(reduce): if i >= num_warmups: t_backward += time.perf_counter() - t_start + print(f'torch_sparse forward: {t_forward:.4f}s') + if args.backward: + print(f'torch_sparse backward: {t_backward:.4f}s') + print('==============================') + + t_forward = t_backward = 0 + for i in range(num_warmups + num_steps): + x = torch.randn(num_edges, num_feats, device=args.device) + if args.backward: + x.requires_grad_(True) + + torch.cuda.synchronize() + t_start = time.perf_counter() + + out = scatter(x, index, dim=0, dim_size=num_nodes, reduce=aggr) + + torch.cuda.synchronize() + if i >= num_warmups: + t_forward += time.perf_counter() - t_start + + if args.backward: + t_start = time.perf_counter() + out.backward(out_grad) + + torch.cuda.synchronize() + if i >= num_warmups: + t_backward += time.perf_counter() - t_start + print(f'torch_sparse forward: {t_forward:.4f}s') if args.backward: print(f'torch_sparse backward: {t_backward:.4f}s') diff --git a/torch_geometric/nn/aggr/base.py b/torch_geometric/nn/aggr/base.py index cfb52742bd5c..4841fc0284a9 100644 --- a/torch_geometric/nn/aggr/base.py +++ b/torch_geometric/nn/aggr/base.py @@ -160,14 +160,14 @@ def assert_two_dimensional_input(self, x: Tensor, dim: int): def reduce(self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, - dim: int = -2, reduce: str = 'add') -> Tensor: + dim: int = -2, reduce: str = 'sum') -> Tensor: if ptr is not None: ptr = expand_left(ptr, dim, dims=x.dim()) return segment_csr(x, ptr, reduce=reduce) assert index is not None - return scatter(x, index, dim=dim, dim_size=dim_size, reduce=reduce) + return scatter(x, index, dim, dim_size, reduce) def to_dense_batch(self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, diff --git a/torch_geometric/nn/aggr/fused.py b/torch_geometric/nn/aggr/fused.py index 5431ca975cd6..cc513d143e5c 100644 --- a/torch_geometric/nn/aggr/fused.py +++ b/torch_geometric/nn/aggr/fused.py @@ -1,6 +1,5 @@ from typing import Dict, List, Optional, Tuple, Union -import torch from torch import Tensor from torch_geometric.nn.aggr.base import Aggregation @@ -14,6 +13,7 @@ VarAggregation, ) from torch_geometric.nn.resolver import aggregation_resolver +from torch_geometric.utils import scatter class FusedAggregation(Aggregation): @@ -33,21 +33,21 @@ class FusedAggregation(Aggregation): :class:`VarAggregation`, :class:`MeanAggregation` or :class:`SumAggregation` in case one of them is present as well. - In addition, temporary values such as the count per group index or the - mask for invalid rows are shared as well. + In addition, temporary values such as the count per group index are shared + as well. 
Benchmarking results on PyTorch 1.12 (summed over 1000 runs): +------------------------------+---------+---------+ | Aggregators | Vanilla | Fusion | +==============================+=========+=========+ - | :obj:`[sum, mean]` | 0.4019s | 0.1666s | + | :obj:`[sum, mean]` | 0.3325s | 0.1996s | +------------------------------+---------+---------+ - | :obj:`[sum, mean, min, max]` | 0.7841s | 0.4223s | + | :obj:`[sum, mean, min, max]` | 0.7139s | 0.5037s | +------------------------------+---------+---------+ - | :obj:`[sum, mean, var]` | 0.9711s | 0.3614s | + | :obj:`[sum, mean, var]` | 0.6849s | 0.3871s | +------------------------------+---------+---------+ - | :obj:`[sum, mean, var, std]` | 1.5994s | 0.3722s | + | :obj:`[sum, mean, var, std]` | 1.0955s | 0.3973s | +------------------------------+---------+---------+ Args: @@ -71,20 +71,13 @@ class FusedAggregation(Aggregation): StdAggregation, } - # All aggregations that require manual masking for invalid rows: - MASK_REQUIRED_AGGRS = { - MinAggregation, - MaxAggregation, - MulAggregation, - } - # Map aggregations to `reduce` options in `scatter` directives. REDUCE = { 'SumAggregation': 'sum', 'MeanAggregation': 'sum', - 'MinAggregation': 'amin', - 'MaxAggregation': 'amax', - 'MulAggregation': 'prod', + 'MinAggregation': 'min', + 'MaxAggregation': 'max', + 'MulAggregation': 'mul', 'VarAggregation': 'pow_sum', 'StdAggregation': 'pow_sum', } @@ -125,12 +118,6 @@ def __init__(self, aggrs: List[Union[Aggregation, str]]): if cls in self.DEGREE_BASED_AGGRS: self.need_degree = True - # Check whether we need to compute mask information: - self.requires_mask = False - for cls in aggr_classes: - if cls in self.MASK_REQUIRED_AGGRS: - self.requires_mask = True - # Determine which reduction to use for each aggregator: # An entry of `None` means that this operator re-uses intermediate # outputs from other aggregators. 
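(For context, a minimal usage sketch of `FusedAggregation`, mirroring the call pattern in test/nn/aggr/test_fused.py above; shapes and the choice of aggregators are illustrative only:)

    import torch
    from torch_geometric.nn.aggr.fused import FusedAggregation

    x = torch.randn(6, 16)                    # 6 input rows with 16 features each
    index = torch.tensor([0, 0, 1, 1, 2, 2])  # group assignment per row

    aggr = FusedAggregation(['sum', 'mean', 'min', 'max'])
    outs = aggr(x, index)                     # list with one (3, 16) tensor per aggregator
    out = torch.cat(outs, dim=-1)             # (3, 64), as in the test above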
@@ -214,21 +201,11 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, dim_size = int(index.max()) + 1 if index.numel() > 0 else 0 count: Optional[Tensor] = None - mask: Optional[Tensor] = None if self.need_degree: count = x.new_zeros(dim_size) count.scatter_add_(0, index, x.new_ones(x.size(0))) - if self.requires_mask: - mask = count == 0 count = count.clamp_(min=1).view(-1, 1) - elif self.requires_mask: # Mask to set non-existing indicses to zero: - mask = x.new_ones(dim_size, dtype=torch.bool) - mask[index] = False - - num_feats = x.size(-1) - index = index.view(-1, 1).expand(-1, num_feats) - ####################################################################### outs: List[Optional[Tensor]] = [] @@ -240,32 +217,14 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, continue assert isinstance(reduce, str) - fill_value = 0.0 - if reduce == 'amin': - fill_value = float('inf') - elif reduce == 'amax': - fill_value = float('-inf') - elif reduce == 'prod': - fill_value = 1.0 - - # `include_self=True` + manual masking leads to faster runtime: - out = x.new_full((dim_size, num_feats), fill_value) - if reduce == 'pow_sum': - reduce = 'sum' if self.semi_grad: - with torch.no_grad(): - out.scatter_reduce_(0, index, x * x, reduce, - include_self=True) + out = scatter(x.detach() * x.detach(), index, 0, dim_size, + reduce='sum') else: - out.scatter_reduce_(0, index, x * x, reduce, - include_self=True) + out = scatter(x * x, index, 0, dim_size, reduce='sum') else: - out.scatter_reduce_(0, index, x, reduce, include_self=True) - - if fill_value != 0.0: - assert mask is not None - out = out.masked_fill(mask.view(-1, 1), 0.0) + out = scatter(x, index, 0, dim_size, reduce=reduce) outs.append(out) @@ -295,9 +254,8 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, assert count is not None if self.lookup_ops[i] is None: - mean = x.new_zeros(dim_size, num_feats) - mean.scatter_reduce_(0, index, x, 'sum', include_self=True) - mean = mean / count + sum_ = scatter(x, index, 0, dim_size, reduce='sum') + mean = sum_ / count else: lookup_op = self.lookup_ops[i] assert lookup_op is not None @@ -327,10 +285,9 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, if self.lookup_ops[i] is None: pow_sum = outs[i] - mean = x.new_zeros(dim_size, num_feats) - mean.scatter_reduce_(0, index, x, 'sum', include_self=True) + sum_ = scatter(x, index, 0, dim_size, reduce='sum') assert count is not None - mean = mean / count + mean = sum_ / count else: lookup_op = self.lookup_ops[i] assert lookup_op is not None diff --git a/torch_geometric/testing/__init__.py b/torch_geometric/testing/__init__.py index 4105bda72f93..888e4d96e576 100644 --- a/torch_geometric/testing/__init__.py +++ b/torch_geometric/testing/__init__.py @@ -1,11 +1,12 @@ -from .decorators import (is_full_test, onlyFullTest, onlyUnix, withPython, - withPackage, withCUDA) +from .decorators import (is_full_test, onlyFullTest, onlyUnix, onlyPython, + withPackage, onlyCUDA, withCUDA) __all__ = [ 'is_full_test', 'onlyFullTest', 'onlyUnix', - 'withPython', + 'onlyPython', 'withPackage', + 'onlyCUDA', 'withCUDA', ] diff --git a/torch_geometric/testing/decorators.py b/torch_geometric/testing/decorators.py index e54ee7d022bc..95e992cc9d5d 100644 --- a/torch_geometric/testing/decorators.py +++ b/torch_geometric/testing/decorators.py @@ -33,7 +33,7 @@ def onlyUnix(func: Callable) -> Callable: )(func) -def withPython(*args) -> Callable: +def onlyPython(*args) -> Callable: r"""A decorator to skip tests for any Python version not 
listed.""" def decorator(func: Callable) -> Callable: import pytest @@ -71,10 +71,21 @@ def decorator(func: Callable) -> Callable: return decorator -def withCUDA(func: Callable) -> Callable: +def onlyCUDA(func: Callable) -> Callable: r"""A decorator to skip tests if CUDA is not found.""" import pytest return pytest.mark.skipif( not torch.cuda.is_available(), reason="CUDA not available", )(func) + + +def withCUDA(func: Callable): + r"""A decorator to test both on CPU and CUDA (if available).""" + import pytest + + devices = [torch.device('cpu')] + if torch.cuda.is_available(): + devices.append(torch.device('cuda:0')) + + return pytest.mark.parametrize('device', devices)(func) diff --git a/torch_geometric/utils/scatter.py b/torch_geometric/utils/scatter.py index 04e27e5741de..44905e6c0bec 100644 --- a/torch_geometric/utils/scatter.py +++ b/torch_geometric/utils/scatter.py @@ -1,7 +1,10 @@ +import warnings from typing import Optional import torch +import torch_scatter from torch import Tensor +from torch_scatter import scatter_max, scatter_min, scatter_mul major, minor, _ = torch.__version__.split('.', maxsplit=2) major, minor = int(major), int(minor) @@ -9,86 +12,85 @@ if has_pytorch112: # pragma: no cover - class ScatterHelpers: - @staticmethod - def broadcast(src: Tensor, other: Tensor, dim: int) -> Tensor: - dim = other.dim() + dim if dim < 0 else dim - if src.dim() == 1: - for _ in range(0, dim): - src = src.unsqueeze(0) - for _ in range(src.dim(), other.dim()): - src = src.unsqueeze(-1) - src = src.expand(other.size()) - return src - - @staticmethod - def generate_out(src: Tensor, index: Tensor, dim: int, - dim_size: Optional[int]) -> Tensor: - size = list(src.size()) - if dim_size is not None: - size[dim] = dim_size - elif index.numel() > 0: - size[dim] = int(index.max()) + 1 + warnings.filterwarnings('ignore', '.*is in beta and the API may change.*') + + def broadcast(src: Tensor, ref: Tensor, dim: int) -> Tensor: + size = [1] * ref.dim() + size[dim] = -1 + return src.view(size).expand_as(ref) + + def scatter(src: Tensor, index: Tensor, dim: int = 0, + dim_size: Optional[int] = None, reduce: str = 'sum') -> Tensor: + + if index.dim() != 1: + raise ValueError(f"The `index` argument must be one-dimensional " + f"(got {index.dim()} dimensions)") + + dim = src.dim() + dim if dim < 0 else dim + + if dim < 0 or dim >= src.dim(): + raise ValueError(f"The `dim` argument must lay between 0 and " + f"{src.dim() - 1} (got {dim})") + + if dim_size is None: + dim_size = int(index.max()) + 1 if index.numel() > 0 else 0 + + # For now, we maintain various different code paths, based on whether + # the input requires gradients and whether it lays on the CPU/GPU. + # For example, `torch_scatter` is usually faster than + # `torch.scatter_reduce` on GPU, while `torch.scatter_reduce` is faster + # on CPU. + # `torch.scatter_reduce` has a faster forward implementation for + # "min"/"max" reductions since it does not compute additional arg + # indices, but is therefore way slower in its backward implementation. + # More insights can be found in `test/utils/test_scatter.py`. 
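+        # Illustrative summary of the dispatch below (the code is
+        # authoritative):
+        #   * "sum"/"add": `scatter_add_` on both CPU and GPU
+        #   * "mean":      `scatter_add_` followed by division by the
+        #                  (clamped) group size
+        #   * "min"/"max": `scatter_reduce_` on CPU or when gradients are
+        #                  not required, `torch_scatter.scatter_min/max`
+        #                  otherwise
+        #   * "mul":       `scatter_reduce_` on CPU,
+        #                  `torch_scatter.scatter_mul` on GPU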
+ + size = list(src.size()) + size[dim] = dim_size + + # For "sum" and "mean" reduction, we make use of `scatter_add_`: + if reduce == 'sum' or reduce == 'add': + index = broadcast(index, src, dim) + return src.new_zeros(size).scatter_add_(dim, index, src) + + if reduce == 'mean': + count = src.new_zeros(dim_size) + count.scatter_add_(0, index, src.new_ones(src.size(dim))) + count = count.clamp_(min=1) + + index = broadcast(index, src, dim) + out = src.new_zeros(size).scatter_add_(dim, index, src) + + return out / broadcast(count, out, dim) + + # For "min" and "max" reduction, we prefer `scatter_reduce_` on CPU or + # in case the input does not require gradients: + if reduce == 'min' or reduce == 'max': + if not src.is_cuda or not src.requires_grad: + index = broadcast(index, src, dim) + return src.new_zeros(size).scatter_reduce_( + dim, index, src, reduce=f'a{reduce}', include_self=False) + + if reduce == 'min': + return scatter_min(src, index, dim, dim_size=dim_size)[0] else: - size[dim] = 0 - return src.new_zeros(size) - - @staticmethod - def scatter_mean(src: Tensor, index: Tensor, dim: int, out: Tensor, - dim_size: Optional[int]) -> Tensor: - out.scatter_add_(dim, index, src) - - index_dim = dim - if index_dim < 0: - # `index_dim` counts axes from the begining (0) - index_dim = index_dim + src.dim() - if index.dim() <= index_dim: - # in case `index` was broadcasted, `count` scatter should be - # performed over the last axis - index_dim = index.dim() - 1 - - ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) - count = ScatterHelpers.generate_out(ones, index, index_dim, - dim_size) - count.scatter_add_(index_dim, index, ones) - count[count < 1] = 1 - count = ScatterHelpers.broadcast(count, out, dim) - if out.is_floating_point(): - out.true_divide_(count) + return scatter_max(src, index, dim, dim_size=dim_size)[0] + + # For "mul" reduction, we prefer `scatter_reduce_` on CPU: + if reduce == 'mul': + if not src.is_cuda: + index = broadcast(index, src, dim) + # We initialize with `one` here to match `scatter_mul` output: + return src.new_ones(size).scatter_reduce_( + dim, index, src, reduce='prod', include_self=True) else: - out.div_(count, rounding_mode='floor') - return out - - def scatter(src: Tensor, index: Tensor, dim: int = -1, - out: Optional[Tensor] = None, dim_size: Optional[int] = None, - reduce: str = 'sum') -> Tensor: - reduce = 'sum' if reduce == 'add' else reduce - reduce = 'prod' if reduce == 'mul' else reduce - reduce = 'amin' if reduce == 'min' else reduce - reduce = 'amax' if reduce == 'max' else reduce - - index = ScatterHelpers.broadcast(index, src, dim) - include_self = out is not None - - if out is None: # Generate `out` if not given: - out = ScatterHelpers.generate_out(src, index, dim, dim_size) - - # explicit usage of `torch.scatter_add_` and switching to - # `torch_scatter` implementation of mean algorithm comes with - # significant performance boost. - # TODO: use only `torch.scatter_reduce_` after performance issue will - # be fixed on the PyTorch side. 
- if reduce == 'mean': - return ScatterHelpers.scatter_mean(src, index, dim, out, dim_size) - elif reduce == 'sum': - return out.scatter_add_(dim, index, src) - return out.scatter_reduce_(dim, index, src, reduce, - include_self=include_self) + return scatter_mul(src, index, dim, dim_size=dim_size) + + raise ValueError(f"Encountered invalid `reduce` argument '{reduce}'") else: - import torch_scatter - def scatter(src: Tensor, index: Tensor, dim: int = -1, - out: Optional[Tensor] = None, dim_size: Optional[int] = None, - reduce: str = 'sum') -> Tensor: - return torch_scatter.scatter(src, index, dim, out, dim_size, reduce) + def scatter(src: Tensor, index: Tensor, dim: int = 0, + dim_size: Optional[int] = None, reduce: str = 'sum') -> Tensor: + return torch_scatter.scatter(src, index, dim, dim_size=dim_size, + reduce=reduce)
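Usage note (a minimal sketch, not part of the patch): after this change, `onlyCUDA` keeps the former skip-when-no-GPU behaviour of `withCUDA`, while the new `withCUDA` parametrizes a test over CPU and, if available, CUDA through a `device` argument; the updated `scatter` drops the `out` argument and takes `dim_size`/`reduce` keyword arguments, as exercised in test/utils/test_scatter.py above.

    import torch
    from torch_geometric.testing import onlyCUDA, withCUDA
    from torch_geometric.utils import scatter

    @withCUDA  # runs once with device=cpu and, if available, once with device=cuda:0
    def test_scatter_sum(device):
        src = torch.randn(100, 16, device=device)
        index = torch.randint(0, 8, (100, ), device=device)
        out = scatter(src, index, dim=0, dim_size=8, reduce='sum')
        assert out.size() == (8, 16)

    @onlyCUDA  # skipped entirely when no GPU is present
    def test_gpu_only_path():
        assert torch.cuda.is_available()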