This document outlines the integration of the Gemma model into the EasyLM framework, including instructions for training, converting the model format, and serving the model with Gradio.
You can skip this step by downloading the pre-converted checkpoint from https://huggingface.co/beomi/gemma-ko-7b/resolve/flax-init/flax_model.msgpack
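Alternatively, you can fetch the same file programmatically. A minimal sketch using huggingface_hub (the repo id and revision come from the URL above; the file lands in the local Hub cache):

# Sketch: download the pre-converted Flax checkpoint from the Hugging Face Hub.
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="beomi/gemma-ko-7b",
    filename="flax_model.msgpack",
    revision="flax-init",  # the branch holding the Flax-init weights
)
print(path)  # local cache path of the downloaded file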
First, concatenate the sharded Flax model weights available at Hugging Face - Gemma 7B into a single file.
Use the following example code to accomplish this:
from transformers import FlaxGemmaForCausalLM

# Load the sharded Flax weights and save them back as a single shard.
# Add revision="flax" if the Flax shards live on a separate branch of the repo.
model = FlaxGemmaForCausalLM.from_pretrained("google/gemma-7b")
model.save_pretrained("./flax-concatted", max_shard_size="99GB")
This script generates a flax-concatted/flax_model.msgpack file. We will use this .msgpack file during the training process.
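To verify the concatenated checkpoint before uploading, a quick sanity check is to restore the msgpack and inspect the top-level parameter tree. This is a sketch, assuming flax is installed; note that it reads the entire multi-gigabyte file into memory:

# Sketch: restore the serialized parameter tree and list its top-level keys.
from flax.serialization import msgpack_restore

with open("./flax-concatted/flax_model.msgpack", "rb") as f:
    params = msgpack_restore(f.read())
print(list(params.keys()))  # top-level module names of the model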
Execute the following command to upload the generated .msgpack file to your GCS bucket:
gsutil cp ./flax-concatted/flax_model.msgpack gs://YOUR_GCS_REPO_NAME
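If you prefer Python over the gsutil CLI, the same upload can be done with the google-cloud-storage client. A sketch; the bucket name is a placeholder, and authentication uses your default GCP credentials:

# Sketch: upload the checkpoint with the google-cloud-storage client.
from google.cloud import storage

client = storage.Client()
bucket = client.bucket("YOUR_GCS_REPO_NAME")
bucket.blob("flax_model.msgpack").upload_from_filename(
    "./flax-concatted/flax_model.msgpack"
)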
Adjust the paths for load_checkpoint, train_dataset.json_dataset.path, and logger.output_dir within the train.sh script to match your setup.
The provided example train.sh script assumes training will be conducted on a TPUv4-64 pod slice.
Execute the training script to start the training process:
./train.sh
Download the streaming_train_state file from your GCS bucket using the following command:
gsutil cp gs://YOUR_GCS_REPO_NAME/.../streaming_train_state_80000 .
Note: the file name will be either streaming_train_state or streaming_train_state_STEPNO.
In the convert_easylm_stream_to_hf_safetensors.py file, modify the path to point at the downloaded streaming checkpoint file:
# Modify this line: the 'trainstate_params::' prefix tells EasyLM to load only
# the model parameters from the streaming train state checkpoint.
_, param = StreamingCheckpointer.load_trainstate_checkpoint(
    load_from='trainstate_params::/home/latheledusjp/streaming_train_state_80000'
)
Run the conversion script to convert the EasyLM model format to Hugging Face's format:
python convert_easylm_stream_to_hf_safetensors.py
Check the generated output files in the ./gemma-ko-8.5b-dev directory.
The Flax version of the weight file can be found in the ./flax-gemma-ko-8b folder.
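As a final sanity check, you can confirm that both converted checkpoints load with transformers. A sketch; it assumes the conversion script wrote the config files alongside the weights:

# Sketch: confirm both converted checkpoints load cleanly.
from transformers import FlaxGemmaForCausalLM, GemmaForCausalLM

pt_model = GemmaForCausalLM.from_pretrained("./gemma-ko-8.5b-dev")       # safetensors weights
flax_model = FlaxGemmaForCausalLM.from_pretrained("./flax-gemma-ko-8b")  # Flax weights
print(pt_model.num_parameters())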
To serve the model using Gradio, follow these steps:
cd EasyLM/models/gemma
pip install -r serving_requirements.txt
./serve_test.sh
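Once the Gradio app is running, you can also query it from Python with gradio_client. A sketch, assuming the default local address; the endpoint name below is hypothetical, so list the real ones with view_api() first:

# Sketch: talk to the local Gradio app programmatically.
from gradio_client import Client

client = Client("http://127.0.0.1:7860")  # default Gradio address; adjust if needed
client.view_api()  # prints the endpoints the app actually exposes
# result = client.predict("Hello", api_name="/predict")  # hypothetical endpoint name
# print(result)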
If you found EasyLM useful in your research or applications, please cite using the following BibTeX:
@software{geng2023easylm,
  author = {Geng, Xinyang},
  title = {EasyLM: A Simple And Scalable Training Framework for Large Language Models},
  month = mar,
  year = 2023,
  url = {https://github.com/young-geng/EasyLM}
}
- The LLaMA implementation is from JAX_llama
- The JAX/Flax GPT-J and RoBERTa implementations are from transformers
- Most of the JAX utilities are from mlxu
- The codebase is heavily inspired by JAXSeq