training UX: automatically generating make_train_step #8495

Open · wants to merge 9 commits into base: master · Changes from all commits
111 changes: 26 additions & 85 deletions experimental/torch_xla2/examples/basic_training_jax.py
@@ -3,6 +3,8 @@
https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
"""

import functools
from torch_xla2 import train, interop
import torch
from torch.utils import _pytree as pytree
import torchvision
@@ -17,6 +19,8 @@
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

env = torch_xla2.enable_globally()  # route torch ops through JAX; enables the 'jax' device used below


transform = transforms.Compose(
    [transforms.ToTensor(),
@@ -38,29 +42,7 @@
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print(' '.join(classes[labels[j]] for j in range(4)))


import torch.nn as nn
import torch.nn.functional as F

@@ -83,62 +65,55 @@ def forward(self, x):
model = GarmentClassifier()
loss_fn = torch.nn.CrossEntropyLoss()

jax_weights, jax_func = torch_xla2.extract_jax(model)
jax_func = jax.jit(jax_func, inline=True)
jax_optimizer = optax.adam(0.01)
opt_state = jax_optimizer.init(jax_weights)

model.to('jax') # move the model to jax device
model_jittable = interop.JittableModule(model)
weights = model_jittable.params # these are trainable parameters
buffers = model_jittable.buffers # these are non-trainable parameters

def jax_loss(weights, data, label):
    pred = jax_func(weights, data)
    loss = torch_xla2.interop.call_torch(loss_fn, pred, label)
    return loss
opt_state = interop.call_jax(jax_optimizer.init, weights)
model_fn = functools.partial(model_jittable.functional_call, 'forward')

grad_fn = jax.jit(jax.value_and_grad(jax_loss))
train_step = train.make_train_step(model_fn, loss_fn, jax_optimizer)

train_step = interop.jax_jit(train_step, kwargs_for_jax_jit={'donate_argnums': (0, 2)})
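# donate_argnums=(0, 2) donates the buffers of positional args 0 (weights) and
# 2 (opt_state): XLA may reuse their device memory for the updated outputs
# instead of allocating new buffers. Donated inputs must not be used again.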

Review comment (Collaborator):
Does donate_argnums here imply that input buffers are donated to the outputs? The (0, 2) is pretty cryptic to me. Consider commenting on their meaning.

Or better, maybe this could be handled internally? We could jit the function inside make_train_step.

Reply (Collaborator, Author):
You're right. The current issue is that sometimes I want to print out the StableHLO for inspection, so the jax_jit'd object also needs to store the underlying JAX function. I'll follow up.

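For readers new to buffer donation, here is a minimal, self-contained sketch in plain JAX (not part of this PR; the update rule and shapes are made up) showing the same donate_argnums mechanism:

```python
import functools

import jax
import jax.numpy as jnp

# Donate argument 0 (params): XLA may write the updated params into the
# same device buffers, avoiding an extra copy of the whole parameter tree.
@functools.partial(jax.jit, donate_argnums=(0,))
def sgd_update(params, grads):
    return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

params = {'w': jnp.ones((1024, 1024))}
grads = {'w': jnp.full((1024, 1024), 0.5)}
params = sgd_update(params, grads)  # the old `params` buffers are now invalid
```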

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's confidence in each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
dummy_inputs = torch.rand(4, 28, 28).to('jax')
dummy_outputs = torch.rand(4, 10).to('jax')
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))

dummy_labels = torch.tensor([1, 5, 3, 7]).to('jax')

def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer):
# test train_step

def train_one_epoch(weights, buffers, opt_state, epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        # NEW: Move model to XLA device
        data = pytree.tree_map_only(torch.Tensor,
                                    torch_xla2.tensor.t2j, data)
        inputs, labels = data

        val, grads = grad_fn(jax_weights, (inputs, ), labels)
        updates, opt_state = jax_optimizer.update(grads, opt_state)
        jax_weights = optax.apply_updates(jax_weights, updates)
        inputs = inputs.to('jax')
        labels = labels.to('jax')

        loss, weights, opt_state = train_step(
            weights, buffers, opt_state, inputs, labels)

        # Gather data and report
        running_loss += val.item()
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000  # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss, opt_state
    return last_loss, weights, opt_state



@@ -152,39 +127,5 @@ def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer):
for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)

    avg_loss, opt_state = train_one_epoch(jax_weights, opt_state, epoch_number, writer)

    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = pytree.tree_map_only(torch.Tensor, torch_xla2.tensor.t2j, vdata)
            voutputs = jax_func(jax_weights, (vinputs, ))  # call model's forward
            vloss = torch_xla2.interop.call_torch(loss_fn, voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                       {'Training': np.asarray(avg_loss), 'Validation': np.asarray(avg_vloss)},
                       epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1
    avg_loss, weights, opt_state = train_one_epoch(weights, buffers, opt_state, epoch_number, writer)
    print(avg_loss)
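Conceptually, train.make_train_step packages up the pattern the old code spelled out by hand with jax.value_and_grad and optax. A rough sketch of that pattern in plain JAX/optax (ignoring the torch-to-JAX interop that torch_xla2 handles, and not the library's actual implementation):

```python
import jax
import optax

def make_train_step_sketch(model_fn, loss_fn, optimizer):
    # model_fn(weights, buffers, inputs) -> predictions, mirroring the
    # functional_call usage above; loss_fn(pred, labels) -> scalar loss.
    def compute_loss(weights, buffers, inputs, labels):
        pred = model_fn(weights, buffers, inputs)
        return loss_fn(pred, labels)

    def step(weights, buffers, opt_state, inputs, labels):
        # Differentiate w.r.t. the first argument (weights) only.
        loss, grads = jax.value_and_grad(compute_loss)(
            weights, buffers, inputs, labels)
        updates, opt_state = optimizer.update(grads, opt_state)
        weights = optax.apply_updates(weights, updates)
        return loss, weights, opt_state

    return step
```

Jitting the returned step with the weights and optimizer state donated, as done above, then lets XLA update the training state in place.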
34 changes: 34 additions & 0 deletions experimental/torch_xla2/examples/train_llama_torchtitan/Dockerfile
@@ -0,0 +1,34 @@
# syntax=docker/dockerfile:experimental
# Use Python 3.10 as the base image
FROM python:3.10-slim-bullseye

# Install system dependencies
RUN apt-get update && apt-get upgrade -y
RUN apt-get update && apt-get install -y curl gnupg

# Add the Google Cloud SDK package repository
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -

# Install the Google Cloud SDK
RUN apt-get update && apt-get install -y google-cloud-sdk git

# Set the default Python version to 3.10
RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.10 1
RUN pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
RUN pip install optax fire tensorflow tensorboard-plugin-profile
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

WORKDIR /
RUN git clone https://github.com/pytorch/torchtitan.git
WORKDIR /torchtitan
RUN pip install -r requirements.txt
RUN pip install .

WORKDIR /
RUN git clone https://github.com/pytorch/xla.git
WORKDIR xla/experimental/torch_xla2
RUN pip install -e .

ENTRYPOINT ["python", "examples/train_llama_torchtitan/train_llama.py"]
CMD ["--batch_size=8", "--seqlen=2048"]