Fix random token-generation issue + MP-checkpoint loading/saving #2132

Merged · 29 commits, merged on Jul 29, 2022
Changes from 23 commits
Commits
29 commits
cc0a7db
Fix random token-generation issue + MP-checkpoint loading/saving
Jul 25, 2022
79ba8b9
Merge branch 'master' into ds-inference/bloom-fix
RezaYazdaniAminabadi Jul 25, 2022
b7085ea
small fix
Jul 25, 2022
dc7fa6e
Merge branch 'master' into ds-inference/bloom-fix
jeffra Jul 25, 2022
fa6b6ae
get the path for saving mp-checkpoints
Jul 26, 2022
f39c78f
Merge branch 'ds-inference/bloom-fix' of github.com:microsoft/DeepSpe…
Jul 26, 2022
13b1aa4
Merge branch 'master' into ds-inference/bloom-fix
jeffra Jul 26, 2022
1ae7896
bug fix + formatting
jeffra Jul 26, 2022
070f022
fix save_checkpoint path
Jul 26, 2022
c70b529
Merge branch 'ds-inference/bloom-fix' of github.com:microsoft/DeepSpe…
Jul 26, 2022
51c2e7d
Merge branch 'master' into ds-inference/bloom-fix
RezaYazdaniAminabadi Jul 26, 2022
3dcdbe4
Merge branch 'master' into ds-inference/bloom-fix
jeffra Jul 26, 2022
75b33c6
Merge branch 'master' into ds-inference/bloom-fix
jeffra Jul 27, 2022
e8ef956
Merge branch 'master' into ds-inference/bloom-fix
tjruwase Jul 27, 2022
4d3c652
Merge branch 'master' into ds-inference/bloom-fix
jeffra Jul 27, 2022
d794e6c
Modify checkpoint saving to include the config json used during loading
Jul 28, 2022
833f260
git pushMerge branch 'ds-inference/bloom-fix' of github.com:microsoft…
Jul 28, 2022
557521d
set ckpt_mp_size to world_size by default
Jul 28, 2022
b51c447
add missing None
Jul 28, 2022
01003bf
small fix: change None -> 0
Jul 28, 2022
6ea98e8
Merge branch 'master' into ds-inference/bloom-fix
RezaYazdaniAminabadi Jul 28, 2022
043cd98
fix indentation
Jul 28, 2022
bee0b8f
several fixes
Jul 28, 2022
4d5a4ac
support checkpoint as dict or json file
jeffra Jul 28, 2022
b8237c7
fix for non-bloom models
jeffra Jul 28, 2022
007551e
add default if parallelization doesn't exist
jeffra Jul 28, 2022
851a681
fix the path to save non-tp checkpoint
Jul 28, 2022
a53fe08
Merge branch 'ds-inference/bloom-fix' of github.com:microsoft/DeepSpe…
Jul 28, 2022
b249f27
Merge branch 'master' into ds-inference/bloom-fix
jeffra Jul 29, 2022
6 changes: 4 additions & 2 deletions deepspeed/__init__.py
@@ -238,7 +238,8 @@ def init_inference(model,
moe_experts=1,
moe_type='standard',
args=None,
enable_cuda_graph=False):
enable_cuda_graph=False,
save_mp_checkpoint_path=None):
"""Initialize the DeepSpeed InferenceEngine.

Arguments:
@@ -304,6 +305,7 @@ def init_inference(model,
moe_experts,
moe_type,
args,
enable_cuda_graph)
enable_cuda_graph,
save_mp_checkpoint_path)

return engine
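
Usage note (illustrative, not part of the diff): a minimal sketch of how the new save_mp_checkpoint_path argument might be passed to deepspeed.init_inference. The model name, paths, and mp_size below are placeholder assumptions, not values from this PR.

import torch
import deepspeed
from transformers import AutoModelForCausalLM

# Load the original (non-tensor-parallel) model weights.
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom")  # placeholder model

# Kernel injection splits the model across mp_size GPUs; with the new
# save_mp_checkpoint_path argument, the resulting tensor-parallel (MP)
# checkpoint shards are written out so later runs can load them directly.
engine = deepspeed.init_inference(model,
                                  mp_size=8,
                                  dtype=torch.half,
                                  replace_with_kernel_inject=True,
                                  save_mp_checkpoint_path="/tmp/bloom-mp8")  # placeholder path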
77 changes: 45 additions & 32 deletions deepspeed/inference/engine.py
@@ -14,7 +14,6 @@
from ..runtime.state_dict_factory import SDLoaderFactory
from ..runtime.weight_quantizer import WeightQuantization
from ..module_inject.replace_module import replace_transformer_layer
from ..utils import logger
from ..comm.comm import init_distributed
from ..pipe import PipelineModule
from ..moe.utils import has_moe_layers
@@ -50,7 +49,8 @@ def __init__(self,
moe_experts=1,
moe_type='standard',
config=None,
enable_cuda_graph=False):
enable_cuda_graph=False,
save_mp_checkpoint_path=None):
"""
Args:
model: torch.nn.Module
@@ -130,7 +130,8 @@ def __init__(self,
moe_experts,
moe_type,
training_mp_size,
self.checkpoint if replace_with_kernel_inject else None)
self.checkpoint if replace_with_kernel_inject else None,
save_mp_checkpoint_path=save_mp_checkpoint_path)
elif replace_method == 'auto':
self._apply_injection_policy(
return_tuple=return_tuple,
@@ -139,12 +140,17 @@
moe_experts=moe_experts,
moe_type=moe_type,
training_mp_size=training_mp_size,
checkpoint_dir=self.checkpoint if replace_with_kernel_inject else None)
checkpoint_dir=self.checkpoint if replace_with_kernel_inject else None,
save_mp_checkpoint_path=save_mp_checkpoint_path)

device = torch.cuda.current_device()
logger.info(f"Place model to device: {device}")
self.module.to(device)

if self.mp_world_size > 1:
_rng_state = torch.cuda.get_rng_state().to(torch.cuda.current_device())
dist.broadcast(_rng_state, 0)
torch.cuda.set_rng_state(_rng_state.cpu())

if self.mp_world_size > 1:
self.model_orig_fwd = self.module.forward
self.module.forward = self.forward
@@ -315,32 +321,40 @@ def _apply_injection_policy(self,
moe_experts=1,
moe_type='standard',
training_mp_size=1,
checkpoint_dir=None):
checkpoint_dir=None,
save_mp_checkpoint_path=False):
checkpoint = SDLoaderFactory.get_sd_loader_json(
checkpoint_dir) if checkpoint_dir is not None else None
replace_transformer_layer(client_module,
self.module,
triangular_masking=self.triangular_masking,
policy=injection_policy,
mp_size=self.mp_world_size,
mp_group=self.mp_group,
ep_group=self.ep_group,
expert_mp_group=self.expert_mp_group,
config=self.config,
fp16=(self.dtype == torch.half),
training=False,
return_tuple=return_tuple,
quantize=(self.dtype == torch.int8),
quantize_settings=(self.quantization_scales,
self.quantize_merge_count,
self.mlp_extra_grouping,
self.quantize_groups),
replace_with_kernel_inject=replace_with_kernel_inject,
moe=moe,
moe_experts=moe_experts,
moe_type=moe_type,
training_mp_size=training_mp_size,
checkpoint=checkpoint)
checkpoint_dir,
self.checkpoint_engine) if checkpoint_dir is not None else (None,
None,
None,
0)
replace_transformer_layer(
client_module,
self.module,
triangular_masking=self.triangular_masking,
policy=injection_policy,
mp_size=self.mp_world_size,
mp_group=self.mp_group,
ep_group=self.ep_group,
expert_mp_group=self.expert_mp_group,
config=self.config,
fp16=(self.dtype == torch.half),
training=False,
return_tuple=return_tuple,
quantize=(self.dtype == torch.int8),
quantize_settings=(self.quantization_scales,
self.quantize_merge_count,
self.mlp_extra_grouping,
self.quantize_groups),
replace_with_kernel_inject=replace_with_kernel_inject,
moe=moe,
moe_experts=moe_experts,
moe_type=moe_type,
training_mp_size=training_mp_size,
checkpoint_dict=checkpoint,
save_mp_checkpoint_path=save_mp_checkpoint_path,
)

def _get_all_ckpt_names(self, checkpoints_path, tag):
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path,
@@ -380,8 +394,7 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
ckpt_list = self._get_all_ckpt_names(load_dir, tag)
sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine)
else:
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir,
self.checkpoint_engine)
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir)

if type(sd_loader) is list:
self.sd = torch.load(sd_loader[0], map_location='cpu')
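
Note on the token-generation fix (illustrative, not part of the diff): the new block in engine.py broadcasts rank 0's CUDA RNG state to all ranks so that, when mp_world_size > 1, every tensor-parallel rank draws the same random numbers during sampling and therefore generates the same tokens. A standalone sketch of the same pattern, assuming torch.distributed is already initialized and each rank owns one GPU:

import torch
import torch.distributed as dist

def sync_cuda_rng_state(src_rank=0):
    # Move the (CPU) RNG state tensor onto the current GPU so it can be
    # broadcast over NCCL, then restore it on every rank.
    rng_state = torch.cuda.get_rng_state().to(torch.cuda.current_device())
    dist.broadcast(rng_state, src_rank)
    torch.cuda.set_rng_state(rng_state.cpu())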
110 changes: 79 additions & 31 deletions deepspeed/module_inject/load_checkpoint.py
@@ -2,9 +2,10 @@
import deepspeed.ops.transformer as transformer_inference
from ..runtime.zero import GatheredParameters
from .layers import LinearLayer, Normalize, EmbeddingLayer
import torch


def load_model_with_checkpoint(r_module, sd, mp_replace):
def load_model_with_checkpoint(r_module, sd, mp_replace, ckpt_type, rank=0):
error_msgs = []

def transpose(data):
@@ -29,33 +30,76 @@ def load(module, prefix):
module.bias = mp_replace.copy(module.bias.data, sd[prefix + 'bias'])

def load_transformer_layer(module, prefix):
module.norm_w.data.copy_(sd[prefix + 'input_layernorm.' + 'weight'])
module.norm_b.data.copy_(sd[prefix + 'input_layernorm.' + 'bias'])
module.attention.attn_qkvw = mp_replace.copy(
module.attention.attn_qkvw.data,
transpose(sd[prefix + 'self_attention.query_key_value.' + 'weight']))
module.attention.attn_qkvb = mp_replace.copy(
module.attention.attn_qkvb.data,
sd[prefix + 'self_attention.query_key_value.' + 'bias'])
module.attention.attn_ow = mp_replace.copy(
module.attention.attn_ow.data,
transpose(sd[prefix + 'self_attention.dense.' + 'weight']))
module.attention.attn_ob = mp_replace.copy(
module.attention.attn_ob.data,
sd[prefix + 'self_attention.dense.' + 'bias'])
module.mlp.attn_nw.data.copy_(sd[prefix + 'post_attention_layernorm.' +
'weight'])
module.mlp.attn_nb.data.copy_(sd[prefix + 'post_attention_layernorm.' + 'bias'])
module.mlp.inter_w = mp_replace.copy(
module.mlp.inter_w.data,
transpose(sd[prefix + 'mlp.dense_h_to_4h.' + 'weight']))
module.mlp.inter_b = mp_replace.copy(module.mlp.inter_b.data,
sd[prefix + 'mlp.dense_h_to_4h.' + 'bias'])
module.mlp.output_w = mp_replace.copy(
module.mlp.output_w.data,
transpose(sd[prefix + 'mlp.dense_4h_to_h.' + 'weight']))
module.mlp.output_b = mp_replace.copy(module.mlp.output_b.data,
sd[prefix + 'mlp.dense_4h_to_h.' + 'bias'])
if ckpt_type == "tp":

def load_parameters(module, prefix):
for n, p in module.named_parameters():
if len(n.split('.')) == 1:
src_shape = sd[prefix + n].shape
dst_shape = p.shape

if (len(src_shape) == 2 and len(dst_shape) == 2):
if src_shape[0] == dst_shape[0] and src_shape[
1] == dst_shape[1]:
p.data.copy_(sd[prefix + n])
else:
if src_shape[0] != dst_shape[0]:
weight_split = torch.split(
sd[prefix + n],
dst_shape[0],
dim=0)[rank].to(
torch.cuda.current_device()).contiguous()
else:
weight_split = torch.split(
sd[prefix + n],
dst_shape[1],
dim=1)[rank].to(
torch.cuda.current_device()).contiguous()
p.data.copy_(weight_split.contiguous())
else:
if src_shape[0] == dst_shape[0]:
p.data.copy_(sd[prefix + n])
else:
bias_split = torch.split(
sd[prefix + n],
dst_shape[-1])[rank].to(
torch.cuda.current_device()).contiguous()
p.data.copy_(bias_split)

load_parameters(module, prefix)
for n, child in module.named_children():
load_parameters(child, prefix + n + '.')
else:
module.norm_w.data.copy_(sd[prefix + 'input_layernorm.' + 'weight'])
module.norm_b.data.copy_(sd[prefix + 'input_layernorm.' + 'bias'])
module.attention.attn_qkvw = mp_replace.copy(
module.attention.attn_qkvw.data,
transpose(sd[prefix + 'self_attention.query_key_value.' + 'weight']))
module.attention.attn_qkvb = mp_replace.copy(
module.attention.attn_qkvb.data,
sd[prefix + 'self_attention.query_key_value.' + 'bias'])
module.attention.attn_ow = mp_replace.copy(
module.attention.attn_ow.data,
transpose(sd[prefix + 'self_attention.dense.' + 'weight']))
module.attention.attn_ob = mp_replace.copy(
module.attention.attn_ob.data,
sd[prefix + 'self_attention.dense.' + 'bias'])
module.mlp.attn_nw.data.copy_(sd[prefix + 'post_attention_layernorm.' +
'weight'])
module.mlp.attn_nb.data.copy_(sd[prefix + 'post_attention_layernorm.' +
'bias'])
module.mlp.inter_w = mp_replace.copy(
module.mlp.inter_w.data,
transpose(sd[prefix + 'mlp.dense_h_to_4h.' + 'weight']))
module.mlp.inter_b = mp_replace.copy(
module.mlp.inter_b.data,
sd[prefix + 'mlp.dense_h_to_4h.' + 'bias'])
module.mlp.output_w = mp_replace.copy(
module.mlp.output_w.data,
transpose(sd[prefix + 'mlp.dense_4h_to_h.' + 'weight']))
module.mlp.output_b = mp_replace.copy(
module.mlp.output_b.data,
sd[prefix + 'mlp.dense_4h_to_h.' + 'bias'])

layer_policies = {
nn.Linear: load,
@@ -95,6 +139,9 @@ def load_module_recursive(module, prefix='', level=0):
dtype=child.weight.dtype,
eps=child.eps)
setattr(module, name, child)
elif child.__class__ is nn.Linear:
child = LinearLayer(weight=child.weight, bias=child.bias)
setattr(module, name, child)
else:
ds_id = None
if hasattr(child.weight, 'ds_id'):
@@ -107,9 +154,10 @@

layer_policies[child.__class__](child, prefix + name + '.')
else:
load_module_recursive(child,
prefix if level == 0 else prefix + name + '.',
level + 1)
load_module_recursive(
child,
prefix if level == 0 and ckpt_type == 'pp' else prefix + name + '.',
level + 1)

load_module_recursive(r_module)

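
Note on the new "tp" checkpoint loading path (illustrative, not part of the diff): when ckpt_type == "tp", load_parameters copies a checkpoint tensor directly if its shape already matches the destination parameter; otherwise it splits the tensor along the mismatched dimension and copies only the current rank's shard. A simplified standalone sketch of that rule, assuming shapes divide evenly across ranks:

import torch

def copy_tp_shard(dst, src, rank):
    # dst: this rank's (already sharded) parameter; src: the full checkpoint
    # tensor from the state dict.
    if dst.shape == src.shape:
        dst.copy_(src)
        return
    # Pick the dimension that was sharded and keep this rank's chunk.
    dim = 0 if dst.shape[0] != src.shape[0] else 1
    shard = torch.split(src, dst.shape[dim], dim=dim)[rank]
    dst.copy_(shard.contiguous())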