Fix random token-generation issue + MP-checkpoint loading/saving (#2132)
Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
3 people authored Jul 29, 2022
1 parent 57140e8 commit 556f005
Showing 7 changed files with 249 additions and 87 deletions.
6 changes: 4 additions & 2 deletions deepspeed/__init__.py
@@ -238,7 +238,8 @@ def init_inference(model,
moe_experts=1,
moe_type='standard',
args=None,
enable_cuda_graph=False):
enable_cuda_graph=False,
save_mp_checkpoint_path=None):
"""Initialize the DeepSpeed InferenceEngine.
Arguments:
@@ -304,6 +305,7 @@ def init_inference(model,
moe_experts,
moe_type,
args,
enable_cuda_graph)
enable_cuda_graph,
save_mp_checkpoint_path)

return engine
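
A minimal usage sketch of the new argument (the model name, mp_size, and save path below are illustrative and not part of this diff; run under the deepspeed launcher with a matching world size):

import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

# Inject the DeepSpeed inference kernels; once the tensor-parallel partitions
# are built, also save them so later runs can load the pre-sharded ("MP")
# checkpoint instead of re-partitioning the original weights.
engine = deepspeed.init_inference(model,
                                  mp_size=2,
                                  dtype=torch.half,
                                  replace_with_kernel_inject=True,
                                  save_mp_checkpoint_path="/tmp/bloom-mp-ckpt")
model = engine.module  # ready for forward() / generate()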
78 changes: 44 additions & 34 deletions deepspeed/inference/engine.py
@@ -14,7 +14,6 @@
from ..runtime.state_dict_factory import SDLoaderFactory
from ..runtime.weight_quantizer import WeightQuantization
from ..module_inject.replace_module import replace_transformer_layer
from ..utils import logger
from ..comm.comm import init_distributed
from ..pipe import PipelineModule
from ..moe.utils import has_moe_layers
@@ -50,7 +49,8 @@ def __init__(self,
moe_experts=1,
moe_type='standard',
config=None,
enable_cuda_graph=False):
enable_cuda_graph=False,
save_mp_checkpoint_path=None):
"""
Args:
model: torch.nn.Module
@@ -130,7 +130,8 @@ def __init__(self,
moe_experts,
moe_type,
training_mp_size,
self.checkpoint if replace_with_kernel_inject else None)
self.checkpoint if replace_with_kernel_inject else None,
save_mp_checkpoint_path=save_mp_checkpoint_path)
elif replace_method == 'auto':
self._apply_injection_policy(
return_tuple=return_tuple,
@@ -139,12 +140,17 @@
moe_experts=moe_experts,
moe_type=moe_type,
training_mp_size=training_mp_size,
checkpoint_dir=self.checkpoint if replace_with_kernel_inject else None)
checkpoint_dir=self.checkpoint if replace_with_kernel_inject else None,
save_mp_checkpoint_path=save_mp_checkpoint_path)

device = torch.cuda.current_device()
logger.info(f"Place model to device: {device}")
self.module.to(device)

if self.mp_world_size > 1:
_rng_state = torch.cuda.get_rng_state().to(torch.cuda.current_device())
dist.broadcast(_rng_state, 0)
torch.cuda.set_rng_state(_rng_state.cpu())

if self.mp_world_size > 1:
self.model_orig_fwd = self.module.forward
self.module.forward = self.forward
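
The broadcast added above is the fix for the random token-generation issue: without it, each tensor-parallel rank samples from its own CUDA RNG and the ranks can disagree on the next token. A standalone sketch of the same idea, assuming the process group is already initialized (the function name is ours, not DeepSpeed's):

import torch
import torch.distributed as dist

def sync_cuda_rng_state(group=None, src=0):
    # The CUDA RNG state is a CPU ByteTensor; move it to the GPU so it can
    # be broadcast from rank `src`, then install the shared state locally.
    state = torch.cuda.get_rng_state().to(torch.cuda.current_device())
    dist.broadcast(state, src, group=group)
    torch.cuda.set_rng_state(state.cpu())

After this call, identical sampling operations (e.g. torch.multinomial) return the same draw on every rank in the group.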
@@ -226,9 +232,9 @@ def _validate_args(self, mpu):
for method in methods:
if not hasattr(mpu, method):
raise ValueError(f"mpu is missing {method}")
if self.checkpoint is not None and not isinstance(self.checkpoint, str):
if self.checkpoint is not None and not isinstance(self.checkpoint, (str, dict)):
raise ValueError(
f"checkpoint must be None or a str, got {type(self.checkpoint)}")
f"checkpoint must be None, str or dict, got {type(self.checkpoint)}")

supported_dtypes = [None, torch.half, torch.int8, torch.float]
if self.dtype not in supported_dtypes:
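
With the relaxed check, checkpoint may be either the path to a checkpoint-description JSON (as before) or the already-parsed dict. A brief sketch of the dict form, reusing the model from the earlier example ("checkpoints.json" is a placeholder file name):

import json
import torch
import deepspeed

# Parse the checkpoint description ourselves and hand the dict to DeepSpeed;
# previously only a path string was accepted here.
with open("checkpoints.json") as f:
    ckpt_desc = json.load(f)

engine = deepspeed.init_inference(model,
                                  mp_size=2,
                                  dtype=torch.half,
                                  replace_with_kernel_inject=True,
                                  checkpoint=ckpt_desc)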
@@ -315,32 +321,37 @@ def _apply_injection_policy(self,
moe_experts=1,
moe_type='standard',
training_mp_size=1,
checkpoint_dir=None):
checkpoint_dir=None,
save_mp_checkpoint_path=False):
checkpoint = SDLoaderFactory.get_sd_loader_json(
checkpoint_dir) if checkpoint_dir is not None else None
replace_transformer_layer(client_module,
self.module,
triangular_masking=self.triangular_masking,
policy=injection_policy,
mp_size=self.mp_world_size,
mp_group=self.mp_group,
ep_group=self.ep_group,
expert_mp_group=self.expert_mp_group,
config=self.config,
fp16=(self.dtype == torch.half),
training=False,
return_tuple=return_tuple,
quantize=(self.dtype == torch.int8),
quantize_settings=(self.quantization_scales,
self.quantize_merge_count,
self.mlp_extra_grouping,
self.quantize_groups),
replace_with_kernel_inject=replace_with_kernel_inject,
moe=moe,
moe_experts=moe_experts,
moe_type=moe_type,
training_mp_size=training_mp_size,
checkpoint=checkpoint)
checkpoint_dir,
self.checkpoint_engine) if checkpoint_dir is not None else None
replace_transformer_layer(
client_module,
self.module,
triangular_masking=self.triangular_masking,
policy=injection_policy,
mp_size=self.mp_world_size,
mp_group=self.mp_group,
ep_group=self.ep_group,
expert_mp_group=self.expert_mp_group,
config=self.config,
fp16=(self.dtype == torch.half),
training=False,
return_tuple=return_tuple,
quantize=(self.dtype == torch.int8),
quantize_settings=(self.quantization_scales,
self.quantize_merge_count,
self.mlp_extra_grouping,
self.quantize_groups),
replace_with_kernel_inject=replace_with_kernel_inject,
moe=moe,
moe_experts=moe_experts,
moe_type=moe_type,
training_mp_size=training_mp_size,
checkpoint_dict=checkpoint,
save_mp_checkpoint_path=save_mp_checkpoint_path,
)

def _get_all_ckpt_names(self, checkpoints_path, tag):
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path,
@@ -380,8 +391,7 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
ckpt_list = self._get_all_ckpt_names(load_dir, tag)
sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine)
else:
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir,
self.checkpoint_engine)
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir)

if type(sd_loader) is list:
self.sd = torch.load(sd_loader[0], map_location='cpu')
110 changes: 79 additions & 31 deletions deepspeed/module_inject/load_checkpoint.py
@@ -2,9 +2,10 @@
import deepspeed.ops.transformer as transformer_inference
from ..runtime.zero import GatheredParameters
from .layers import LinearLayer, Normalize, EmbeddingLayer
import torch


def load_model_with_checkpoint(r_module, sd, mp_replace):
def load_model_with_checkpoint(r_module, sd, mp_replace, ckpt_type, rank=0):
error_msgs = []

def transpose(data):
@@ -29,33 +30,76 @@ def load(module, prefix):
module.bias = mp_replace.copy(module.bias.data, sd[prefix + 'bias'])

def load_transformer_layer(module, prefix):
module.norm_w.data.copy_(sd[prefix + 'input_layernorm.' + 'weight'])
module.norm_b.data.copy_(sd[prefix + 'input_layernorm.' + 'bias'])
module.attention.attn_qkvw = mp_replace.copy(
module.attention.attn_qkvw.data,
transpose(sd[prefix + 'self_attention.query_key_value.' + 'weight']))
module.attention.attn_qkvb = mp_replace.copy(
module.attention.attn_qkvb.data,
sd[prefix + 'self_attention.query_key_value.' + 'bias'])
module.attention.attn_ow = mp_replace.copy(
module.attention.attn_ow.data,
transpose(sd[prefix + 'self_attention.dense.' + 'weight']))
module.attention.attn_ob = mp_replace.copy(
module.attention.attn_ob.data,
sd[prefix + 'self_attention.dense.' + 'bias'])
module.mlp.attn_nw.data.copy_(sd[prefix + 'post_attention_layernorm.' +
'weight'])
module.mlp.attn_nb.data.copy_(sd[prefix + 'post_attention_layernorm.' + 'bias'])
module.mlp.inter_w = mp_replace.copy(
module.mlp.inter_w.data,
transpose(sd[prefix + 'mlp.dense_h_to_4h.' + 'weight']))
module.mlp.inter_b = mp_replace.copy(module.mlp.inter_b.data,
sd[prefix + 'mlp.dense_h_to_4h.' + 'bias'])
module.mlp.output_w = mp_replace.copy(
module.mlp.output_w.data,
transpose(sd[prefix + 'mlp.dense_4h_to_h.' + 'weight']))
module.mlp.output_b = mp_replace.copy(module.mlp.output_b.data,
sd[prefix + 'mlp.dense_4h_to_h.' + 'bias'])
if ckpt_type == "tp":

def load_parameters(module, prefix):
for n, p in module.named_parameters():
if len(n.split('.')) == 1:
src_shape = sd[prefix + n].shape
dst_shape = p.shape

if (len(src_shape) == 2 and len(dst_shape) == 2):
if src_shape[0] == dst_shape[0] and src_shape[
1] == dst_shape[1]:
p.data.copy_(sd[prefix + n])
else:
if src_shape[0] != dst_shape[0]:
weight_split = torch.split(
sd[prefix + n],
dst_shape[0],
dim=0)[rank].to(
torch.cuda.current_device()).contiguous()
else:
weight_split = torch.split(
sd[prefix + n],
dst_shape[1],
dim=1)[rank].to(
torch.cuda.current_device()).contiguous()
p.data.copy_(weight_split.contiguous())
else:
if src_shape[0] == dst_shape[0]:
p.data.copy_(sd[prefix + n])
else:
bias_split = torch.split(
sd[prefix + n],
dst_shape[-1])[rank].to(
torch.cuda.current_device()).contiguous()
p.data.copy_(bias_split)

load_parameters(module, prefix)
for n, child in module.named_children():
load_parameters(child, prefix + n + '.')
else:
module.norm_w.data.copy_(sd[prefix + 'input_layernorm.' + 'weight'])
module.norm_b.data.copy_(sd[prefix + 'input_layernorm.' + 'bias'])
module.attention.attn_qkvw = mp_replace.copy(
module.attention.attn_qkvw.data,
transpose(sd[prefix + 'self_attention.query_key_value.' + 'weight']))
module.attention.attn_qkvb = mp_replace.copy(
module.attention.attn_qkvb.data,
sd[prefix + 'self_attention.query_key_value.' + 'bias'])
module.attention.attn_ow = mp_replace.copy(
module.attention.attn_ow.data,
transpose(sd[prefix + 'self_attention.dense.' + 'weight']))
module.attention.attn_ob = mp_replace.copy(
module.attention.attn_ob.data,
sd[prefix + 'self_attention.dense.' + 'bias'])
module.mlp.attn_nw.data.copy_(sd[prefix + 'post_attention_layernorm.' +
'weight'])
module.mlp.attn_nb.data.copy_(sd[prefix + 'post_attention_layernorm.' +
'bias'])
module.mlp.inter_w = mp_replace.copy(
module.mlp.inter_w.data,
transpose(sd[prefix + 'mlp.dense_h_to_4h.' + 'weight']))
module.mlp.inter_b = mp_replace.copy(
module.mlp.inter_b.data,
sd[prefix + 'mlp.dense_h_to_4h.' + 'bias'])
module.mlp.output_w = mp_replace.copy(
module.mlp.output_w.data,
transpose(sd[prefix + 'mlp.dense_4h_to_h.' + 'weight']))
module.mlp.output_b = mp_replace.copy(
module.mlp.output_b.data,
sd[prefix + 'mlp.dense_4h_to_h.' + 'bias'])

layer_policies = {
nn.Linear: load,
@@ -95,6 +139,9 @@ def load_module_recursive(module, prefix='', level=0):
dtype=child.weight.dtype,
eps=child.eps)
setattr(module, name, child)
elif child.__class__ is nn.Linear:
child = LinearLayer(weight=child.weight, bias=child.bias)
setattr(module, name, child)
else:
ds_id = None
if hasattr(child.weight, 'ds_id'):
@@ -107,9 +154,10 @@

layer_policies[child.__class__](child, prefix + name + '.')
else:
load_module_recursive(child,
prefix if level == 0 else prefix + name + '.',
level + 1)
load_module_recursive(
child,
prefix if level == 0 and ckpt_type == 'pp' else prefix + name + '.',
level + 1)

load_module_recursive(r_module)
