diff --git a/wenet/bin/train.py b/wenet/bin/train.py
index f772ff85f..1ddfe0435 100644
--- a/wenet/bin/train.py
+++ b/wenet/bin/train.py
@@ -30,23 +30,24 @@
 from wenet.utils.init_model import init_model
 from wenet.utils.init_tokenizer import init_tokenizer
 from wenet.utils.train_utils import (
-    add_model_args, add_dataset_args, add_ddp_args, add_deepspeed_args,
-    add_trace_args, init_distributed, init_dataset_and_dataloader,
-    check_modify_and_save_config, init_optimizer_and_scheduler,
-    trace_and_print_model, wrap_cuda_model, init_summarywriter, save_model,
-    log_per_epoch)
+    add_fsdp_args, add_model_args, add_dataset_args, add_ddp_args,
+    add_deepspeed_args, add_trace_args, init_distributed,
+    init_dataset_and_dataloader, check_modify_and_save_config,
+    init_optimizer_and_scheduler, init_scaler, trace_and_print_model,
+    wrap_cuda_model, init_summarywriter, save_model, log_per_epoch)
 
 
 def get_args():
     parser = argparse.ArgumentParser(description='training your network')
     parser.add_argument('--train_engine',
                         default='torch_ddp',
-                        choices=['torch_ddp', 'deepspeed'],
+                        choices=['torch_ddp', 'torch_fsdp', 'deepspeed'],
                         help='Engine for paralleled training')
     parser = add_model_args(parser)
     parser = add_dataset_args(parser)
     parser = add_ddp_args(parser)
     parser = add_deepspeed_args(parser)
+    parser = add_fsdp_args(parser)
     parser = add_trace_args(parser)
     args = parser.parse_args()
     if args.train_engine == "deepspeed":
@@ -96,7 +97,7 @@ def main():
     writer = init_summarywriter(args)
 
     # Dispatch model from cpu to gpu
-    model, device = wrap_cuda_model(args, model)
+    model, device = wrap_cuda_model(args, model, configs)
 
     # Get optimizer & scheduler
     model, optimizer, scheduler = init_optimizer_and_scheduler(
@@ -118,9 +119,7 @@ def main():
                                 int("step_" in tag))
 
     # Init scaler, used for pytorch amp mixed precision training
-    scaler = None
-    if args.use_amp:
-        scaler = torch.cuda.amp.GradScaler()
+    scaler = init_scaler(args)
 
     # Start training loop
     start_epoch = configs["init_infos"].get('epoch', 0) + int("epoch_" in tag)
@@ -173,6 +172,7 @@ def main():
             final_model_path) else None
         os.symlink('{}.pt'.format(final_epoch), final_model_path)
     writer.close()
+    dist.destroy_process_group()
 
 
 if __name__ == '__main__':
diff --git a/wenet/utils/checkpoint.py b/wenet/utils/checkpoint.py
index 42fc8fe67..d60582716 100644
--- a/wenet/utils/checkpoint.py
+++ b/wenet/utils/checkpoint.py
@@ -40,6 +40,17 @@ def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
     return configs
 
 
+def save_state_dict_and_infos(state_dict, path: str, infos=None):
+    torch.save(state_dict, path)
+    info_path = re.sub('.pt$', '.yaml', path)
+    if infos is None:
+        infos = {}
+    infos['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
+    with open(info_path, 'w') as fout:
+        data = yaml.dump(infos)
+        fout.write(data)
+
+
 def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
     '''
     Args:
@@ -52,14 +63,7 @@ def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
         state_dict = model.module.state_dict()
     else:
         state_dict = model.state_dict()
-    torch.save(state_dict, path)
-    info_path = re.sub('.pt$', '.yaml', path)
-    if infos is None:
-        infos = {}
-    infos['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
-    with open(info_path, 'w') as fout:
-        data = yaml.dump(infos)
-        fout.write(data)
+    save_state_dict_and_infos(state_dict, path, infos)
 
 
 def filter_modules(model_state_dict, modules):
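The checkpoint.py change above is a pure refactor: the tensor-saving and YAML-info logic moves into save_state_dict_and_infos so that the FSDP saving path added later in this patch can reuse it on an already-gathered state dict. A minimal sketch of calling the helper directly, assuming wenet is importable; the state dict, path and info values are made-up placeholders and the target directory must already exist.

# Editorial sketch, not part of the patch.
import torch
from wenet.utils.checkpoint import save_state_dict_and_infos

state_dict = {'proj.weight': torch.zeros(4, 4)}  # any name -> tensor mapping
save_state_dict_and_infos(state_dict, 'exp/demo/epoch_0.pt',
                          infos={'epoch': 0, 'lr': 1e-3})
# Writes exp/demo/epoch_0.pt plus exp/demo/epoch_0.yaml, the latter holding
# the infos dict with an added 'save_time' field.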
diff --git a/wenet/utils/executor.py b/wenet/utils/executor.py
index bd9db9338..11bc24706 100644
--- a/wenet/utils/executor.py
+++ b/wenet/utils/executor.py
@@ -68,8 +68,9 @@ def train(self, model, optimizer, scheduler, train_data_loader,
                 # Disable gradient synchronizations across DDP processes.
                 # Within this context, gradients will be accumulated on module
                 # variables, which will later be synchronized.
-                if info_dict.get("train_engine", "torch_ddp") == "torch_ddp" and \
-                        (batch_idx + 1) % info_dict["accum_grad"] != 0:
+                if info_dict.get("train_engine", "torch_ddp") in [
+                        "torch_ddp", "torch_fsdp"
+                ] and (batch_idx + 1) % info_dict["accum_grad"] != 0:
                     context = model.no_sync
                 # Used for single gpu training and DDP gradient synchronization
                 # processes.
@@ -87,6 +88,9 @@ def train(self, model, optimizer, scheduler, train_data_loader,
                 save_interval = info_dict.get('save_interval', sys.maxsize)
                 if self.step % save_interval == 0 and self.step != 0 \
                         and (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                    import torch.distributed as dist
+                    # Ensure all ranks start CV at the same time in step mode
+                    dist.barrier()
                     loss_dict = self.cv(model, cv_data_loader, configs)
                     model.train()
                     info_dict.update({
@@ -100,11 +104,12 @@ def train(self, model, optimizer, scheduler, train_data_loader,
                         optimizer.param_groups[0]['lr']
                     })
                     save_model(model, info_dict)
+                    # Ensure all ranks start Train at the same time in step mode
+                    dist.barrier()
                 log_per_step(writer, info_dict, timer=self.train_step_timer)
                 self.step += 1 if (batch_idx +
                                    1) % info_dict["accum_grad"] == 0 else 0
 
-
     def cv(self, model, cv_data_loader, configs):
         ''' Cross validation on
         '''
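The executor change works because DistributedDataParallel and FullyShardedDataParallel both expose a no_sync() context manager, so gradient accumulation can treat the two engines uniformly. A standalone sketch of that pattern, assuming `model` is already wrapped in DDP or FSDP; the loss function, loader and accum_grad value are illustrative only.

# Editorial sketch of the no_sync()/accum_grad pattern used above; not wenet code.
import torch
from contextlib import nullcontext

def train_with_accum(model, loader, optimizer, accum_grad=4):
    for batch_idx, (x, y) in enumerate(loader):
        boundary = (batch_idx + 1) % accum_grad == 0
        # no_sync() skips the all-reduce (DDP) / reduce-scatter (FSDP) for
        # this micro-batch; gradients simply accumulate locally.
        context = nullcontext if boundary else model.no_sync
        with context():
            loss = torch.nn.functional.mse_loss(model(x), y) / accum_grad
            loss.backward()
        if boundary:
            optimizer.step()
            optimizer.zero_grad()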
diff --git a/wenet/utils/fsdp_utils.py b/wenet/utils/fsdp_utils.py
new file mode 100644
index 000000000..33871f6f0
--- /dev/null
+++ b/wenet/utils/fsdp_utils.py
@@ -0,0 +1,115 @@
+from functools import partial
+import os
+from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
+                                    FullStateDictConfig, StateDictType)
+
+from torch.distributed.fsdp.wrap import (lambda_auto_wrap_policy,
+                                         transformer_auto_wrap_policy)
+from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
+from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer
+from wenet.efficient_conformer.encoder_layer import StrideConformerEncoderLayer
+from wenet.paraformer.layers import AliParaformerEncoderLayer, SanmDecoderLayer
+from wenet.squeezeformer.encoder_layer import SqueezeformerEncoderLayer
+from wenet.transformer.encoder_layer import (ConformerEncoderLayer,
+                                             TransformerEncoderLayer)
+from wenet.transformer.decoder_layer import DecoderLayer
+from wenet.utils.checkpoint import save_state_dict_and_infos
+from wenet.utils.init_model import WENET_DECODER_CLASSES, WENET_ENCODER_CLASSES
+
+WENET_ENCODER_LAYERS_CLASSES = {
+    'transformer_encoder_layer': TransformerEncoderLayer,
+    'conformer_encoder_layer': ConformerEncoderLayer,
+    'paraformer_encoder_layer': AliParaformerEncoderLayer,
+    'squeezeformer_encoder_layer': SqueezeformerEncoderLayer,
+    'ebranchformer_encoder_layer': EBranchformerEncoderLayer,
+    'efficient_conformer_encoder_layer': StrideConformerEncoderLayer,
+    'branchformer_encoder_layer': BranchformerEncoderLayer,
+}
+
+WENET_DECODER_LAYERS_CLASSES = {
+    'transformer_decoder_layer': DecoderLayer,
+    'paraformer_decoder_layer': SanmDecoderLayer,
+    # TODO(Mddct):
+    #     1 wrap transducer's predictor and joint
+    #     2 wrap paraformer's cif and ignore lstm
+}
+
+
+def wenet_fsdp_wrap_policy(mode):
+    # different wrap methods
+    # please refer to: https://openmmlab.medium.com/its-2023-is-pytorch-s-fsdp-the-best-choice-for-training-large-models-fe8d2848832f # noqa
+    assert mode in ['no_shard', 'model', 'zero2', 'zero3']
+    if mode == 'no_shard':
+        return None
+    else:
+        # TODO(Mddct): support user customization
+        # see more wrap methods:
+        # https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/utils/fsdp_utils.py#L13 # noqa
+        if mode == 'model':
+            enc_dec_wrap_policy = partial(
+                lambda_auto_wrap_policy,
+                lambda_fn=lambda module: isinstance(
+                    module,
+                    tuple(WENET_ENCODER_CLASSES.values()) + tuple(
+                        WENET_DECODER_CLASSES.values())))
+            return enc_dec_wrap_policy
+        else:
+            to_wrap_class = set()
+            to_wrap_class.update(set(WENET_ENCODER_LAYERS_CLASSES.values()))
+            to_wrap_class.update(set(WENET_DECODER_LAYERS_CLASSES.values()))
+            layers_wrap_policy = partial(transformer_auto_wrap_policy,
+                                         transformer_layer_cls=to_wrap_class)
+            return layers_wrap_policy
+
+
+fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True,
+                                            rank0_only=True)
+
+
+def fsdp_save_model(model, save_model_path, info_dict):
+    # TODO(Mddct): When the model is large, saving the full model takes a long
+    # time. Ideally the shards would be saved asynchronously; the current
+    # approach is good enough until LLM-scale models are supported.
+
+    rank = int(os.environ.get('RANK', 0))
+    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT,
+                              fullstate_save_policy):
+        state_dict = model.state_dict()
+        if rank == 0:
+            save_state_dict_and_infos(state_dict, save_model_path, info_dict)
+
+
+def check_gradient_checkpoint(model):
+    ckpt_layer_types = []
+    if hasattr(model, 'encoder') and hasattr(model.encoder,
+                                             'gradient_checkpointing'):
+        if model.encoder.gradient_checkpointing:
+            model.encoder.gradient_checkpointing = False
+            ckpt_layer_types += list(WENET_ENCODER_LAYERS_CLASSES.values())
+    if hasattr(model, 'decoder') and hasattr(model.decoder,
+                                             'gradient_checkpointing'):
+        if model.decoder.gradient_checkpointing:
+            model.decoder.gradient_checkpointing = False
+            ckpt_layer_types += list(WENET_DECODER_LAYERS_CLASSES.values())
+    return tuple(ckpt_layer_types)
+
+
+def apply_fsdp_checkpointing(model, ckpt_layer_types: tuple):
+    # NOTE(Mddct): torch.utils.checkpoint (wenet's own gradient_checkpointing
+    # mode) is currently incompatible with FSDP, so we apply FSDP's
+    # checkpoint_wrapper instead; see https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/policies/activation_checkpointing_functions.py#L21   # noqa
+    if len(ckpt_layer_types) == 0:
+        return
+    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+        checkpoint_wrapper,
+        CheckpointImpl,
+        apply_activation_checkpointing,
+    )
+    non_reentrant_wrapper = partial(
+        checkpoint_wrapper,
+        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+    )
+    apply_activation_checkpointing(
+        model,
+        checkpoint_wrapper_fn=non_reentrant_wrapper,
+        check_fn=lambda submodule: isinstance(submodule, ckpt_layer_types))
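To make the wrap modes above concrete: 'model' turns the whole encoder and decoder into single FSDP units via lambda_auto_wrap_policy, while 'zero2'/'zero3' wrap at the granularity of individual encoder/decoder layers via transformer_auto_wrap_policy. A rough sketch of inspecting the result, assuming a process group is already initialized and `model` is an already-constructed wenet ASR model instance.

# Editorial sketch, not part of the patch; `model` is a placeholder wenet model
# and dist.init_process_group() is assumed to have run.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from wenet.utils.fsdp_utils import wenet_fsdp_wrap_policy

policy = wenet_fsdp_wrap_policy(mode='zero3')  # or 'model' / 'zero2'
fsdp_model = FSDP(model, auto_wrap_policy=policy, use_orig_params=True)
for name, module in fsdp_model.named_modules():
    if isinstance(module, FSDP):
        # 'zero2'/'zero3': roughly one unit per encoder/decoder layer;
        # 'model': just the top-level encoder and decoder (plus the root).
        print(name)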
diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py
index 46a69c52f..5801065de 100644
--- a/wenet/utils/train_utils.py
+++ b/wenet/utils/train_utils.py
@@ -13,8 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from contextlib import nullcontext
 import copy
 from typing import Optional
+
 import deepspeed
 import json
 import logging
@@ -28,6 +30,9 @@
 from tensorboardX import SummaryWriter
 from torch.utils.data import DataLoader
 from torch.nn.utils import clip_grad_norm_
+from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
+                                    CPUOffload, MixedPrecision,
+                                    sharded_grad_scaler, ShardingStrategy)
 from deepspeed.runtime.zero.stage_1_and_2 import (
     estimate_zero2_model_states_mem_needs_all_live)
 from deepspeed.runtime.zero.stage3 import (
@@ -36,6 +41,9 @@
     convert_zero_checkpoint_to_fp32_state_dict)
 from wenet.dataset.dataset import Dataset
 from wenet.utils.checkpoint import save_checkpoint
+from wenet.utils.fsdp_utils import (check_gradient_checkpoint, fsdp_save_model,
+                                    apply_fsdp_checkpointing,
+                                    wenet_fsdp_wrap_policy)
 from wenet.utils.common import StepTimer
 from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing
 from wenet.utils.ctc_utils import get_blank_id
@@ -142,13 +150,48 @@ def add_deepspeed_args(parser):
     return parser
 
 
+def add_fsdp_args(parser):
+    parser.add_argument(
+        '--dtype',
+        default='fp32',
+        choices=['fp32', 'fp16', 'bf16'],
+        help='when amp is used, dtype is automatically set to fp16.\
+        this arg has no effect when deepspeed is enabled.')
+    parser.add_argument(
+        '--fsdp_cpu_offload',
+        default=False,
+        type=bool,
+        help='whether to offload parameters to CPU',
+    )
+    parser.add_argument(
+        '--fsdp_sync_module_states',
+        type=bool,
+        default=True,
+        help='\
+        each FSDP module will broadcast module parameters and buffers from \
+        rank 0 to ensure that they are replicated across ranks',
+    )
+    parser.add_argument(
+        '--fsdp_sharding_strategy',
+        default='zero2',
+        # TODO(Mddct): pipeline and model parallel (3-D parallelism)
+        choices=['no_shard', 'model', 'zero2', 'zero3'],
+        help='Sharding strategy for FSDP. Choose from the following options:\n'
+        ' - "no_shard": Equivalent to DistributedDataParallel (DDP).\n'
+        ' - "model": WENET_ENC_DEC strategy, equivalent to DeepSpeed zero1.\n'
+        ' - "zero2": SHARD_GRAD_OP strategy, equivalent to DeepSpeed zero2.\n'
+        ' - "zero3": FULL_SHARD strategy, equivalent to DeepSpeed zero3.\n'
+        'For more information, refer to the FSDP API documentation.')
+    return parser
+
+
 def init_distributed(args):
     world_size = int(os.environ.get('WORLD_SIZE', 1))
     local_rank = int(os.environ.get('LOCAL_RANK', 0))
     rank = int(os.environ.get('RANK', 0))
     logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
                  ', rank {}, world_size {}'.format(rank, world_size))
-    if args.train_engine == "torch_ddp":
+    if args.train_engine in ["torch_ddp", "torch_fsdp"]:
         torch.cuda.set_device(local_rank)
         dist.init_process_group(args.dist_backend)
     elif args.train_engine == "deepspeed":
@@ -159,11 +202,12 @@
 
 
 def check_modify_and_save_config(args, configs, symbol_table):
-    if args.train_engine == "torch_ddp":
+    if args.train_engine in ["torch_ddp", "torch_fsdp"]:
         if args.use_amp:
             configs["dtype"] = "fp16"
+            args.dtype = 'fp16'
         else:
-            configs["dtype"] = "fp32"
+            configs["dtype"] = args.dtype
     elif args.train_engine == "deepspeed":
         # NOTE(xcsong): DeepSpeed does not support uneven data. When using custom
         #   dataset, we need to manually ensure that the data is evenly distributed
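The new flags can be exercised in isolation as a quick sanity check of add_fsdp_args; the values below are arbitrary examples rather than recommended settings.

# Editorial sketch, not part of the patch.
import argparse
from wenet.utils.train_utils import add_fsdp_args

parser = argparse.ArgumentParser()
parser.add_argument('--train_engine', default='torch_ddp',
                    choices=['torch_ddp', 'torch_fsdp', 'deepspeed'])
parser = add_fsdp_args(parser)
args = parser.parse_args(['--train_engine', 'torch_fsdp',
                          '--dtype', 'bf16',
                          '--fsdp_sharding_strategy', 'zero3'])
print(args.dtype, args.fsdp_sharding_strategy)  # -> bf16 zero3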
@@ -282,25 +326,19 @@
     return train_dataset, cv_dataset, train_data_loader, cv_data_loader
 
 
-def wrap_cuda_model(args, model):
+def wrap_cuda_model(args, model, configs=None):
     local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
     world_size = int(os.environ.get('WORLD_SIZE', 1))
     if hasattr(model, 'encoder'):
         grad_ckpt = getattr(model.encoder, 'gradient_checkpointing', False)
     else:
         grad_ckpt = False
-    # TODO(xcsong): could one GPU use ddp? and int(os.environ.get('WORLD_SIZE', 1)) > 1
     if args.train_engine == "torch_ddp":  # native pytorch ddp
         assert (torch.cuda.is_available())
         model.cuda()
         model = torch.nn.parallel.DistributedDataParallel(
             model, find_unused_parameters=not grad_ckpt)
         device = torch.device("cuda")
-        if args.fp16_grad_sync:
-            from torch.distributed.algorithms.ddp_comm_hooks import (
-                default as comm_hooks, )
-            model.register_comm_hook(state=None,
-                                     hook=comm_hooks.fp16_compress_hook)
     elif args.train_engine == "deepspeed":  # deepspeed
         # NOTE(xcsong): look in detail how the memory estimator API works:
         #   https://deepspeed.readthedocs.io/en/latest/memory.html#discussion
@@ -317,8 +355,50 @@ def wrap_cuda_model(args, model):
                 num_nodes=world_size // local_world_size)
         device = None  # Init device later
         pass  # Init DeepSpeed later
+    elif args.train_engine == 'torch_fsdp':
+        assert configs is not None
+        mixed_precision_dtype = {
+            'fp32': torch.float32,
+            "fp16": torch.float16,
+            "bf16": torch.bfloat16,
+        }[configs['dtype']]
+
+        sharding_strategy = {
+            'model': ShardingStrategy.SHARD_GRAD_OP,
+            'zero2': ShardingStrategy.SHARD_GRAD_OP,
+            'zero3': ShardingStrategy.FULL_SHARD,
+            'no_shard': ShardingStrategy.NO_SHARD,
+        }[args.fsdp_sharding_strategy]
+        wrap_policy = wenet_fsdp_wrap_policy(mode=args.fsdp_sharding_strategy)
+        layer_types = check_gradient_checkpoint(model)
+        model = FSDP(
+            model,
+            auto_wrap_policy=wrap_policy,
+            cpu_offload=CPUOffload(offload_params=True)
+            if args.fsdp_cpu_offload is True else None,
+            mixed_precision=MixedPrecision(
+                param_dtype=mixed_precision_dtype,
+                reduce_dtype=mixed_precision_dtype,
+                buffer_dtype=mixed_precision_dtype,
+            ),
+            sharding_strategy=sharding_strategy,
+            limit_all_gathers=True,
+            use_orig_params=True,
+            sync_module_states=args.fsdp_sync_module_states,
+            # init_distributed() has already called torch.cuda.set_device,
+            # so device_id must be set here; see the FSDP API docs.
+            device_id=torch.cuda.current_device(),
+        )
+        apply_fsdp_checkpointing(model, layer_types)
+        device = torch.device("cuda")
     else:
         logging.error("not supported engine: {}".format(args.train_engine))
+    if args.train_engine in ["torch_fsdp", "torch_ddp"]:
+        if args.fp16_grad_sync:
+            from torch.distributed.algorithms.ddp_comm_hooks import (
+                default as comm_hooks, )
+            model.register_comm_hook(state=None,
+                                     hook=comm_hooks.fp16_compress_hook)
     return model, device
 
 
@@ -395,10 +475,23 @@ def init_summarywriter(args):
     return writer
 
 
+def init_scaler(args):
+    scaler = None
+    if args.use_amp:
+        scaler = torch.cuda.amp.GradScaler()
+    elif args.train_engine == 'torch_fsdp':
+        # why bf16 doesn't need a scaler:
+        # https://discuss.pytorch.org/t/why-bf16-do-not-need-loss-scaling/176596
+        if args.dtype in ['fp16']:
+            scaler = sharded_grad_scaler.ShardedGradScaler(enabled=True)
+    return scaler
+
+
 def save_model(model, info_dict):
     rank = int(os.environ.get('RANK', 0))
     tag = info_dict["tag"]
     model_dir = info_dict["model_dir"]
+    save_model_path = os.path.join(model_dir, '{}.pt'.format(tag))
     # save ckpt
     if info_dict["train_engine"] == "deepspeed":
         # NOTE(xcsong): All ranks should call this API, but only rank 0
@@ -410,13 +503,14 @@ def save_model(model, info_dict):
                                   client_state=info_dict)
         if info_dict["save_states"] == "model_only" and rank == 0:
             convert_zero_checkpoint_to_fp32_state_dict(model_dir,
-                                                       "{}/{}.pt".format(
-                                                           model_dir, tag),
+                                                       save_model_path,
                                                        tag=tag)
             os.system("rm -rf {}/{}".format(model_dir, tag))
+
+    elif info_dict['train_engine'] == "torch_fsdp":
+        fsdp_save_model(model, save_model_path, info_dict)
     elif rank == 0:
         # NOTE(xcsong): For torch_ddp, only rank-0 should call this.
-        save_model_path = os.path.join(model_dir, '{}.pt'.format(tag))
         save_checkpoint(model, save_model_path, info_dict)
     # save yaml
     if rank == 0:
@@ -467,21 +561,24 @@ def batch_forward(model, batch, scaler, info_dict):
     else:  # fp32
         dtype = None
 
-    if train_engine == "deepspeed":
-        # deepspeed
-        with torch.cuda.amp.autocast(enabled=dtype is not None,
-                                     dtype=dtype,
-                                     cache_enabled=False):
-            loss_dict = model(batch, device)
-    else:
-        # torch_ddp
-        # autocast context
-        # The more details about amp can be found in
-        # https://pytorch.org/docs/stable/notes/amp_examples.html
-        with torch.cuda.amp.autocast(scaler is not None):
-            loss_dict = model(batch, device)
-    info_dict['loss_dict'] = loss_dict
+    # autocast context
+    # The more details about amp can be found in
+    # https://pytorch.org/docs/stable/notes/amp_examples.html
+    autocast = {
+        "deepspeed":
+        torch.cuda.amp.autocast(enabled=dtype is not None,
+                                dtype=dtype,
+                                cache_enabled=False),
+        "torch_ddp":
+        torch.cuda.amp.autocast(enabled=scaler is not None),
+        "torch_fsdp":
+        torch.cuda.amp.autocast(enabled=True, dtype=dtype)
+        if dtype is not None else nullcontext()
+    }[train_engine]
+    with autocast:
+        loss_dict = model(batch, device)
+    info_dict['loss_dict'] = loss_dict
 
     return info_dict
 
 
@@ -498,12 +595,17 @@ def batch_backward(model, scaler, info_dict):
         #               `scale_loss_wrt_accum_grad + loss.backward()`
         # ref: https://www.deepspeed.ai/tutorials/megatron/#using-the-training-api
         scaled_loss = model.backward(loss)
-    elif train_engine == "torch_ddp":
+    else:
+        assert train_engine in ["torch_ddp", "torch_fsdp"]
         scaled_loss = loss / accum_grad
-        if use_amp:
+        if scaler is not None:
+            # fp16 (amp and fsdp)
            scaler.scale(scaled_loss).backward()
         else:
+            # float32 (ddp and fsdp)
+            # bf16 (fsdp)
             scaled_loss.backward()
+
     info_dict['loss_dict']['loss'] = scaled_loss
     for loss_name, loss_value in info_dict['loss_dict'].items():
         if loss_value is not None:
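Taken together, batch_forward and batch_backward implement the standard mixed-precision recipe, with ShardedGradScaler standing in for GradScaler when the model is sharded. A condensed sketch of the fp16 FSDP path, assuming `fsdp_model` returns a dict with a 'loss' entry as wenet models do; all objects here are placeholders. The matching unscale_/clip/step half of the recipe appears in update_parameter_and_lr just below, where FSDP models must use model.clip_grad_norm_ instead of the global clip_grad_norm_ utility.

# Editorial sketch, not part of the patch.
import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

scaler = ShardedGradScaler(enabled=True)  # what init_scaler() returns for fp16

def fp16_micro_step(fsdp_model, batch, device, accum_grad=4):
    # Forward under fp16 autocast (bf16 would skip the scaler entirely).
    with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16):
        loss = fsdp_model(batch, device)['loss']
    # Scale before backward so small fp16 gradients do not flush to zero.
    scaler.scale(loss / accum_grad).backward()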
@@ -539,9 +641,14 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
         grad_norm = model.get_global_grad_norm()
     elif (batch_idx + 1) % accum_grad == 0:
         # Use mixed precision training
-        if use_amp:
+        # fp16 (ddp fsdp)
+        if scaler is not None:
             scaler.unscale_(optimizer)
-            grad_norm = clip_grad_norm_(model.parameters(), clip)
+            if train_engine == "torch_ddp":
+                grad_norm = clip_grad_norm_(model.parameters(), clip)
+            else:
+                # fsdp
+                grad_norm = model.clip_grad_norm_(clip)
             # Must invoke scaler.update() if unscale_() is used in
             # the iteration to avoid the following error:
             #   RuntimeError: unscale_() has already been called
@@ -552,7 +659,10 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
             scaler.step(optimizer)
             scaler.update()
         else:
-            grad_norm = clip_grad_norm_(model.parameters(), clip)
+            if train_engine == "torch_ddp":
+                grad_norm = clip_grad_norm_(model.parameters(), clip)
+            else:
+                grad_norm = model.clip_grad_norm_(clip)
             if torch.isfinite(grad_norm):
                 optimizer.step()
         optimizer.zero_grad()
@@ -581,8 +691,9 @@ def log_per_step(writer, info_dict, timer: Optional[StepTimer] = None):
     rank = int(os.environ.get('RANK', 0))
 
     if tag == "TRAIN" and rank == 0 and writer is not None:
-        if (train_engine == "deepspeed" and is_gradient_accumulation_boundary) or \
-                (train_engine == "torch_ddp" and (batch_idx + 1) % accum_grad == 0):
+        if (train_engine == "deepspeed" and is_gradient_accumulation_boundary
+            ) or (train_engine in ["torch_ddp", "torch_fsdp"] and
+                  (batch_idx + 1) % accum_grad == 0):
             writer.add_scalar('train/train_loss',
                               loss_dict['loss'] * accum_grad, step + 1)
             writer.add_scalar('train/grad_norm', info_dict['grad_norm'],
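One design note on fsdp_save_model: because it gathers a FULL_STATE_DICT with rank0_only=True and offload_to_cpu=True, the file written on rank 0 is an ordinary name-to-tensor mapping, so the existing load_checkpoint and export tooling keeps working unchanged. A trivial sketch, with a placeholder path.

# Editorial sketch, not part of the patch; the path is a placeholder.
import torch

state_dict = torch.load('exp/demo/final.pt', map_location='cpu')
print(list(state_dict.keys())[:5])  # plain parameter names, no FSDP prefixes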