Benchmark for scatter_reduce #6051

Merged (4 commits, Nov 24, 2022)

1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -72,6 +72,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
- Added `torch_geometric.explain` module with base functionality for explainability methods ([#5804](https://github.com/pyg-team/pytorch_geometric/pull/5804))
### Changed
- Optimized scatter implementations for CPU/GPU, both with and without backward computation ([#6051](https://github.com/pyg-team/pytorch_geometric/pull/6051))
- Support temperature value in `dense_mincut_pool` ([#5908](https://github.com/pyg-team/pytorch_geometric/pull/5908))
- Fixed a bug in which `VirtualNode` mistakenly treated node features as edge features ([#5819](https://github.com/pyg-team/pytorch_geometric/pull/5819))
- Fixed `setter` and `getter` handling in `BaseStorage` ([#5815](https://github.com/pyg-team/pytorch_geometric/pull/5815))
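
For reference, the changelog entry above refers to the `scatter` helper in `torch_geometric.utils`. A minimal usage sketch (shapes and values are illustrative, not taken from this PR):

    import torch
    from torch_geometric.utils import scatter

    src = torch.randn(6, 4)                    # one feature row per edge
    index = torch.tensor([0, 1, 0, 1, 2, 2])   # target node of each edge

    # Reduce the six rows of `src` into three node rows; the benchmarks
    # below cover the 'sum', 'mean', 'min', 'max', and 'mul' modes.
    out = scatter(src, index, dim=0, dim_size=3, reduce='sum')
    assert out.size() == (3, 4)

Here `dim_size` plays the role of `num_nodes` in the benchmark scripts: it fixes the number of output rows even when some nodes receive no messages.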
19 changes: 13 additions & 6 deletions test/nn/aggr/test_fused.py
@@ -46,29 +46,35 @@ def test_fused_aggregation(aggrs):
    warnings.filterwarnings('ignore', '.*is in beta and the API may change.*')

    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--backward', action='store_true')
    args = parser.parse_args()

    num_nodes, num_edges, num_feats = 1000, 50000, 64

    num_warmups, num_steps = 500, 1000
    if args.device == 'cpu':
        num_warmups, num_steps = num_warmups // 10, num_steps // 10

    aggrs = ['sum', 'mean', 'max', 'std']
    print(f'Aggregators: {", ".join(aggrs)}')
    print('=========================')
    aggrs = [aggregation_resolver(aggr) for aggr in aggrs]
    fused_aggr = FusedAggregation(aggrs)

    index = torch.randint(num_nodes, (num_edges, ), device='cuda')
    out_grad = torch.randn(num_nodes, len(aggrs) * num_feats, device='cuda')
    index = torch.randint(num_nodes, (num_edges, ), device=args.device)
    out_grad = torch.randn(num_nodes,
                           len(aggrs) * num_feats, device=args.device)

    t_forward = t_backward = 0
    for i in range(num_warmups + num_steps):
        x = torch.randn(num_edges, num_feats, device='cuda')
        x = torch.randn(num_edges, num_feats, device=args.device)
        if args.backward:
            x.requires_grad_(True)
        torch.cuda.synchronize()

        torch.cuda.synchronize()
        t_start = time.perf_counter()

        outs = [aggr(x, index, dim_size=num_nodes) for aggr in aggrs]
        out = torch.cat(outs, dim=-1)

Expand All @@ -91,12 +97,13 @@ def test_fused_aggregation(aggrs):

    t_forward = t_backward = 0
    for i in range(num_warmups + num_steps):
        x = torch.randn(num_edges, num_feats, device='cuda')
        x = torch.randn(num_edges, num_feats, device=args.device)
        if args.backward:
            x.requires_grad_(True)
        torch.cuda.synchronize()

        torch.cuda.synchronize()
        t_start = time.perf_counter()

        out = torch.cat(fused_aggr(x, index, dim_size=num_nodes), dim=-1)

        torch.cuda.synchronize()
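
To make concrete what the two timed loops above compare: running the aggregation modules one by one versus a single fused pass. A minimal sketch against the same API (shapes are illustrative; `fused_aggr(...)` returns one output per requested aggregation, mirroring the assertions in `test_fused_aggregation`):

    import torch
    from torch_geometric.nn.aggr import FusedAggregation
    from torch_geometric.nn.resolver import aggregation_resolver

    x = torch.randn(50, 16)           # 50 edges with 16 features each
    index = torch.randint(8, (50, ))  # maps every edge to one of 8 nodes

    aggrs = [aggregation_resolver(a) for a in ['sum', 'mean', 'max', 'std']]
    fused_aggr = FusedAggregation(aggrs)

    # Sequential: one pass over `x` per aggregator ...
    expected = torch.cat([aggr(x, index, dim_size=8) for aggr in aggrs], dim=-1)

    # ... versus fused: a single traversal returning a list of outputs.
    out = torch.cat(fused_aggr(x, index, dim_size=8), dim=-1)
    assert torch.allclose(expected, out, atol=1e-6)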
115 changes: 115 additions & 0 deletions test/utils/test_scatter.py
@@ -58,3 +58,118 @@ def test_scatter_with_out(reduce):
    out2 = torch_scatter.scatter(src, index, dim=1, out=out.clone(),
                                 reduce=reduce)
    assert torch.allclose(out1, out2, atol=1e-6)


if __name__ == '__main__':
    # Insights on GPU:
    # ================
    # * "sum": Prefer `scatter_add_` implementation
    # * "mean": Prefer manual implementation via `scatter_add_` + `count`
    # * "min"/"max":
    #   * Prefer `scatter_reduce_` implementation without gradients
    #   * Prefer `torch_sparse` implementation with gradients
    # * "mul": Prefer `torch_sparse` implementation
    #
    # Insights on CPU:
    # ================
    # * "sum": Prefer `scatter_add_` implementation
    # * "mean": Prefer manual implementation via `scatter_add_` + `count`
    # * "min"/"max": Prefer `scatter_reduce_` implementation
    # * "mul" (probably not worth branching for this):
    #   * Prefer `scatter_reduce_` implementation without gradients
    #   * Prefer `torch_sparse` implementation with gradients

    import argparse
    import time
    import warnings

    import torch_scatter

    warnings.filterwarnings('ignore', '.*is in beta and the API may change.*')

    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--backward', action='store_true')
    args = parser.parse_args()

    num_nodes, num_edges, num_feats = 1000, 50000, 64

    num_warmups, num_steps = 500, 1000
    if args.device == 'cpu':
        num_warmups, num_steps = num_warmups // 10, num_steps // 10

    index = torch.randint(num_nodes - 5, (num_edges, ), device=args.device)
    out_grad = torch.randn(num_nodes, num_feats, device=args.device)

    aggrs = ['sum', 'mean', 'min', 'max', 'mul']
    for aggr in aggrs:
        print(f'Aggregator: {aggr}')
        print('==============================')

        reduce = aggr
        if reduce == 'min' or reduce == 'max':
            reduce = f'a{aggr}'  # `amin` or `amax`
        elif reduce == 'mul':
            reduce = 'prod'

        t_forward = t_backward = 0
        for i in range(num_warmups + num_steps):
            x = torch.randn(num_edges, num_feats, device=args.device)
            if args.backward:
                x.requires_grad_(True)

            torch.cuda.synchronize()
            t_start = time.perf_counter()

            out = x.new_zeros((num_nodes, num_feats))
            include_self = reduce in ['sum', 'mean']
            broadcasted_index = index.view(-1, 1).expand(-1, num_feats)
            out.scatter_reduce_(0, broadcasted_index, x, reduce,
                                include_self=include_self)

            torch.cuda.synchronize()
            if i >= num_warmups:
                t_forward += time.perf_counter() - t_start

            if args.backward:
                t_start = time.perf_counter()
                out.backward(out_grad)

                torch.cuda.synchronize()
                if i >= num_warmups:
                    t_backward += time.perf_counter() - t_start

        print(f'PyTorch forward: {t_forward:.4f}s')
        if args.backward:
            print(f'PyTorch backward: {t_backward:.4f}s')
        print('==============================')

        t_forward = t_backward = 0
        for i in range(num_warmups + num_steps):
            x = torch.randn(num_edges, num_feats, device=args.device)
            if args.backward:
                x.requires_grad_(True)

            torch.cuda.synchronize()
            t_start = time.perf_counter()

            out = torch_scatter.scatter(x, index, dim=0, dim_size=num_nodes,
                                        reduce=aggr)

            torch.cuda.synchronize()
            if i >= num_warmups:
                t_forward += time.perf_counter() - t_start

            if args.backward:
                t_start = time.perf_counter()
                out.backward(out_grad)

                torch.cuda.synchronize()
                if i >= num_warmups:
                    t_backward += time.perf_counter() - t_start

        print(f'torch_sparse forward: {t_forward:.4f}s')
        if args.backward:
            print(f'torch_sparse backward: {t_backward:.4f}s')

        print()
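
The insights at the top of this file can be read as a dispatch rule. A sketch of what that branching could look like — `scatter_dispatch` is a hypothetical helper for illustration, not the implementation this PR touches; it assumes the benchmark's layout of one feature row per edge:

    import torch
    import torch_scatter

    def scatter_dispatch(x, index, dim_size, reduce, needs_grad=False):
        if reduce in ('sum', 'mean'):
            # "sum"/"mean": `scatter_add_`-based path, plus a count for "mean".
            out = x.new_zeros(dim_size, x.size(1))
            out.scatter_add_(0, index.view(-1, 1).expand_as(x), x)
            if reduce == 'mean':
                count = x.new_zeros(dim_size)
                count.scatter_add_(0, index, x.new_ones(index.numel()))
                out = out / count.clamp_(min=1).view(-1, 1)
            return out
        if reduce in ('min', 'max') and not needs_grad:
            # "min"/"max" without gradients: `scatter_reduce_` ('amin'/'amax'),
            # with `include_self=False` so untouched rows stay at zero.
            out = x.new_zeros(dim_size, x.size(1))
            out.scatter_reduce_(0, index.view(-1, 1).expand_as(x), x,
                                f'a{reduce}', include_self=False)
            return out
        # "min"/"max" with gradients and "mul": fall back to `torch_scatter`.
        return torch_scatter.scatter(x, index, dim=0, dim_size=dim_size,
                                     reduce=reduce)

Note the naming mismatch the benchmark also has to bridge: `torch_scatter` calls the product reduction 'mul', while `scatter_reduce_` calls it 'prod' (and 'min'/'max' become 'amin'/'amax').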