Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(train): support deepspeed #1849

Merged
merged 17 commits into from
May 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions examples/aishell/s0/conf/ds_stage2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"steps_per_print": 100,
"gradient_clipping": 0.0001,
"fp16": {
"enabled": false,
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 8,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": false
},
"zero_force_ds_cpu_optimizer": false,
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"offload_param": {
"device": "none",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 1e7,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 1e7,
"contiguous_gradients" : true
},
"activation_checkpointing": {
"partition_activations": false,
"cpu_checkpointing": false,
"contiguous_memory_optimization": false,
"number_checkpoints": null,
"synchronize_checkpoint_boundary": false,
"profile": true
},
"flops_profiler": {
"enabled": false,
"profile_step": 100,
"module_depth": -1,
"top_modules": 1,
"detailed": true,
"output_file": null
},
"tensorboard": {
"enabled": true,
"output_path": "tensorboard/ds_logs/",
"job_name": "deepspeed"
}
}
90 changes: 90 additions & 0 deletions examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# network architecture
# encoder related
encoder: conformer
encoder_conf:
output_size: 2048 # dimension of attention
attention_heads: 16
linear_units: 8192 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
input_layer: conv2d8 # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: true
cnn_module_kernel: 8
use_cnn_module: True
activation_type: 'swish'
pos_enc_layer_type: 'rel_pos'
selfattention_layer_type: 'rel_selfattn'
causal: true
use_dynamic_chunk: true
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk: false

# decoder related
decoder: bitransformer
decoder_conf:
attention_heads: 16
linear_units: 8192
num_blocks: 3
r_num_blocks: 3
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1

# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
reverse_weight: 0.3

dataset_conf:
filter_conf:
max_length: 40960
min_length: 0
token_max_length: 200
token_min_length: 1
resample_conf:
resample_rate: 16000
speed_perturb: true
fbank_conf:
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.0
spec_aug: true
spec_aug_conf:
num_t_mask: 2
num_f_mask: 2
max_t: 50
max_f: 10
spec_sub: true
spec_sub_conf:
num_t_sub: 3
max_t: 30
spec_trim: false
spec_trim_conf:
max_t: 50
shuffle: true
shuffle_conf:
shuffle_size: 1500
sort: true
sort_conf:
sort_size: 500 # sort_size should be less than shuffle_size
batch_conf:
batch_type: 'static' # static or dynamic
batch_size: 16

grad_clip: 5
accum_grad: 1
max_epoch: 100
log_interval: 100

optim: adam
optim_conf:
lr: 0.001
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
93 changes: 65 additions & 28 deletions examples/aishell/s0/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,19 @@ train_config=conf/train_conformer.yaml
cmvn=true
dir=exp/conformer
checkpoint=
num_workers=8
prefetch=500

# use average_checkpoint will get better result
average_checkpoint=true
decode_checkpoint=$dir/final.pt
average_num=30
decode_modes="ctc_greedy_search ctc_prefix_beam_search attention attention_rescoring"

deepspeed=false
deepspeed_config=conf/ds_stage2.json
deepspeed_save_states="model_only"

