From 7ebd34664503503eac37587bc425cb5c75aca67a Mon Sep 17 00:00:00 2001
From: Han Qi
Date: Wed, 18 Dec 2024 20:01:03 +0000
Subject: [PATCH] Add readme and docker file

---
 .../train_llama_torchtitan/Dockerfile         |  35 +++
 .../examples/train_llama_torchtitan/README.md |  15 ++
 .../train_llama_torchtitan/train_llama.py     | 228 +++++++-----------
 experimental/torch_xla2/torch_xla2/tensor.py  |  10 +-
 experimental/torch_xla2/torch_xla2/train.py   |   9 +-
 5 files changed, 150 insertions(+), 147 deletions(-)
 create mode 100644 experimental/torch_xla2/examples/train_llama_torchtitan/Dockerfile
 create mode 100644 experimental/torch_xla2/examples/train_llama_torchtitan/README.md

diff --git a/experimental/torch_xla2/examples/train_llama_torchtitan/Dockerfile b/experimental/torch_xla2/examples/train_llama_torchtitan/Dockerfile
new file mode 100644
index 00000000000..dd7e74024f4
--- /dev/null
+++ b/experimental/torch_xla2/examples/train_llama_torchtitan/Dockerfile
@@ -0,0 +1,35 @@
+# syntax=docker/dockerfile:experimental
+# Use Python 3.10 as the base image
+FROM python:3.10-slim-bullseye
+
+# Install system dependencies
+RUN apt-get update && apt-get upgrade -y
+RUN apt-get update && apt-get install -y curl gnupg
+
+# Add the Google Cloud SDK package repository
+RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
+RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
+
+# Install the Google Cloud SDK
+RUN apt-get update && apt-get install -y google-cloud-sdk git
+
+# Set the default Python version to 3.10
+RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.10 1
+RUN pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+RUN pip install optax fire tensorflow tensorboard-plugin-profile
+RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+
+WORKDIR /
+RUN git clone https://github.com/pytorch/torchtitan.git
+WORKDIR /torchtitan
+RUN pip install -r requirements.txt
+RUN pip install .
+
+WORKDIR /
+RUN git clone https://github.com/pytorch/xla.git
+WORKDIR xla/experimental/torch_xla2
+RUN git checkout hanq_hybrid_mesh
+RUN pip install -e .
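+
+# Example usage (the image tag and any extra runtime flags are illustrative,
+# not required by this Dockerfile):
+#   docker build -t train_llama_torchtitan .
+#   docker run train_llama_torchtitan --batch_size=8 --seqlen=2048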
+
+ENTRYPOINT ["python", "examples/train_llama_torchtitan/train_llama.py"]
+CMD ["--batch_size=8", "--seqlen=2048"]
\ No newline at end of file
diff --git a/experimental/torch_xla2/examples/train_llama_torchtitan/README.md b/experimental/torch_xla2/examples/train_llama_torchtitan/README.md
new file mode 100644
index 00000000000..9519eaa9dba
--- /dev/null
+++ b/experimental/torch_xla2/examples/train_llama_torchtitan/README.md
@@ -0,0 +1,15 @@
+Training based on the torchtitan Llama model
+====================================
+
+```bash
+python train_llama.py
+```
+
+
+
+## Detailed numbers
+
+### v5p-8
+
+seqlen = 8192
+bs = 8
diff --git a/experimental/torch_xla2/examples/train_llama_torchtitan/train_llama.py b/experimental/torch_xla2/examples/train_llama_torchtitan/train_llama.py
index b7bdfcc7615..edb714b4387 100644
--- a/experimental/torch_xla2/examples/train_llama_torchtitan/train_llama.py
+++ b/experimental/torch_xla2/examples/train_llama_torchtitan/train_llama.py
@@ -5,8 +5,8 @@
 from collections import defaultdict
 import functools

+
 def _setup_default_env():
-  os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
   os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
   os.environ.setdefault('ALLOW_MULTIPLE_LIBTPU_LOAD', '1')
   # only need for tpu v4
@@ -22,6 +22,7 @@ def _setup_default_env():

 import torch_xla2
 import torch_xla2.interop
+import torch_xla2.train
 from torch_xla2.interop import jax_view, torch_view, JittableModule
 import jax
 import jax.numpy as jnp
@@ -34,10 +35,6 @@ def _setup_default_env():

 P = jax.sharding.PartitionSpec

-
-
-SEQLEN = 8192
-BATCH = 8
 global_axis: Tuple[str, str] = ('fsdp', )
 num_global_devices = jax.device_count()
 num_local_devices = jax.local_device_count()
@@ -56,59 +53,10 @@ def sharded_device_put(tensor, sharding):
   return jax.make_array_from_single_device_arrays(shape, sharding, x_split)


-class FSDPv2(torch.nn.Module):
-
-  def __init__(self, mod):
-    super().__init__()
-    self.mod = mod
-    self.mesh = jax.sharding.Mesh(
-        mesh_utils.create_device_mesh(num_partitions),
-        axis_names=global_axis,
-    )
-    self.sharding = jax.sharding.NamedSharding(self.mesh, P(*global_axis))
-
-  def forward(self, *args):
-    args = list(args)
-    args[0] = self.shard(args[0])
-    res = self.mod(*args)
-    return self.shard(res)
-
-  def shard(self, x):
-    return torch_xla2.interop.call_jax(
-        jax.lax.with_sharding_constraint,
-        x,
-        self.sharding,
-    )
-
-def print_shapes(pyt):
-  for p in pytree.tree_flatten(pyt)[0]:
-    if hasattr(p, 'shape'):
-      print(p.shape, p.dtype)
-
-
-class Module(torch.nn.Module):
-
-  def __init__(self, inner):
-    super().__init__()
-    self.inner = FSDPv2(inner)
-
-  def training_step(self, data, batch_id):
-    x, y = data
-    logits = self.inner(x)
-    num_tokens = logits.shape[-1]
-    logits = logits.reshape(-1, num_tokens)
-    y = y.reshape(-1)
-    return torch.nn.functional.cross_entropy(
-        logits, y)
-
-
 class Trainer:
-  def __init__(self):
-    self.mesh = jax.sharding.Mesh(
-        mesh_utils.create_device_mesh(num_partitions),
-        axis_names=global_axis,
-    )
+  def __init__(self, mesh):
+    self.mesh = mesh
     self.x_sharding = jax.sharding.NamedSharding(self.mesh, P(global_axis))
     self.replicated = jax.sharding.NamedSharding(self.mesh, P())

@@ -116,9 +64,8 @@ def _shard_fsdp_style(self, state_dict, sharding=None):
     if sharding is None:
       sharding = self.x_sharding
     def move_one_tensor(x):
-      jval = torch_xla2.tensor.t2j(x)
-      return sharded_device_put(jval, sharding)
-
+      x = x.to('jax')
+      return x.apply(sharded_device_put, sharding)
     if isinstance(state_dict, torch.Tensor):
       return move_one_tensor(state_dict)
     res = {}
@@ -126,100 +73,77 @@ def move_one_tensor(x):
       res[k] = move_one_tensor(v)
     return res

-  def fit(self, lightning_mod, data_loader):
+  def fit(self, model, loss_fn, data_loader):
     xla_env = torch_xla2.default_env()
     jax.config.update('jax_enable_x64', False)
     xla_env._mesh = self.mesh
     xla_env.use_flash_attention = True
-    jittable_mod = JittableModule(lightning_mod)
-    jax_params = self._shard_fsdp_style(jittable_mod.params)
-    jax_buffers = self._shard_fsdp_style(jittable_mod.buffers)
-
-    @jax.checkpoint
-    def lightning_mod_loss(
-        weights: jax.Array, buffers: jax.Array, data: jax.Array, batch_id):
-      """returns loss"""
-      with jax.named_scope("Computing_loss"):
-        weights, buffers, data = torch_view((weights, buffers, data))
-        # NOTE: these is needed because the original model
-        # did not register those as persistent buffer
-        with xla_env:
-          loss = jittable_mod.functional_call(
-              'training_step',
-              weights, buffers, data, batch_id)
-        return jax_view(loss)
+
+    model.to('jax')
+    jittable_mod = JittableModule(model)

-    jax_optimizer = optax.adamw(0.001)
+    # split the params across the n devices
+    jittable_mod.params = self._shard_fsdp_style(jittable_mod.params)
+    jittable_mod.buffers = self._shard_fsdp_style(jittable_mod.buffers)

-    opt_state = jax_optimizer.init(jax_params)
-    grad_fn = jax.value_and_grad(lightning_mod_loss)
+    def model_fn(weights, buffers, args):
+      return jittable_mod.functional_call('forward', weights, buffers, args)

-    opt_state_sharding = jax.tree_util.tree_map(lambda p : p.sharding, opt_state)
-
-    print('Begining training')
-
-    @functools.partial(
-        jax.jit,
-        donate_argnums=(0, 2),
-    )
-    def step(jax_weights, jax_buffers, optimizer_state, xla_data, bid):
-      print('Tracing inside of step')
-      with jax.named_scope("Computing_loss_and_grad"):
-        loss, grads = grad_fn(jax_weights, jax_buffers, xla_data, bid)
-      with jax.named_scope("optimizer_updates"):
-        updates, opt_state = jax_optimizer.update(
-            grads, optimizer_state, jax_weights)
-        jax_weights = optax.apply_updates(jax_weights, updates)
-      return loss, jax_weights, opt_state
-
-    total_param_size = 0
-    for k, v in jax_params.items():
-      total_param_size += v.size
-
-    print('Total number of params: ', total_param_size)
-
-    print('Start compiling')
-    start = time.perf_counter()
-    lowered = step.lower(
-        jax_params, jax_buffers, opt_state,
-        (jax.ShapeDtypeStruct((BATCH, SEQLEN), jnp.dtype('int32'), sharding=self.x_sharding),
-        jax.ShapeDtypeStruct((BATCH, SEQLEN), jnp.dtype('int32'), sharding=self.x_sharding)),
-        0
-    )
-    # print(lowered.as_text())
-    print('program size:', len(lowered.as_text()) / 1e6, 'm chars')
-    step_compiled = lowered.compile()
-    end = time.perf_counter()
-    compile_time = end - start
-    print('End compiling', compile_time)
+    jax_optimizer = optax.adamw(0.001)
+    opt_state = torch_xla2.interop.call_jax(jax_optimizer.init, jittable_mod.params)

-    for co in step_compiled.cost_analysis():
-      print('flops counter:', co['flops'])
+    train_step = torch_xla2.train.make_train_step(
+      model_fn, loss_fn, jax_optimizer,
+      remat_policy=jax.checkpoint_policies.nothing_saveable,
+      mark_fsdp_sharding_axis='fsdp')

     print('Begining training')
     s = time.perf_counter()
     jax.profiler.start_trace('/tmp/tensorboard')
     print('start training')
     min_loop_time = 10000
     for i, item in enumerate(data_loader):
-      inputs, labels = sharded_device_put(jax_view(xla_env.to_xla(item)),
-                                          self.x_sharding)
-      print('INPUT shape', inputs.shape)
+      inputs, labels = item
+      # Move them to the jax device
+      inputs = inputs.to('jax')
+      labels = labels.to('jax')
+
+      # Shard them on the batch dim for fsdp
+      inputs.apply_(sharded_device_put, self.x_sharding)
+      labels.apply_(sharded_device_put, self.x_sharding)
+      print('INPUT shape', inputs.shape)

       step_start = time.perf_counter()
-      loss, jax_params, opt_state = step_compiled(
-        jax_params, jax_buffers, opt_state, (inputs, labels), 0)
-      jax.block_until_ready((loss, jax_params))
+      loss, jittable_mod.params, opt_state = train_step(
+          jittable_mod.params, jittable_mod.buffers, opt_state, inputs, labels)
+      # wait for the iteration to finish to measure time
+      jax.block_until_ready((loss, jittable_mod.params))
       step_end = time.perf_counter()
       print(i, 'loss', loss, 'step latency: ', step_end - step_start)
       loop_time = step_end - step_start
       min_loop_time = min(min_loop_time, loop_time)
       print('======')
-      if i >= 2:
+      if i >= 3:
         break

     jax.profiler.stop_trace()
-    return min_loop_time, compile_time
+    return min_loop_time
+
+def create_sharded_weights(state_dict, sharding):
+  res = {}
+  env = torch_xla2.default_env()
+  for name, weight_meta in state_dict.items():
+    with jax.default_device(jax.devices('cpu')[0]):
+      weight_torch = torch.randn(
+        weight_meta.shape,
+        dtype=weight_meta.dtype)
+      # weight_jax is a jax array
+      weight_jax = env.to_xla(weight_torch).jax()
+    res[name] = env.j2t_iso(jax.make_array_from_callback(
+      weight_jax.shape, sharding, lambda a: weight_jax[a]
+    ))
+  return res


 def fake_dataloader(size, seqlen, batch_size):
@@ -232,33 +156,51 @@ def main(
     model_type='8B',
     batch_size=8,
     seqlen=2048,
-    mode='regular',
+    override_num_layers=-1,
 ):
-  logging.getLogger("jax").setLevel(logging.DEBUG)
+  torch_xla2.enable_globally()
+  #logging.getLogger("jax").setLevel(logging.DEBUG)
   print(f"Running with parameters {locals()}")
-  global SEQLEN
-  global BATCH
-  SEQLEN = seqlen
-  BATCH = batch_size

   mesh = jax.make_mesh((len(jax.local_devices()), ), ('fsdp', ))
+  sharding = jax.sharding.NamedSharding(mesh, P('fsdp'))
+
   env = torch_xla2.default_env()
-  env.config.use_tpu_flash_attention = use_flash_attention
-  env.config.shmap_flash_attention = use_flash_attention
+  env.config.use_tpu_flash_attention = True
+  env.config.shmap_flash_attention = True
+  env._mesh = mesh # this is the mesh used by flash attention pallas kernel

   args = llama3_configs[model_type]
-  #with torch.device('meta'):
-  gpt = titan.Transformer(args)
-
-  light_mod = Module(gpt)
-  light_mod.to(torch.bfloat16)
-
+  # Note: torchtitan's upstream config did not specify this value
+  args.vocab_size = 128256
+  if override_num_layers > 0:
+    args.n_layers = override_num_layers
+
+  # Note: a single device doesn't have enough HBM memory (nor enough CPU
+  # memory) to hold the parameters, so we instantiate the model on the meta
+  # device, then manually initialize and shard each param.
+  torch.set_default_dtype(torch.bfloat16)
+  with torch.device('meta'):
+    gpt = titan.Transformer(args)
+  gpt.to(torch.bfloat16)
+
+  state_dict = create_sharded_weights(gpt.state_dict(), sharding)
+  gpt.load_state_dict(state_dict, assign=True)
+
   train_loader = fake_dataloader(10, seqlen, batch_size)
+  def loss_fn(logits, y):
+    num_tokens = logits.shape[-1]
+    logits = logits.reshape(-1, num_tokens)
+    y = y.reshape(-1)
+    return torch.nn.functional.cross_entropy(
+      logits, y)
+
   with mesh:
-    trainer = Trainer()
+    trainer = Trainer(mesh)
     return trainer.fit(
-        light_mod,
+        gpt,
+        loss_fn,
         train_loader
     )
diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py
index 35d69eb7326..ec609d93ac7 100644
--- a/experimental/torch_xla2/torch_xla2/tensor.py
+++ b/experimental/torch_xla2/torch_xla2/tensor.py
@@ -155,6 +155,14 @@ def device(self):
   def jax_device(self):
     return self._elem.device

+  def apply(self, jax_function, *args, **kwargs):
+    # Call a jax function on _elem
+    res = jax_function(self._elem, *args, **kwargs)
+    return self._env.j2t_iso(res)
+
+  def apply_(self, jax_function, *args, **kwargs):
+    self._elem = jax_function(self._elem, *args, **kwargs)
+
   def tolist(self):
     return self._elem.tolist()

@@ -294,7 +302,7 @@ def get_as_jax_device(self, device: Any):
     if not self.config.treat_cuda_as_jax_device and device.startswith('cuda'):
       return None

-    if device in ('jax_cpu', 'cpu'):
+    if device == 'cpu':
       return jax.devices('cpu')[0]
     return jax.devices()[0]

diff --git a/experimental/torch_xla2/torch_xla2/train.py b/experimental/torch_xla2/torch_xla2/train.py
index d590dd3ae68..7dd378795a2 100644
--- a/experimental/torch_xla2/torch_xla2/train.py
+++ b/experimental/torch_xla2/torch_xla2/train.py
@@ -41,10 +41,13 @@ def make_train_step(model_fn,
   def loss(weights, buffers, args, label): # inputs are XLATensor
     with env, jax.named_scope('compute_loss'):
       if mark_fsdp_sharding_axis is not None:
-        args = (mark_sharding(args[0], P(mark_fsdp_sharding_axis)), *args[1:])
+        args = mark_sharding(
+            args,
+            jax.sharding.PartitionSpec(mark_fsdp_sharding_axis))
       res = model_fn(weights, buffers, args)
       if mark_fsdp_sharding_axis is not None:
-        res = mark_sharding(res, P(mark_fsdp_sharding_axis))
+        res = mark_sharding(res, jax.sharding.PartitionSpec(mark_fsdp_sharding_axis))
+        label = mark_sharding(label, jax.sharding.PartitionSpec(mark_fsdp_sharding_axis))
       l = loss_fn(res, label)
       return l

@@ -61,4 +64,4 @@ def step(weights, buffers, opt_state, args, label): #inputs are array
     weights = interop.call_jax(optax.apply_updates, weights, updates)
     return loss, weights, opt_state

-  return step
\ No newline at end of file
+  return interop.jax_jit(step, {'donate_argnums': (0, 2)})
\ No newline at end of file
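
Usage note: `main()` reads its arguments from the command line (the Dockerfile's CMD passes `--batch_size` and `--seqlen`), so the new `override_num_layers` parameter also gives a quick way to smoke-test the script on a small host. The flag values below are illustrative only, not a recommended configuration:

```bash
python train_llama.py --model_type=8B --override_num_layers=2 --batch_size=1 --seqlen=256
```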