offline full fp32 weights reconstruction for zero 2+3 checkpoints #892

Merged (1 commit, merged on Mar 26, 2021)
deepspeed/runtime/engine.py (27 changes: 26 additions & 1 deletion)
@@ -3,10 +3,13 @@
'''

import os
import stat
import torch
import warnings
import hashlib
import torch.distributed as dist
from collections import OrderedDict
from shutil import copyfile

from torch.nn.modules import Module
from torch.distributed.distributed_c10d import _get_global_rank
@@ -1684,8 +1687,30 @@ def _save_checkpoint(self, save_dir, tag, client_state={}):
        torch.save(state, save_path)
        self._curr_save_path = None

    def _get_param_shapes(self):
        param_shapes = OrderedDict()
        for name, param in self.module.named_parameters():
            # ZeRO-3 partitioned params carry their full, unpartitioned shape in ds_shape
            param_shapes[name] = param.ds_shape if hasattr(param, "ds_shape") else param.shape
            # print(f"saving param {name} {param_shapes[name]}")
        return param_shapes

    def _copy_recovery_script(self, save_path):
        base_dir = os.path.dirname(os.path.dirname(__file__))
        script = "zero_to_fp32.py"
        src = os.path.join(base_dir, "utils", script)
        dst = os.path.join(save_path, script)
        logger.info(f"creating recovery script {dst}")
        copyfile(src, dst)
        # make executable
        os.chmod(dst, os.stat(dst).st_mode | stat.S_IEXEC)

    def _save_zero_checkpoint(self, save_path, tag):
        zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
-       zero_sd = {'optimizer_state_dict': self.optimizer.state_dict()}
+       zero_sd = dict(
+           optimizer_state_dict=self.optimizer.state_dict(),
+           param_shapes=self._get_param_shapes(),
+       )
        torch.save(zero_sd, zero_checkpoint_name)
        self._copy_recovery_script(save_path)
        logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name))
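
For reference, a minimal sketch of what each per-rank zero checkpoint file holds once the change above is in place; the file name and parameter name are placeholders, and per the recovery script below, optimizer_state_dict is expected to carry "zero_stage", "partition_count" and the fp32 partition(s):

    zero_sd = torch.load("global_step1/rank0_optim_states.pt")  # placeholder file name
    zero_sd["param_shapes"]          # e.g. OrderedDict([('linear.weight', torch.Size([4, 4])), ...])
    zero_sd["optimizer_state_dict"]  # "zero_stage", "partition_count", the fp32 partitions, etc.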
deepspeed/utils/zero_to_fp32.py (151 changes: 151 additions & 0 deletions)
@@ -0,0 +1,151 @@
#!/usr/bin/env python

# This script extracts fp32 consolidated weights from ZeRO 2 and 3 DeepSpeed checkpoints. It gets
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point
# in the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application.
#
# example: python zero_to_fp32.py global_step1 pytorch_model.bin

import argparse
import torch
import glob
import os
from collections import OrderedDict
import deepspeed

# While this script doesn't use deepspeed to recover the data, the checkpoints are pickled with
# DeepSpeed data structures, so the package must be importable in the current python environment.


def get_optim_files(checkpoint_dir):
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")

    # XXX: need to test that this simple glob rule works for multi-node setup too
    optim_files = sorted(glob.glob(f"{checkpoint_dir}/*_optim_states.pt"))

    if len(optim_files) == 0:
        raise FileNotFoundError(
            f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'")

    return optim_files


def parse_optim_states(files):
    state_dicts = []
    for f in files:
        state_dicts.append(torch.load(f))

    if "zero_stage" not in state_dicts[0]['optimizer_state_dict']:
        raise ValueError(f"{files[0]} is not a zero checkpoint")
    zero_stage = state_dicts[0]['optimizer_state_dict']["zero_stage"]

    # the groups are named differently in each stage
    if zero_stage == 2:
        fp32_groups_key = "single_partition_of_fp32_groups"
    elif zero_stage == 3:
        fp32_groups_key = "fp32_flat_groups"
    else:
        raise ValueError(f"unknown zero stage {zero_stage}")

    param_shapes = state_dicts[0]["param_shapes"]
    fp32_flat_groups = [
        state_dicts[i]['optimizer_state_dict'][fp32_groups_key][0]
        for i in range(len(state_dicts))
    ]
    world_size = state_dicts[0]['optimizer_state_dict']["partition_count"]

    return zero_stage, world_size, param_shapes, fp32_flat_groups
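# Illustrative shape of the return values for a hypothetical 2-rank ZeRO-3 checkpoint
# (parameter name and shape are made up):
#   zero_stage       -> 3
#   world_size       -> 2 (read from "partition_count")
#   param_shapes     -> OrderedDict([('linear.weight', torch.Size([4, 4])), ...])
#   fp32_flat_groups -> [flat_fp32_tensor_from_rank_0, flat_fp32_tensor_from_rank_1]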


def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    remainder = unpartitioned_numel % world_size
    padding_numel = (world_size - remainder) if remainder else 0
    partitioned_numel = int(unpartitioned_numel / world_size)
    return partitioned_numel, padding_numel
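# Worked arithmetic for the helper above: unpartitioned_numel=12, world_size=4 gives
# remainder=0, padding_numel=0, partitioned_numel=3; unpartitioned_numel=10, world_size=4
# gives remainder=2, padding_numel=2, partitioned_numel=2.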


def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file):
    """
    Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict file that can be
    loaded with ``torch.load(file)`` and used for training without DeepSpeed.

    Args:
        - ``checkpoint_dir``: path to the deepspeed checkpoint folder
        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)

    """
    print(f"Processing zero checkpoint '{checkpoint_dir}'")

    optim_files = get_optim_files(checkpoint_dir)
    zero_stage, world_size, param_shapes, fp32_flat_groups = parse_optim_states(optim_files)
    print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")

    # Reconstruction protocol:
    #
    # - for zero2 we just need to concat the partitions back to back and reconsolidate over one
    #   huge flat buffer; there is no need to deal with padding, since any padding sits only in
    #   the tail of the last partition and is simply left out
    #
    # - for zero3 we need to zip the partitions together at the boundary of each param,
    #   re-consolidating each param while dealing with padding if any
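    #
    # Illustrative case (numbers chosen so the param divides evenly across ranks): with
    # world_size=2 and a (2, 3) param of 6 elements, under zero3 each rank's flat group holds
    # 3 of its elements at the same offset, so the loop below narrows 3 elements from each rank
    # and concatenates them into 6 before .view(2, 3); under zero2 all 6 elements already sit
    # contiguously in the single concatenated flat vector, so one narrow(0, offset, 6) suffices.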

    if zero_stage == 2:
        # XXX: memory usage doubles here (zero2)
        full_single_fp32_vector = torch.cat(fp32_flat_groups, 0)

    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to
    # support an out-of-core computing solution
    state_dict = OrderedDict()
    offset = 0
    total_numel = 0
    for name, shape in param_shapes.items():
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        if zero_stage == 2:
            # print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
            state_dict[name] = full_single_fp32_vector.narrow(
                0,
                offset,
                unpartitioned_numel).view(shape)
            offset += unpartitioned_numel

        elif zero_stage == 3:
            partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(
                unpartitioned_numel, world_size)
            # print(f"{name} full shape: {shape} partition0 numel {partitioned_numel} partitioned_padding_numel {partitioned_padding_numel}")
            # XXX: memory usage doubles here (zero3)
            state_dict[name] = torch.cat(
                tuple(fp32_flat_groups[i].narrow(0,
                                                 offset,
                                                 partitioned_numel)
                      for i in range(world_size)),
                0).view(shape)
            offset += partitioned_numel + partitioned_padding_numel

    # the job is done
    print(f"Saving fp32 state dict to {output_file} (total_numel={total_numel})")

    torch.save(state_dict, output_file)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "checkpoint_dir",
        type=str,
        help="path to the deepspeed checkpoint folder, e.g., path/checkpoint-1/global_step1")
    parser.add_argument(
        "output_file",
        type=str,
        help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-1/pytorch_model.bin)")
    args = parser.parse_args()

    convert_zero_chkpt_to_fp32_consolid_state_dict(args.checkpoint_dir, args.output_file)
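
As a usage sketch (paths and the model variable below are placeholders, not part of this PR): after saving a DeepSpeed checkpoint under path/checkpoint-1/global_step1, run

    python zero_to_fp32.py path/checkpoint-1/global_step1 path/checkpoint-1/pytorch_model.bin

and the resulting file can then be consumed without DeepSpeed:

    import torch
    state_dict = torch.load("path/checkpoint-1/pytorch_model.bin")  # plain fp32 state_dict
    model.load_state_dict(state_dict)  # 'model' being the corresponding un-wrapped torch.nn.Module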