
Commit a1251ab

Improvements to multinode papers100m default hyperparams and adding eval on all ranks (#8823)

> using the main branch of PyG (16 Grace Hopper nodes):
> Val Acc: 0.4546 (45.46%)
> Test Acc: 0.3770 (37.70%)
> 
> using the PR branch (2 Grace Hopper nodes due to availability):
> Validation Accuracy: 51.1759%
> Test Accuracy: 44.5692%

PR ready

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <[email protected]>
3 people authored Mar 12, 2024
1 parent cfdb4ce commit a1251ab
Showing 2 changed files with 58 additions and 56 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

+- Improvements to multi-node `ogbn-papers100m` default hyperparameters and adding evaluation on all ranks ([#8823](https://github.com/pyg-team/pytorch_geometric/pull/8823))
- Changed distributed sampler and loader tests to correctly report failures in subprocesses to `pytest` ([#8978](https://github.com/pyg-team/pytorch_geometric/pull/8978))
- Remove filtering of node/edge types in `trim_to_layer` functionality ([#9021](https://github.com/pyg-team/pytorch_geometric/pull/9021))
- Default to `scatter` operations in `MessagePassing` in case `torch.use_deterministic_algorithms` is not set ([#9009](https://github.com/pyg-team/pytorch_geometric/pull/9009))
113 changes: 57 additions & 56 deletions examples/multi_gpu/papers100m_gcn_multinode.py
@@ -1,22 +1,24 @@
"""Multi-node multi-GPU example on ogbn-papers100m.
To run:
Example way to run using srun:
srun -l -N<num_nodes> --ntasks-per-node=<ngpu_per_node> \
--container-name=cont --container-image=<image_url> \
--container-mounts=/ogb-papers100m/:/workspace/dataset
python3 path_to_script.py
"""
import os
import time
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from ogb.nodeproppred import PygNodePropPredDataset
from torch.nn.parallel import DistributedDataParallel
from torchmetrics import Accuracy

from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GCN


def get_num_workers() -> int:
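The helper's body falls between the hunks shown here, so it is collapsed in this view. For orientation, a plausible sketch of such a helper (an assumption, not the committed code) derives the DataLoader worker count from the CPUs actually available to the process:

import os

def get_num_workers() -> int:  # hypothetical sketch; the committed body is collapsed above
    # Respect CPU affinity (cgroups, SLURM bindings) where the platform exposes it.
    if hasattr(os, 'sched_getaffinity'):
        try:
            return len(os.sched_getaffinity(0)) // 2
        except Exception:
            pass
    # Fallback: half of all visible CPUs, but at least one worker.
    return max((os.cpu_count() or 2) // 2, 1)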
@@ -31,21 +33,7 @@ def get_num_workers() -> int:
return num_workers


-class GCN(torch.nn.Module):
-    def __init__(self, in_channels, hidden_channels, out_channels):
-        super().__init__()
-        self.conv1 = GCNConv(in_channels, hidden_channels)
-        self.conv2 = GCNConv(hidden_channels, out_channels)
-
-    def forward(self, x, edge_index):
-        x = F.dropout(x, p=0.5, training=self.training)
-        x = self.conv1(x, edge_index).relu()
-        x = F.dropout(x, p=0.5, training=self.training)
-        x = self.conv2(x, edge_index)
-        return x
-
-
-def run(world_size, data, split_idx, model):
+def run(world_size, data, split_idx, model, acc, wall_clock_start):
local_id = int(os.environ['LOCAL_RANK'])
rank = torch.distributed.get_rank()
torch.cuda.set_device(local_id)
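For orientation: `LOCAL_RANK` (set by the launcher) indexes GPUs within one node, while `torch.distributed.get_rank()` is global across all nodes, which is why the device is pinned with the local ID but the data is sharded by the global rank below. A standalone toy illustration, assuming a hypothetical 2-node, 4-GPU-per-node job:

# Not part of the commit: shows how a global rank maps to (node, local GPU).
world_size, gpus_per_node = 8, 4
for rank in range(world_size):
    local_rank = rank % gpus_per_node   # what the launcher exports as LOCAL_RANK
    node = rank // gpus_per_node
    print(f'global rank {rank} -> node {node}, cuda:{local_rank}')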
@@ -54,38 +42,48 @@ def run(world_size, data, split_idx, model):
print(f'Using {nprocs} GPUs...')

    split_idx['train'] = split_idx['train'].split(
-        split_idx['train'].size(0) // world_size,
-        dim=0,
-    )[rank].clone()
+        split_idx['train'].size(0) // world_size, dim=0)[rank].clone()
+    split_idx['valid'] = split_idx['valid'].split(
+        split_idx['valid'].size(0) // world_size, dim=0)[rank].clone()
+    split_idx['test'] = split_idx['test'].split(
+        split_idx['test'].size(0) // world_size, dim=0)[rank].clone()

model = DistributedDataParallel(model.to(device), device_ids=[local_id])
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.001,
+                                 weight_decay=5e-4)

kwargs = dict(
data=data,
-        batch_size=128,
+        batch_size=1024,
num_workers=get_num_workers(),
-        num_neighbors=[50, 50],
+        num_neighbors=[30, 30],
)

train_loader = NeighborLoader(
input_nodes=split_idx['train'],
shuffle=True,
drop_last=True,
**kwargs,
)
-    if rank == 0:
-        val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs)
-        test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs)
+    val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs)
+    test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs)

val_steps = 1000
warmup_steps = 100
+    acc = acc.to(device)
+    dist.barrier()
+    torch.cuda.synchronize()
+    if rank == 0:
+        prep_time = round(time.perf_counter() - wall_clock_start, 2)
+        print("Total time before training begins (prep_time)=", prep_time,
+              "seconds")
+        print("Beginning training...")

-    for epoch in range(1, 4):
+    for epoch in range(1, 21):
model.train()
for i, batch in enumerate(train_loader):
if i == warmup_steps:
torch.cuda.synchronize()
start = time.time()
batch = batch.to(device)
optimizer.zero_grad()
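A note on the sharding at the top of this hunk: `Tensor.split(n // world_size)[rank]` hands each rank a contiguous chunk, and when the split is uneven the trailing remainder chunk is simply never selected. A standalone toy example with made-up sizes:

import torch

idx = torch.arange(10)                       # stand-in for training node indices
world_size = 3
chunks = idx.split(10 // world_size, dim=0)  # chunk size 3 -> pieces of 3, 3, 3, 1
for rank in range(world_size):
    print(rank, chunks[rank].tolist())       # rank r keeps chunk r
# 0 [0, 1, 2]
# 1 [3, 4, 5]
# 2 [6, 7, 8]   (the leftover [9] sits in the unused fourth chunk)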
@@ -98,53 +96,56 @@ def run(world_size, data, split_idx, model):
if rank == 0 and i % 10 == 0:
print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}')

+        dist.barrier()
+        torch.cuda.synchronize()
if rank == 0:
-            sec_per_iter = (time.time() - start) / (i - warmup_steps)
+            sec_per_iter = (time.time() - start) / (i + 1 - warmup_steps)
print(f"Avg Training Iteration Time: {sec_per_iter:.6f} s/iter")

+    @torch.no_grad()
+    def test(loader: NeighborLoader, num_steps: Optional[int] = None):
model.eval()
-            total_correct = total_examples = 0
-            for i, batch in enumerate(val_loader):
-                if i >= val_steps:
+        for j, batch in enumerate(loader):
+            if num_steps is not None and j >= num_steps:
break
-                if i == warmup_steps:
-                    start = time.time()

batch = batch.to(device)
-                with torch.no_grad():
-                    out = model(batch.x, batch.edge_index)[:batch.batch_size]
-                pred = out.argmax(dim=-1)
+            out = model(batch.x, batch.edge_index)[:batch.batch_size]
y = batch.y[:batch.batch_size].view(-1).to(torch.long)
+            acc(out, y)
+        acc_sum = acc.compute()
+        return acc_sum

-                total_correct += int((pred == y).sum())
-                total_examples += y.size(0)
+    eval_acc = test(val_loader, num_steps=val_steps)
+    if rank == 0:
+        print(f"Val Accuracy: {eval_acc:.4f}%", )

print(f"Val Acc: {total_correct / total_examples:.4f}")
sec_per_iter = (time.time() - start) / (i - warmup_steps)
print(f"Avg Inference Iteration Time: {sec_per_iter:.6f} s/iter")
+    acc.reset()
+    dist.barrier()

+    test_acc = test(test_loader)
if rank == 0:
-        model.eval()
-        total_correct = total_examples = 0
-        for i, batch in enumerate(test_loader):
-            batch = batch.to(device)
-            with torch.no_grad():
-                out = model(batch.x, batch.edge_index)[:batch.batch_size]
-            pred = out.argmax(dim=-1)
-            y = batch.y[:batch.batch_size].view(-1).to(torch.long)
+        print(f"Test Accuracy: {test_acc:.4f}%", )

-            total_correct += int((pred == y).sum())
-            total_examples += y.size(0)
-        print(f"Test Acc: {total_correct / total_examples:.4f}")
+    dist.barrier()
+    acc.reset()
+    torch.cuda.synchronize()

+    if rank == 0:
+        total_time = round(time.perf_counter() - wall_clock_start, 2)
+        print("Total Program Runtime (total_time) =", total_time, "seconds")
+        print("total_time - prep_time =", total_time - prep_time, "seconds")


if __name__ == '__main__':
+    wall_clock_start = time.perf_counter()
# Setup multi-node:
torch.distributed.init_process_group("nccl")
nprocs = dist.get_world_size()
assert dist.is_initialized(), "Distributed cluster not initialized"
dataset = PygNodePropPredDataset(name='ogbn-papers100M')
split_idx = dataset.get_idx_split()
-    model = GCN(dataset.num_features, 64, dataset.num_classes)
-
-    run(nprocs, dataset[0], split_idx, model)
+    model = GCN(dataset.num_features, 256, 2, dataset.num_classes)
+    acc = Accuracy(task="multiclass", num_classes=dataset.num_classes)
+    data = dataset[0]
+    data.y = data.y.reshape(-1)
+    run(nprocs, data, split_idx, model, acc, wall_clock_start)

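Two notes on the pieces this diff leans on. First, in `GCN(dataset.num_features, 256, 2, dataset.num_classes)` the third positional argument is the number of layers: PyG's built-in `torch_geometric.nn.GCN` follows the `BasicGNN` signature `(in_channels, hidden_channels, num_layers, out_channels)`, which is what lets the hand-rolled two-layer `GCNConv` module above be deleted. Second, the all-rank evaluation works because `torchmetrics` metrics keep per-rank state and, when `torch.distributed` is initialized, `compute()` synchronizes that state across the process group by default. A minimal single-process sketch of the same update/compute/reset cycle (illustrative only; the class count is ogbn-papers100M's 172, and the tensors are random stand-ins):

import torch
from torchmetrics import Accuracy

acc = Accuracy(task='multiclass', num_classes=172)

logits = torch.randn(8, 172)          # stand-in for model(batch.x, batch.edge_index)
target = torch.randint(0, 172, (8,))  # stand-in for batch.y
acc(logits, target)                   # accumulate per-batch (per-rank, under DDP) state
print(acc.compute())                  # under DDP this reduces state across all ranks
acc.reset()                           # clear state before reusing the metric on another split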