Optimized scatter operations for CPU/GPU #6052

Merged: 5 commits, Nov 24, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -72,7 +72,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
- Added `torch_geometric.explain` module with base functionality for explainability methods ([#5804](https://github.com/pyg-team/pytorch_geometric/pull/5804))
### Changed
-- Optimized scatter implementations for CPU/GPU, both with and without backward computation ([#6051](https://github.com/pyg-team/pytorch_geometric/pull/6051))
+- Optimized scatter implementations for CPU/GPU, both with and without backward computation ([#6051](https://github.com/pyg-team/pytorch_geometric/pull/6051), [#6052](https://github.com/pyg-team/pytorch_geometric/pull/6052))
- Support temperature value in `dense_mincut_pool` ([#5908](https://github.com/pyg-team/pytorch_geometric/pull/5908))
- Fixed a bug in which `VirtualNode` mistakenly treated node features as edge features ([#5819](https://github.com/pyg-team/pytorch_geometric/pull/5819))
- Fixed `setter` and `getter` handling in `BaseStorage` ([#5815](https://github.com/pyg-team/pytorch_geometric/pull/5815))
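For reference, the entry above concerns PyG's `torch_geometric.utils.scatter` helper, whose CPU/GPU code paths these two PRs optimize. A minimal usage sketch (the shapes and reduce mode here are illustrative, not taken from the PR):

```python
# Minimal sketch of calling torch_geometric.utils.scatter; shapes and the
# reduce mode are illustrative only.
import torch

from torch_geometric.utils import scatter

src = torch.randn(100, 16)            # one feature row per element
index = torch.randint(0, 8, (100, ))  # target group of each element

# Reduce the 100 rows into 8 groups along dim 0:
out = scatter(src, index, dim=0, dim_size=8, reduce='sum')
assert out.size() == (8, 16)
```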
10 changes: 5 additions & 5 deletions test/data/test_lightning_datamodule.py
@@ -13,7 +13,7 @@
)
from torch_geometric.nn import global_mean_pool
from torch_geometric.sampler import BaseSampler, NeighborSampler
-from torch_geometric.testing import onlyFullTest, withCUDA, withPackage
+from torch_geometric.testing import onlyCUDA, onlyFullTest, withPackage
from torch_geometric.testing.feature_store import MyFeatureStore
from torch_geometric.testing.graph_store import MyGraphStore

@@ -66,7 +66,7 @@ def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.01)


-@withCUDA
+@onlyCUDA
@onlyFullTest
@withPackage('pytorch_lightning')
@pytest.mark.parametrize('strategy_type', [None, 'ddp_spawn'])
@@ -157,7 +157,7 @@ def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.01)


-@withCUDA
+@onlyCUDA
@onlyFullTest
@withPackage('pytorch_lightning')
@pytest.mark.parametrize('loader', ['full', 'neighbor'])
@@ -252,7 +252,7 @@ def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.01)


-@withCUDA
+@onlyCUDA
@onlyFullTest
@withPackage('pytorch_lightning')
def test_lightning_hetero_node_data(get_dataset):
@@ -303,7 +303,7 @@ def sample_from_nodes(self, *args, **kwargs):
assert isinstance(datamodule.neighbor_sampler, DummySampler)


-@withCUDA
+@onlyCUDA
@onlyFullTest
@withPackage('pytorch_lightning')
def test_lightning_hetero_link_data():
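The renames from `withCUDA` to `onlyCUDA` above reflect a split in PyG's testing decorators: judging by the reworked `test/utils/test_scatter.py` later in this diff, `onlyCUDA` skips a test when no GPU is present, while `withCUDA` now parametrizes a test over all available devices. A hedged sketch of how such decorators can be written (not the actual `torch_geometric.testing` source):

```python
# Hedged sketch, not the actual torch_geometric.testing implementation:
# onlyCUDA skips the test without a GPU; withCUDA runs it once per device.
import pytest
import torch


def onlyCUDA(func):
    return pytest.mark.skipif(not torch.cuda.is_available(),
                              reason='CUDA not available')(func)


def withCUDA(func):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda:0'))
    return pytest.mark.parametrize('device', devices)(func)
```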
8 changes: 4 additions & 4 deletions test/nn/aggr/test_fused.py
@@ -29,7 +29,10 @@ def test_fused_aggregation(aggrs):
out = torch.cat(aggr(x, index), dim=-1)

expected = torch.cat([aggr(y, index) for aggr in aggrs], dim=-1)
-    assert torch.allclose(out, expected)
+    assert torch.allclose(out, expected, atol=1e-6)
+
+    jit = torch.jit.script(aggr)
+    assert torch.allclose(torch.cat(jit(x, index), dim=-1), out, atol=1e-6)

out.mean().backward()
assert x.grad is not None
@@ -41,9 +44,6 @@ def test_fused_aggregation(aggrs):
if __name__ == '__main__':
import argparse
import time
-    import warnings
-
-    warnings.filterwarnings('ignore', '.*is in beta and the API may change.*')

parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda')
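Beyond the tightened tolerances, the test now TorchScript-compiles the fused module. A hedged sketch of the pattern under test, assuming `FusedAggregation` is exported from `torch_geometric.nn.aggr` (as in PyG releases of this era) and accepts aggregation names:

```python
# Hedged sketch of the fused-aggregation pattern the updated test exercises:
# several reductions computed in one pass, returned as one tensor each.
import torch

from torch_geometric.nn.aggr import FusedAggregation

x = torch.randn(6, 16)
index = torch.tensor([0, 0, 1, 1, 2, 2])

aggr = FusedAggregation(['sum', 'mean', 'max'])
outs = aggr(x, index)           # list with one tensor per aggregation
out = torch.cat(outs, dim=-1)   # [3, 48]

# Mirrors the new assertion: the module must survive torch.jit.script.
jit = torch.jit.script(aggr)
assert torch.allclose(torch.cat(jit(x, index), dim=-1), out, atol=1e-6)
```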
4 changes: 2 additions & 2 deletions test/nn/conv/test_fused_gat_conv.py
@@ -1,10 +1,10 @@
import torch

from torch_geometric.nn import FusedGATConv
-from torch_geometric.testing import withCUDA, withPackage
+from torch_geometric.testing import onlyCUDA, withPackage


-@withCUDA
+@onlyCUDA
@withPackage('dgNN')
def test_fused_gat_conv():
device = torch.device('cuda')
4 changes: 2 additions & 2 deletions test/nn/models/test_basic_gnn.py
@@ -9,7 +9,7 @@
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv
from torch_geometric.nn.models import GAT, GCN, GIN, PNA, EdgeCNN, GraphSAGE
-from torch_geometric.testing import withPackage, withPython
+from torch_geometric.testing import onlyPython, withPackage

out_dims = [None, 8]
dropouts = [0.0, 0.5]
@@ -138,7 +138,7 @@ def test_basic_gnn_inference(get_dataset, jk):
assert 'n_id' not in data


-@withPython('3.7', '3.8', '3.9')  # Packaging does not support Python 3.10 yet.
+@onlyPython('3.7', '3.8', '3.9')  # Packaging does not support Python 3.10 yet.
def test_packaging():
os.makedirs(torch.hub._get_torch_home(), exist_ok=True)

4 changes: 2 additions & 2 deletions test/nn/test_data_parallel.py
@@ -3,10 +3,10 @@

from torch_geometric.data import Data
from torch_geometric.nn import DataParallel
-from torch_geometric.testing import withCUDA
+from torch_geometric.testing import onlyCUDA


-@withCUDA
+@onlyCUDA
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='No multiple GPUs')
def test_data_parallel():
module = DataParallel(None)
4 changes: 2 additions & 2 deletions test/profile/test_profile.py
@@ -11,10 +11,10 @@
timeit,
)
from torch_geometric.profile.profile import torch_profile
-from torch_geometric.testing import onlyFullTest, withCUDA
+from torch_geometric.testing import onlyCUDA, onlyFullTest


-@withCUDA
+@onlyCUDA
@onlyFullTest
def test_profile(get_dataset):
dataset = get_dataset(name='PubMed')
6 changes: 3 additions & 3 deletions test/profile/test_utils.py
@@ -11,7 +11,7 @@
get_gpu_memory_from_nvidia_smi,
get_model_size,
)
-from torch_geometric.testing import withCUDA
+from torch_geometric.testing import onlyCUDA


def test_count_parameters():
@@ -40,15 +40,15 @@ def test_get_cpu_memory_from_gc():
assert new_mem - old_mem == 10 * 128 * 4


-@withCUDA
+@onlyCUDA
def test_get_gpu_memory_from_gc():
old_mem = get_gpu_memory_from_gc()
_ = torch.randn(10, 128, device='cuda')
new_mem = get_gpu_memory_from_gc()
assert new_mem - old_mem == 10 * 128 * 4


-@withCUDA
+@onlyCUDA
def test_get_gpu_memory_from_nvidia_smi():
free_mem, used_mem = get_gpu_memory_from_nvidia_smi(device=0, digits=2)
assert free_mem >= 0
106 changes: 63 additions & 43 deletions test/utils/test_scatter.py
@@ -2,62 +2,57 @@
import torch
import torch_scatter

-from torch_geometric.testing import withPackage
+from torch_geometric.testing import withCUDA
 from torch_geometric.utils import scatter
 
 
-@pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'min', 'max', 'mul'])
-def test_scatter(reduce):
-    torch.manual_seed(12345)
-
-    src = torch.randn(8, 100, 32)
+def test_scatter_validate():
+    src = torch.randn(100, 32)
     index = torch.randint(0, 10, (100, ), dtype=torch.long)
 
-    out1 = scatter(src, index, dim=1, reduce=reduce)
-    out2 = torch_scatter.scatter(src, index, dim=1, reduce=reduce)
-    assert torch.allclose(out1, out2, atol=1e-6)
+    with pytest.raises(ValueError, match="must be one-dimensional"):
+        scatter(src, index.view(-1, 1))
+
+    with pytest.raises(ValueError, match="must lay between 0 and 1"):
+        scatter(src, index, dim=2)
+
+    with pytest.raises(ValueError, match="invalid `reduce` argument 'std'"):
+        scatter(src, index, reduce='std')
 
 
-@pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'min', 'max'])
-def test_pytorch_scatter_backward(reduce):
-    torch.manual_seed(12345)
-
-    src = torch.randn(8, 100, 32).requires_grad_(True)
-    index = torch.randint(0, 10, (100, ), dtype=torch.long)
-
-    out = scatter(src, index, dim=1, reduce=reduce).relu()
-
-    assert src.grad is None
-    out.mean().backward()
-    assert src.grad is not None
-
-
-@withPackage('torch>=1.12.0')
-@pytest.mark.parametrize('reduce', ['min', 'max'])
-def test_pytorch_scatter_inplace_backward(reduce):
-    torch.manual_seed(12345)
-
-    src = torch.randn(8, 100, 32).requires_grad_(True)
-    index = torch.randint(0, 10, (100, ), dtype=torch.long)
-
-    out = scatter(src, index, dim=1, reduce=reduce).relu_()
-
-    with pytest.raises(RuntimeError, match="modified by an inplace operation"):
-        out.mean().backward()
+@withCUDA
+@pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'min', 'max'])
+def test_scatter(reduce, device):
+    src = torch.randn(100, 16, device=device)
+    index = torch.randint(0, 8, (100, ), device=device)
+
+    out1 = scatter(src, index, dim=0, reduce=reduce)
+    out2 = torch_scatter.scatter(src, index, dim=0, reduce=reduce)
+    assert out1.device == device
+    assert torch.allclose(out1, out2, atol=1e-6)
+
+    jit = torch.jit.script(scatter)
+    out3 = jit(src, index, dim=0, reduce=reduce)
+    assert torch.allclose(out1, out3, atol=1e-6)
+
+    src = torch.randn(8, 100, 16, device=device)
+    out1 = scatter(src, index, dim=1, reduce=reduce)
+    out2 = torch_scatter.scatter(src, index, dim=1, reduce=reduce)
+    assert out1.device == device
+    assert torch.allclose(out1, out2, atol=1e-6)
 
 
-@pytest.mark.parametrize('reduce', ['sum', 'add', 'min', 'max', 'mul'])
-def test_scatter_with_out(reduce):
-    torch.manual_seed(12345)
-
-    src = torch.randn(8, 100, 32)
-    index = torch.randint(0, 10, (100, ), dtype=torch.long)
-    out = torch.randn(8, 10, 32)
-
-    out1 = scatter(src, index, dim=1, out=out.clone(), reduce=reduce)
-    out2 = torch_scatter.scatter(src, index, dim=1, out=out.clone(),
-                                 reduce=reduce)
-    assert torch.allclose(out1, out2, atol=1e-6)
+@withCUDA
+@pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'min', 'max'])
+def test_scatter_backward(reduce, device):
+    src = torch.randn(8, 100, 16, device=device, requires_grad=True)
+    index = torch.randint(0, 8, (100, ), device=device)
+
+    out = scatter(src, index, dim=1, reduce=reduce)
+
+    assert src.grad is None
+    out.mean().backward()
+    assert src.grad is not None
 
 
if __name__ == '__main__':
@@ -81,12 +76,9 @@ def test_scatter_with_out(reduce):

import argparse
import time
-    import warnings

import torch_scatter

-    warnings.filterwarnings('ignore', '.*is in beta and the API may change.*')
-
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--backward', action='store_true')
@@ -168,6 +160,34 @@ def test_scatter_with_out(reduce):
if i >= num_warmups:
t_backward += time.perf_counter() - t_start

+    print(f'torch_sparse forward: {t_forward:.4f}s')
+    if args.backward:
+        print(f'torch_sparse backward: {t_backward:.4f}s')
+    print('==============================')
+
+    t_forward = t_backward = 0
+    for i in range(num_warmups + num_steps):
+        x = torch.randn(num_edges, num_feats, device=args.device)
+        if args.backward:
+            x.requires_grad_(True)
+
+        torch.cuda.synchronize()
+        t_start = time.perf_counter()
+
+        out = scatter(x, index, dim=0, dim_size=num_nodes, reduce=aggr)
+
+        torch.cuda.synchronize()
+        if i >= num_warmups:
+            t_forward += time.perf_counter() - t_start
+
+        if args.backward:
+            t_start = time.perf_counter()
+            out.backward(out_grad)
+
+            torch.cuda.synchronize()
+            if i >= num_warmups:
+                t_backward += time.perf_counter() - t_start
+
print(f'torch_sparse forward: {t_forward:.4f}s')
if args.backward:
print(f'torch_sparse backward: {t_backward:.4f}s')
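The block appended to the `__main__` benchmark above follows the usual CUDA timing discipline: synchronize before each clock read so queued kernels are counted, and discard warm-up iterations. A condensed, self-contained sketch of the same pattern (sizes and reduce mode are illustrative):

```python
# Condensed sketch of the benchmark's timing pattern; sizes are illustrative.
import time

import torch

from torch_geometric.utils import scatter

device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_nodes, num_edges, num_feats = 10_000, 200_000, 64
x = torch.randn(num_edges, num_feats, device=device)
index = torch.randint(0, num_nodes, (num_edges, ), device=device)

num_warmups, num_steps, t_forward = 10, 100, 0.0
for i in range(num_warmups + num_steps):
    if device == 'cuda':
        torch.cuda.synchronize()  # drain queued kernels before timing
    t_start = time.perf_counter()

    out = scatter(x, index, dim=0, dim_size=num_nodes, reduce='sum')

    if device == 'cuda':
        torch.cuda.synchronize()  # wait for the scatter kernel to finish
    if i >= num_warmups:          # discard warm-up iterations
        t_forward += time.perf_counter() - t_start

print(f'forward: {t_forward:.4f}s')
```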
4 changes: 2 additions & 2 deletions torch_geometric/nn/aggr/base.py
@@ -160,14 +160,14 @@ def assert_two_dimensional_input(self, x: Tensor, dim: int):

def reduce(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
-               dim: int = -2, reduce: str = 'add') -> Tensor:
+               dim: int = -2, reduce: str = 'sum') -> Tensor:

if ptr is not None:
ptr = expand_left(ptr, dim, dims=x.dim())
return segment_csr(x, ptr, reduce=reduce)

assert index is not None
-        return scatter(x, index, dim=dim, dim_size=dim_size, reduce=reduce)
+        return scatter(x, index, dim, dim_size, reduce)

def to_dense_batch(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None,
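For context on the two lines changed above: `Aggregation.reduce` dispatches to `segment_csr` when a CSR `ptr` is given and to `scatter` when a plain `index` is given, and the two paths compute the same reduction. The default's rename from `'add'` to `'sum'` appears to be pure naming alignment, since the tests in this PR treat both as summation, and the switch to positional arguments plausibly serves the TorchScript coverage added elsewhere in the diff. A small equivalence sketch (values are illustrative; `segment_csr` comes from the `torch_scatter` package PyG depended on at the time):

```python
# Sketch of the two reduction paths Aggregation.reduce dispatches to:
# the same six rows grouped into three segments, summed either way.
import torch
from torch_scatter import segment_csr

from torch_geometric.utils import scatter

x = torch.randn(6, 16)
index = torch.tensor([0, 0, 0, 1, 1, 2])  # segment id per row
ptr = torch.tensor([0, 3, 5, 6])          # the same grouping in CSR form

out1 = scatter(x, index, dim=0, dim_size=3, reduce='sum')
out2 = segment_csr(x, ptr, reduce='sum')
assert torch.allclose(out1, out2, atol=1e-6)
```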