
feat(train): support deepspeed #1849

Merged
merged 17 commits into main from xcsong-deepspeed on May 27, 2023

Conversation

xingchensong (Member) commented on May 10, 2023

Brief

This PR integrates DeepSpeed into wenet, which enables:

  1. faster training (~10% speedup) on small models (e.g., a standard conformer with 48M (0.048B) params)
  2. the ability to train much larger models (e.g., with DeepSpeed's ZeRO optimization we can train a ~2B-parameter model without pain); see the sketch below
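To give a sense of how the wrapping works on the training side, here is a minimal sketch (the helper name build_deepspeed_engine and its arguments are illustrative assumptions, not the exact wenet code):

import deepspeed
import torch

def build_deepspeed_engine(model: torch.nn.Module, ds_config_path: str, args):
    # deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler);
    # the engine takes over DDP-style gradient sync and ZeRO state partitioning.
    engine, optimizer, _, scheduler = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=model.parameters(),
        config=ds_config_path,
    )
    return engine, optimizer, scheduler

In the training loop, engine.backward(loss) and engine.step() then replace the usual loss.backward() / optimizer.step() calls.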

Initial result (part-1):

| mode      | ctc greedy | rescore | comment                                                                    |
|-----------|------------|---------|----------------------------------------------------------------------------|
| torch.ddp | 5.19       | 4.63    | chunk -1, 360 epoch, batch16, 8GPUs, following standard u2pp_conformer config |
| deepspeed | 5.15       | 4.82    | chunk -1, 360 epoch, batch32, 4GPUs, following standard u2pp_conformer config |

I believe we would get the same results under identical training configurations (the current minor differences are likely due to the different batch sizes and numbers of GPUs). This initial result indicates that the DeepSpeed integration is correct.

Initial result (part-2):

| mode      | ctc greedy | rescore | comment                                          |
|-----------|------------|---------|--------------------------------------------------|
| torch.ddp | 5.62       | 5.03    | chunk -1, 360 epoch, batch64, 4GPUs, float16     |
| deepspeed | 5.24       | 4.94    | chunk -1, 360 epoch, batch64, 4GPUs, float16     |
| deepspeed | 5.42       | 4.82    | chunk -1, 360 epoch, batch64, 4GPUs, bfloat16    |

We can clearly see that under the same training configurations, deepspeed+bfloat16 is better than torch.ddp+float16.

[screenshot-20230519-105243: TensorBoard curves for torch.ddp vs deepspeed]

From the TensorBoard logs, we can see that torch.ddp and deepspeed show the same trends for train_loss, cv_loss, and lr.

Benchmark: Training speed on a small model

Setup: 4 × RTX 3090 (24 GB), fp32 training, 8 dataloader workers with prefetch 500, batch size 32, NCCL backend

| mode      | data type | time cost per epoch |
|-----------|-----------|---------------------|
| torch.ddp | raw       | 8~9 min             |
| torch.ddp | shard     | 8~9 min             |
| deepspeed | raw       | 7~8 min             |
| deepspeed | shard     | 7~8 min             |

Benchmark: Ability to train a 1.8B model at a practical batch size

Setup: 4 × RTX 3090 (24 GB), batch size 16 per device, bf16 training, 8 dataloader workers with prefetch 500, NCCL backend

Training takes about 50 min per epoch.

TODO

  • enable fp16/bf16 training (the problem here is that DeepSpeed does not attempt any automatic casting for mixed-precision training, so we need some hacks to manually cast tensor types)

Limitations

  • Currently only single-node multi-GPU training is supported

xingchensong (Member, Author) commented on May 15, 2023

I've discovered that torch.autocast integrates seamlessly with DeepSpeed, which saves us a significant amount of manual tensor-type casting when fp16/bf16 is enabled:

# `model_wrapped_by_deepspeed_initialize` is the engine returned by deepspeed.initialize()
with torch.cuda.amp.autocast(cache_enabled=False):
    loss = model_wrapped_by_deepspeed_initialize(inputs)
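For context, a sketch of how this sits in one training step (here `engine` stands for the model engine returned by deepspeed.initialize(); the input names and the bf16 dtype are illustrative, since in practice the dtype comes from the DeepSpeed config):

import torch

# Forward pass under autocast; backward/step go through the DeepSpeed engine,
# which handles loss scaling and ZeRO gradient partitioning itself.
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=False):
    loss = engine(feats, feats_lengths, target, target_lengths)
engine.backward(loss)
engine.step()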

xingchensong (Member, Author) commented:

The script I used to test the 1.8B model:

#!/bin/bash
# Copyright [2023-05-10] <[email protected], Xingchen Song>
size="1.8B"
stage=stage2
dir=u2pp_conformer_deepspeed_shard_nccl_${size}_${stage}
rm -rf tensorboard/$dir
rm -rf exp/$dir

if [ -d "/usr/local/cuda" ]; then
  export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
  #export LD_LIBRARY_PATH=/usr/local/cuda-10.0/lib64:/usr/local/cuda-10.0/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}:/usr/local/lib/openmpi/:/usr/local/nccl_2.4.7-1+cuda10.0_x86_64/lib
  export CUDA_HOME=/usr/local/cuda
  export CFLAGS="-I$CUDA_HOME/include $CFLAGS"
  export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
  #export LIBRARY_PATH=/usr/local/nccl_2.4.7-1+cuda10.0_x86_64/lib/:$LIBRARY_PATH
  export CUDA_PATH=$CUDA_HOME
fi

# Launch wenet training (stage 4) with DeepSpeed enabled.
bash run.sh \
  --deepspeed true \
  --train_set "train" \
  --data_type "shard" \
  --stage 4 --stop_stage 4 \
  --deepspeed_save_states "model+optimizer" \
  --deepspeed_config conf/ds_$stage.json \
  --train_config conf/train_u2++_conformer_1.8B.yaml \
  --dir exp/$dir
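For reference, conf/ds_stage2.json above holds the DeepSpeed configuration. A ZeRO stage-2 + bf16 setup along these lines (shown as the equivalent Python dict that could be passed via deepspeed.initialize(config=...); the values are illustrative, not the exact ones from this PR) looks roughly like:

# Illustrative DeepSpeed config, not the exact conf/ds_stage2.json from this PR.
ds_config = {
    "train_micro_batch_size_per_gpu": 16,
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 5.0,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,                   # partition optimizer states and gradients
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
}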

xingchensong (Member, Author) commented on May 26, 2023

I think this PR is ready for final review. We can merge it so that others can start experimenting, and then fix whatever needs to be fixed. cc @robin1001

robin1001 (Collaborator) commented

LGTM. train.py and the executor are getting more complicated; we will have to refactor them in the future.

@robin1001 robin1001 merged commit 2505d18 into main May 27, 2023
@robin1001 robin1001 deleted the xcsong-deepspeed branch May 27, 2023 04:20
xingchensong (Member, Author) commented

Double-check on efficientconformer-v1-stream:

[screenshot: efficientconformer-v1-stream results]

xingchensong (Member, Author) commented

> LGTM. train.py and the executor are getting more complicated; we will have to refactor them in the future.

Done in #2055.