. tools/parse_options.sh || exit 1;

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
Expand Down Expand Up @@ -116,11 +122,12 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# You have to rm `INIT_FILE` manually when you resume or restart a
# multi-machine training.
INIT_FILE=$dir/ddp_init
rm -f ${INIT_FILE} # remove previous INIT_FILE
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
# Use "nccl" if it works, otherwise use "gloo"
dist_backend="gloo"
dist_backend="nccl"
world_size=`expr $num_gpus \* $num_nodes`
echo "total gpus is: $world_size"
cmvn_opts=
Expand All @@ -130,30 +137,60 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# train.py rewrite $train_config to $dir/train.yaml with model input
# and output dimension, and $dir/train.yaml will be used for inference
# and export.
for ((i = 0; i < $num_gpus; ++i)); do
{
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
# Rank of each gpu/process used for knowing whether it is
# the master of a worker.
rank=`expr $node_rank \* $num_gpus + $i`
python wenet/bin/train.py --gpu $gpu_id \
--config $train_config \
--data_type $data_type \
--symbol_table $dict \
--train_data data/$train_set/data.list \
--cv_data data/dev/data.list \
${checkpoint:+--checkpoint $checkpoint} \
--model_dir $dir \
--ddp.init_method $init_method \
--ddp.world_size $world_size \
--ddp.rank $rank \
--ddp.dist_backend $dist_backend \
--num_workers 1 \
$cmvn_opts \
--pin_memory
} &
done
wait
if [ ${deepspeed} == true ]; then
echo "using deepspeed"
# NOTE(xcsong): deepspeed fails with gloo, see
# https://github.com/microsoft/DeepSpeed/issues/2818
dist_backend="nccl"
[ ! -f data/$train_set/data.list.filter ] && \
python tools/filter_uneven_data.py data/$train_set/data.list \
$data_type $num_gpus $num_utts_per_shard data/$train_set/data.list.filter
deepspeed --include localhost:$CUDA_VISIBLE_DEVICES \
wenet/bin/train.py \
--deepspeed \
--deepspeed_config ${deepspeed_config} \
--deepspeed.save_states ${deepspeed_save_states} \
--ddp.dist_backend $dist_backend \
--ddp.init_method $init_method \
--data_type $data_type \
--config $train_config \
--symbol_table data/dict/lang_char.txt \
--train_data data/$train_set/data.list.filter \
--cv_data data/dev/data.list \
${checkpoint:+--checkpoint $checkpoint} \
--model_dir $dir \
--num_workers ${num_workers} \
--prefetch ${prefetch} \
$cmvn_opts \
--pin_memory
else
echo "using torch ddp"
for ((i = 0; i < $num_gpus; ++i)); do
{
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
# Rank of each gpu/process used for knowing whether it is
# the master of a worker.
rank=`expr $node_rank \* $num_gpus + $i`
python wenet/bin/train.py --gpu $gpu_id \
--config $train_config \
--data_type $data_type \
--symbol_table $dict \
--train_data data/$train_set/data.list \
--cv_data data/dev/data.list \
${checkpoint:+--checkpoint $checkpoint} \
--model_dir $dir \
--ddp.init_method $init_method \
--ddp.world_size $world_size \
--ddp.rank $rank \
--ddp.dist_backend $dist_backend \
--num_workers ${num_workers} \
--prefetch ${prefetch} \
$cmvn_opts \
--pin_memory
} &
done
wait
fi
fi

if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
Expand All @@ -171,8 +208,8 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# non-streaming model. The default value is -1, which is full chunk
# for non-streaming inference.
decoding_chunk_size=
ctc_weight=0.5
reverse_weight=0.0
ctc_weight=0.3
reverse_weight=0.5
for mode in ${decode_modes}; do
{
test_dir=$dir/test_${mode}
Expand Down Expand Up @@ -298,4 +335,4 @@ if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
# --lfmmi_dir data/local/lfmmi

# 9.3 Run HLG decode from stage 8.2
fi
fi
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ pycodestyle==2.6.0
pyflakes==2.2.0
torch==1.13.0
torchaudio==0.13.0
deepspeed
51 changes: 51 additions & 0 deletions tools/filter_uneven_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright [2023-04-27] <[email protected], Xingchen Song>

import os
import random
import tarfile

random.seed(1024)

# parse arg from command line
datalist = os.sys.argv[1]
datatype = os.sys.argv[2]
num_gpus = int(os.sys.argv[3])
num_samples_per_tar = int(os.sys.argv[4]) # only used in shard mode
new_datalist = os.sys.argv[5]

assert datatype in ["shard", "raw"]


filtered_list = []
with open(datalist, "r") as f:
lines = f.readlines()
lines = [l.strip() for l in lines]
if datatype == "raw":
valid_num = len(lines) // num_gpus * num_gpus
random.shuffle(lines)
filtered_list = lines[:valid_num]
else:
for line in lines:
cnt = 0
with open(line, "rb") as tar:
stream = tarfile.open(fileobj=tar, mode="r|*")
for tarinfo in stream:
name = tarinfo.name
pos = name.rfind('.')
assert pos > 0
prefix, postfix = name[:pos], name[pos + 1:]
if postfix == 'txt':
cnt += 1
if cnt == num_samples_per_tar:
filtered_list.append(line)
valid_num = len(filtered_list) // num_gpus * num_gpus
random.shuffle(filtered_list)
filtered_list = filtered_list[:valid_num]
filtered_list.sort()
print("before filter: {} after filter: {}".format(len(lines), len(filtered_list)))

with open(new_datalist, "w") as f:
for line in filtered_list:
f.writelines("{}\n".format(line))
Loading