Commit 7ebd346: Add readme and docker file
qihqi committed Dec 18, 2024 (parent 6416dd3)
Showing 5 changed files with 150 additions and 147 deletions.
35 changes: 35 additions & 0 deletions experimental/torch_xla2/examples/train_llama_torchtitan/Dockerfile
@@ -0,0 +1,35 @@
# syntax=docker/dockerfile:experimental
# Use Python 3.10 as the base image
FROM python:3.10-slim-bullseye

# Install system dependencies
RUN apt-get update && apt-get upgrade -y
RUN apt-get update && apt-get install -y curl gnupg

# Add the Google Cloud SDK package repository
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -

# Install the Google Cloud SDK
RUN apt-get update && apt-get install -y google-cloud-sdk git

# Set the default Python version to 3.10
RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.10 1
RUN pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
RUN pip install optax fire tensorflow tensorboard-plugin-profile
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

WORKDIR /
RUN git clone https://github.com/pytorch/torchtitan.git
WORKDIR /torchtitan
RUN pip install -r requirements.txt
RUN pip install .

WORKDIR /
RUN git clone https://github.com/pytorch/xla.git
WORKDIR xla/experimental/torch_xla2
RUN git checkout hanq_hybrid_mesh
RUN pip install -e .

ENTRYPOINT ["python", "examples/train_llama_torchtitan/train_llama.py"]
CMD ["--batch_size=8", "--seqlen=2048"]
15 changes: 15 additions & 0 deletions experimental/torch_xla2/examples/train_llama_torchtitan/README.md
@@ -0,0 +1,15 @@
Training based on the torchtitan Llama model
====================================

```bash
python train_llama.py
```



## Detailed numbers

### v5p-8

seqlen = 8192
bs = 8
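
The script's `main()` (see the diff below) accepts `model_type`, `batch_size`, `seqlen`, and `override_num_layers`. Assuming these are exposed as command-line flags, as the Dockerfile's `CMD` suggests, a run matching the configuration above might be:

```bash
# Hypothetical invocation; flag names mirror main()'s keyword arguments
python train_llama.py --model_type=8B --batch_size=8 --seqlen=8192
```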
228 changes: 85 additions & 143 deletions experimental/torch_xla2/examples/train_llama_torchtitan/train_llama.py
@@ -5,8 +5,8 @@
from collections import defaultdict
import functools


def _setup_default_env():
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
os.environ.setdefault('ALLOW_MULTIPLE_LIBTPU_LOAD', '1')
# only need for tpu v4
@@ -22,6 +22,7 @@ def _setup_default_env():

import torch_xla2
import torch_xla2.interop
import torch_xla2.train
from torch_xla2.interop import jax_view, torch_view, JittableModule
import jax
import jax.numpy as jnp
@@ -34,10 +35,6 @@ def _setup_default_env():

P = jax.sharding.PartitionSpec



SEQLEN = 8192
BATCH = 8
global_axis: Tuple[str, str] = ('fsdp', )
num_global_devices = jax.device_count()
num_local_devices = jax.local_device_count()
@@ -56,170 +53,97 @@ def sharded_device_put(tensor, sharding):
return jax.make_array_from_single_device_arrays(shape, sharding, x_split)


class FSDPv2(torch.nn.Module):

def __init__(self, mod):
super().__init__()
self.mod = mod
self.mesh = jax.sharding.Mesh(
mesh_utils.create_device_mesh(num_partitions),
axis_names=global_axis,
)
self.sharding = jax.sharding.NamedSharding(self.mesh, P(*global_axis))

def forward(self, *args):
args = list(args)
args[0] = self.shard(args[0])
res = self.mod(*args)
return self.shard(res)

def shard(self, x):
return torch_xla2.interop.call_jax(
jax.lax.with_sharding_constraint,
x,
self.sharding,
)

def print_shapes(pyt):
for p in pytree.tree_flatten(pyt)[0]:
if hasattr(p, 'shape'):
print(p.shape, p.dtype)


class Module(torch.nn.Module):

def __init__(self, inner):
super().__init__()
self.inner = FSDPv2(inner)

def training_step(self, data, batch_id):
x, y = data
logits = self.inner(x)
num_tokens = logits.shape[-1]
logits = logits.reshape(-1, num_tokens)
y = y.reshape(-1)
return torch.nn.functional.cross_entropy(
logits, y)


class Trainer:

def __init__(self):
self.mesh = jax.sharding.Mesh(
mesh_utils.create_device_mesh(num_partitions),
axis_names=global_axis,
)
def __init__(self, mesh):
self.mesh = mesh
self.x_sharding = jax.sharding.NamedSharding(self.mesh, P(global_axis))
self.replicated = jax.sharding.NamedSharding(self.mesh, P())

def _shard_fsdp_style(self, state_dict, sharding=None):
if sharding is None:
sharding = self.x_sharding
def move_one_tensor(x):
jval = torch_xla2.tensor.t2j(x)
return sharded_device_put(jval, sharding)

x = x.to('jax')
return x.apply(sharded_device_put, sharding)
if isinstance(state_dict, torch.Tensor):
return move_one_tensor(state_dict)
res = {}
for k, v in sorted(state_dict.items()):
res[k] = move_one_tensor(v)
return res

def fit(self, lightning_mod, data_loader):
def fit(self, model, loss_fn, data_loader):
xla_env = torch_xla2.default_env()
jax.config.update('jax_enable_x64', False)
xla_env._mesh = self.mesh
xla_env.use_flash_attention = True

jittable_mod = JittableModule(lightning_mod)
jax_params = self._shard_fsdp_style(jittable_mod.params)
jax_buffers = self._shard_fsdp_style(jittable_mod.buffers)

@jax.checkpoint
def lightning_mod_loss(
weights: jax.Array, buffers: jax.Array, data: jax.Array, batch_id):
"""returns loss"""
with jax.named_scope("Computing_loss"):
weights, buffers, data = torch_view((weights, buffers, data))
# NOTE: these is needed because the original model
# did not register those as persistent buffer
with xla_env:
loss = jittable_mod.functional_call(
'training_step',
weights, buffers, data, batch_id)
return jax_view(loss)

model.to('jax')
jittable_mod = JittableModule(model)

jax_optimizer = optax.adamw(0.001)
# split the params to the n devices
jittable_mod.params = self._shard_fsdp_style(jittable_mod.params)
jittable_mod.buffers = self._shard_fsdp_style(jittable_mod.buffers)

opt_state = jax_optimizer.init(jax_params)
grad_fn = jax.value_and_grad(lightning_mod_loss)
def model_fn(weights, buffers, args):
return jittable_mod.functional_call('forward', weights, buffers, args)

opt_state_sharding = jax.tree_util.tree_map(lambda p : p.sharding, opt_state)

print('Begining training')

@functools.partial(
jax.jit,
donate_argnums=(0, 2),
)
def step(jax_weights, jax_buffers, optimizer_state, xla_data, bid):
print('Tracing inside of step')
with jax.named_scope("Computing_loss_and_grad"):
loss, grads = grad_fn(jax_weights, jax_buffers, xla_data, bid)
with jax.named_scope("optimizer_updates"):
updates, opt_state = jax_optimizer.update(
grads, optimizer_state, jax_weights)
jax_weights = optax.apply_updates(jax_weights, updates)
return loss, jax_weights, opt_state

total_param_size = 0
for k, v in jax_params.items():
total_param_size += v.size

print('Total number of params: ', total_param_size)

print('Start compiling')
start = time.perf_counter()
lowered = step.lower(
jax_params, jax_buffers, opt_state,
(jax.ShapeDtypeStruct((BATCH, SEQLEN), jnp.dtype('int32'), sharding=self.x_sharding),
jax.ShapeDtypeStruct((BATCH, SEQLEN), jnp.dtype('int32'), sharding=self.x_sharding)),
0
)
# print(lowered.as_text())
print('program size:', len(lowered.as_text()) / 1e6, 'm chars')
step_compiled = lowered.compile()
end = time.perf_counter()
compile_time = end - start
print('End compiling', compile_time)
jax_optimizer = optax.adamw(0.001)
opt_state = torch_xla2.interop.call_jax(jax_optimizer.init, jittable_mod.params)

for co in step_compiled.cost_analysis():
print('flops counter:', co['flops'])
train_step = torch_xla2.train.make_train_step(
model_fn, loss_fn, jax_optimizer,
remat_policy=jax.checkpoint_policies.nothing_saveable,
mark_fsdp_sharding_axis='fsdp')

print('Beginning training')
s = time.perf_counter()
jax.profiler.start_trace('/tmp/tensorboard')
print('start training')
min_loop_time = 10000
for i, item in enumerate(data_loader):
inputs, labels = sharded_device_put(jax_view(xla_env.to_xla(item)),
self.x_sharding)
print('INPUT shape', inputs.shape)
inputs, labels = item
# Move them to jax device
inputs = inputs.to('jax')
labels = labels.to('jax')

# Shard them on batch dim for fsdp
inputs.apply_(sharded_device_put, self.x_sharding)
labels.apply_(sharded_device_put, self.x_sharding)

print('INPUT shape', inputs.shape)
step_start = time.perf_counter()
loss, jax_params, opt_state = step_compiled(
jax_params, jax_buffers, opt_state, (inputs, labels), 0)
jax.block_until_ready((loss, jax_params))
loss, jittable_mod.params, opt_state = train_step(
jittable_mod.params, jittable_mod.buffers, opt_state, inputs, labels)
# wait for iteration to finish to measure time
jax.block_until_ready((loss, jittable_mod.params))
step_end = time.perf_counter()
print(i, 'loss', loss, 'step latency: ', step_end - step_start)
loop_time = step_end - step_start
min_loop_time = min(min_loop_time, loop_time)
print('======')
if i >= 2:
if i >= 3:
break
jax.profiler.stop_trace()
return min_loop_time, compile_time
return min_loop_time


def create_sharded_weights(state_dict, sharding):
res = {}
env = torch_xla2.default_env()
for name, weight_meta in state_dict.items():
with jax.default_device(jax.devices('cpu')[0]):
weight_torch = torch.randn(
weight_meta.shape,
dtype=weight_meta.dtype)
# weight_jax is jax array
weight_jax = env.to_xla(weight_torch).jax()
res[name] = env.j2t_iso(jax.make_array_from_callback(
weight_jax.shape, sharding, lambda a: weight_jax[a]
))
return res


def fake_dataloader(size, seqlen, batch_size):
@@ -232,33 +156,51 @@ def main(
model_type='8B',
batch_size=8,
seqlen=2048,
mode='regular',
override_num_layers=-1,
):
logging.getLogger("jax").setLevel(logging.DEBUG)
torch_xla2.enable_globally()
#logging.getLogger("jax").setLevel(logging.DEBUG)
print(f"Running with parameters {locals()}")
global SEQLEN
global BATCH
SEQLEN = seqlen
BATCH = batch_size

mesh = jax.make_mesh((len(jax.local_devices()), ), ('fsdp', ))
sharding = jax.sharding.NamedSharding(mesh, P('fsdp'))

env = torch_xla2.default_env()
env.config.use_tpu_flash_attention = use_flash_attention
env.config.shmap_flash_attention = use_flash_attention
env.config.use_tpu_flash_attention = True
env.config.shmap_flash_attention = True
env._mesh = mesh # this is the mesh used by flash attention pallas kernel

args = llama3_configs[model_type]
#with torch.device('meta'):
gpt = titan.Transformer(args)

light_mod = Module(gpt)
light_mod.to(torch.bfloat16)

# Note: torchtitan's upstream config did not specify this value
args.vocab_size = 128256
if override_num_layers > 0:
args.n_layers = override_num_layers

# Note: a single device doesn't have enough HBM memory (nor enough CPU
# memory) to hold the parameters, so we instantiate the model on the meta
# device, then manually initialize and shard each param.
torch.set_default_dtype(torch.bfloat16)
with torch.device('meta'):
gpt = titan.Transformer(args)
gpt.to(torch.bfloat16)

state_dict = create_sharded_weights(gpt.state_dict(), sharding)
gpt.load_state_dict(state_dict, assign=True)

train_loader = fake_dataloader(10, seqlen, batch_size)

def loss_fn(logits, y):
num_tokens = logits.shape[-1]
logits = logits.reshape(-1, num_tokens)
y = y.reshape(-1)
return torch.nn.functional.cross_entropy(
logits, y)

with mesh:
trainer = Trainer()
return trainer.fit(
light_mod,
gpt,
loss_fn,
train_loader
)

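The new training loop brackets the measured steps with `jax.profiler.start_trace('/tmp/tensorboard')` and `jax.profiler.stop_trace()`. A sketch of how the captured trace could be inspected, assuming the `tensorboard` and `tensorboard-plugin-profile` packages installed by the Dockerfile are available where the trace was written:

```bash
# The JAX trace shows up under TensorBoard's Profile tab
tensorboard --logdir /tmp/tensorboard
```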
10 changes: 9 additions & 1 deletion experimental/torch_xla2/torch_xla2/tensor.py
@@ -155,6 +155,14 @@ def device(self):
def jax_device(self):
return self._elem.device

def apply(self, jax_function, *args, **kwargs):
# Call a jax function on _elem
res = jax_function(self._elem, *args, **kwargs)
return self._env.j2t_iso(res)

def apply_(self, jax_function, *args, **kwargs):
self._elem = jax_function(self._elem, *args, **kwargs)

def tolist(self):
return self._elem.tolist()

@@ -294,7 +302,7 @@ def get_as_jax_device(self, device: Any):
if not self.config.treat_cuda_as_jax_device and device.startswith('cuda'):
return None

if device in ('jax_cpu', 'cpu'):
if device == 'cpu':
return jax.devices('cpu')[0]
return jax.devices()[0]
