training UX: automatic generating make_train_step #8495
Open
qihqi wants to merge 9 commits into master from hanq_train
Commits:
e23ad91 training UX: automatic generating make_train_step (qihqi)
55e191c Add readme and docker file (qihqi)
50502fd commit Scan (qihqi)
017b4a3 checkpoint on v6e (qihqi)
0d4a029 commit changes (qihqi)
45c6765 Move sharding logic inside of the model function (qihqi)
b20f9d5 train update (qihqi)
5680121 update from v6e examples (qihqi)
c7f3893 misc fixes (qihqi)
34 changes: 34 additions & 0 deletions
experimental/torch_xla2/examples/train_llama_torchtitan/Dockerfile
@@ -0,0 +1,34 @@
# syntax=docker/dockerfile:experimental
# Use Python 3.10 as the base image
FROM python:3.10-slim-bullseye

# Install system dependencies
RUN apt-get update && apt-get upgrade -y
RUN apt-get update && apt-get install -y curl gnupg

# Add the Google Cloud SDK package repository
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -

# Install the Google Cloud SDK
RUN apt-get update && apt-get install -y google-cloud-sdk git

# Set the default Python version to 3.10
RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.10 1
RUN pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
RUN pip install optax fire tensorflow tensorboard-plugin-profile
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

WORKDIR /
RUN git clone https://github.com/pytorch/torchtitan.git
WORKDIR /torchtitan
RUN pip install -r requirements.txt
RUN pip install .

WORKDIR /
RUN git clone https://github.com/pytorch/xla.git
WORKDIR xla/experimental/torch_xla2
RUN pip install -e .

ENTRYPOINT ["python", "examples/train_llama_torchtitan/train_llama.py"]
CMD ["--batch_size=8", "--seqlen=2048"]
Does donate_argnums here imply that input buffers are donated to outputs? The (0, 2) is pretty cryptic to me. Consider commenting on their meaning.
Or better, maybe this could be handled internally? We could jit the function inside make_train_step.
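For readers unfamiliar with the option, here is a minimal sketch of what donate_argnums=(0, 2) means in jax.jit. The train_step below is a hypothetical stand-in, not the function from this PR: it assumes that positional argument 0 holds the weights and argument 2 holds the optimizer state, which is one plausible reading of (0, 2). Donating an argument tells XLA it may reuse that input's buffer for an output, so the caller must not touch the donated input afterwards.

```python
import jax
import jax.numpy as jnp


def train_step(weights, batch, opt_state):
    # Hypothetical update rule (not the PR's): plain SGD on a toy loss,
    # with a step counter standing in for real optimizer state.
    def loss_fn(w):
        return jnp.mean((batch @ w) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(weights)
    new_weights = weights - 0.1 * grads
    new_opt_state = opt_state + 1
    return loss, new_weights, new_opt_state


# Positions 0 and 2 are `weights` and `opt_state`. Their buffers may be
# reused for `new_weights` and `new_opt_state`; `batch` (position 1) is not
# donated. Backends that cannot use a donation ignore it, possibly with a
# warning.
jitted_step = jax.jit(train_step, donate_argnums=(0, 2))

weights = jnp.ones((4,))
batch = jnp.ones((8, 4))
opt_state = jnp.zeros((), dtype=jnp.int32)

# Rebind the names to the outputs: the old `weights` and `opt_state` arrays
# may have been invalidated by the donation.
loss, weights, opt_state = jitted_step(weights, batch, opt_state)
```

Rebinding the outputs to the same names is the usual pattern in a training loop, since the donated inputs should be treated as consumed after the call.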
You're right. The current issue is that I sometimes want to print out the StableHLO for inspection, so the jax_jit'd object also needs to store the JAX function. I'll follow up.
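One way the two goals could coexist, sketched in plain JAX rather than with this PR's jax_jit wrapper: a helper that jits the function internally and also returns a callable for inspecting the lowering. The names make_jitted and stablehlo_text are hypothetical and do not come from torch_xla2. The sketch relies on jax.jit(...).lower(*args).as_text(), which in recent JAX versions returns the lowered module as StableHLO MLIR text.

```python
import jax
import jax.numpy as jnp


def make_jitted(fn, donate_argnums=()):
    """Hypothetical helper: jit `fn` and also expose its lowered text."""
    jitted = jax.jit(fn, donate_argnums=donate_argnums)

    def stablehlo_text(*args):
        # .lower() traces and lowers the function for these arguments
        # without executing it, then renders the module as text.
        return jitted.lower(*args).as_text()

    return jitted, stablehlo_text


def scale_and_sum(x, y):
    # Toy function standing in for a train step.
    return jnp.sum(x * 2.0 + y)


jitted, stablehlo_text = make_jitted(scale_and_sum)

x = jnp.arange(4.0)
y = jnp.ones((4,))

hlo = stablehlo_text(x, y)  # text form of the lowered module, for inspection
result = jitted(x, y)       # normal jitted execution
```

With this shape, the donate_argnums choice stays inside the helper while callers can still call print(hlo) when they want to look at the lowering.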