Optimized scatter operations for CPU/GPU (#6052)
rusty1s authored Nov 24, 2022
1 parent 29615b9 commit 086a20e
Showing 14 changed files with 201 additions and 210 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -72,7 +72,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
- Added `torch_geometric.explain` module with base functionality for explainability methods ([#5804](https://github.com/pyg-team/pytorch_geometric/pull/5804))
### Changed
-- Optimized scatter implementations for CPU/GPU, both with and without backward computation ([#6051](https://github.com/pyg-team/pytorch_geometric/pull/6051))
+- Optimized scatter implementations for CPU/GPU, both with and without backward computation ([#6051](https://github.com/pyg-team/pytorch_geometric/pull/6051), [#6052](https://github.com/pyg-team/pytorch_geometric/pull/6052))
- Support temperature value in `dense_mincut_pool` ([#5908](https://github.com/pyg-team/pytorch_geometric/pull/5908))
- Fixed a bug in which `VirtualNode` mistakenly treated node features as edge features ([#5819](https://github.com/pyg-team/pytorch_geometric/pull/5819))
- Fixed `setter` and `getter` handling in `BaseStorage` ([#5815](https://github.com/pyg-team/pytorch_geometric/pull/5815))
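The utility being optimized here is `torch_geometric.utils.scatter`, which reduces rows of `src` that share the same index. A minimal sketch of the call, with shapes chosen for illustration (the signature used here matches the calls in the tests and benchmark below):

```python
import torch
from torch_geometric.utils import scatter

src = torch.randn(100, 16)            # 100 source rows
index = torch.randint(0, 8, (100, ))  # maps each row to one of 8 groups

# The tests in this commit cover 'sum', 'add', 'mean', 'min' and 'max':
out = scatter(src, index, dim=0, dim_size=8, reduce='mean')
assert out.size() == (8, 16)
```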
10 changes: 5 additions & 5 deletions test/data/test_lightning_datamodule.py
@@ -13,7 +13,7 @@
)
from torch_geometric.nn import global_mean_pool
from torch_geometric.sampler import BaseSampler, NeighborSampler
-from torch_geometric.testing import onlyFullTest, withCUDA, withPackage
+from torch_geometric.testing import onlyCUDA, onlyFullTest, withPackage
from torch_geometric.testing.feature_store import MyFeatureStore
from torch_geometric.testing.graph_store import MyGraphStore

@@ -66,7 +66,7 @@ def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)


-@withCUDA
+@onlyCUDA
@onlyFullTest
@withPackage('pytorch_lightning')
@pytest.mark.parametrize('strategy_type', [None, 'ddp_spawn'])
@@ -157,7 +157,7 @@ def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)


-@withCUDA
+@onlyCUDA
@onlyFullTest
@withPackage('pytorch_lightning')
@pytest.mark.parametrize('loader', ['full', 'neighbor'])
@@ -252,7 +252,7 @@ def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)


-@withCUDA
+@onlyCUDA
@onlyFullTest
@withPackage('pytorch_lightning')
def test_lightning_hetero_node_data(get_dataset):
@@ -303,7 +303,7 @@ def sample_from_nodes(self, *args, **kwargs):
    assert isinstance(datamodule.neighbor_sampler, DummySampler)


-@withCUDA
+@onlyCUDA
@onlyFullTest
@withPackage('pytorch_lightning')
def test_lightning_hetero_link_data():
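The rename from `withCUDA` to `onlyCUDA` across these test files, together with the new `withCUDA` usage in `test_scatter.py` below (whose tests now take a `device` argument), suggests a split: `onlyCUDA` skips a test when no GPU is present, while `withCUDA` parametrizes a test over all available devices. A hypothetical sketch of both decorators; the real implementations in `torch_geometric.testing` may differ:

```python
import pytest
import torch

# Hypothetical sketches; the actual helpers in `torch_geometric.testing`
# may be implemented differently.


def onlyCUDA(func):
    # Skip the test entirely unless a CUDA device is available.
    return pytest.mark.skipif(
        not torch.cuda.is_available(),
        reason='CUDA not available',
    )(func)


def withCUDA(func):
    # Run the test on CPU and, if available, additionally on GPU.
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda:0'))
    return pytest.mark.parametrize('device', devices)(func)
```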
8 changes: 4 additions & 4 deletions test/nn/aggr/test_fused.py
@@ -29,7 +29,10 @@ def test_fused_aggregation(aggrs):
    out = torch.cat(aggr(x, index), dim=-1)

    expected = torch.cat([aggr(y, index) for aggr in aggrs], dim=-1)
-    assert torch.allclose(out, expected)
+    assert torch.allclose(out, expected, atol=1e-6)
+
+    jit = torch.jit.script(aggr)
+    assert torch.allclose(torch.cat(jit(x, index), dim=-1), out, atol=1e-6)

    out.mean().backward()
    assert x.grad is not None
@@ -41,9 +44,6 @@ def test_fused_aggregation(aggrs):
if __name__ == '__main__':
    import argparse
    import time
-    import warnings
-
-    warnings.filterwarnings('ignore', '.*is in beta and the API may change.*')

    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cuda')
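The new lines extend the fused-aggregation test with a TorchScript check. A usage sketch, assuming `FusedAggregation` accepts a list of aggregator names and returns one output tensor per aggregator, as the test implies (the import path follows the test file's location):

```python
import torch
from torch_geometric.nn.aggr.fused import FusedAggregation

# Assumed constructor format: a list of aggregator names.
aggr = FusedAggregation(['mean', 'max'])

x = torch.randn(100, 16)
index = torch.randint(0, 8, (100, ))

outs = aggr(x, index, dim_size=8)  # one tensor per requested aggregator
out = torch.cat(outs, dim=-1)      # [8, 32], as in the test's layout

# TorchScript support, as exercised by the new test lines:
jit = torch.jit.script(aggr)
```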
4 changes: 2 additions & 2 deletions test/nn/conv/test_fused_gat_conv.py
@@ -1,10 +1,10 @@
import torch

from torch_geometric.nn import FusedGATConv
-from torch_geometric.testing import withCUDA, withPackage
+from torch_geometric.testing import onlyCUDA, withPackage


-@withCUDA
+@onlyCUDA
@withPackage('dgNN')
def test_fused_gat_conv():
device = torch.device('cuda')
4 changes: 2 additions & 2 deletions test/nn/models/test_basic_gnn.py
@@ -9,7 +9,7 @@
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv
from torch_geometric.nn.models import GAT, GCN, GIN, PNA, EdgeCNN, GraphSAGE
-from torch_geometric.testing import withPackage, withPython
+from torch_geometric.testing import onlyPython, withPackage

out_dims = [None, 8]
dropouts = [0.0, 0.5]
@@ -138,7 +138,7 @@ def test_basic_gnn_inference(get_dataset, jk):
    assert 'n_id' not in data


-@withPython('3.7', '3.8', '3.9')  # Packaging does not support Python 3.10 yet.
+@onlyPython('3.7', '3.8', '3.9')  # Packaging does not support Python 3.10 yet.
def test_packaging():
os.makedirs(torch.hub._get_torch_home(), exist_ok=True)

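`test_packaging` (gated above to Python 3.7-3.9) revolves around `torch.package`. A hypothetical sketch of the save/load round trip such a test exercises; the exact intern/extern policy needed in practice depends on the model's dependencies:

```python
import torch
from torch.package import PackageExporter, PackageImporter

# Hypothetical sketch of a `torch.package` round trip; not taken from
# the test itself.
def roundtrip(model: torch.nn.Module, path: str = 'model.pt'):
    with PackageExporter(path) as exporter:
        exporter.extern('torch.**')             # resolve from the environment
        exporter.extern('torch_geometric.**')   # likewise for PyG
        exporter.save_pickle('models', 'model.pkl', model)
    return PackageImporter(path).load_pickle('models', 'model.pkl')
```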
4 changes: 2 additions & 2 deletions test/nn/test_data_parallel.py
@@ -3,10 +3,10 @@

from torch_geometric.data import Data
from torch_geometric.nn import DataParallel
-from torch_geometric.testing import withCUDA
+from torch_geometric.testing import onlyCUDA


-@withCUDA
+@onlyCUDA
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='No multiple GPUs')
def test_data_parallel():
module = DataParallel(None)
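For reference, PyG's `DataParallel` differs from plain `torch.nn.DataParallel`: it consumes lists of `Data` objects rather than batched tensors. A brief sketch of the intended pattern (the model class is a placeholder):

```python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataListLoader
from torch_geometric.nn import DataParallel

# `DataListLoader` yields Python lists of `Data` objects, which
# `DataParallel` then scatters across the available GPUs:
dataset = [Data(x=torch.randn(4, 16)) for _ in range(32)]
loader = DataListLoader(dataset, batch_size=8)

# model = DataParallel(MyGNN().cuda())  # MyGNN is a placeholder
# for data_list in loader:
#     out = model(data_list)
```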
4 changes: 2 additions & 2 deletions test/profile/test_profile.py
@@ -11,10 +11,10 @@
timeit,
)
from torch_geometric.profile.profile import torch_profile
-from torch_geometric.testing import onlyFullTest, withCUDA
+from torch_geometric.testing import onlyCUDA, onlyFullTest


-@withCUDA
+@onlyCUDA
@onlyFullTest
def test_profile(get_dataset):
dataset = get_dataset(name='PubMed')
6 changes: 3 additions & 3 deletions test/profile/test_utils.py
@@ -11,7 +11,7 @@
get_gpu_memory_from_nvidia_smi,
get_model_size,
)
-from torch_geometric.testing import withCUDA
+from torch_geometric.testing import onlyCUDA


def test_count_parameters():
@@ -40,15 +40,15 @@ def test_get_cpu_memory_from_gc():
    assert new_mem - old_mem == 10 * 128 * 4


-@withCUDA
+@onlyCUDA
def test_get_gpu_memory_from_gc():
    old_mem = get_gpu_memory_from_gc()
    _ = torch.randn(10, 128, device='cuda')
    new_mem = get_gpu_memory_from_gc()
    assert new_mem - old_mem == 10 * 128 * 4


-@withCUDA
+@onlyCUDA
def test_get_gpu_memory_from_nvidia_smi():
free_mem, used_mem = get_gpu_memory_from_nvidia_smi(device=0, digits=2)
assert free_mem >= 0
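The assertions above follow from allocating a 10 x 128 float32 tensor: 10 * 128 elements * 4 bytes each. A sketch of how garbage-collector-based GPU memory accounting can work; the real `get_gpu_memory_from_gc` in `torch_geometric.profile` may differ:

```python
import gc
import torch

# Sketch: walk Python's garbage collector and sum the byte sizes of all
# live CUDA tensors. The real helper may be implemented differently.
def gpu_memory_from_gc() -> int:
    num_bytes = 0
    for obj in gc.get_objects():
        if isinstance(obj, torch.Tensor) and obj.is_cuda:
            num_bytes += obj.numel() * obj.element_size()
    return num_bytes
```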
106 changes: 63 additions & 43 deletions test/utils/test_scatter.py
@@ -2,62 +2,57 @@
import torch
import torch_scatter

-from torch_geometric.testing import withPackage
+from torch_geometric.testing import withCUDA
from torch_geometric.utils import scatter


-@pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'min', 'max', 'mul'])
-def test_scatter(reduce):
-    torch.manual_seed(12345)
-
-    src = torch.randn(8, 100, 32)
+def test_scatter_validate():
+    src = torch.randn(100, 32)
    index = torch.randint(0, 10, (100, ), dtype=torch.long)

-    out1 = scatter(src, index, dim=1, reduce=reduce)
-    out2 = torch_scatter.scatter(src, index, dim=1, reduce=reduce)
-    assert torch.allclose(out1, out2, atol=1e-6)
+    with pytest.raises(ValueError, match="must be one-dimensional"):
+        scatter(src, index.view(-1, 1))
+
+    with pytest.raises(ValueError, match="must lay between 0 and 1"):
+        scatter(src, index, dim=2)
+
+    with pytest.raises(ValueError, match="invalid `reduce` argument 'std'"):
+        scatter(src, index, reduce='std')


-@pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'min', 'max'])
-def test_pytorch_scatter_backward(reduce):
-    torch.manual_seed(12345)
-
-    src = torch.randn(8, 100, 32).requires_grad_(True)
-    index = torch.randint(0, 10, (100, ), dtype=torch.long)
-
-    out = scatter(src, index, dim=1, reduce=reduce).relu()
-
-    assert src.grad is None
-    out.mean().backward()
-    assert src.grad is not None
-
-
-@withPackage('torch>=1.12.0')
-@pytest.mark.parametrize('reduce', ['min', 'max'])
-def test_pytorch_scatter_inplace_backward(reduce):
-    torch.manual_seed(12345)
-
-    src = torch.randn(8, 100, 32).requires_grad_(True)
-    index = torch.randint(0, 10, (100, ), dtype=torch.long)
-
-    out = scatter(src, index, dim=1, reduce=reduce).relu_()
-
-    with pytest.raises(RuntimeError, match="modified by an inplace operation"):
-        out.mean().backward()
+@withCUDA
+@pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'min', 'max'])
+def test_scatter(reduce, device):
+    src = torch.randn(100, 16, device=device)
+    index = torch.randint(0, 8, (100, ), device=device)
+
+    out1 = scatter(src, index, dim=0, reduce=reduce)
+    out2 = torch_scatter.scatter(src, index, dim=0, reduce=reduce)
+    assert out1.device == device
+    assert torch.allclose(out1, out2, atol=1e-6)
+
+    jit = torch.jit.script(scatter)
+    out3 = jit(src, index, dim=0, reduce=reduce)
+    assert torch.allclose(out1, out3, atol=1e-6)
+
+    src = torch.randn(8, 100, 16, device=device)
+    out1 = scatter(src, index, dim=1, reduce=reduce)
+    out2 = torch_scatter.scatter(src, index, dim=1, reduce=reduce)
+    assert out1.device == device
+    assert torch.allclose(out1, out2, atol=1e-6)


-@pytest.mark.parametrize('reduce', ['sum', 'add', 'min', 'max', 'mul'])
-def test_scatter_with_out(reduce):
-    torch.manual_seed(12345)
-
-    src = torch.randn(8, 100, 32)
-    index = torch.randint(0, 10, (100, ), dtype=torch.long)
-    out = torch.randn(8, 10, 32)
-
-    out1 = scatter(src, index, dim=1, out=out.clone(), reduce=reduce)
-    out2 = torch_scatter.scatter(src, index, dim=1, out=out.clone(),
-                                 reduce=reduce)
-    assert torch.allclose(out1, out2, atol=1e-6)
+@withCUDA
+@pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'min', 'max'])
+def test_scatter_backward(reduce, device):
+    src = torch.randn(8, 100, 16, device=device, requires_grad=True)
+    index = torch.randint(0, 8, (100, ), device=device)
+
+    out = scatter(src, index, dim=1, reduce=reduce)
+
+    assert src.grad is None
+    out.mean().backward()
+    assert src.grad is not None


if __name__ == '__main__':
@@ -81,12 +76,9 @@ def test_scatter_with_out(reduce):

    import argparse
    import time
-    import warnings

    import torch_scatter

-    warnings.filterwarnings('ignore', '.*is in beta and the API may change.*')
-
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--backward', action='store_true')
@@ -168,6 +160,34 @@
            if i >= num_warmups:
                t_backward += time.perf_counter() - t_start

+    print(f'torch_sparse forward:  {t_forward:.4f}s')
+    if args.backward:
+        print(f'torch_sparse backward: {t_backward:.4f}s')
+    print('==============================')
+
+    t_forward = t_backward = 0
+    for i in range(num_warmups + num_steps):
+        x = torch.randn(num_edges, num_feats, device=args.device)
+        if args.backward:
+            x.requires_grad_(True)
+
+        torch.cuda.synchronize()
+        t_start = time.perf_counter()
+
+        out = scatter(x, index, dim=0, dim_size=num_nodes, reduce=aggr)
+
+        torch.cuda.synchronize()
+        if i >= num_warmups:
+            t_forward += time.perf_counter() - t_start
+
+        if args.backward:
+            t_start = time.perf_counter()
+            out.backward(out_grad)
+
+            torch.cuda.synchronize()
+            if i >= num_warmups:
+                t_backward += time.perf_counter() - t_start
+
    print(f'torch_geometric forward:  {t_forward:.4f}s')
    if args.backward:
        print(f'torch_geometric backward: {t_backward:.4f}s')
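The benchmark added above follows the usual CUDA timing discipline: warm-up iterations are discarded, and `torch.cuda.synchronize()` brackets each timed region, since kernel launches return before the GPU finishes. The same pattern, distilled into a helper (a sketch, not part of the commit):

```python
import time
import torch

# Sketch of the benchmark's timing pattern. Requires a CUDA device.
def time_cuda(fn, num_warmups: int = 10, num_steps: int = 100) -> float:
    total = 0.0
    for i in range(num_warmups + num_steps):
        torch.cuda.synchronize()
        t_start = time.perf_counter()
        fn()
        torch.cuda.synchronize()
        if i >= num_warmups:
            total += time.perf_counter() - t_start
    return total
```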
4 changes: 2 additions & 2 deletions torch_geometric/nn/aggr/base.py
@@ -160,14 +160,14 @@ def assert_two_dimensional_input(self, x: Tensor, dim: int):

    def reduce(self, x: Tensor, index: Optional[Tensor] = None,
               ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
-               dim: int = -2, reduce: str = 'add') -> Tensor:
+               dim: int = -2, reduce: str = 'sum') -> Tensor:

        if ptr is not None:
            ptr = expand_left(ptr, dim, dims=x.dim())
            return segment_csr(x, ptr, reduce=reduce)

        assert index is not None
-        return scatter(x, index, dim=dim, dim_size=dim_size, reduce=reduce)
+        return scatter(x, index, dim, dim_size, reduce)

    def to_dense_batch(self, x: Tensor, index: Optional[Tensor] = None,
                       ptr: Optional[Tensor] = None,
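Two changes here: the default reduction is renamed from 'add' to 'sum' (PyTorch's own naming; the tests above parametrize over both spellings as synonyms), and the `scatter` call switches to positional arguments matching the new utility's signature. A quick equivalence check (a sketch, not from the commit):

```python
import torch
from torch_geometric.utils import scatter

x = torch.randn(6, 4)
index = torch.tensor([0, 1, 0, 2, 1, 0])

# 'sum' and 'add' denote the same reduction, as the parametrized tests
# above suggest:
assert torch.allclose(
    scatter(x, index, 0, reduce='sum'),
    scatter(x, index, 0, reduce='add'),
)
```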