-
Notifications
You must be signed in to change notification settings - Fork 491
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
150 additions
and
147 deletions.
There are no files selected for viewing
35 changes: 35 additions & 0 deletions
35
experimental/torch_xla2/examples/train_llama_torchtitan/Dockerfile
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# syntax=docker/dockerfile:experimental | ||
# Use Python 3.10 as the base image | ||
FROM python:3.10-slim-bullseye | ||
|
||
# Install system dependencies | ||
RUN apt-get update && apt-get upgrade -y | ||
RUN apt-get update && apt-get install -y curl gnupg | ||
|
||
# Add the Google Cloud SDK package repository | ||
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list | ||
RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - | ||
|
||
# Install the Google Cloud SDK | ||
RUN apt-get update && apt-get install -y google-cloud-sdk git | ||
|
||
# Set the default Python version to 3.10 | ||
RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.10 1 | ||
RUN pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html | ||
RUN pip install optax fire tensorflow tensorboard-plugin-profile | ||
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu | ||
|
||
WORKDIR / | ||
RUN git clone https://github.com/pytorch/torchtitan.git | ||
WORKDIR /torchtitan | ||
RUN pip install -r requirements.txt | ||
RUN pip install . | ||
|
||
WORKDIR / | ||
RUN git clone https://github.com/pytorch/xla.git | ||
WORKDIR xla/experimental/torch_xla2 | ||
RUN git checkout hanq_hybrid_mesh | ||
RUN pip install -e . | ||
|
||
ENTRYPOINT ["python", "examples/train_llama_torchtitan/train_llama.py"] | ||
CMD ["--batch_size=8", "--seqlen=2048"] |
15 changes: 15 additions & 0 deletions
15
experimental/torch_xla2/examples/train_llama_torchtitan/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
Training based on torchtitan llama model | ||
==================================== | ||
|
||
```bash | ||
python train_llama.py | ||
``` | ||
|
||
|
||
|
||
## Detailed numbers | ||
|
||
### v5p-8 | ||
|
||
seqlen = 8192 | ||
bs = 8 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.