update pytorch multi-gpu example, incorporate comments @samhita-alla @kumare3

Signed-off-by: Niels Bantilan <[email protected]>
cosmicBboy committed Aug 16, 2021
1 parent 3524ed1 commit f63efc1
Showing 3 changed files with 72 additions and 502 deletions.
@@ -266,7 +266,6 @@ def pytorch_mnist_task(hp: Hyperparameters) -> TrainingOutputs:

# store the hyperparameters' config in ``wandb``
wandb.config.update(json.loads(hp.to_json()))
print(wandb.config)

# set random seed
torch.manual_seed(hp.seed)
@@ -2,16 +2,19 @@
Single Node, Multi GPU Training
--------------------------------
When you need to scale up model training in pytorch, you can use the :py:class:`pytorch:torch.nn.DataParallel` for
single node, multi-gpu/cpu training or :py:class`pytorch:torch.nn.parallel.DistributedDataParallel` for multi-node,
When you need to scale up model training in pytorch, you can use the :py:class:`~pytorch:torch.nn.DataParallel` for
single node, multi-gpu/cpu training or :py:class:`~pytorch:torch.nn.parallel.DistributedDataParallel` for multi-node,
multi-gpu training.
This tutorial will cover how to write a simple training script on the MNIST dataset that uses
:py:class`pytorch:torch.nn.parallel.DistributedDataParallel`, since this is the `recommended way <https://pytorch.org/docs/stable/notes/cuda.html#cuda-nn-ddp-instead>`__
of distributing your training workload. For training on a single node and gpu see
:ref:`this tutorial <sphx_glr_auto_case_studies_ml_training_mnist_classifier_pytorch_single_node.py>`.
For more information on distributed training, check out the
``DistributedDataParallel``. Its functionality is a superset of ``DataParallel``, supporting both single- and
multi-node training, and it is the `recommended way <https://pytorch.org/docs/stable/notes/cuda.html#cuda-nn-ddp-instead>`__
of distributing your training workload. Note, however, that this tutorial will only work for single-node, multi-gpu
settings.
For training on a single node and gpu see
:ref:`this tutorial <sphx_glr_auto_case_studies_ml_training_mnist_classifier_pytorch_single_node.py>`, and for more
information on distributed training, check out the
`pytorch documentation <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`__.
"""

@@ -25,17 +28,23 @@
import torch
import torch.nn.functional as F
import wandb
from dataclasses_json import dataclass_json
from flytekit import Resources, task, workflow
from flytekit.types.file import PythonPickledFile
from torch import distributed as dist
from torch import nn, multiprocessing as mp, optim
from torchvision import datasets, transforms

# %%
# We'll re-use certain classes and functions from the
# :ref:`single node and gpu tutorial <sphx_glr_auto_case_studies_ml_training_mnist_classifier_pytorch_single_node.py>`
# such as the ``Net`` model architecture, ``Hyperparameters``, and ``log_test_predictions``.
from mnist_classifier.pytorch_single_node_and_gpu import Net, Hyperparameters, log_test_predictions

# %%
# Let's define some variables to be used later.
#
# ``WORLD_SIZE`` defines the total number of GPUs we want to use to distribute our training job.
# ``WORLD_SIZE`` defines the total number of GPUs we want to use to distribute our training job, and ``DATA_DIR``
# specifies where the downloaded data should be written to.
WORLD_SIZE = 2
DATA_DIR = "./data"
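
# %%
# ``WORLD_SIZE`` is hard-coded to ``2`` here. On a machine with a different number of GPUs you could instead derive
# it from the hardware. The helper below is a hypothetical alternative, not something this example actually uses:

def detect_world_size(default: int = 1) -> int:
    """Return the number of visible CUDA devices, falling back to ``default`` on CPU-only machines."""
    return torch.cuda.device_count() or default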

@@ -58,28 +67,11 @@ def wandb_setup():
wandb.init(project="mnist-single-node-multi-gpu", entity=os.environ.get("WANDB_USERNAME", "my-user-name"))

# %%
# Creating the Network
# ====================
# Re-using the Network from the Single GPU Example
# ================================================
#
# We'll use the same neural network architecture as the one we define in the
# :ref:`single node tutorial <sphx_glr_auto_case_studies_ml_training_mnist_classifier_pytorch_single_node.py>`.
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4 * 4 * 50, 500)
self.fc2 = nn.Linear(500, 10)

def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# :ref:`single node and gpu tutorial <sphx_glr_auto_case_studies_ml_training_mnist_classifier_pytorch_single_node_and_gpu.py>`.


# %%
@@ -95,6 +87,9 @@ def download_mnist(data_dir):
# %%
# The Data Loader
# ===============
#
# This function will be called inside the training function, which runs as a separate process on each available GPU.
# Note that we set ``download=False`` here to avoid the race conditions mentioned above. A simplified sketch of the
# distributed sampling pattern such a loader relies on is shown below, followed by the actual implementation.
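
# %%
# ``example_distributed_loader`` below is a hypothetical helper, not part of this example: it shows the
# ``DistributedSampler`` pattern that hands each rank a disjoint shard of the dataset. The actual
# ``mnist_dataloader`` follows right after this sketch.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def example_distributed_loader(dataset, rank: int, world_size: int, batch_size: int = 64):
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    # when a sampler is supplied, ``shuffle`` must be left unset on the DataLoader itself
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
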
def mnist_dataloader(data_dir, batch_size, train=True, distributed=False, rank=None, world_size=None, **kwargs):
dataset = datasets.MNIST(
data_dir,
@@ -122,8 +117,8 @@ def mnist_dataloader(data_dir, batch_size, train=True, distributed=False, rank=N
# Training
# ========
#
# We define a ``train`` function to enclose the training loop per epoch, i.e., this gets called for every successive epoch.
# Additionally, we log the loss and epoch progression, which can later be visualized in a ``wandb`` dashboard.
# We define a ``train`` function to enclose the training loop per epoch, and we log the loss and epoch progression,
# which can later be visualized in a ``wandb`` dashboard.
def train(model, rank, train_loader, optimizer, epoch, log_interval):
model.train()

@@ -154,35 +149,13 @@ def train(model, rank, train_loader, optimizer, epoch, log_interval):
wandb.log({"loss": loss, "epoch": epoch})


# %%
# We define a test logger function which will be called when we run the model on test dataset.
def log_test_predictions(images, labels, outputs, predicted, my_table, log_counter):
"""
Convenience function to log predictions for a batch of test images
"""

# obtain confidence scores for all classes
scores = F.softmax(outputs.data, dim=1)

# assign ids based on the order of the images
for i, (image, pred, label, score) in enumerate(
zip(*[x.cpu().numpy() for x in (images, predicted, labels, scores)])
):
# add required info to data table: id, image pixels, model's guess, true label, scores for all classes
my_table.add_data(f"{i}_{log_counter}", wandb.Image(image), pred, label, *score)
if i == LOG_IMAGES_PER_BATCH:
break



# %%
# Evaluation
# ==========
#
# We define a ``test`` function to test the model on the test dataset.
#
# We log ``accuracy``, ``test_loss``, and a ``wandb`` `table <https://docs.wandb.ai/guides/data-vis/log-tables>`__.
# The ``wandb`` table can help in depicting the model's performance in a structured format.
# We define a ``test`` function to evaluate the model on the test dataset, logging ``accuracy`` and ``test_loss`` to a
# ``wandb`` `table <https://docs.wandb.ai/guides/data-vis/log-tables>`__, which helps us visualize the model's
# performance in a structured format.
def test(model, rank, test_loader):

model.eval()
@@ -228,36 +201,6 @@ def test(model, rank, test_loader):
return accuracy
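
# %%
# The body of ``test`` is abbreviated in this diff. For readers unfamiliar with ``wandb`` tables, the logging pattern
# it relies on looks roughly like the sketch below; ``log_example_table`` and its column names are placeholders for
# illustration, not the exact names used in the tutorial.

def log_example_table():
    columns = ["id", "image", "guess", "truth"] + [f"score_{digit}" for digit in range(10)]
    table = wandb.Table(columns=columns)
    # rows are appended with ``table.add_data(...)`` while iterating over the test set,
    # then the whole table is logged once at the end of evaluation
    wandb.log({"mnist_predictions": table})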


# %%
# Hyperparameters
# ===============
#
# We define a few hyperparameters for training our model.
@dataclass_json
@dataclass
class Hyperparameters(object):
"""
Args:
backend: pytorch backend to use, e.g. "gloo" or "nccl"
sgd_momentum: SGD momentum (default: 0.5)
seed: random seed (default: 1)
log_interval: how many batches to wait before logging training status
batch_size: input batch size for training (default: 64)
test_batch_size: input batch size for testing (default: 1000)
epochs: number of epochs to train (default: 10)
learning_rate: learning rate (default: 0.01)
"""

backend: str = dist.Backend.GLOO
sgd_momentum: float = 0.5
seed: int = 1
log_interval: int = 10
batch_size: int = 64
test_batch_size: int = 1000
epochs: int = 10
learning_rate: float = 0.01


# %%
# Training and Evaluating
# =======================
@@ -268,23 +211,40 @@ class Hyperparameters(object):
model_state=PythonPickledFile,
)

MODEL_FILE = "./mnist_cnn.pt"
ACCURACIES_FILE = "./mnist_cnn_accuracies.json"

# %%
# Setting up Distributed Training
# ===============================
#
# ``dist_setup`` is a helper function that initializes the distributed process group, pointing every worker process
# (one per GPU) at the address of the main process so that all processes can coordinate with each other.

def dist_setup(rank, world_size, backend):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "8888"
dist.init_process_group(backend, rank=rank, world_size=world_size)
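
# %%
# For illustration (``example_worker`` is hypothetical and not part of this example), a process spawned for a given
# ``rank`` would typically call ``dist_setup`` first, wrap its model replica in ``DistributedDataParallel``, and tear
# the process group down when it is done. The ``train_mnist`` function defined below follows this same pattern.

def example_worker(rank: int, world_size: int, backend: str = "gloo"):
    dist_setup(rank, world_size, backend)
    model = Net()
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)  # bind this process to its own GPU
        model = nn.parallel.DistributedDataParallel(model.to(rank), device_ids=[rank])
    else:
        model = nn.parallel.DistributedDataParallel(model)  # cpu-only fallback, e.g. with the "gloo" backend
    # ... the training and evaluation loops would go here ...
    dist.destroy_process_group()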


def train_mnist(rank: int, world_size: int, hp: Hyperparameters) -> TrainingOutputs:
# %%
# These global variables specify where the trained model and the validation accuracies will be saved.
MODEL_FILE = "./mnist_cnn.pt"
ACCURACIES_FILE = "./mnist_cnn_accuracies.json"

# %%
# Then we define the ``train_mnist`` function. Note the conditionals that check for ``rank == 0``. These parts of the
# function are only executed in the main process, which is the ``0``th rank. The reason for this is that we only want
# the main process to perform certain actions, such as:
#
# - log metrics via ``wandb``
# - save the trained model to disk
# - keep track of validation metrics

def train_mnist(rank: int, world_size: int, hp: Hyperparameters):

# store the hyperparameters' config in ``wandb``
if rank == 0:
wandb_setup()
wandb.config.update(json.loads(hp.to_json()))
print("wandb config:", wandb.config)

# set random seed
torch.manual_seed(hp.seed)
@@ -320,14 +280,18 @@ def train_mnist(rank: int, world_size: int, hp: Hyperparameters) -> TrainingOutp
# only compute validation metrics in the main process
if rank == 0:
accuracies.append(test(model, rank, test_data_loader))
dist.barrier() # wait for main process to complete validation before continuing training

# wait for the main process to complete validation before continuing the training process
dist.barrier()

if rank == 0:
wandb.finish() # this is important to tell to wandb that we're done logging metrics
# tell wandb that we're done logging metrics
wandb.finish()

# after training the model, we can simply save it to disk and return it from the Flyte task as a :py:class:`flytekit.types.file.FlyteFile`
# type, which is the ``PythonPickledFile``. ``PythonPickledFile`` is simply a decorator on the ``FlyteFile`` that records the format
# of the serialized model as ``pickled``
# after training the model, we can simply save it to disk and return it from the Flyte
# task as a ``flytekit.types.file.FlyteFile`` type, which here is the ``PythonPickledFile``.
# ``PythonPickledFile`` is a variant of ``FlyteFile`` that records the format
# of the serialized model as ``pickled``.
print("Saving model")
torch.save(model.state_dict(), MODEL_FILE)

@@ -340,6 +304,7 @@ def train_mnist(rank: int, world_size: int, hp: Hyperparameters) -> TrainingOutp
dist.destroy_process_group() # clean up


# %%
# The output model using :py:func:`pytorch:torch.save` saves the `state_dict` as described
# `in pytorch docs <https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-and-loading-models>`_.
# A common convention is to have the ``.pt`` extension for the model file.
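
# %%
# As a usage note (``load_trained_model`` below is a hypothetical helper, not part of this example), a ``state_dict``
# saved this way is restored by instantiating the architecture first and then loading the weights into it. If the
# model was saved while still wrapped in ``DistributedDataParallel``, its keys carry a ``module.`` prefix, which the
# sketch strips defensively.

def load_trained_model(model_file: str = MODEL_FILE) -> Net:
    state_dict = torch.load(model_file, map_location="cpu")
    state_dict = {
        (key[len("module."):] if key.startswith("module.") else key): value
        for key, value in state_dict.items()
    }
    model = Net()
    model.load_state_dict(state_dict)
    model.eval()  # switch to inference mode before making predictions
    return model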
@@ -350,6 +315,18 @@ def train_mnist(rank: int, world_size: int, hp: Hyperparameters) -> TrainingOutp
# procured. Also, for the GPU Training to work, the Dockerfile needs to be built as explained in the
# :ref:`pytorch-dockerfile` section.

# %%
# Defining the ``task``
# =====================
#
# Next we define the flyte task that kicks off the distributed training process. Here we call the
# pytorch :ref:`multiprocessing <pytorch:torch.multiprocessing.spawn>` function to initiate a process on each
# available GPU. Since we're parallelizing the data, each process will contain a copy of the model, and pytorch
# will synchronize the gradients across all processes during the backward pass so that every replica applies the
# same update when ``optimizer.step()`` is called.
#
# See `here <https://pytorch.org/tutorials/beginner/dist_overview.html>`_ to read more about pytorch distributed
# training.

@task(
retries=2,
cache=True,
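
# %%
# The body of the task is truncated in this diff. As an illustrative sketch (not the commit's exact code), launching
# one training process per GPU with ``torch.multiprocessing`` would look roughly like the hypothetical
# ``example_spawn`` below, with ``train_mnist`` as the per-process entry point and ``WORLD_SIZE`` processes in total.

def example_spawn(hp: Hyperparameters):
    mp.spawn(
        train_mnist,            # invoked as train_mnist(rank, world_size, hp) in each spawned process
        args=(WORLD_SIZE, hp),  # spawn passes the process rank as the first positional argument
        nprocs=WORLD_SIZE,
        join=True,              # block until every worker process has finished
    )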