Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[train_engine] support fsdp #2412

Merged
merged 26 commits into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions wenet/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,24 @@
from wenet.utils.init_model import init_model
from wenet.utils.init_tokenizer import init_tokenizer
from wenet.utils.train_utils import (
add_model_args, add_dataset_args, add_ddp_args, add_deepspeed_args,
add_trace_args, init_distributed, init_dataset_and_dataloader,
check_modify_and_save_config, init_optimizer_and_scheduler,
trace_and_print_model, wrap_cuda_model, init_summarywriter, save_model,
log_per_epoch)
add_fsdp_args, add_model_args, add_dataset_args, add_ddp_args,
add_deepspeed_args, add_trace_args, init_distributed,
init_dataset_and_dataloader, check_modify_and_save_config,
init_optimizer_and_scheduler, init_scaler, trace_and_print_model,
wrap_cuda_model, init_summarywriter, save_model, log_per_epoch)


def get_args():
parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--train_engine',
default='torch_ddp',
choices=['torch_ddp', 'deepspeed'],
choices=['torch_ddp', 'torch_fsdp', 'deepspeed'],
help='Engine for paralleled training')
parser = add_model_args(parser)
parser = add_dataset_args(parser)
parser = add_ddp_args(parser)
parser = add_deepspeed_args(parser)
parser = add_fsdp_args(parser)
parser = add_trace_args(parser)
args = parser.parse_args()
if args.train_engine == "deepspeed":
Expand Down Expand Up @@ -96,7 +97,7 @@ def main():
writer = init_summarywriter(args)

# Dispatch model from cpu to gpu
model, device = wrap_cuda_model(args, model)
model, device = wrap_cuda_model(args, model, configs)

# Get optimizer & scheduler
model, optimizer, scheduler = init_optimizer_and_scheduler(
Expand All @@ -118,9 +119,7 @@ def main():
int("step_" in tag))

# Init scaler, used for pytorch amp mixed precision training
scaler = None
if args.use_amp:
scaler = torch.cuda.amp.GradScaler()
scaler = init_scaler(args)

# Start training loop
start_epoch = configs["init_infos"].get('epoch', 0) + int("epoch_" in tag)
Expand Down Expand Up @@ -173,6 +172,7 @@ def main():
final_model_path) else None
os.symlink('{}.pt'.format(final_epoch), final_model_path)
writer.close()
dist.destroy_process_group()


if __name__ == '__main__':
Expand Down
20 changes: 12 additions & 8 deletions wenet/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
return configs


