[train_engine] support fsdp
Mddct committed Mar 15, 2024
1 parent 9406947 commit ee2ac5d
Showing 3 changed files with 21 additions and 15 deletions.
wenet/utils/checkpoint.py (20 changes: 12 additions & 8 deletions)
@@ -40,6 +40,17 @@ def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
     return configs
 
 
+def save_state_dict_and_yaml(state_dict, path: str, infos=None):
+    torch.save(state_dict, path)
+    info_path = re.sub('.pt$', '.yaml', path)
+    if infos is None:
+        infos = {}
+    infos['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
+    with open(info_path, 'w') as fout:
+        data = yaml.dump(infos)
+        fout.write(data)
+
+
 def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
     '''
     Args:
@@ -52,14 +63,7 @@ def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
         state_dict = model.module.state_dict()
     else:
         state_dict = model.state_dict()
-    torch.save(state_dict, path)
-    info_path = re.sub('.pt$', '.yaml', path)
-    if infos is None:
-        infos = {}
-    infos['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
-    with open(info_path, 'w') as fout:
-        data = yaml.dump(infos)
-        fout.write(data)
+    save_state_dict_and_yaml(state_dict, path, infos)
 
 
 def filter_modules(model_state_dict, modules):
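
For reference, the new helper can be exercised on its own; a minimal sketch (the toy module, file name, and infos keys below are illustrative, not from the commit):

    import torch

    from wenet.utils.checkpoint import save_state_dict_and_yaml

    # Any nn.Module's state dict works; a Linear layer stands in for
    # a real WeNet model here.
    model = torch.nn.Linear(4, 2)

    # Writes foo.pt (the weights) and foo.yaml (the metadata); the
    # helper adds a 'save_time' field to the yaml automatically.
    save_state_dict_and_yaml(model.state_dict(),
                             'foo.pt',
                             infos={'epoch': 3, 'step': 1000})

Factoring the torch.save plus yaml logic out of save_checkpoint lets callers that already hold a bare state dict, such as the FSDP path below, reuse it without a module handle.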
wenet/utils/fsdp_utils.py (6 changes: 4 additions & 2 deletions)
@@ -5,7 +5,7 @@
 
 from torch.distributed.fsdp.wrap import (_or_policy, lambda_auto_wrap_policy,
                                          transformer_auto_wrap_policy)
-from wenet.utils.checkpoint import save_checkpoint
+from wenet.utils.checkpoint import save_state_dict_and_yaml
 
 from wenet.utils.init_model import (WENET_DECODER_CLASSES,
                                     WENET_ENCODER_CLASSES)
@@ -48,4 +48,6 @@ def fsdp_save_model(model, save_model_path, info_dict):
     rank = int(os.environ.get('RANK', 0))
     with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT,
                               fullstate_save_policy):
-        save_checkpoint(model, save_model_path, info_dict)
+        state_dict = model.state_dict()
+        if rank == 0:
+            save_state_dict_and_yaml(state_dict, save_model_path, info_dict)
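
The fullstate_save_policy used above is defined outside this hunk; in PyTorch's FSDP API it would typically be a FullStateDictConfig along these lines (a sketch of the assumed definition, not shown in the diff):

    from torch.distributed.fsdp import FullStateDictConfig

    # Gather the full, unsharded state dict with parameters offloaded
    # to CPU, and materialize it only on rank 0 to cap peak memory.
    fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True,
                                                rank0_only=True)

Note that model.state_dict() is a collective under FULL_STATE_DICT, so every rank must execute it; only the actual write is guarded by the rank-0 check.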
wenet/utils/train_utils.py (10 changes: 5 additions & 5 deletions)
@@ -469,6 +469,7 @@ def save_model(model, info_dict):
     rank = int(os.environ.get('RANK', 0))
     tag = info_dict["tag"]
     model_dir = info_dict["model_dir"]
+    save_model_path = os.path.join(model_dir, '{}.pt'.format(tag))
     # save ckpt
     if info_dict["train_engine"] == "deepspeed":
         # NOTE(xcsong): All ranks should call this API, but only rank 0
@@ -484,13 +485,12 @@ def save_model(model, info_dict):
                                                            model_dir, tag),
                                                        tag=tag)
             os.system("rm -rf {}/{}".format(model_dir, tag))
 
+    elif info_dict['train_engine'] == "torch_fsdp":
+        fsdp_save_model(model, save_model_path, info_dict)
     elif rank == 0:
         # NOTE(xcsong): For torch_ddp, only rank-0 should call this.
-        save_model_path = os.path.join(model_dir, '{}.pt'.format(tag))
-        if info_dict['train_engine'] == "torch_fsdp":
-            fsdp_save_model(model, save_model_path, info_dict)
-        else:
-            save_checkpoint(model, save_model_path, info_dict)
+        save_checkpoint(model, save_model_path, info_dict)
     # save yaml
     if rank == 0:
         with open("{}/{}.yaml".format(model_dir, tag), 'w') as fout:
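
Taken together, save_model now dispatches on the train engine before the rank check; a condensed sketch of the resulting control flow (bodies elided):

    # deepspeed: every rank calls DeepSpeed's collective save API,
    # then rank 0 converts the ZeRO shards to a plain checkpoint.
    # torch_fsdp: every rank must enter fsdp_save_model, because
    # gathering the full state dict is itself a collective.
    # torch_ddp: rank 0 alone holds a full copy and can save directly.
    if info_dict["train_engine"] == "deepspeed":
        ...
    elif info_dict['train_engine'] == "torch_fsdp":
        fsdp_save_model(model, save_model_path, info_dict)
    elif rank == 0:
        save_checkpoint(model, save_model_path, info_dict)

Hoisting save_model_path out of the rank-0 branch is what allows the torch_fsdp branch, which runs on all ranks, to reuse it.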
