Skip to content

Commit

Permalink
Single node GPU training example
Browse files Browse the repository at this point in the history
Signed-off-by: Ketan Umare <[email protected]>
  • Loading branch information
kumare3 authored and cosmicBboy committed Jul 27, 2021
1 parent 4a77808 commit 1aea0e9
Show file tree
Hide file tree
Showing 13 changed files with 550 additions and 1 deletion.
2 changes: 2 additions & 0 deletions cookbook/case_studies/ml_training/pytorch/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
include ../../../common/Makefile
include ../../../common/parent.mk
Empty file.
24 changes: 24 additions & 0 deletions cookbook/case_studies/ml_training/pytorch/in_container.mk
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Targets meant to be run inside the task container (the Dockerfile copies this file to /root/Makefile).
# Directory that receives the files written by `pyflyte ... serialize`.
SERIALIZED_PB_OUTPUT_DIR := /tmp/output

# Remove everything inside the output directory (the directory itself is kept).
.PHONY: clean
clean:
rm -rf $(SERIALIZED_PB_OUTPUT_DIR)/*

# Create the output directory; depends on `clean`, so its contents are wiped first.
$(SERIALIZED_PB_OUTPUT_DIR): clean
mkdir -p $(SERIALIZED_PB_OUTPUT_DIR)

# Serialize the `workflows` package into the output directory using /root/sandbox.config.
.PHONY: serialize
serialize: $(SERIALIZED_PB_OUTPUT_DIR)
pyflyte --config /root/sandbox.config serialize workflows -f $(SERIALIZED_PB_OUTPUT_DIR)

# Register the serialized files with flyte-cli in the `development` domain.
# FLYTE_HOST, INSECURE_FLAG, PROJECT, VERSION, SERVICE_ACCOUNT and OUTPUT_DATA_PREFIX are not set in this file
# and are expected to come from the environment.
.PHONY: register
register: serialize
flyte-cli register-files -h ${FLYTE_HOST} ${INSECURE_FLAG} -p ${PROJECT} -d development -v ${VERSION} --kubernetes-service-account ${SERVICE_ACCOUNT} --output-location-prefix ${OUTPUT_DATA_PREFIX} $(SERIALIZED_PB_OUTPUT_DIR)/*

# Same as `serialize`, but uses the `serialize fast` sub-command.
.PHONY: fast_serialize
fast_serialize: $(SERIALIZED_PB_OUTPUT_DIR)
pyflyte --config /root/sandbox.config serialize fast workflows -f $(SERIALIZED_PB_OUTPUT_DIR)

# Same as `register`, but uses `fast-register-files`: no -v ${VERSION} is passed, and ADDL_DISTRIBUTION_DIR
# (also expected from the environment) is supplied as the additional distribution dir.
.PHONY: fast_register
fast_register: fast_serialize
flyte-cli fast-register-files -h ${FLYTE_HOST} ${INSECURE_FLAG} -p ${PROJECT} -d development --kubernetes-service-account ${SERVICE_ACCOUNT} --output-location-prefix ${OUTPUT_DATA_PREFIX} --additional-distribution-dir ${ADDL_DISTRIBUTION_DIR} $(SERIALIZED_PB_OUTPUT_DIR)/*
30 changes: 30 additions & 0 deletions cookbook/case_studies/ml_training/pytorch/single_node/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Image used to run the single-node GPU training example; built on NVIDIA's PyTorch base image.
FROM nvcr.io/nvidia/pytorch:21.06-py3
LABEL org.opencontainers.image.source https://github.com/flyteorg/flytesnacks

WORKDIR /root
ENV LANG C.UTF-8
ENV LC_ALL C.UTF-8
# /root is on PYTHONPATH so the package copied to /root/single_node below is importable.
ENV PYTHONPATH /root

# Install the AWS cli for AWS support
# NOTE(review): this runs before the virtual environment exists, so awscli goes into the base image's Python,
# not into ${VENV}.
RUN pip install awscli

ENV VENV /opt/venv
# Virtual environment
RUN python3 -m venv ${VENV}
# Put the venv's bin directory first so the `pip`/`python` calls below resolve to the venv.
ENV PATH="${VENV}/bin:$PATH"

# Install Python dependencies
# NOTE(review): the venv is created without --system-site-packages, so packages preinstalled in the base
# image are presumably not visible inside it and torch is reinstalled from requirements.txt — confirm intended.
COPY single_node/requirements.txt /root
RUN pip install -r /root/requirements.txt

# Copy the makefile targets to expose on the container. This makes it easier to register.
COPY in_container.mk /root/Makefile

# Copy the actual code
COPY single_node/ /root/single_node/

# This tag is supplied by the build script and will be used to determine the version
# when registering tasks, workflows, and launch plans
ARG tag
ENV FLYTE_INTERNAL_IMAGE $tag
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
PREFIX=single_node
include ../../../../common/Makefile
include ../../../../common/leaf.mk
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.. _pytorch-training:

Train a Pytorch model on one GPU
---------------------------------


Walkthrough
====================
Empty file.
293 changes: 293 additions & 0 deletions cookbook/case_studies/ml_training/pytorch/single_node/gpu_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
"""
Single Node GPU training
-------------------------
Training a model on a single node with one GPU is as simple as writing any Flyte task and setting the GPU request to "1".
As long as the Docker image is built correctly with the right version of the GPU drivers and your Flyte backend is
provisioned with GPU machines, Flyte will execute your task on a node that has GPUs.
Currently Flyte does not provide a specific task type for PyTorch (though it is entirely possible to provide a task type
that supports pytorch-ignite or pytorch-lightning, this is not critical). You can request a GPU simply
by setting the gpu="1" resource request, and at runtime you will receive a GPU.
This example shows how you can create any PyTorch model and train it using Flyte and your specialized container. Flyte
will handle the data passing and can export the model out of the training environment.
"""
import os
import typing
from dataclasses import dataclass

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from dataclasses_json import dataclass_json
from flytekit import Resources, task, workflow
from flytekit.types.directory import TensorboardLogs
from flytekit.types.file import PNGImageFile, PythonPickledFile
from tensorboardX import SummaryWriter
from torch import distributed as dist
from torch import nn, optim
from torchvision import datasets, transforms


# Read from the WORLD_SIZE environment variable; falls back to 1 when it is unset.
# In this file the value is only reported in a log line by the training task.
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))


# %%
# Actual model
# ============
# This is the Model class with all the layers
class Net(nn.Module):
    """Small convolutional classifier that emits 10 log-probabilities per sample.

    Two 5x5 convolutions (1 -> 20 -> 50 channels), each followed by ReLU and a
    2x2 max-pool, feed two fully connected layers (800 -> 500 -> 10). The
    flattened size of ``4 * 4 * 50`` matches single-channel 28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Return ``log_softmax`` scores over dimension 1 for the batch ``x``."""
        # Both convolutional stages share the same conv -> ReLU -> 2x2 max-pool shape.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2, 2)
        # Collapse channel and spatial dimensions into one feature vector per sample.
        flat = x.view(-1, 4 * 4 * 50)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)


# %%
# Trainer
# =======
# This is the core training loop, which runs for 1 epoch. The loss values are written to the SummaryWriter
def train(model, device, train_loader, optimizer, epoch, writer, log_interval):
    """Run one pass over ``train_loader``, updating ``model`` with ``optimizer``.

    Args:
        model: module whose output is fed to ``F.nll_loss``.
        device: device that every batch is moved to before the forward pass.
        train_loader: iterable of ``(data, target)`` batches; must support ``len``
            and expose a ``dataset`` attribute (both are used for progress output).
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used in the progress message and the global step.
        writer: object with an ``add_scalar(tag, value, step)`` method.
        log_interval: progress is printed and logged every ``log_interval`` batches.
    """
    model.train()
    batches_per_epoch = len(train_loader)
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()

        # Only every ``log_interval``-th batch is reported.
        if batch_idx % log_interval != 0:
            continue

        progress = "Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}".format(
            epoch,
            batch_idx * len(data),
            len(train_loader.dataset),
            100.0 * batch_idx / batches_per_epoch,
            loss.item(),
        )
        print(progress)
        # Global step so that points from successive epochs do not overlap.
        global_step = epoch * batches_per_epoch + batch_idx
        writer.add_scalar("loss", loss.item(), global_step)


# %%
# Test the model
# ==============
# This method calculates the accuracy for the given model per epoch
def test(model, device, test_loader, writer, epoch):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(
output, target, reduction="sum"
).item() # sum up batch loss
pred = output.max(1, keepdim=True)[
1
] # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)
print("\naccuracy={:.4f}\n".format(float(correct) / len(test_loader.dataset)))
accuracy = float(correct) / len(test_loader.dataset)
writer.add_scalar("accuracy", accuracy, epoch)
return accuracy


def epoch_step(
    model, device, train_loader, test_loader, optimizer, epoch, writer, log_interval
):
    """Train for one epoch with ``train``, then evaluate with ``test``.

    Returns:
        Whatever ``test`` returns for this epoch (the accuracy).
    """
    train(
        model=model,
        device=device,
        train_loader=train_loader,
        optimizer=optimizer,
        epoch=epoch,
        writer=writer,
        log_interval=log_interval,
    )
    accuracy = test(
        model=model,
        device=device,
        test_loader=test_loader,
        writer=writer,
        epoch=epoch,
    )
    return accuracy


# %%
# Training Hyperparameters
# ========================
#
@dataclass_json
@dataclass
class Hyperparameters:
    """
    Hyperparameters accepted by the training task.

    Args:
        backend: distributed backend name (default: ``dist.Backend.GLOO``); the training task
            in this file does not read it
        sgd_momentum: SGD momentum (default: 0.5)
        seed: random seed (default: 1)
        log_interval: how many batches to wait before logging training status (default: 10)
        batch_size: input batch size for training (default: 64)
        test_batch_size: input batch size for testing (default: 1000)
        epochs: number of epochs to train (default: 10)
        learning_rate: learning rate (default: 0.01)
    """

    backend: str = dist.Backend.GLOO
    sgd_momentum: float = 0.5
    seed: int = 1
    log_interval: int = 10
    batch_size: int = 64
    test_batch_size: int = 1000
    epochs: int = 10
    learning_rate: float = 0.01


# %%
# Actual Training algorithm
# =========================
# The output model using `torch.save` saves the `state_dict` as described
# `in pytorch docs <https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-and-loading-models>`_.
# A common convention is to have the ``.pt`` extension for the file
#
# Notice we are also generating an output variable called logs, these logs can be used to visualize the training in
# Tensorboard and are the output of the `SummaryWriter` interface
# Refer to section :ref:`pytorch_tensorboard` to visualize the outputs of this example.
#
# .. note::
#
# Note the usage of requests=Resources(gpu="1"). This will force Flyte to allocate this task onto a machine with GPUs
# The task will be queued up until a machine with gpu's can be procured. Also for the GPU Training to work, the
# dockerfile needs to be built as explained in the :ref:`pytorch-training` section.
#
# Named outputs of the training task: per-epoch accuracies, the saved model state file and the
# Tensorboard log directory.
# The fields are given as a list of (name, type) pairs because the keyword-argument form of the
# functional ``typing.NamedTuple`` syntax is deprecated since Python 3.13.
TrainingOutputs = typing.NamedTuple(
    "TrainingOutputs",
    [
        ("epoch_accuracies", typing.List[float]),
        ("model_state", PythonPickledFile),
        ("logs", TensorboardLogs),
    ],
)


@task(retries=2, cache=True, cache_version="1.0", requests=Resources(gpu="1"))
def train_mnist(hp: Hyperparameters) -> TrainingOutputs:
    """Train ``Net`` on MNIST and return accuracies, the saved model and the logs.

    Args:
        hp: hyperparameters; ``seed``, ``batch_size``, ``test_batch_size``,
            ``learning_rate``, ``sgd_momentum``, ``log_interval`` and ``epochs`` are read.

    Returns:
        ``TrainingOutputs`` holding one test accuracy per epoch, the ``state_dict``
        saved to ``mnist_cnn.pt`` and the ``logs`` directory written by the
        ``SummaryWriter``.
    """
    log_dir = "logs"
    writer = SummaryWriter(log_dir)
    try:
        torch.manual_seed(hp.seed)

        # Ideally if GPU training is required, and if cuda is not available, we can raise an Exception, but as we
        # want this algorithm to work locally as well (and most users dont have a GPU locally), we will fallback
        # to using a CPU
        use_cuda = torch.cuda.is_available()
        print(f"Use cuda {use_cuda}")
        device = torch.device("cuda" if use_cuda else "cpu")

        print("Using device: {}, world size: {}".format(device, WORLD_SIZE))

        # LOAD Data
        # The same conversion and normalization is applied to the training and the test split.
        mnist_transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
        training_data_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                "../data",
                train=True,
                download=True,
                transform=mnist_transform,
            ),
            batch_size=hp.batch_size,
            shuffle=True,
            **kwargs,
        )
        # No ``download`` flag here: the test split relies on the download triggered above.
        test_data_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                "../data",
                train=False,
                transform=mnist_transform,
            ),
            batch_size=hp.test_batch_size,
            shuffle=False,
            **kwargs,
        )

        # Train the model
        model = Net().to(device)

        optimizer = optim.SGD(
            model.parameters(), lr=hp.learning_rate, momentum=hp.sgd_momentum
        )

        # Run multiple epochs and capture the accuracies for each epoch
        accuracies = []
        for epoch in range(1, hp.epochs + 1):
            accuracies.append(
                epoch_step(
                    model,
                    device,
                    train_loader=training_data_loader,
                    test_loader=test_data_loader,
                    optimizer=optimizer,
                    epoch=epoch,
                    writer=writer,
                    log_interval=hp.log_interval,
                )
            )
    finally:
        # Close the writer so buffered events are written to ``log_dir`` before the directory is
        # handed back as a task output (and so the file handle is released on failure as well).
        writer.close()

    # After training the model, we can simply save it to disk and return it from the Flyte task as a FlyteFile type
    # called the PythonPickledFile. PythonPickledFile is simply a decorator on the FlyteFile, that records the format
    # of the serialized model as ``pickled``
    model_file = "mnist_cnn.pt"
    torch.save(model.state_dict(), model_file)

    return TrainingOutputs(
        epoch_accuracies=accuracies,
        model_state=PythonPickledFile(model_file),
        logs=TensorboardLogs(log_dir),
    )


# %%
# Let us plot the accuracy
# ========================
# We will output the accuracy plot as a PNG image
@task
def plot_accuracy(epoch_accuracies: typing.List[float]) -> PNGImageFile:
    """Plot the per-epoch accuracies and return the plot as a PNG file.

    Args:
        epoch_accuracies: one accuracy value per epoch, plotted in order.

    Returns:
        A ``PNGImageFile`` pointing at ``accuracy.png`` in the current working directory.
    """
    accuracy_plot = "accuracy.png"
    # Draw on a dedicated figure instead of pyplot's implicit global one, so repeated calls in the
    # same process (e.g. local execution) do not overlay curves from earlier calls.
    fig, ax = plt.subplots()
    try:
        # summarize history for accuracy
        ax.plot(epoch_accuracies)
        ax.set_title("Accuracy")
        ax.set_ylabel("accuracy")
        ax.set_xlabel("epoch")
        fig.savefig(accuracy_plot)
    finally:
        # Release the figure; pyplot otherwise keeps it alive for the life of the process.
        plt.close(fig)

    return PNGImageFile(accuracy_plot)


# %%
# Create a pipeline
# =================
# Now the training and the plotting can be put together into a pipeline, in which the training is performed first,
# followed by the plotting of the accuracy. Data is passed between them, and the workflow itself outputs the image and
# the serialized model.
@workflow
def pytorch_training_wf(
    hp: Hyperparameters,
) -> (PythonPickledFile, PNGImageFile, TensorboardLogs):
    """Train the model, plot its per-epoch accuracy, and return model, plot and logs."""
    epoch_accuracies, model_state, tensorboard_logs = train_mnist(hp=hp)
    accuracy_png = plot_accuracy(epoch_accuracies=epoch_accuracies)
    return model_state, accuracy_png, tensorboard_logs


# %%
# Run the model locally
# =====================
# It is possible to run the model locally with almost no modifications (as long as the code takes care of the resolving
# if distributed or not)
if __name__ == "__main__":
    # A short local run: 2 epochs with a batch size of 128.
    local_hp = Hyperparameters(epochs=2, batch_size=128)
    model, plot, logs = pytorch_training_wf(hp=local_hp)
    print(f"Model: {model}, plot PNG: {plot}, Tensorboard Log Dir: {logs}")
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-r ../../../../common/requirements-common.in
torch
torchvision
Loading

0 comments on commit 1aea0e9

Please sign in to comment.