def save_state_dict_and_infos(state_dict, path: str, infos=None):
    """Save ``state_dict`` to ``path`` and write a YAML sidecar next to it.

    Args:
        state_dict: object handed to ``torch.save``.
        path: checkpoint destination, normally ending in ``.pt``.
        infos: optional dict dumped to the sidecar file. A ``save_time``
            entry is added to it in place, so a caller passing its own dict
            will see the new key afterwards.
    """
    torch.save(state_dict, path)
    # Escape the dot so that only a literal ".pt" suffix is replaced. The
    # unescaped pattern also matched e.g. "model.ckpt" -> "model.c.yaml".
    info_path = re.sub(r'\.pt$', '.yaml', path)
    if info_path == path:
        # No ".pt" suffix: append instead, so the YAML can never overwrite
        # the checkpoint that was just written.
        info_path = path + '.yaml'
    if infos is None:
        infos = {}
    infos['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
    with open(info_path, 'w') as fout:
        data = yaml.dump(infos)
        fout.write(data)


def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
'''
Args:
Expand All @@ -52,14 +63,7 @@ def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save(state_dict, path)
info_path = re.sub('.pt$', '.yaml', path)
if infos is None:
infos = {}
infos['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
with open(info_path, 'w') as fout:
data = yaml.dump(infos)
fout.write(data)
save_state_dict_and_infos(state_dict, path, infos)


def filter_modules(model_state_dict, modules):
Expand Down
11 changes: 8 additions & 3 deletions wenet/utils/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def train(self, model, optimizer, scheduler, train_data_loader,
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
if info_dict.get("train_engine", "torch_ddp") == "torch_ddp" and \
(batch_idx + 1) % info_dict["accum_grad"] != 0:
if info_dict.get("train_engine", "torch_ddp") in [
"torch_ddp", "torch_fsdp"
] and (batch_idx + 1) % info_dict["accum_grad"] != 0:
context = model.no_sync
# Used for single gpu training and DDP gradient synchronization
# processes.
Expand All @@ -87,6 +88,9 @@ def train(self, model, optimizer, scheduler, train_data_loader,
save_interval = info_dict.get('save_interval', sys.maxsize)
if self.step % save_interval == 0 and self.step != 0 \
and (batch_idx + 1) % info_dict["accum_grad"] == 0:
import torch.distributed as dist
# Ensure all ranks start CV at the same time in step mode
dist.barrier()
xingchensong marked this conversation as resolved.
Show resolved Hide resolved
loss_dict = self.cv(model, cv_data_loader, configs)
model.train()
info_dict.update({
Expand All @@ -100,11 +104,12 @@ def train(self, model, optimizer, scheduler, train_data_loader,
optimizer.param_groups[0]['lr']
})
save_model(model, info_dict)
# Ensure all ranks start Train at the same time in step mode
dist.barrier()
log_per_step(writer, info_dict, timer=self.train_step_timer)
self.step += 1 if (batch_idx +
1) % info_dict["accum_grad"] == 0 else 0


def cv(self, model, cv_data_loader, configs):
''' Cross validation on
'''
Expand Down
115 changes: 115 additions & 0 deletions wenet/utils/fsdp_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from functools import partial
import os
from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
FullStateDictConfig, StateDictType)

from torch.distributed.fsdp.wrap import (lambda_auto_wrap_policy,
transformer_auto_wrap_policy)
from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer
from wenet.efficient_conformer.encoder_layer import StrideConformerEncoderLayer
from wenet.paraformer.layers import AliParaformerEncoderLayer, SanmDecoderLayer
from wenet.squeezeformer.encoder_layer import SqueezeformerEncoderLayer
from wenet.transformer.encoder_layer import (ConformerEncoderLayer,
TransformerEncoderLayer)
from wenet.transformer.decoder_layer import DecoderLayer
from wenet.utils.checkpoint import save_state_dict_and_infos
from wenet.utils.init_model import WENET_DECODER_CLASSES, WENET_ENCODER_CLASSES

# Registry of encoder *layer* classes, keyed by a descriptive name.
# In this file only the values are consumed: as layer types to wrap in the
# per-layer FSDP policy, and as layer types to activation-checkpoint.
WENET_ENCODER_LAYERS_CLASSES = {
    'transformer_encoder_layer': TransformerEncoderLayer,
    'conformer_encoder_layer': ConformerEncoderLayer,
    'paraformer_encoder_layer': AliParaformerEncoderLayer,
    'squeezeformer_encoder_layer': SqueezeformerEncoderLayer,
    'ebranchformer_encoder_layer': EBranchformerEncoderLayer,
    'efficient_conformer_encoder_layer': StrideConformerEncoderLayer,
    'branchformer_encoder_layer': BranchformerEncoderLayer,
}

# Registry of decoder *layer* classes, keyed by a descriptive name.
# In this file only the values are consumed: as layer types to wrap in the
# per-layer FSDP policy, and as layer types to activation-checkpoint.
WENET_DECODER_LAYERS_CLASSES = {
    'transformer_decoder_layer': DecoderLayer,
    'paraformer_decoder_layer': SanmDecoderLayer,
    # TODO(Mddct):
    # 1 wrap transducer's predictor and joint
    # 2 wrap paraformer's cif and ignore lstm
}


def wenet_fsdp_wrap_policy(mode):
    """Pick the FSDP auto-wrap policy that corresponds to ``mode``.

    Args:
        mode: one of ``'no_shard'``, ``'model'``, ``'zero2'``, ``'zero3'``.

    Returns:
        ``None`` for ``'no_shard'``; otherwise a ``functools.partial`` around
        one of torch's auto-wrap policy functions. ``'zero2'`` and ``'zero3'``
        yield the same per-layer policy here.

    Raises:
        AssertionError: if ``mode`` is not one of the supported values.
    """
    # Background on the different wrap methods:
    # https://openmmlab.medium.com/its-2023-is-pytorch-s-fsdp-the-best-choice-for-training-large-models-fe8d2848832f # noqa
    assert mode in ['no_shard', 'model', 'zero2', 'zero3']

    if mode == 'no_shard':
        return None

    # TODO(Mddct): Support user customization
    # see more wrap methods:
    # https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/utils/fsdp_utils.py#L13 # noqa
    #
    # Why two granularities (author's note from code review): each wrapped
    # unit is a boundary at which FSDP communicates (all-gather) around that
    # unit's forward, so the choice of what to wrap controls how fine the
    # sharding is. Wrapping whole encoders/decoders is described as roughly
    # optimizer-state sharding only (memory-wise like ZeRO-1), while wrapping
    # every layer gives layer-level sharding (like ZeRO-2).
    if mode == 'model':

        def _is_encoder_or_decoder(module):
            # Registries are read on every call, not captured up front.
            whole_blocks = (tuple(WENET_ENCODER_CLASSES.values()) +
                            tuple(WENET_DECODER_CLASSES.values()))
            return isinstance(module, whole_blocks)

        return partial(lambda_auto_wrap_policy,
                       lambda_fn=_is_encoder_or_decoder)

    # 'zero2' / 'zero3': wrap each individual encoder/decoder layer.
    layer_classes = {
        *WENET_ENCODER_LAYERS_CLASSES.values(),
        *WENET_DECODER_LAYERS_CLASSES.values(),
    }
    return partial(transformer_auto_wrap_policy,
                   transformer_layer_cls=layer_classes)


# Shared config passed to FSDP.state_dict_type() by fsdp_save_model when it
# collects a FULL_STATE_DICT. NOTE(review): going by the flag names, the full
# state is offloaded to CPU and kept on rank 0 only -- confirm against the
# torch FullStateDictConfig documentation.
fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True,
                                            rank0_only=True)


def fsdp_save_model(model, save_model_path, info_dict):
    """Collect the full state dict of an FSDP model and save it on rank 0.

    Args:
        model: the FSDP-wrapped model whose ``state_dict()`` is collected.
        save_model_path: destination handed to ``save_state_dict_and_infos``.
        info_dict: extra infos handed to ``save_state_dict_and_infos``.
    """
    # TODO(Mddct): For large models saving takes a long time. Ideally only the
    # shards would be kept and written asynchronously; the current approach is
    # good enough for now and will be revisited once LLMs are supported.
    is_rank_zero = int(os.environ.get('RANK', 0)) == 0

    # state_dict() is called by every process; only the write below is
    # restricted to rank 0.
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT,
                              fullstate_save_policy):
        full_state = model.state_dict()

    if not is_rank_zero:
        return
    save_state_dict_and_infos(full_state, save_model_path, info_dict)

xingchensong marked this conversation as resolved.
Show resolved Hide resolved

def check_gradient_checkpoint(model):
    """Turn off built-in gradient checkpointing and report what to re-wrap.

    For ``model.encoder`` and ``model.decoder`` (when present), a truthy
    ``gradient_checkpointing`` attribute is reset to ``False`` and the
    matching registry's layer classes are collected.

    Args:
        model: object that may expose ``encoder`` / ``decoder`` attributes.

    Returns:
        Tuple of layer classes (encoder classes first, then decoder classes);
        empty when neither part had checkpointing enabled.
    """
    collected = []
    for part_name in ('encoder', 'decoder'):
        part = getattr(model, part_name, None)
        if not getattr(part, 'gradient_checkpointing', False):
            continue
        part.gradient_checkpointing = False
        # Look the registry up only when it is actually needed.
        if part_name == 'encoder':
            registry = WENET_ENCODER_LAYERS_CLASSES
        else:
            registry = WENET_DECODER_LAYERS_CLASSES
        collected.extend(registry.values())
    return tuple(collected)


def apply_fsdp_checkpointing(model, ckpt_layer_types: tuple):
    """Wrap matching submodules of ``model`` with activation checkpointing.

    Args:
        model: module tree handed to torch's ``apply_activation_checkpointing``.
        ckpt_layer_types: tuple of classes; every submodule that is an
            instance of one of them is wrapped. An empty tuple is a no-op.
    """
    # NOTE(Mddct): torch.utils.checkpoint is currently incompatible with
    # wenet's model mode, hence this wrapper-based approach. Please refer to
    # https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/policies/activation_checkpointing_functions.py#L21 # noqa
    if not ckpt_layer_types:
        return

    # Imported lazily so the no-op path never touches this module.
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        CheckpointImpl,
        apply_activation_checkpointing,
        checkpoint_wrapper,
    )

    def _needs_checkpoint(submodule):
        return isinstance(submodule, ckpt_layer_types)

    wrap_non_reentrant = partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=wrap_non_reentrant,
        check_fn=_needs_checkpoint,
    )
Loading
Loading