# Mixtral 8x7B

## Install

```bash
# Install the latest xtuner
pip install -U 'xtuner[deepspeed]'

# Mixtral requires flash-attn
pip install flash-attn

# Install the latest transformers
pip install -U transformers
```
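
To confirm that the Mixtral configs shipped with your xtuner installation are visible, you can list them first. This is only a sanity check; the `-p` pattern filter is assumed to be available in your xtuner version, and plain `xtuner list-cfg` works as a fallback.

```bash
# List built-in configs whose names match "mixtral" (assumes the -p pattern
# filter is available; otherwise run `xtuner list-cfg` and grep the output).
xtuner list-cfg -p mixtral
```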

## QLoRA Fine-tune

QLoRA fine-tuning needs only a single A100-80G:

```bash
xtuner train mixtral_8x7b_instruct_qlora_oasst1_e3 --deepspeed deepspeed_zero2
```
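
After training finishes, the DeepSpeed checkpoint under `work_dirs` can be converted to a HuggingFace-format LoRA adapter with `xtuner convert pth_to_hf`. A minimal sketch; the checkpoint filename below is a placeholder, so use the one actually written to your work directory.

```bash
# Convert the saved checkpoint (placeholder path) to a HuggingFace LoRA adapter.
xtuner convert pth_to_hf mixtral_8x7b_instruct_qlora_oasst1_e3 \
    ./work_dirs/mixtral_8x7b_instruct_qlora_oasst1_e3/epoch_3.pth \
    ./mixtral_8x7b_qlora_adapter
```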

## Full Parameter Fine-tune

Full-parameter fine-tuning needs 16 A100-80G GPUs.

### slurm

Note: `$PARTITION` refers to the slurm partition used to submit the job.

```bash
srun -p $PARTITION --job-name=mixtral --nodes=2 --gres=gpu:8 --ntasks-per-node=8 xtuner train mixtral_8x7b_instruct_full_oasst1_e3 --deepspeed deepspeed_zero3 --launcher slurm
```
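
If you prefer batch submission over an interactive `srun`, the same job can be expressed as an sbatch script. This is a sketch: the partition name is a placeholder, and the resource directives simply mirror the command above.

```bash
#!/bin/bash
#SBATCH --job-name=mixtral
#SBATCH --partition=your_partition
#SBATCH --nodes=2
#SBATCH --gres=gpu:8
#SBATCH --ntasks-per-node=8

# Launch one task per GPU across both nodes, matching the srun command above.
srun xtuner train mixtral_8x7b_instruct_full_oasst1_e3 --deepspeed deepspeed_zero3 --launcher slurm
```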

### torchrun

Note: `$NODE_0_ADDR` refers to the IP address of the node 0 machine.

```bash
# Execute on node 0
NPROC_PER_NODE=8 NNODES=2 PORT=29600 ADDR=$NODE_0_ADDR NODE_RANK=0 xtuner train mixtral_8x7b_instruct_full_oasst1_e3 --deepspeed deepspeed_zero3

# Execute on node 1
NPROC_PER_NODE=8 NNODES=2 PORT=29600 ADDR=$NODE_0_ADDR NODE_RANK=1 xtuner train mixtral_8x7b_instruct_full_oasst1_e3 --deepspeed deepspeed_zero3
```
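
If node 1 hangs at start-up, a common cause is that it cannot reach node 0 on the rendezvous port. Once the node 0 command is running, you can verify connectivity from node 1 (assuming `nc` is installed; 29600 matches the `PORT` used above).

```bash
# Run on node 1 after node 0 has started: checks that the rendezvous port is reachable.
nc -zv $NODE_0_ADDR 29600
```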

## Speed

Benchmarked on 16 × A100 80G GPUs:

| Model        | Sequence Length | Use Varlen Attn | Sequence Parallel World Size | Tokens per Second |
| ------------ | --------------- | --------------- | ---------------------------- | ----------------- |
| mixtral_8x7b | 32k             | False           | 1                            | 853.7             |
| mixtral_8x7b | 32k             | True            | 1                            | 910.1             |
| mixtral_8x7b | 32k             | False           | 2                            | 635.2             |
| mixtral_8x7b | 32k             | True            | 2                            | 650.9             |
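
The last two columns map to switches in the training config. A sketch of reproducing, for example, the varlen-attention row with sequence parallel size 2: copy the config and edit the two variables. The names `use_varlen_attn` and `sequence_parallel_size` follow xtuner's config convention but are assumptions here, so verify them in the copied file before launching.

```bash
# Copy the built-in config to the current directory; xtuner names the copy
# <config>_copy.py. Then flip the two switches benchmarked in the table above.
xtuner copy-cfg mixtral_8x7b_instruct_full_oasst1_e3 .
sed -i 's/^use_varlen_attn = False/use_varlen_attn = True/' mixtral_8x7b_instruct_full_oasst1_e3_copy.py
sed -i 's/^sequence_parallel_size = 1/sequence_parallel_size = 2/' mixtral_8x7b_instruct_full_oasst1_e3_copy.py

# Then pass the copied config path to the srun/torchrun launch commands above.
```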