Multi-GPU Training in Pure PyTorch
==================================

For many large-scale, real-world datasets, it may be necessary to scale up training across multiple GPUs.
This tutorial goes over how to set up a multi-GPU training and inference pipeline in :pyg:`PyG` with pure :pytorch:`PyTorch` via :class:`torch.nn.parallel.DistributedDataParallel`, without the need for any other third-party libraries (such as :lightning:`PyTorch Lightning`).

In particular, this tutorial shows how to train a :class:`~torch_geometric.nn.models.GraphSAGE` GNN model on the :class:`~torch_geometric.datasets.Reddit` dataset.
For this, we utilize the :class:`~torch_geometric.loader.NeighborLoader` together with :class:`torch.nn.parallel.DistributedDataParallel` to scale up training across all available GPUs.

.. note::
    A runnable example of this tutorial can be found at `examples/multi_gpu/distributed_sampling.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/multi_gpu/distributed_sampling.py>`_.

Defining a Spawnable Runner
~~~~~~~~~~~~~~~~~~~~~~~~~~~

`DistributedDataParallel (DDP) <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_ implements data parallelism at the module level and can run across multiple machines.
Applications using DDP spawn multiple processes and create a single DDP instance per process.
DDP processes can be placed on the same machine or across machines.

To create a DDP module, we first need to set up process groups properly and define a spawnable runner function.
Here, :obj:`world_size` corresponds to the number of GPUs we will be using at once.
For each GPU, the corresponding process is labeled with a process ID which we call its :obj:`rank`.
:meth:`torch.multiprocessing.spawn` will take care of spawning :obj:`world_size` many processes:

.. code-block:: python

    import torch
    import torch.multiprocessing as mp

    from torch_geometric.datasets import Reddit


    def run(rank: int, world_size: int, dataset: Reddit):
        pass


    if __name__ == '__main__':
        dataset = Reddit('./data/Reddit')

        world_size = torch.cuda.device_count()
        print("Let's use", world_size, "GPUs!")
        mp.spawn(run, args=(world_size, dataset), nprocs=world_size, join=True)

Note that we initialize the dataset *before* spawning any processes.
With this, we only initialize the dataset once, and any data inside of it will be automatically moved to shared memory via :obj:`torch.multiprocessing` such that processes do not need to work on their own replica of the data.
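
If you want to convince yourself of this behavior, an optional sanity check (not part of the original example) is to call :meth:`torch.Tensor.is_shared` inside the spawned runner, which is expected to report :obj:`True` for tensors backed by shared memory:

.. code-block:: python

    def run(rank: int, world_size: int, dataset: Reddit):
        data = dataset[0]
        # Tensors passed through `mp.spawn` are moved to shared memory, so
        # every process reads from the same underlying storage:
        print(f'[rank {rank}] x in shared memory: {data.x.is_shared()}')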

With this, we can start to implement our spawnable runner function:

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist


    def run(rank: int, world_size: int, dataset: Reddit):
        # Set up the process group for communication across all GPUs:
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12345'
        dist.init_process_group('nccl', rank=rank, world_size=world_size)

        data = dataset[0]
        # Move node features and labels to GPU for faster feature fetching:
        data = data.to(rank, 'x', 'y')

The first step above is initializing :obj:`torch.distributed`.
More details can be found in `Writing Distributed Applications with PyTorch <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`_.
Note also that we move the node features and labels to the GPU up-front, which allows for faster feature fetching during mini-batch sampling.

Next, we split the training indices into :obj:`world_size` many chunks, one for each GPU, and initialize the :class:`~torch_geometric.loader.NeighborLoader` class to only operate on its specific chunk of the training set:

.. code-block:: python

    from torch_geometric.loader import NeighborLoader


    def run(rank: int, world_size: int, dataset: Reddit):
        ...

        # Split training indices into `world_size` many chunks:
        train_index = data.train_mask.nonzero().view(-1)
        train_index = train_index.split(train_index.size(0) // world_size)[rank]

        train_loader = NeighborLoader(
            data,
            input_nodes=train_index,
            num_neighbors=[25, 10],
            batch_size=1024,
            num_workers=4,
            shuffle=True,
        )

Note that our :meth:`run` function is called on each rank, which means that each rank holds a separate :class:`~torch_geometric.loader.NeighborLoader` instance that only samples from its own reduced set of training indices.

Similarly, we create a :class:`~torch_geometric.loader.NeighborLoader` instance for evaluation.
For simplicity, we only do this on rank :obj:`0` such that the computation of metrics does not need to communicate across different processes.
For distributed computation of metrics, we recommend taking a look at the `torchmetrics <https://torchmetrics.readthedocs.io/en/stable/>`_ package, as sketched right after the following code block:

.. code-block:: python

    def run(rank: int, world_size: int, dataset: Reddit):
        ...

        if rank == 0:  # Create a validation loader on rank 0 only:
            val_index = data.val_mask.nonzero().view(-1)
            val_loader = NeighborLoader(
                data,
                input_nodes=val_index,
                num_neighbors=[25, 10],
                batch_size=1024,
                num_workers=4,
                shuffle=False,
            )

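As an aside, if you later want to evaluate on all ranks simultaneously, the per-rank statistics need to be aggregated across processes.
Below is a minimal sketch of how this could look with :obj:`torchmetrics` (the :meth:`evaluate` helper is our own illustration and not part of the original example; it assumes a :obj:`torchmetrics` version that supports the :obj:`task` argument, in which case :meth:`compute` synchronizes the accumulated state across processes once :obj:`torch.distributed` has been initialized):

.. code-block:: python

    import torch
    from torchmetrics import Accuracy


    def evaluate(model, loader, rank: int, num_classes: int) -> float:
        # Hypothetical helper for distributed evaluation:
        metric = Accuracy(task='multiclass', num_classes=num_classes).to(rank)
        model.eval()
        with torch.no_grad():
            for batch in loader:
                batch = batch.to(rank)
                out = model(batch.x, batch.edge_index)[:batch.batch_size]
                # Accumulate statistics on this rank only:
                metric.update(out.argmax(dim=-1), batch.y[:batch.batch_size])
        # `compute()` reduces the accumulated state across all processes:
        return float(metric.compute())
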
Now that we have our data loaders defined, we initialize our :class:`~torch_geometric.nn.GraphSAGE` model and wrap it inside :pytorch:`PyTorch`'s :class:`~torch.nn.parallel.DistributedDataParallel`.
This wrapper manages communication between all ranks and reduces the loss gradients from each process before updating the model's parameters across all ranks:

.. code-block:: python

    from torch.nn.parallel import DistributedDataParallel

    from torch_geometric.nn import GraphSAGE


    def run(rank: int, world_size: int, dataset: Reddit):
        ...

        torch.manual_seed(12345)
        model = GraphSAGE(
            in_channels=dataset.num_features,
            hidden_channels=256,
            num_layers=2,
            out_channels=dataset.num_classes,
        ).to(rank)
        model = DistributedDataParallel(model, device_ids=[rank])

Finally, we can set up our optimizer and define our training loop, which follows a similar flow as usual single-GPU training loops.
The actual magic of gradient and model weight synchronization across different processes happens behind the scenes within :class:`~torch.nn.parallel.DistributedDataParallel`:

.. code-block:: python

    import torch.nn.functional as F


    def run(rank: int, world_size: int, dataset: Reddit):
        ...

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        for epoch in range(1, 11):
            model.train()
            for batch in train_loader:
                batch = batch.to(rank)  # Move the sampled mini-batch to GPU.
                optimizer.zero_grad()
                out = model(batch.x, batch.edge_index)[:batch.batch_size]
                loss = F.cross_entropy(out, batch.y[:batch.batch_size])
                loss.backward()
                optimizer.step()

After each training epoch, we evaluate the model and report validation metrics.
As previously mentioned, we only do this on a single GPU.
To synchronize all processes and to ensure that the model weights have been updated, we need to call :meth:`torch.distributed.barrier`:

.. code-block:: python

            # Still inside the epoch loop from above:
            dist.barrier()

            if rank == 0:
                print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')

            if rank == 0:  # Evaluation on a single GPU:
                model.eval()
                count = correct = 0
                with torch.no_grad():
                    for batch in val_loader:
                        batch = batch.to(rank)
                        out = model(batch.x, batch.edge_index)[:batch.batch_size]
                        pred = out.argmax(dim=-1)
                        correct += (pred == batch.y[:batch.batch_size]).sum()
                        count += batch.batch_size
                print(f'Validation Accuracy: {correct / count:.4f}')

            dist.barrier()

After finishing training, we can clean up processes and destroy the process group via:

.. code-block:: python

        dist.destroy_process_group()

And that's it.
Putting everything together gives a working multi-GPU example that follows a training flow similar to single-GPU training.
You can run the full tutorial yourself by looking at `examples/multi_gpu/distributed_sampling.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/multi_gpu/distributed_sampling.py>`_.
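
For orientation, here is a condensed sketch of how the individual pieces above fit together in a single script (loader creation and the loop bodies are elided and only hinted at via comments; refer to the full example for the complete version):

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    from torch.nn.parallel import DistributedDataParallel

    from torch_geometric.datasets import Reddit
    from torch_geometric.nn import GraphSAGE


    def run(rank: int, world_size: int, dataset: Reddit):
        # Process group setup and data preparation:
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12345'
        dist.init_process_group('nccl', rank=rank, world_size=world_size)
        data = dataset[0].to(rank, 'x', 'y')

        # ... create `train_loader` (all ranks) and `val_loader` (rank 0) ...

        model = GraphSAGE(
            in_channels=dataset.num_features,
            hidden_channels=256,
            num_layers=2,
            out_channels=dataset.num_classes,
        ).to(rank)
        model = DistributedDataParallel(model, device_ids=[rank])
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        for epoch in range(1, 11):
            model.train()
            # ... mini-batch training loop over `train_loader` ...
            dist.barrier()
            if rank == 0:
                pass  # ... evaluation loop over `val_loader` ...
            dist.barrier()

        dist.destroy_process_group()


    if __name__ == '__main__':
        dataset = Reddit('./data/Reddit')
        world_size = torch.cuda.device_count()
        mp.spawn(run, args=(world_size, dataset), nprocs=world_size, join=True)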
