Single node GPU training example (#333) #352

Merged 2 commits on Aug 17, 2021
14 changes: 9 additions & 5 deletions in cookbook/case_studies/ml_training/mnist_classifier/Dockerfile
@@ -1,20 +1,20 @@
FROM nvcr.io/nvidia/pytorch:21.06-py3
FROM pytorch/pytorch:1.9.0-cuda10.2-cudnn7-runtime
LABEL org.opencontainers.image.source https://github.com/flyteorg/flytesnacks

WORKDIR /root
ENV LANG C.UTF-8
ENV LC_ALL C.UTF-8
ENV PYTHONPATH /root

# Give your wandb API key. Get it from https://wandb.ai/authorize.
# ENV WANDB_API_KEY your-api-key
# Set your wandb API key and user name. Get the API key from https://wandb.ai/authorize.
# ENV WANDB_API_KEY <api_key>
# ENV WANDB_USERNAME <user_name>
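# Alternatively, instead of baking these values into the image, they can be supplied when the
# container is started, e.g. ``docker run -e WANDB_API_KEY=<api_key> -e WANDB_USERNAME=<user_name> <image>``.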

# Install the AWS cli for AWS support
RUN pip install awscli

ENV VENV /opt/venv

# Virtual environment
ENV VENV /opt/venv
RUN python3 -m venv ${VENV}
ENV PATH="${VENV}/bin:$PATH"

@@ -25,6 +25,10 @@ RUN pip install -r /root/requirements.txt
# Copy the actual code
COPY mnist_classifier/ /root/mnist_classifier/

# Copy the makefile targets to expose on the container. This makes it easier to register.
COPY in_container.mk /root/Makefile
COPY mnist_classifier/sandbox.config /root

# This tag is supplied by the build script and will be used to determine the version
# when registering tasks, workflows, and launch plans
ARG tag
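# For reference, a hypothetical build invocation that supplies this tag could look like:
#   docker build --build-arg tag=<version> -t <registry>/<image>:<version> .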
@@ -1,6 +1,6 @@
"""
Single GPU Training
-------------------
Single Node, Single GPU Training
--------------------------------

Training a model on a single node on one GPU is as trivial as writing any Flyte task and simply setting the GPU to ``1``.
As long as the Docker image is built correctly with the right version of the GPU drivers and the Flyte backend is
@@ -31,29 +31,22 @@
from torchvision import datasets, transforms

# %%
# Let's define some variables to be used later.
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))

# %%
# The following variables are specific to ``wandb``:
# Let's define some variables to be used later. The following variables are specific to ``wandb``:
#
# - ``NUM_BATCHES_TO_LOG``: Number of batches to log from the test data for each test step
# - ``LOG_IMAGES_PER_BATCH``: Number of images to log per test batch
NUM_BATCHES_TO_LOG = 10
LOG_IMAGES_PER_BATCH = 32

# %%
# If running remotely, copy your ``wandb`` API key to the Dockerfile. Next, login to ``wandb``.
# You can disable this if you're already logged in on your local machine.
wandb.login()

# %%
# Next, we initialize the ``wandb`` project.
# If running remotely, copy your ``wandb`` API key to the Dockerfile under the environment variable ``WANDB_API_KEY``.
# This function logs into ``wandb`` and initializes the project. If you built your Docker image with the
# ``WANDB_USERNAME``, this will work. Otherwise, replace ``my-user-name`` with your ``wandb`` user name.
#
# .. admonition:: MUST DO!
#
# Replace ``entity`` value with your username.
wandb.init(project="mnist-single-node", entity="your-user-name")
# We'll call this function in the ``pytorch_mnist_task`` defined below.
def wandb_setup():
    wandb.login()
    wandb.init(project="mnist-single-node-single-gpu", entity=os.environ.get("WANDB_USERNAME", "my-user-name"))
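
# %%
# For a purely local run, the user name can also be overridden in the process environment before
# ``wandb_setup`` is called; a hypothetical example (``my-user-name`` is a placeholder):
#
#     os.environ["WANDB_USERNAME"] = "my-user-name"
#     wandb_setup()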

# %%
# Creating the Network
@@ -81,6 +74,26 @@ def forward(self, x):
return F.log_softmax(x, dim=1)


# %%
# The Data Loader
# ===============

def mnist_dataloader(batch_size, train=True, **kwargs):
    return torch.utils.data.DataLoader(
        datasets.MNIST(
            "./data",
            train=train,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        ),
        batch_size=batch_size,
        shuffle=True,
        **kwargs,
    )


# %%
# Training
# ========
@@ -95,24 +108,12 @@ def train(model, device, train_loader, optimizer, epoch, log_interval):

# loop through the training batches
for batch_idx, (data, target) in enumerate(train_loader):

# device conversion
data, target = data.to(device), target.to(device)

# clear gradient
optimizer.zero_grad()

# forward pass
output = model(data)

# compute loss
loss = F.nll_loss(output, target)

# propagate the loss backward
loss.backward()

# update the model parameters
optimizer.step()
data, target = data.to(device), target.to(device) # device conversion
optimizer.zero_grad() # clear gradient
output = model(data) # forward pass
loss = F.nll_loss(output, target) # compute loss
loss.backward() # propagate the loss backward
optimizer.step() # update the model parameters

if batch_idx % log_interval == 0:
print(
@@ -133,26 +134,19 @@ def train(model, device, train_loader, optimizer, epoch, log_interval):
# We define a test logger function which will be called when we run the model on test dataset.
def log_test_predictions(images, labels, outputs, predicted, my_table, log_counter):
"""
Convenience funtion to log predictions for a batch of test images
Convenience function to log predictions for a batch of test images
"""

# obtain confidence scores for all classes
scores = F.softmax(outputs.data, dim=1)
log_scores = scores.cpu().numpy()
log_images = images.cpu().numpy()
log_labels = labels.cpu().numpy()
log_preds = predicted.cpu().numpy()

# assign ids based on the order of the images
_id = 0
for i, l, p, s in zip(log_images, log_labels, log_preds, log_scores):

# add required info to data table:
# id, image pixels, model's guess, true label, scores for all classes
img_id = str(_id) + "_" + str(log_counter)
my_table.add_data(img_id, wandb.Image(i), p, l, *s)
_id += 1
if _id == LOG_IMAGES_PER_BATCH:
for i, (image, pred, label, score) in enumerate(
zip(*[x.cpu().numpy() for x in (images, predicted, labels, scores)])
):
# add required info to data table: id, image pixels, model's guess, true label, scores for all classes
my_table.add_data(f"{i}_{log_counter}", wandb.Image(image), pred, label, *score)
if i == LOG_IMAGES_PER_BATCH:
break


@@ -186,21 +180,11 @@ def test(model, device, test_loader):

# loop through the test data loader
for images, targets in test_loader:

# device conversion
images, targets = images.to(device), targets.to(device)

# forward pass -- generate predictions
outputs = model(images)

# sum up batch loss
test_loss += F.nll_loss(outputs, targets, reduction="sum").item()

# get the index of the max log-probability
_, predicted = torch.max(outputs.data, 1)

# compare predictions to true label
correct += (predicted == targets).sum().item()
images, targets = images.to(device), targets.to(device) # device conversion
outputs = model(images) # forward pass -- generate predictions
test_loss += F.nll_loss(outputs, targets, reduction="sum").item() # sum up batch loss
_, predicted = torch.max(outputs.data, 1) # get the index of the max log-probability
correct += (predicted == targets).sum().item() # compare predictions to true label

# log predictions to the ``wandb`` table
if log_counter < NUM_BATCHES_TO_LOG:
@@ -216,22 +200,11 @@ def test(model, device, test_loader):
accuracy = float(correct) / len(test_loader.dataset)

# log the average loss, accuracy, and table
wandb.log(
{"test_loss": test_loss, "accuracy": accuracy, "mnist_predictions": my_table}
)
wandb.log({"test_loss": test_loss, "accuracy": accuracy, "mnist_predictions": my_table})

return accuracy


# %%
# Next, we define a function that runs for every epoch. It calls the ``train`` and ``test`` functions.
def epoch_step(
model, device, train_loader, test_loader, optimizer, epoch, log_interval
):
train(model, device, train_loader, optimizer, epoch, log_interval)
return test(model, device, test_loader)


# %%
# Hyperparameters
# ===============
@@ -242,14 +215,14 @@ def epoch_step(
class Hyperparameters(object):
"""
Args:
backend: pytorch backend to use, e.g. "gloo" or "nccl"
sgd_momentum: SGD momentum (default: 0.5)
seed: random seed (default: 1)
log_interval: how many batches to wait before logging training status
batch_size: input batch size for training (default: 64)
test_batch_size: input batch size for testing (default: 1000)
epochs: number of epochs to train (default: 10)
learning_rate: learning rate (default: 0.01)
sgd_momentum: SGD momentum (default: 0.5)
seed: random seed (default: 1)
log_interval: how many batches to wait before logging training status
dir: directory where summary logs are stored
"""

backend: str = dist.Backend.GLOO
@@ -281,13 +254,18 @@ class Hyperparameters(object):
)


@task(retries=2, cache=True, cache_version="1.0", requests=Resources(gpu="1"))
def train_mnist(hp: Hyperparameters) -> TrainingOutputs:
@task(
retries=2,
cache=True,
cache_version="1.0",
requests=Resources(gpu="1", mem="3Gi", storage="1Gi"),
limits=Resources(gpu="1", mem="3Gi", storage="1Gi"),
)
def pytorch_mnist_task(hp: Hyperparameters) -> TrainingOutputs:
wandb_setup()

# store the hyperparameters' config in ``wandb``
cfg = wandb.config
cfg.update(json.loads(hp.to_json()))
print(wandb.config)
wandb.config.update(json.loads(hp.to_json()))

# set random seed
torch.manual_seed(hp.seed)
@@ -298,35 +276,10 @@ def train_mnist(hp: Hyperparameters) -> TrainingOutputs:
print(f"Use cuda {use_cuda}")
device = torch.device("cuda" if use_cuda else "cpu")

print("Using device: {}, world size: {}".format(device, WORLD_SIZE))

# load Data
# load data
kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
training_data_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"../data",
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=hp.batch_size,
shuffle=True,
**kwargs,
)
test_data_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"../data",
train=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=hp.test_batch_size,
shuffle=False,
**kwargs,
)
training_data_loader = mnist_dataloader(hp.batch_size, train=True, **kwargs)
test_data_loader = mnist_dataloader(hp.batch_size, train=False, **kwargs)

# train the model
model = Net().to(device)
@@ -336,18 +289,11 @@ def train_mnist(hp: Hyperparameters) -> TrainingOutputs:
)

# run multiple epochs and capture the accuracies for each epoch
accuracies = [
epoch_step(
model,
device,
train_loader=training_data_loader,
test_loader=test_data_loader,
optimizer=optimizer,
epoch=epoch,
log_interval=hp.log_interval,
)
for epoch in range(1, hp.epochs + 1)
]
# train the model: run multiple epochs and capture the accuracies for each epoch
accuracies = []
for epoch in range(1, hp.epochs + 1):
train(model, device, training_data_loader, optimizer, epoch, hp.log_interval)
accuracies.append(test(model, device, test_data_loader))

# after training the model, we can simply save it to disk and return it from the Flyte task as a :py:class:`flytekit.types.file.FlyteFile`
# type, which is the ``PythonPickledFile``. ``PythonPickledFile`` is simply a decorator on the ``FlyteFile`` that records the format
@@ -364,10 +310,9 @@ def train_mnist(hp: Hyperparameters) -> TrainingOutputs:
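# %%
# A rough sketch of that saving step (the exact ``TrainingOutputs`` fields are defined in the
# collapsed code above, so the field names here are assumptions for illustration):
#
#     model_file = "mnist_cnn.pt"
#     torch.save(model.state_dict(), model_file)
#     return TrainingOutputs(epoch_accuracies=accuracies, model_state=PythonPickledFile(model_file))
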
# Finally, we define a workflow to run the training algorithm. We return the model and accuracies.
@workflow
def pytorch_training_wf(
hp: Hyperparameters,
) -> (PythonPickledFile, typing.List[float]):
accuracies, model = train_mnist(hp=hp)
return model, accuracies
hp: Hyperparameters = Hyperparameters(epochs=10, batch_size=128)
) -> TrainingOutputs:
return pytorch_mnist_task(hp=hp)


# %%
@@ -377,9 +322,7 @@ def pytorch_training_wf(
# It is possible to run the model locally with almost no modifications (as long as the code takes care of resolving
# if the code is distributed or not). This is how we can do it:
if __name__ == "__main__":
model, accuracies = pytorch_training_wf(
hp=Hyperparameters(epochs=10, batch_size=128)
)
model, accuracies = pytorch_training_wf(hp=Hyperparameters(epochs=10, batch_size=128))
print(f"Model: {model}, Accuracies: {accuracies}")

# %%