Commit 95175dc

update pytorch multi-gpu example, incorporate comments @samhita-alla @kumare3

Signed-off-by: Niels Bantilan <[email protected]>
cosmicBboy committed Aug 16, 2021
1 parent 871a714 commit 95175dc
Showing 3 changed files with 466 additions and 137 deletions.
14 changes: 9 additions & 5 deletions cookbook/case_studies/ml_training/mnist_classifier/Dockerfile
@@ -1,20 +1,20 @@
FROM nvcr.io/nvidia/pytorch:21.06-py3
FROM pytorch/pytorch:1.9.0-cuda10.2-cudnn7-runtime
LABEL org.opencontainers.image.source https://github.com/flyteorg/flytesnacks

WORKDIR /root
ENV LANG C.UTF-8
ENV LC_ALL C.UTF-8
ENV PYTHONPATH /root

# Give your wandb API key. Get it from https://wandb.ai/authorize.
# ENV WANDB_API_KEY your-api-key
# Set your wandb API key and user name. Get the API key from https://wandb.ai/authorize.
# ENV WANDB_API_KEY <api_key>
# ENV WANDB_USERNAME <user_name>

# Install the AWS CLI for AWS support
RUN pip install awscli

ENV VENV /opt/venv

# Virtual environment
ENV VENV /opt/venv
RUN python3 -m venv ${VENV}
ENV PATH="${VENV}/bin:$PATH"

@@ -25,6 +25,10 @@ RUN pip install -r /root/requirements.txt
# Copy the actual code
COPY mnist_classifier/ /root/mnist_classifier/

# Copy the Makefile targets into the container; this makes it easier to register the workflows.
COPY in_container.mk /root/Makefile
COPY mnist_classifier/sandbox.config /root

# This tag is supplied by the build script and will be used to determine the version
# when registering tasks, workflows, and launch plans
ARG tag
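
# A sketch of how a build script might supply the tag (the image name and registry below are
# placeholders, not the project's actual values):
#   docker build --build-arg tag=<tag> -t <registry>/mnist-classifier:<tag> .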
@@ -1,6 +1,6 @@
"""
Single GPU Training
-------------------
Single Node, Single GPU Training
--------------------------------
Training a model on a single node with one GPU is as simple as writing any Flyte task and setting the GPU request to ``1``.
As long as the Docker image is built correctly with the right version of the GPU drivers and the Flyte backend is
@@ -31,29 +31,22 @@
from torchvision import datasets, transforms

# %%
# Let's define some variables to be used later.
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))

# %%
# The following variables are specific to ``wandb``:
# Let's define some variables to be used later. The following variables are specific to ``wandb``:
#
# - ``NUM_BATCHES_TO_LOG``: Number of batches to log from the test data for each test step
# - ``LOG_IMAGES_PER_BATCH``: Number of images to log per test batch
NUM_BATCHES_TO_LOG = 10
LOG_IMAGES_PER_BATCH = 32

# %%
# If running remotely, copy your ``wandb`` API key to the Dockerfile. Next, login to ``wandb``.
# You can disable this if you're already logged in on your local machine.
wandb.login()

# %%
# Next, we initialize the ``wandb`` project.
# If running remotely, copy your ``wandb`` API key into the Dockerfile under the environment variable ``WANDB_API_KEY``.
# The function below logs into ``wandb`` and initializes the project. If you built your Docker image with the
# ``WANDB_USERNAME`` environment variable set, it is picked up automatically; otherwise, replace ``my-user-name`` with your ``wandb`` user name.
#
# .. admonition:: MUST DO!
#
# Replace ``entity`` value with your username.
wandb.init(project="mnist-single-node", entity="your-user-name")
# We'll call this function in the ``pytorch_mnist_task`` defined below.
def wandb_setup():
wandb.login()
wandb.init(project="mnist-single-node-single-gpu", entity=os.environ.get("WANDB_USERNAME", "my-user-name"))
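
# %%
# .. note::
#
#    If you want to try this example without a ``wandb`` account, one option (a sketch, assuming a
#    reasonably recent ``wandb`` client) is to switch the client to offline mode so that runs are
#    stored locally instead of being synced to the ``wandb`` servers; depending on the client
#    version you may also need to skip the ``wandb.login()`` call:
#
#    .. code-block:: python
#
#       import os
#
#       os.environ["WANDB_MODE"] = "offline"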

# %%
# Creating the Network
@@ -81,6 +74,26 @@ def forward(self, x):
return F.log_softmax(x, dim=1)


# %%
# The Data Loader
# ===============

def mnist_dataloader(batch_size, train=True, **kwargs):
return torch.utils.data.DataLoader(
datasets.MNIST(
"../data",
train=train,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=batch_size,
shuffle=True,
**kwargs,
)


# %%
# Training
# ========
@@ -95,24 +108,12 @@ def train(model, device, train_loader, optimizer, epoch, log_interval):

# loop through the training batches
for batch_idx, (data, target) in enumerate(train_loader):

# device conversion
data, target = data.to(device), target.to(device)

# clear gradient
optimizer.zero_grad()

# forward pass
output = model(data)

# compute loss
loss = F.nll_loss(output, target)

# propagate the loss backward
loss.backward()

# update the model parameters
optimizer.step()
data, target = data.to(device), target.to(device) # device conversion
optimizer.zero_grad() # clear gradient
output = model(data) # forward pass
loss = F.nll_loss(output, target) # compute loss
loss.backward() # propagate the loss backward
optimizer.step() # update the model parameters

if batch_idx % log_interval == 0:
print(
@@ -133,26 +134,19 @@ def train(model, device, train_loader, optimizer, epoch, log_interval):
# We define a test logger function that will be called when we run the model on the test dataset.
def log_test_predictions(images, labels, outputs, predicted, my_table, log_counter):
"""
Convenience funtion to log predictions for a batch of test images
Convenience function to log predictions for a batch of test images
"""

# obtain confidence scores for all classes
scores = F.softmax(outputs.data, dim=1)
log_scores = scores.cpu().numpy()
log_images = images.cpu().numpy()
log_labels = labels.cpu().numpy()
log_preds = predicted.cpu().numpy()

# assign ids based on the order of the images
_id = 0
for i, l, p, s in zip(log_images, log_labels, log_preds, log_scores):

# add required info to data table:
# id, image pixels, model's guess, true label, scores for all classes
img_id = str(_id) + "_" + str(log_counter)
my_table.add_data(img_id, wandb.Image(i), p, l, *s)
_id += 1
if _id == LOG_IMAGES_PER_BATCH:
for i, (image, pred, label, score) in enumerate(
zip(*[x.cpu().numpy() for x in (images, predicted, labels, scores)])
):
# add required info to data table: id, image pixels, model's guess, true label, scores for all classes
my_table.add_data(f"{i}_{log_counter}", wandb.Image(image), pred, label, *score)
if i + 1 == LOG_IMAGES_PER_BATCH:  # stop once LOG_IMAGES_PER_BATCH images have been logged
break
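
# %%
# The ``my_table`` argument above is expected to be a ``wandb.Table``; the table itself is created in
# the ``test`` function, in a part of the diff that is collapsed here. A minimal sketch of how such a
# table might be constructed (the column names below are assumptions for illustration only):
#
# .. code-block:: python
#
#    columns = ["id", "image", "guess", "truth"] + [f"score_{i}" for i in range(10)]
#    my_table = wandb.Table(columns=columns)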


@@ -186,21 +180,11 @@ def test(model, device, test_loader):

# loop through the test data loader
for images, targets in test_loader:

# device conversion
images, targets = images.to(device), targets.to(device)

# forward pass -- generate predictions
outputs = model(images)

# sum up batch loss
test_loss += F.nll_loss(outputs, targets, reduction="sum").item()

# get the index of the max log-probability
_, predicted = torch.max(outputs.data, 1)

# compare predictions to true label
correct += (predicted == targets).sum().item()
images, targets = images.to(device), targets.to(device) # device conversion
outputs = model(images) # forward pass -- generate predictions
test_loss += F.nll_loss(outputs, targets, reduction="sum").item() # sum up batch loss
_, predicted = torch.max(outputs.data, 1) # get the index of the max log-probability
correct += (predicted == targets).sum().item() # compare predictions to true label

# log predictions to the ``wandb`` table
if log_counter < NUM_BATCHES_TO_LOG:
@@ -216,22 +200,11 @@ def test(model, device, test_loader):
accuracy = float(correct) / len(test_loader.dataset)

# log the average loss, accuracy, and table
wandb.log(
{"test_loss": test_loss, "accuracy": accuracy, "mnist_predictions": my_table}
)
wandb.log({"test_loss": test_loss, "accuracy": accuracy, "mnist_predictions": my_table})

return accuracy


# %%
# Next, we define a function that runs for every epoch. It calls the ``train`` and ``test`` functions.
def epoch_step(
model, device, train_loader, test_loader, optimizer, epoch, log_interval
):
train(model, device, train_loader, optimizer, epoch, log_interval)
return test(model, device, test_loader)


# %%
# Hyperparameters
# ===============
@@ -242,14 +215,14 @@ def epoch_step(
class Hyperparameters(object):
"""
Args:
backend: pytorch backend to use, e.g. "gloo" or "nccl"
sgd_momentum: SGD momentum (default: 0.5)
seed: random seed (default: 1)
log_interval: how many batches to wait before logging training status
batch_size: input batch size for training (default: 64)
test_batch_size: input batch size for testing (default: 1000)
epochs: number of epochs to train (default: 10)
learning_rate: learning rate (default: 0.01)
sgd_momentum: SGD momentum (default: 0.5)
seed: random seed (default: 1)
log_interval: how many batches to wait before logging training status
dir: directory where summary logs are stored
"""

backend: str = dist.Backend.GLOO
@@ -281,13 +254,18 @@ class Hyperparameters(object):
)


@task(retries=2, cache=True, cache_version="1.0", requests=Resources(gpu="1"))
def train_mnist(hp: Hyperparameters) -> TrainingOutputs:
@task(
retries=2,
cache=True,
cache_version="1.0",
requests=Resources(gpu="1", mem="3Gi", storage="1Gi"),
limits=Resources(gpu="1", mem="3Gi", storage="1Gi"),
)
def pytorch_mnist_task(hp: Hyperparameters) -> TrainingOutputs:
wandb_setup()

# store the hyperparameters' config in ``wandb``
cfg = wandb.config
cfg.update(json.loads(hp.to_json()))
print(wandb.config)
wandb.config.update(json.loads(hp.to_json()))

# set random seed
torch.manual_seed(hp.seed)
@@ -298,35 +276,10 @@ def train_mnist(hp: Hyperparameters) -> TrainingOutputs:
print(f"Use cuda {use_cuda}")
device = torch.device("cuda" if use_cuda else "cpu")

print("Using device: {}, world size: {}".format(device, WORLD_SIZE))

# load Data
# load data
kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
training_data_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"../data",
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=hp.batch_size,
shuffle=True,
**kwargs,
)
test_data_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"../data",
train=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=hp.test_batch_size,
shuffle=False,
**kwargs,
)
training_data_loader = mnist_dataloader(hp.batch_size, train=True, **kwargs)
test_data_loader = mnist_dataloader(hp.test_batch_size, train=False, **kwargs)

# train the model
model = Net().to(device)
@@ -336,18 +289,11 @@ def train_mnist(hp: Hyperparameters) -> TrainingOutputs:
)

# run multiple epochs and capture the accuracies for each epoch
accuracies = [
epoch_step(
model,
device,
train_loader=training_data_loader,
test_loader=test_data_loader,
optimizer=optimizer,
epoch=epoch,
log_interval=hp.log_interval,
)
for epoch in range(1, hp.epochs + 1)
]
# train the model: run multiple epochs and capture the accuracies for each epoch
accuracies = []
for epoch in range(1, hp.epochs + 1):
train(model, device, training_data_loader, optimizer, epoch, hp.log_interval)
accuracies.append(test(model, device, test_data_loader))

# after training the model, we can simply save it to disk and return it from the Flyte task as a :py:class:`flytekit.types.file.FlyteFile`
# type, which is the ``PythonPickledFile``. ``PythonPickledFile`` is simply a decorator on the ``FlyteFile`` that records the format
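
# %%
# A minimal sketch of that pattern (the output field names here are assumptions; the actual
# ``TrainingOutputs`` definition lives in a collapsed part of this diff):
#
# .. code-block:: python
#
#    model_file = "mnist_cnn.pt"
#    torch.save(model.state_dict(), model_file)
#    return TrainingOutputs(
#        epoch_accuracies=accuracies, model_state=PythonPickledFile(model_file)
#    )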
@@ -364,10 +310,9 @@ def train_mnist(hp: Hyperparameters) -> TrainingOutputs:
# Finally, we define a workflow to run the training algorithm. We return the model and accuracies.
@workflow
def pytorch_training_wf(
hp: Hyperparameters,
) -> (PythonPickledFile, typing.List[float]):
accuracies, model = train_mnist(hp=hp)
return model, accuracies
hp: Hyperparameters = Hyperparameters(epochs=10, batch_size=128)
) -> TrainingOutputs:
return pytorch_mnist_task(hp=hp)


# %%
@@ -377,9 +322,7 @@ def pytorch_training_wf(
# It is possible to run the model locally with almost no modifications (as long as the code takes care of detecting
# whether it is running in a distributed setting or not). This is how we can do it:
if __name__ == "__main__":
model, accuracies = pytorch_training_wf(
hp=Hyperparameters(epochs=10, batch_size=128)
)
model, accuracies = pytorch_training_wf(hp=Hyperparameters(epochs=10, batch_size=128))
print(f"Model: {model}, Accuracies: {accuracies}")
