-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(train): support deepspeed (#1849)
* feat(train): support deepspeed * feat(train): fix bug * feat(train): enable deepspeed in run.sh * feat(train): fix step bug for tensorboard logging * feat(train): recover cv log tensorboard logging * feat(train): make save_states configurable * feat(deepspeed): Support fp16/bf16 and deepspedCPUadam+customLRscheduler * feat(deepspeed): fix lint * feat(deepspeed): fix lint * feat(deepspeed): update stage2 config * feat(deepspeed): avoid re-generate filtered list if exists * feat(deepspeed): add 1.8B model * feat(deepspeed): make workers&prefetch configurable * feat(deepspeed): refine comment * feat(deepspeed): fix saving yaml * feat(deepspeed): refine if-else * feat(deepspeed): refine if-else
- Loading branch information
1 parent
ac9a261
commit 2505d18
Showing
7 changed files
with
480 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
{ | ||
"train_micro_batch_size_per_gpu": 1, | ||
"gradient_accumulation_steps": 1, | ||
"steps_per_print": 100, | ||
"gradient_clipping": 0.0001, | ||
"fp16": { | ||
"enabled": false, | ||
"auto_cast": false, | ||
"loss_scale": 0, | ||
"initial_scale_power": 8, | ||
"loss_scale_window": 1000, | ||
"hysteresis": 2, | ||
"min_loss_scale": 1 | ||
}, | ||
"bf16": { | ||
"enabled": false | ||
}, | ||
"zero_force_ds_cpu_optimizer": false, | ||
"zero_optimization": { | ||
"stage": 2, | ||
"offload_optimizer": { | ||
"device": "none", | ||
"pin_memory": true | ||
}, | ||
"offload_param": { | ||
"device": "none", | ||
"pin_memory": true | ||
}, | ||
"allgather_partitions": true, | ||
"allgather_bucket_size": 1e7, | ||
"overlap_comm": true, | ||
"reduce_scatter": true, | ||
"reduce_bucket_size": 1e7, | ||
"contiguous_gradients" : true | ||
}, | ||
"activation_checkpointing": { | ||
"partition_activations": false, | ||
"cpu_checkpointing": false, | ||
"contiguous_memory_optimization": false, | ||
"number_checkpoints": null, | ||
"synchronize_checkpoint_boundary": false, | ||
"profile": true | ||
}, | ||
"flops_profiler": { | ||
"enabled": false, | ||
"profile_step": 100, | ||
"module_depth": -1, | ||
"top_modules": 1, | ||
"detailed": true, | ||
"output_file": null | ||
}, | ||
"tensorboard": { | ||
"enabled": true, | ||
"output_path": "tensorboard/ds_logs/", | ||
"job_name": "deepspeed" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# network architecture | ||
# encoder related | ||
encoder: conformer | ||
encoder_conf: | ||
output_size: 2048 # dimension of attention | ||
attention_heads: 16 | ||
linear_units: 8192 # the number of units of position-wise feed forward | ||
num_blocks: 12 # the number of encoder blocks | ||
dropout_rate: 0.1 | ||
positional_dropout_rate: 0.1 | ||
attention_dropout_rate: 0.1 | ||
input_layer: conv2d8 # encoder input type, you can chose conv2d, conv2d6 and conv2d8 | ||
normalize_before: true | ||
cnn_module_kernel: 8 | ||
use_cnn_module: True | ||
activation_type: 'swish' | ||
pos_enc_layer_type: 'rel_pos' | ||
selfattention_layer_type: 'rel_selfattn' | ||
causal: true | ||
use_dynamic_chunk: true | ||
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster | ||
use_dynamic_left_chunk: false | ||
|
||
# decoder related | ||
decoder: bitransformer | ||
decoder_conf: | ||
attention_heads: 16 | ||
linear_units: 8192 | ||
num_blocks: 3 | ||
r_num_blocks: 3 | ||
dropout_rate: 0.1 | ||
positional_dropout_rate: 0.1 | ||
self_attention_dropout_rate: 0.1 | ||
src_attention_dropout_rate: 0.1 | ||
|
||
# hybrid CTC/attention | ||
model_conf: | ||
ctc_weight: 0.3 | ||
lsm_weight: 0.1 # label smoothing option | ||
length_normalized_loss: false | ||
reverse_weight: 0.3 | ||
|
||
dataset_conf: | ||
filter_conf: | ||
max_length: 40960 | ||
min_length: 0 | ||
token_max_length: 200 | ||
token_min_length: 1 | ||
resample_conf: | ||
resample_rate: 16000 | ||
speed_perturb: true | ||
fbank_conf: | ||
num_mel_bins: 80 | ||
frame_shift: 10 | ||
frame_length: 25 | ||
dither: 1.0 | ||
spec_aug: true | ||
spec_aug_conf: | ||
num_t_mask: 2 | ||
num_f_mask: 2 | ||
max_t: 50 | ||
max_f: 10 | ||
spec_sub: true | ||
spec_sub_conf: | ||
num_t_sub: 3 | ||
max_t: 30 | ||
spec_trim: false | ||
spec_trim_conf: | ||
max_t: 50 | ||
shuffle: true | ||
shuffle_conf: | ||
shuffle_size: 1500 | ||
sort: true | ||
sort_conf: | ||
sort_size: 500 # sort_size should be less than shuffle_size | ||
batch_conf: | ||
batch_type: 'static' # static or dynamic | ||
batch_size: 16 | ||
|
||
grad_clip: 5 | ||
accum_grad: 1 | ||
max_epoch: 100 | ||
log_interval: 100 | ||
|
||
optim: adam | ||
optim_conf: | ||
lr: 0.001 | ||
scheduler: warmuplr # pytorch v1.1.0+ required | ||
scheduler_conf: | ||
warmup_steps: 25000 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,3 +15,4 @@ pycodestyle==2.6.0 | |
pyflakes==2.2.0 | ||
torch==1.13.0 | ||
torchaudio==0.13.0 | ||
deepspeed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
# Copyright [2023-04-27] <[email protected], Xingchen Song> | ||
|
||
import os | ||
import random | ||
import tarfile | ||
|
||
random.seed(1024) | ||
|
||
# parse arg from command line | ||
datalist = os.sys.argv[1] | ||
datatype = os.sys.argv[2] | ||
num_gpus = int(os.sys.argv[3]) | ||
num_samples_per_tar = int(os.sys.argv[4]) # only used in shard mode | ||
new_datalist = os.sys.argv[5] | ||
|
||
assert datatype in ["shard", "raw"] | ||
|
||
|
||
filtered_list = [] | ||
with open(datalist, "r") as f: | ||
lines = f.readlines() | ||
lines = [l.strip() for l in lines] | ||
if datatype == "raw": | ||
valid_num = len(lines) // num_gpus * num_gpus | ||
random.shuffle(lines) | ||
filtered_list = lines[:valid_num] | ||
else: | ||
for line in lines: | ||
cnt = 0 | ||
with open(line, "rb") as tar: | ||
stream = tarfile.open(fileobj=tar, mode="r|*") | ||
for tarinfo in stream: | ||
name = tarinfo.name | ||
pos = name.rfind('.') | ||
assert pos > 0 | ||
prefix, postfix = name[:pos], name[pos + 1:] | ||
if postfix == 'txt': | ||
cnt += 1 | ||
if cnt == num_samples_per_tar: | ||
filtered_list.append(line) | ||
valid_num = len(filtered_list) // num_gpus * num_gpus | ||
random.shuffle(filtered_list) | ||
filtered_list = filtered_list[:valid_num] | ||
filtered_list.sort() | ||
print("before filter: {} after filter: {}".format(len(lines), len(filtered_list))) | ||
|
||
with open(new_datalist, "w") as f: | ||
for line in filtered_list: | ||
f.writelines("{}\n".format(line)) |
Oops, something went wrong.