From 4aa350cc7f68cf3b38f48e928be55932ed0cb9f0 Mon Sep 17 00:00:00 2001 From: zhutong Date: Wed, 22 Nov 2023 15:10:38 +0800 Subject: [PATCH] - add final data portion of sheared llama - add `gate_network_type`, `moe_calculator_score_scale_factor`, and update `prob_map` arguments in config - add exec scripts --- .../get_layer_wise_score_scale_factor.sh | 8 +- .../baseline_112gpus.sh | 164 +++++++++++++++++ .../baseline_112gpus_scale2.0.sh | 165 +++++++++++++++++ .../baseline_112gpus_sheared_llama_portion.sh | 165 +++++++++++++++++ .../sheared_llama_112gpus_100B.sh | 161 +++++++++++++++++ smoe/data/dynamic_selection.py | 12 +- .../get_layer_wise_score_scale_factor.py | 22 ++- .../analysis/scale_factor_simulation.py | 169 +++++++++++++----- smoe/entrypoint/cpt/cpt_fpt.py | 67 +++---- smoe/modules/moe/moe_gates.py | 1 + smoe/utils/config.py | 19 +- tests/data/test_streaming.py | 42 ++--- 12 files changed, 870 insertions(+), 125 deletions(-) create mode 100644 scripts/cpt/dynamic_data_selection/baseline_112gpus.sh create mode 100644 scripts/cpt/dynamic_data_selection/baseline_112gpus_scale2.0.sh create mode 100644 scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion.sh create mode 100644 scripts/cpt/dynamic_data_selection/sheared_llama_112gpus_100B.sh diff --git a/scripts/analysis/get_layer_wise_score_scale_factor.sh b/scripts/analysis/get_layer_wise_score_scale_factor.sh index 82893ca..92cda61 100644 --- a/scripts/analysis/get_layer_wise_score_scale_factor.sh +++ b/scripts/analysis/get_layer_wise_score_scale_factor.sh @@ -2,13 +2,15 @@ # llama_7B llama_13B llama_30B llama_base llama_3B # llama2_7B llama2_13B llama2_30B llama2_base -llama_size="llama_13B" -model_path=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM-copy/Gradient-max-l1_norm-sample-feature_change/llama_13B-16Select4-864Neurons +llama_size="llama2_7B" +# model_path=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM-copy/Gradient-max-l1_norm-sample-feature_change/llama_13B-16Select4-864Neurons +model_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B data_begin_index=0 data_end_index=500 batch_size=8 -block_size=2048 +# block_size=2048 +block_size=4096 #save_folder=${llama_size}_dense save_folder=${llama_size}_moe_trained diff --git a/scripts/cpt/dynamic_data_selection/baseline_112gpus.sh b/scripts/cpt/dynamic_data_selection/baseline_112gpus.sh new file mode 100644 index 0000000..1929a77 --- /dev/null +++ b/scripts/cpt/dynamic_data_selection/baseline_112gpus.sh @@ -0,0 +1,164 @@ +#!/usr/bin/bash + +#SBATCH --job-name=cpt-llama2_random_scale4_112gpus_dynamic_data +#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=0 + +#SBATCH --nodes=14 +#SBATCH --gres=gpu:8 +#SBATCH --quotatype=reserved +#SBATCH -x SH-IDCA1404-10-140-54-36 + +# reserved spot + +source ~/anaconda3/bin/activate smoe + +{ + num_nodes=14 # should match with --nodes + num_gpu_per_node=8 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + ############################################################## + ############### LLAMA 7B Moefication 16Experts ############### + 
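+    # Note: this script is the baseline of the dynamic_data_selection runs. The
+    # sibling scripts in this directory keep the same training setup and mainly
+    # differ in the extra flags passed to cpt_fpt.py: baseline_112gpus_scale2.0.sh
+    # adds --moe_calculator_score_scale_factor 2.0, baseline_112gpus_sheared_llama_portion.sh
+    # adds --prob_map "sheared_llama", and sheared_llama_112gpus_100B.sh adds
+    # --dynamic_data_selection "sheared_llama" with a 100B-token budget from a later checkpoint.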
# comment="llama 7B residual, gradient, 2 + 2/14 | soft residual 2.0 | soft moe 2.0 | GPU num 1, per-device bs 64, lr 1e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + ######## LLAMA 2 7B 16 Experts all kinds of ablations ######## + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + model_type="llama_moe" + comment="llama 2 7B, random 4/16, one linear layer gate" + pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama2_7B-16Select4-up_proj-Scale4.0 + + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 2.0, moe soft 2.0 | GPU num 16, per-device bs 32, lr 3e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + + tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=2e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + gradient_accumulation_steps=4 + block_size=4096 + num_tokens="200*10^9" + warmup_tokens="15*10^8" + # warmup_tokens="0" + eval_tokens="2.5*10^9" + seed=1227 + deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json + + num_selects=4 + + max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) + max_train_samples=$(echo "${num_tokens} / ($block_size)" | bc) + echo "max_steps: $max_steps" + echo "max_train_samples: $max_train_samples" + global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) + echo "global batch size: $global_bs" + tokens_per_batch=$(echo "$global_bs * $block_size" | bc) + echo "#tokens/batch: $tokens_per_batch" + warmup_steps=$(echo "$warmup_tokens / ($tokens_per_batch)" | bc) + echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps" + eval_steps=$(echo "$eval_tokens / ($tokens_per_batch)" | bc) + echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" + + data_cache=resources/cache + base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data" + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p $output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git 
diff > $output_dir/diff.patch + env > $output_dir/env + echo $comment > $output_dir/comment.txt + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + echo "Node list: $SLURM_JOB_NODELIS" + + srun torchrun \ + --nnodes ${num_nodes} \ + --nproc_per_node ${num_gpu_per_node} \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29518 \ + smoe/entrypoint/cpt/cpt_fpt.py \ + --deepspeed ${deepspeed_config_file} \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_train \ + --evaluation_strategy steps \ + --eval_steps ${eval_steps} \ + --seed ${seed} \ + --bf16 \ + --num_train_epochs 1 \ + --final_lr_portion ${final_lr_portion} \ + --optim adamw_torch \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --learning_rate ${lr} \ + --weight_decay 0.1 \ + --max_grad_norm 1.0 \ + --warmup_steps 100 \ + --max_steps ${max_steps} \ + --max_train_samples ${max_train_samples} \ + --save_strategy steps \ + --save_total_limit 1 \ + --save_steps ${eval_steps} \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none \ + --gate_type "TopKBalancedNoisyGate" \ + --calculator_type "UniversalCalculator" \ + --num_selects ${num_selects} +} diff --git a/scripts/cpt/dynamic_data_selection/baseline_112gpus_scale2.0.sh b/scripts/cpt/dynamic_data_selection/baseline_112gpus_scale2.0.sh new file mode 100644 index 0000000..025a72a --- /dev/null +++ b/scripts/cpt/dynamic_data_selection/baseline_112gpus_scale2.0.sh @@ -0,0 +1,165 @@ +#!/usr/bin/bash + +#SBATCH --job-name=cpt-llama2_random_scale4_112gpus_dynamic_data +#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=0 + +#SBATCH --nodes=14 +#SBATCH --gres=gpu:8 +#SBATCH --quotatype=reserved +#SBATCH -x SH-IDCA1404-10-140-54-36 + +# reserved spot + +source ~/anaconda3/bin/activate smoe + +{ + num_nodes=14 # should match with --nodes + num_gpu_per_node=8 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + 
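+    # This script follows baseline_112gpus.sh and additionally passes
+    # `--moe_calculator_score_scale_factor 2.0` to cpt_fpt.py (see the torchrun
+    # command below); the rest of the setup matches the baseline.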
############################################################## + ############### LLAMA 7B Moefication 16Experts ############### + # comment="llama 7B residual, gradient, 2 + 2/14 | soft residual 2.0 | soft moe 2.0 | GPU num 1, per-device bs 64, lr 1e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + ######## LLAMA 2 7B 16 Experts all kinds of ablations ######## + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + model_type="llama_moe" + comment="llama 2 7B, random 4/16" + pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama2_7B-16Select4-up_proj-Scale4.0 + + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 2.0, moe soft 2.0 | GPU num 16, per-device bs 32, lr 3e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + + tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=2e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + gradient_accumulation_steps=4 + block_size=4096 + num_tokens="200*10^9" + warmup_tokens="15*10^8" + # warmup_tokens="0" + eval_tokens="2.5*10^9" + seed=1227 + deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json + + num_selects=4 + + max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) + max_train_samples=$(echo "${num_tokens} / ($block_size)" | bc) + echo "max_steps: $max_steps" + echo "max_train_samples: $max_train_samples" + global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) + echo "global batch size: $global_bs" + tokens_per_batch=$(echo "$global_bs * $block_size" | bc) + echo "#tokens/batch: $tokens_per_batch" + warmup_steps=$(echo "$warmup_tokens / ($tokens_per_batch)" | bc) + echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps" + eval_steps=$(echo "$eval_tokens / ($tokens_per_batch)" | bc) + echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" + + data_cache=resources/cache + base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data" + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p 
$output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo $comment > $output_dir/comment.txt + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + echo "Node list: $SLURM_JOB_NODELIS" + + srun torchrun \ + --nnodes ${num_nodes} \ + --nproc_per_node ${num_gpu_per_node} \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29518 \ + smoe/entrypoint/cpt/cpt_fpt.py \ + --moe_calculator_score_scale_factor 2.0 \ + --deepspeed ${deepspeed_config_file} \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_train \ + --evaluation_strategy steps \ + --eval_steps ${eval_steps} \ + --seed ${seed} \ + --bf16 \ + --num_train_epochs 1 \ + --final_lr_portion ${final_lr_portion} \ + --optim adamw_torch \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --learning_rate ${lr} \ + --weight_decay 0.1 \ + --max_grad_norm 1.0 \ + --warmup_steps 100 \ + --max_steps ${max_steps} \ + --max_train_samples ${max_train_samples} \ + --save_strategy steps \ + --save_total_limit 1 \ + --save_steps ${eval_steps} \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none \ + --gate_type "TopKBalancedNoisyGate" \ + --calculator_type "UniversalCalculator" \ + --num_selects ${num_selects} +} diff --git a/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion.sh b/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion.sh new file mode 100644 index 0000000..a74b1d1 --- /dev/null +++ b/scripts/cpt/dynamic_data_selection/baseline_112gpus_sheared_llama_portion.sh @@ -0,0 +1,165 @@ +#!/usr/bin/bash + +#SBATCH --job-name=cpt-llama2_random_scale4_112gpus_dynamic_data +#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=0 + +#SBATCH --nodes=14 +#SBATCH --gres=gpu:8 +#SBATCH --quotatype=reserved +#SBATCH -x SH-IDCA1404-10-140-54-36 + +# reserved spot + +source ~/anaconda3/bin/activate smoe + +{ + num_nodes=14 # should match with --nodes + num_gpu_per_node=8 # should match with --gres + + # #cpu/#num_gpu_per_node + export 
OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + ############################################################## + ############### LLAMA 7B Moefication 16Experts ############### + # comment="llama 7B residual, gradient, 2 + 2/14 | soft residual 2.0 | soft moe 2.0 | GPU num 1, per-device bs 64, lr 1e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + ######## LLAMA 2 7B 16 Experts all kinds of ablations ######## + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + model_type="llama_moe" + comment="llama 2 7B, random 4/16, mlp gate, sheared llama data portion" + pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama2_7B-16Select4-up_proj-Scale4.0 + + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 2.0, moe soft 2.0 | GPU num 16, per-device bs 32, lr 3e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + + tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=2e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + gradient_accumulation_steps=4 + block_size=4096 + num_tokens="200*10^9" + warmup_tokens="15*10^8" + # warmup_tokens="0" + eval_tokens="2.5*10^9" + seed=1227 + deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json + + num_selects=4 + + max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) + max_train_samples=$(echo "${num_tokens} / ($block_size)" | bc) + echo "max_steps: $max_steps" + echo "max_train_samples: $max_train_samples" + global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) + echo "global batch size: $global_bs" + tokens_per_batch=$(echo "$global_bs * $block_size" | bc) + echo "#tokens/batch: $tokens_per_batch" + warmup_steps=$(echo "$warmup_tokens / ($tokens_per_batch)" | bc) + echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps" + eval_steps=$(echo "$eval_tokens / ($tokens_per_batch)" | bc) + echo "eval interval (tokens): $eval_tokens, 
steps: $eval_steps" + + data_cache=resources/cache + base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data" + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p $output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo $comment > $output_dir/comment.txt + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + echo "Node list: $SLURM_JOB_NODELIS" + + srun torchrun \ + --nnodes ${num_nodes} \ + --nproc_per_node ${num_gpu_per_node} \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29518 \ + smoe/entrypoint/cpt/cpt_fpt.py \ + --prob_map "sheared_llama" \ + --deepspeed ${deepspeed_config_file} \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_train \ + --evaluation_strategy steps \ + --eval_steps ${eval_steps} \ + --seed ${seed} \ + --bf16 \ + --num_train_epochs 1 \ + --final_lr_portion ${final_lr_portion} \ + --optim adamw_torch \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --learning_rate ${lr} \ + --weight_decay 0.1 \ + --max_grad_norm 1.0 \ + --warmup_steps 100 \ + --max_steps ${max_steps} \ + --max_train_samples ${max_train_samples} \ + --save_strategy steps \ + --save_total_limit 1 \ + --save_steps ${eval_steps} \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none \ + --gate_type "TopKBalancedNoisyGate" \ + --calculator_type "UniversalCalculator" \ + --num_selects ${num_selects} +} diff --git a/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus_100B.sh b/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus_100B.sh new file mode 100644 index 0000000..fe2af8f --- /dev/null +++ b/scripts/cpt/dynamic_data_selection/sheared_llama_112gpus_100B.sh @@ -0,0 +1,161 @@ +#!/usr/bin/bash + +#SBATCH --job-name=cpt-llama2_random_scale4_112gpus_dynamic_data +#SBATCH --output=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/%x-%j.log +#SBATCH --error=/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=64 +#SBATCH --mem=0 + +#SBATCH --nodes=14 +#SBATCH --gres=gpu:8 +#SBATCH --quotatype=reserved +#SBATCH -x SH-IDCA1404-10-140-54-36 + +# reserved spot + +source 
~/anaconda3/bin/activate smoe + +{ + num_nodes=14 # should match with --nodes + num_gpu_per_node=8 # should match with --gres + + # #cpu/#num_gpu_per_node + export OMP_NUM_THREADS=32 + export LOGLEVEL=INFO + # export NCCL_DEBUG=INFO + # export TORCH_DISTRIBUTED_DEBUG=DETAIL + # export TORCH_SHOW_CPP_STACKTRACES=1 + # export CUDA_LAUNCH_BLOCKING=1 + + ############################################################## + ############### LLAMA 7B Moefication 16Experts ############### + # comment="llama 7B residual, gradient, 2 + 2/14 | soft residual 2.0 | soft moe 2.0 | GPU num 1, per-device bs 64, lr 1e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + ######## LLAMA 2 7B 16 Experts all kinds of ablations ######## + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, moefication gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + model_type="llama_moe" + comment="llama 2 7B, random 4/16" + pretrained_model=/mnt/petrelfs/share_data/quxiaoye/llama_moe_cpt/llama2_7b/random_scale4/checkpoint-6800 + + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual hard, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual plain soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 8.0, moe soft 8.0 | GPU num 16, per-device bs 32, lr 3e-4" + # comment="llama 2 7B, residual 2, share gradient 2/14 | residual learn soft 2.0, moe soft 2.0 | GPU num 16, per-device bs 32, lr 3e-4" + # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEResidualForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-14Select2-2Residuals-688Neurons-Share + + ############################################################## + + tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B + dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed + validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized + + lr=1e-4 + final_lr_portion=0.1 + per_device_train_batch_size=8 + per_device_eval_batch_size=8 + gradient_accumulation_steps=4 + block_size=4096 + num_tokens="100*10^9" + eval_tokens="2.5*10^9" + seed=1227 + deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json + + num_selects=4 + + max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc) + max_train_samples=$(echo "${num_tokens} / ($block_size)" | bc) + echo "max_steps: $max_steps" + echo "max_train_samples: $max_train_samples" + global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc) + echo "global batch size: $global_bs" + tokens_per_batch=$(echo "$global_bs * $block_size" | bc) + echo "#tokens/batch: $tokens_per_batch" + eval_steps=$(echo "$eval_tokens / ($tokens_per_batch)" | bc) + echo "eval interval (tokens): $eval_tokens, steps: $eval_steps" + + data_cache=resources/cache + 
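+    # Illustrative check of the bc arithmetic above for this 100B-token run:
+    #   global_bs        = 8 * 4 * 14 * 8         = 3584 sequences per step
+    #   tokens_per_batch = 3584 * 4096             = 14,680,064 tokens per step
+    #   max_steps        = 100*10^9 / 14,680,064   => 6811 (bc truncates)
+    #   eval_steps       = 2.5*10^9 / 14,680,064   => 170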
base_dir="/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_scale4_112gpus_dynamic_data" + output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID + mkdir -p $output_dir + echo "output_dir: $output_dir" + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo $comment > $output_dir/comment.txt + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + echo "Node: $head_node" + echo "Node IP: $head_node_ip" + echo "Node list: $SLURM_JOB_NODELIS" + + srun torchrun \ + --nnodes ${num_nodes} \ + --nproc_per_node ${num_gpu_per_node} \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29518 \ + smoe/entrypoint/cpt/cpt_fpt.py \ + --dynamic_data_selection "sheared_llama" \ + --deepspeed ${deepspeed_config_file} \ + --model_name_or_path ${pretrained_model} \ + --model_type ${model_type} \ + --tokenizer_name_or_path ${tokenizer_path} \ + --dataset_dir ${dataset_dir} \ + --data_cache_dir ${data_cache} \ + --validation_dir ${validation_dir} \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --do_train \ + --evaluation_strategy steps \ + --eval_steps ${eval_steps} \ + --seed ${seed} \ + --bf16 \ + --num_train_epochs 1 \ + --final_lr_portion ${final_lr_portion} \ + --optim adamw_torch \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --learning_rate ${lr} \ + --weight_decay 0.1 \ + --max_grad_norm 1.0 \ + --warmup_steps 0 \ + --max_steps ${max_steps} \ + --max_train_samples ${max_train_samples} \ + --save_strategy steps \ + --save_total_limit 1 \ + --save_steps ${eval_steps} \ + --dataloader_num_workers 0 \ + --dataloader_pin_memory True \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --block_size ${block_size} \ + --output_dir ${output_dir} \ + --overwrite_output_dir \ + --ddp_timeout 3600 \ + --ddp_find_unused_parameters False \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --logging_first_step True \ + --logging_strategy steps \ + --logging_steps 5 \ + --log_level info \ + --log_level_replica warning \ + --log_on_each_node False \ + --report_to none \ + --gate_type "TopKBalancedNoisyGate" \ + --calculator_type "UniversalCalculator" \ + --num_selects ${num_selects} +} diff --git a/smoe/data/dynamic_selection.py b/smoe/data/dynamic_selection.py index f9049c2..8a08e42 100644 --- a/smoe/data/dynamic_selection.py +++ b/smoe/data/dynamic_selection.py @@ -11,7 +11,17 @@ "en_stack": 0.02, } -LLAMA_DATA_PORTION_AVG = { +SHEAREDLLAMA_DATA_PORTION = { + "en_cc": 0.361, + "en_c4": 0.492, + "github": 0.008, + "en_wikipedia": 0.031, + "en_book": 0.091, + "en_arxiv": 0.007, + "en_stack": 0.01, +} + +AVERAGE_SLIMPAJAMA_DATA_PORTION = { "en_cc": 1 / 7, "en_c4": 1 / 7, "github": 1 / 7, diff --git a/smoe/entrypoint/analysis/get_layer_wise_score_scale_factor.py b/smoe/entrypoint/analysis/get_layer_wise_score_scale_factor.py index 0498cd4..38205ba 100644 --- a/smoe/entrypoint/analysis/get_layer_wise_score_scale_factor.py +++ b/smoe/entrypoint/analysis/get_layer_wise_score_scale_factor.py @@ -4,21 +4,24 @@ import torch from torch.utils.data import DataLoader 
from tqdm import tqdm -from transformers import LlamaTokenizer +from transformers import LlamaForCausalLM, LlamaTokenizer from smoe.data.collate_fn import tensor_dict_cat_collator from smoe.data.datasets_moefication import LineByLineJsonlTextDataset from smoe.models.llama_moe import LlamaMoEForCausalLM +from smoe.utils.model_operation.modify_llama_model import ( + llama_with_hidden_states_scale_recording, +) from smoe.utils.model_operation.modify_llama_moe_model import ( llama_moe_with_hidden_states_scale_recording_early_stop, ) -from smoe.utils.string_operation import str2bool # fmt: off if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--tokenizer_path', type=str) parser.add_argument('--model_path', type=str) + parser.add_argument('--model_type', type=str, choices=["llama", "llama_moe"], default="llama_moe") parser.add_argument('--target_scale_file_path', type=str) parser.add_argument('--data_path', type=str) parser.add_argument('--save_path', type=str) @@ -59,8 +62,16 @@ """load model""" print("Loading llama model...", flush=True) - model = LlamaMoEForCausalLM.from_pretrained(args.model_path).model - model.set_moe_calculator_score_scale_factor(1.0) + new_forward = None + if args.model_type == "llama_moe": + model = LlamaMoEForCausalLM.from_pretrained(args.model_path).model + new_forward = llama_moe_with_hidden_states_scale_recording_early_stop + model.set_moe_calculator_score_scale_factor(1.0) + elif args.model_type == "llama": + model = LlamaForCausalLM.from_pretrained(args.model_path).model + new_forward = llama_with_hidden_states_scale_recording + else: + raise ValueError("Unknown model type: " + args.model_type) """calculate scale factor layer by layer""" print("Start evaluation...", flush=True) @@ -70,7 +81,8 @@ model.half() model.eval() for layer_index in tqdm(range(model.config.num_hidden_layers), desc="forward by layer", leave=True): - model = llama_moe_with_hidden_states_scale_recording_early_stop(model, early_stop_layer=layer_index) + # model = llama_moe_with_hidden_states_scale_recording_early_stop(model, early_stop_layer=layer_index) + model = new_forward(model) iter_train = iter(data_loader) for step in tqdm(range(len(data_loader)), desc="forward step", leave=False): diff --git a/smoe/entrypoint/analysis/scale_factor_simulation.py b/smoe/entrypoint/analysis/scale_factor_simulation.py index c5e5512..5d0a5b7 100644 --- a/smoe/entrypoint/analysis/scale_factor_simulation.py +++ b/smoe/entrypoint/analysis/scale_factor_simulation.py @@ -6,9 +6,7 @@ import torch import torch.nn as nn import torch.nn.functional as F - -xs = [] -ys = [] +from scipy import stats def kaiming_init(*size): @@ -17,48 +15,123 @@ def kaiming_init(*size): return tensor -init_func = torch.randn -# init_func = kaiming_init - -max_num_experts = 32 -intermediate_size = 11008 -hidden_size = 4096 -base = None -for k in range(1, max_num_experts + 1): - mid = int(intermediate_size / k) - distances = [] - for _ in range(10): - gate = init_func(hidden_size, mid) - up = init_func(hidden_size, mid) - down = init_func(mid, hidden_size) - - x = init_func(1, hidden_size) - # y = x @ l1 @ l2 - y = (F.silu(x @ gate) * (x @ up)) @ down - - dist = (x - y).abs().sum() - # dist = (x - y).pow(2).sum() - distances.append(dist.item()) - - xs.append(k) - if base is None and k == 1: - base = sts.mean(distances) - ys.append(base / sts.mean(distances)) - print(xs[-1], ys[-1]) - - -plt.plot(xs, ys, label="simulation") -plt.plot(xs, np.sqrt(xs), label="sqrt", linestyle="dashed") -plt.legend() 
-plt.xlabel("#Experts") -plt.ylabel("Scale Factor") -plt.grid(True, zorder=-1) - -# plt.title("SwiGLU Kaiming Normal Initialization (fan_out)") -# plt.savefig("swiglu_kaiming_fan_out_1024.png") - -out_dir = Path("results/analysis_scale_factor") -out_dir.mkdir(exist_ok=True, parents=True) -plt.title("Normal Initialization") -plt.savefig(out_dir / "normal.png") -# plt.savefig(out_dir / "normal_dropout_rescale.png") +def simulation(): + xs = [] + ys = [] + + init_func = torch.randn + # init_func = kaiming_init + + max_num_experts = 32 + intermediate_size = 11008 + hidden_size = 4096 + base = None + for k in range(1, max_num_experts + 1): + mid = int(intermediate_size / k) + distances = [] + for _ in range(10): + gate = init_func(hidden_size, mid) + up = init_func(hidden_size, mid) + down = init_func(mid, hidden_size) + + x = init_func(1, hidden_size) + # y = x @ l1 @ l2 + y = (F.silu(x @ gate) * (x @ up)) @ down + + dist = (x - y).abs().sum() + # dist = (x - y).pow(2).sum() + distances.append(dist.item()) + + xs.append(k) + if base is None and k == 1: + base = sts.mean(distances) + ys.append(base / sts.mean(distances)) + print(xs[-1], ys[-1]) + + plt.plot(xs, ys, label="simulation") + plt.plot(xs, np.sqrt(xs), label="sqrt", linestyle="dashed") + plt.legend() + plt.xlabel("#Experts") + plt.ylabel("Scale Factor") + plt.grid(True, zorder=-1) + + # plt.title("SwiGLU Kaiming Normal Initialization (fan_out)") + # plt.savefig("swiglu_kaiming_fan_out_1024.png") + + out_dir = Path("results/analysis_scale_factor") + out_dir.mkdir(exist_ok=True, parents=True) + plt.title("Normal Initialization") + plt.savefig(out_dir / "normal.png") + # plt.savefig(out_dir / "normal_dropout_rescale.png") + + +def line_graph(vals: list, label=None): + plt.hist(vals, bins=100, density=True) + xt = plt.xticks()[0] + xmin, xmax = min(xt), max(xt) + lnspc = np.linspace(xmin, xmax, len(vals)) + ab, bb, cb, db = stats.beta.fit(vals) + pdf_beta = stats.beta.pdf(lnspc, ab, bb, cb, db) + plt.plot(lnspc, pdf_beta, label=label) + return lnspc, pdf_beta + + +def stat_plot(): + dense_hidden = ( + torch.load("results/analysis_scale_factor/hidden_states.pt") + .detach() + .numpy() + .flatten() + ) + dense_hidden_x, dense_hidden_y = line_graph(dense_hidden, label="dense_hidden") + dense_residual = ( + torch.load("results/analysis_scale_factor/residual.pt") + .detach() + .numpy() + .flatten() + ) + dense_residual_x, dense_residual_y = line_graph( + dense_residual, label="dense_residual" + ) + moe_hidden = ( + torch.load("results/analysis_scale_factor/moe_hidden_states.pt") + .detach() + .numpy() + .flatten() + ) + moe_hidden_x, moe_hidden_y = line_graph(moe_hidden, label="moe_hidden") + moe_residual = ( + torch.load("results/analysis_scale_factor/moe_residual.pt") + .detach() + .numpy() + .flatten() + ) + moe_residual_x, moe_residual_y = line_graph(moe_residual, label="moe_residual") + plt.xlim(-0.3, 0.3) + plt.legend() + plt.savefig("results/analysis_scale_factor/hist.png") + plt.close() + + plt.plot(dense_hidden_x, dense_hidden_y, label="dense_hidden") + plt.fill_between( + dense_hidden_x, dense_hidden_y, [0] * len(dense_hidden_x), alpha=0.1 + ) + plt.plot(dense_residual_x, dense_residual_y, label="dense_residual") + plt.fill_between( + dense_residual_x, dense_residual_y, [0] * len(dense_residual_x), alpha=0.1 + ) + plt.plot(moe_hidden_x, moe_hidden_y, label="moe_hidden") + plt.fill_between(moe_hidden_x, moe_hidden_y, [0] * len(moe_hidden_x), alpha=0.1) + plt.plot(moe_residual_x, moe_residual_y, label="moe_residual") + 
plt.fill_between( + moe_residual_x, moe_residual_y, [0] * len(moe_residual_x), alpha=0.1 + ) + plt.xlim(-0.3, 0.3) + plt.legend() + plt.show() + plt.savefig("results/analysis_scale_factor/comparison.png") + + +if __name__ == "__main__": + # simulation() + stat_plot() diff --git a/smoe/entrypoint/cpt/cpt_fpt.py b/smoe/entrypoint/cpt/cpt_fpt.py index 23a79b5..35a98e8 100644 --- a/smoe/entrypoint/cpt/cpt_fpt.py +++ b/smoe/entrypoint/cpt/cpt_fpt.py @@ -23,6 +23,11 @@ from smoe.callbacks.save_model import SchedulerStateCallback from smoe.callbacks.tensorboard import EnhancedTensorboardCallback from smoe.data.collate_fn import fault_tolerance_data_collator +from smoe.data.dynamic_selection import ( + AVERAGE_SLIMPAJAMA_DATA_PORTION, + LLAMA_DATA_PORTION, + SHEAREDLLAMA_DATA_PORTION, +) from smoe.data.streaming import CachedJsonlDataset, SubDirWeightedPackedJsonlDataset from smoe.metrics.preprocess import logits_argmax from smoe.models.llama_moe.configuration_llama_moe import LlamaMoEConfig @@ -133,6 +138,11 @@ def main(): "cache_dir": model_args.cache_dir, "revision": model_args.model_revision, "use_auth_token": True if model_args.use_auth_token else None, + "gate_type": model_args.gate_type, + "calculator_type": model_args.calculator_type, + "num_selects": model_args.num_selects, + "gate_network": model_args.gate_network_type, + "score_scale_factor": model_args.moe_calculator_score_scale_factor, } ConfigClass = AutoConfig if model_args.config_name == "llama_moe" or model_args.model_type == "llama_moe": @@ -160,15 +170,6 @@ def main(): if training_args.gradient_checkpointing: config.use_cache = False - config: LlamaMoEConfig - config.gate_type = model_args.gate_type - config.calculator_type = model_args.calculator_type - config.num_selects = model_args.num_selects - # config.add_weight_norm = True - # config.score_scale_factor = 4.0 - # config.gate_use_balance = False - # config.gate_add_noise = False - # zhutong: this is for debug usage only if training_args.debug_mode: config.num_hidden_layers = 2 @@ -215,37 +216,16 @@ def main(): ) block_size = min(data_args.block_size, tokenizer.model_max_length) - if data_args.prob_map is None: - # slimpajama samples openllama-3B tokenized - # data_args.prob_map = { - # "cc": 0.67, - # "wikipedia": 0.33, - # } - - # redpajama - data_args.prob_map = { - "en_cc": 0.67, - "en_c4": 0.15, - "github": 0.045, - "en_wikipedia": 0.045, - "en_book": 0.045, - "en_arxiv": 0.025, - "en_stack": 0.02, - } - # data_args.prob_map = { - # "en_cc_v2": 0.67, - # "en_c4_v2": 0.15, - # "github_v2": 0.045, - # "en_wikipedia": 0.045, - # "en_book": 0.045, - # "en_arxiv": 0.025, - # "en_stack": 0.02, - # } + prob_map = LLAMA_DATA_PORTION + if data_args.prob_map == "uniform": + prob_map = AVERAGE_SLIMPAJAMA_DATA_PORTION + elif data_args.prob_map == "sheared_llama": + prob_map = SHEAREDLLAMA_DATA_PORTION with training_args.main_process_first(desc="dataset map tokenization and grouping"): lm_datasets = SubDirWeightedPackedJsonlDataset( data_args.dataset_dir, - prob_map=data_args.prob_map, + prob_map=prob_map, seed=training_args.seed, block_size=data_args.block_size, ) @@ -341,12 +321,15 @@ def make_inputs_require_grad(module, input, output): model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - if hasattr(model, "set_moe_calculator_score_scale_factor"): - # update config for checkpoint retrival - # model.set_moe_gate_balance_loss_weight(0.1) - model.set_moe_calculator_score_scale_factor(4.0) - # model.set_moe_calculator_score_scale_factor(1.0) - 
model.update_config() + # if hasattr(model, "set_moe_calculator_score_scale_factor"): + # # update config for checkpoint retrival + # # model.set_moe_gate_balance_loss_weight(0.1) + # # model.set_moe_calculator_score_scale_factor(4.0) + # model.set_moe_calculator_score_scale_factor( + # model_args.moe_calculator_score_scale_factor + # ) + # # model.set_moe_calculator_score_scale_factor(1.0) + # model.update_config() model_vocab_size = model.get_output_embeddings().weight.size(0) if model_vocab_size != len(tokenizer): diff --git a/smoe/modules/moe/moe_gates.py b/smoe/modules/moe/moe_gates.py index aeef17e..c072113 100644 --- a/smoe/modules/moe/moe_gates.py +++ b/smoe/modules/moe/moe_gates.py @@ -232,6 +232,7 @@ def __init__( self.gate_network_type = gate_network self.gate_network = get_gate_network(gate_network, input_size, num_experts) + # self.gate_network = get_gate_network("linear", input_size, num_experts) self.use_softmax = use_softmax self.softmax = nn.Softmax(1) diff --git a/smoe/utils/config.py b/smoe/utils/config.py index 525c89e..abebac6 100644 --- a/smoe/utils/config.py +++ b/smoe/utils/config.py @@ -125,6 +125,12 @@ class ModelArguments: "choices": ["auto", "bfloat16", "float16", "float32"], }, ) + gate_network_type: Literal["mlp", "linear"] = field( + default="mlp", + metadata={ + "help": "The type of gate network, should be one of `mlp` and `linear`" + }, + ) gate_type: Literal["TopKBalancedNoisyGate", "SwitchBalancedGate"] = field( default="TopKBalancedNoisyGate", metadata={ @@ -139,6 +145,10 @@ class ModelArguments: "help": "The type of gate calculator, should be one of `UniversalCalculator` and `SwitchDropTokenCalculator`" }, ) + moe_calculator_score_scale_factor: float = field( + default=4.0, + metadata={"help": "scale factor for the calculator to be multiplied"}, + ) num_selects: int = field( default=4, metadata={"help": "The number of experts to be selected"} ) @@ -247,13 +257,10 @@ class DataArguments: data_cache_dir: Optional[str] = field( default="./", metadata={"help": "The datasets processed stored"} ) - prob_map: Optional[dict[str, float]] = field( - default=None, + prob_map: Optional[str] = field( + default="llama", metadata={ - "help": ( - 'data type to sampling probabilities. e.g. {"commoncrawl": 0.67, "c4":' - " 0.15}" - ) + "help": ("data portion. 
choices in ['llama', 'uniform', 'sheared_llama']") }, ) diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index 8d458a3..2e71d81 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -117,7 +117,7 @@ def test_weighted_streaming_loader(): "en_arxiv": 0.025, "en_stack": 0.02, } - num_test_case = 2 + num_test_case = 20000 block_size = 2048 bsz = 1 @@ -135,25 +135,27 @@ def test_weighted_streaming_loader(): pin_memory=False, ) loader = ac.prepare_data_loader(loader) - - for batch_idx, batch in enumerate(loader): - if batch_idx == 0: - print(f"RANK {ac.process_index}/{ac.num_processes} - {batch}") - if num_test_case <= 0: - break - assert len(batch["input_ids"]) == bsz - # print( - # f"RANK {ac.process_index}/{ac.num_processes} - {loader.dataset.consumed_tokens} SUM: {sum(loader.dataset.consumed_tokens.values())}, Expected: {(batch_idx + 1) * bsz * block_size}" - # ) - # assert sum(loader.dataset.consumed_tokens.values()) == (batch_idx + 1) * block_size - print(loader.dataset.prob_map) - num_test_case -= 1 - lm_datasets.update_existed_prob_map({"en_cc": 0.5, "en_c4": 0.5}) - # loader.dataset.update_existed_prob_map({"en_cc": 0.5, "en_c4": 0.5}) - print(loader.dataset.prob_map) - # print( - # f"RANK {ac.process_index}/{ac.num_processes} - {loader.dataset.consumed_tokens} SUM: {sum(loader.dataset.consumed_tokens.values())}, Expected: {(batch_idx + 1) * bsz * block_size}" - # ) + print(type(loader)) + print(loader.sampler, type(loader.sampler)) + + # for batch_idx, batch in enumerate(loader): + # if batch_idx == 0: + # print(f"RANK {ac.process_index}/{ac.num_processes} - {batch}") + # if num_test_case <= 0: + # break + # assert len(batch["input_ids"]) == bsz + # # print( + # # f"RANK {ac.process_index}/{ac.num_processes} - {loader.dataset.consumed_tokens} SUM: {sum(loader.dataset.consumed_tokens.values())}, Expected: {(batch_idx + 1) * bsz * block_size}" + # # ) + # # assert sum(loader.dataset.consumed_tokens.values()) == (batch_idx + 1) * block_size + # print(loader.dataset.prob_map) + # num_test_case -= 1 + # lm_datasets.update_existed_prob_map({"en_cc": 0.5, "en_c4": 0.5}) + # # loader.dataset.update_existed_prob_map({"en_cc": 0.5, "en_c4": 0.5}) + # print(loader.dataset.prob_map) + # # print( + # # f"RANK {ac.process_index}/{ac.num_processes} - {loader.dataset.consumed_tokens} SUM: {sum(loader.dataset.consumed_tokens.values())}, Expected: {(batch_idx + 1) * bsz * block_size}" + # # ) def test_skip_tokens():
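A minimal sketch (not part of the patch) of how the new string-valued `prob_map` argument is resolved, mirroring the branch added to smoe/entrypoint/cpt/cpt_fpt.py; the helper name `resolve_prob_map` is illustrative only:

from smoe.data.dynamic_selection import (
    AVERAGE_SLIMPAJAMA_DATA_PORTION,
    LLAMA_DATA_PORTION,
    SHEAREDLLAMA_DATA_PORTION,
)


def resolve_prob_map(name: str) -> dict:
    """Map the DataArguments.prob_map string to a sampling-portion dict."""
    if name == "uniform":
        return AVERAGE_SLIMPAJAMA_DATA_PORTION
    if name == "sheared_llama":
        return SHEAREDLLAMA_DATA_PORTION
    return LLAMA_DATA_PORTION  # default, prob_map="llama"


# The Sheared-LLaMA portions added in this patch sum to 1.0:
assert abs(sum(resolve_prob_map("sheared_llama").values()) - 1.0) < 1e-9
assert resolve_prob_map("sheared_llama")["en_c4"] == 0.492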