Generic loading of checkpoints in deepspeed-inference #2405

Closed · wants to merge 7 commits
12 changes: 11 additions & 1 deletion csrc/transformer/inference/includes/inference_context.h
@@ -161,6 +161,16 @@ class Context {
inline size_t GetMaxTokenLenght() const { return _max_seq_len; }

cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }
inline void advance_tokens()
{
if (_num_tokens >= _max_seq_len)
printf(
"Requesting to generate more tokens (%d) than max-seq-len allowed by cache (%d)\n",
_num_tokens,
_max_seq_len);
assert(_num_tokens < _max_seq_len);
_num_tokens++;
}

size_t get_workspace_size() const { return _workSpaceSize; }
void* GetWorkSpace() { return _workspace; }
@@ -182,7 +192,7 @@ class Context {

inline unsigned current_tokens() const { return _num_tokens; }

inline void advance_tokens() { _num_tokens++; }
//inline void advance_tokens() { _num_tokens++; }

cudaStream_t GetCommStream(bool async_op = false)
{
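Note (not part of the diff): the hunk above replaces the unconditional advance_tokens() further down in this header (now commented out) with a bounds-checked version, so a request that would generate past the cache's max-seq-len fails loudly instead of writing past the allocated KV-cache workspace. A minimal Python sketch of the same guard, for illustration only (the class and names here are hypothetical, not DeepSpeed API; it raises instead of printf+assert):

class TokenCounter:
    """Python analogue of Context::advance_tokens() after this change."""
    def __init__(self, max_seq_len: int):
        self.max_seq_len = max_seq_len   # number of tokens the KV cache was sized for
        self.num_tokens = 0              # tokens generated so far

    def advance_tokens(self) -> None:
        # Fail before overrunning the cache rather than silently corrupting the workspace.
        if self.num_tokens >= self.max_seq_len:
            raise AssertionError(
                f"Requesting to generate more tokens ({self.num_tokens}) "
                f"than max-seq-len allowed by cache ({self.max_seq_len})")
        self.num_tokens += 1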
@@ -68,7 +68,7 @@ def __init__(self,
merge_count,
mlp_extra_grouping)

device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu'
device = torch.cuda.current_device() #if config.bigscience_bloom else 'cpu'
self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size,
dtype=data_type,
device=device),
33 changes: 29 additions & 4 deletions deepspeed/module_inject/layers.py
@@ -23,7 +23,7 @@ def forward(self, input):


class LinearLayer(nn.Module):
def __init__(self, weight_shape=None, dtype=None, weight=None, bias=None):
def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
super(LinearLayer, self).__init__()
if weight is not None:
self.weight = weight
@@ -34,9 +34,10 @@ def __init__(self, weight_shape=None, dtype=None, weight=None, bias=None):
dtype=dtype,
device=torch.cuda.current_device()))
self.bias = Parameter(
torch.empty(weight_shape[0],
dtype=dtype,
device=torch.cuda.current_device()))
torch.empty(
weight_shape[0],
dtype=dtype,
device=torch.cuda.current_device())) if bias is not None else None

def forward(self, input):
output = torch.matmul(input, self.weight.transpose(-1, -2))
@@ -67,3 +68,27 @@ def __init__(self, weight_shape, dtype=torch.float):

def forward(self, input):
return F.embedding(input, self.weight)

class OPTEmbedding(EmbeddingLayer):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, weight_shape):
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
super().__init__(weight_shape)

def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
attention_mask = attention_mask.long()

# create positions depending on attention_mask
positions = (torch.cumsum(attention_mask,
dim=1).type_as(attention_mask) *
attention_mask).long() - 1

# cut positions if `past_key_values_length` is > 0
positions = positions[:, past_key_values_length:]

return super().forward(positions + self.offset)
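Note (not part of the diff): OPTEmbedding mirrors Hugging Face's OPTLearnedPositionalEmbedding, which this PR swaps in during module injection. Position ids are derived from the attention mask, so padding does not advance the position counter, and the OPT-specific offset of 2 is added before the embedding lookup. A small worked example of the position computation used in forward() above, for illustration only:

import torch

# one left-padded sequence and one full sequence
attention_mask = torch.tensor([[0, 1, 1, 1],
                               [1, 1, 1, 1]])
positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask)
             * attention_mask).long() - 1
# positions == [[-1, 0, 1, 2],
#               [ 0, 1, 2, 3]]   padded slots collapse to -1
# OPTEmbedding then looks up positions + self.offset (i.e. + 2), and with
# past_key_values_length > 0 only the trailing columns are kept.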
140 changes: 102 additions & 38 deletions deepspeed/module_inject/load_checkpoint.py
@@ -1,7 +1,7 @@
from torch import nn
import deepspeed.ops.transformer as transformer_inference
from ..runtime.zero import GatheredParameters
from .layers import LinearLayer, Normalize, EmbeddingLayer
from .layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding
import torch
import gc

@@ -11,7 +11,9 @@ def load_model_with_checkpoint(r_module,
mp_replace,
ckpt_type,
weight_quantizer=None,
rank=0):
rank=0,
transformer_config=None,
param_names=None):
error_msgs = []

def transpose(data):
@@ -138,45 +140,103 @@ def load_parameters(module, prefix):
for n, child in module.named_children():
load_parameters(child, prefix + n + '.')
else:
module.norm_w.data.copy_(sd[0][prefix + 'input_layernorm.' + 'weight'])
module.norm_b.data.copy_(sd[0][prefix + 'input_layernorm.' + 'bias'])
module.attention.attn_qkvw = mp_replace.copy(module.attention.attn_qkvw,
weight_quantizer.quantize(sd[0][prefix + 'self_attention.query_key_value.' + 'weight']) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(transpose(sd[0][prefix + 'self_attention.query_key_value.' + 'weight'])))
module.attention.attn_qkvb = mp_replace.copy(
module.attention.attn_qkvb.data,
sd[0][prefix + 'self_attention.query_key_value.' + 'bias'])
module.attention.attn_ow = mp_replace.copy(module.attention.attn_ow,
weight_quantizer.quantize(sd[0][prefix + 'self_attention.dense.' + 'weight']) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(transpose(sd[0][prefix + 'self_attention.dense.' + 'weight'])))
module.attention.attn_ob = mp_replace.copy(
module.attention.attn_ob.data,
sd[0][prefix + 'self_attention.dense.' + 'bias'])
module.mlp.attn_nw.data.copy_(sd[0][prefix + 'post_attention_layernorm.' +
'weight'])
module.mlp.attn_nb.data.copy_(sd[0][prefix + 'post_attention_layernorm.' +
'bias'])
module.mlp.inter_w = mp_replace.copy(module.mlp.inter_w,
weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight']) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(transpose(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight'])))
module.mlp.inter_b = mp_replace.copy(
module.mlp.inter_b.data,
sd[0][prefix + 'mlp.dense_h_to_4h.' + 'bias'])
module.mlp.output_w = mp_replace.copy(module.mlp.output_w,
weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight']) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(transpose(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight'])))
module.mlp.output_b = mp_replace.copy(
module.mlp.output_b.data,
sd[0][prefix + 'mlp.dense_4h_to_h.' + 'bias'])

def maybe_copy(module, dst_name, src_name, qkv=False):
if src_name in sd[0]:
dst = getattr(module, dst_name)
if len(dst.shape) == 1:
if qkv:
dst = mp_replace.qkv_copy(
dst,
(sd[0][src_name]).contiguous())
else:
dst = mp_replace.copy(dst, sd[0][src_name])
else:
if qkv:
dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, sd[0][src_name] if weight_quantizer.q_int8 else \
((transpose(sd[0][src_name])).contiguous())))
else:
dst = weight_quantizer.quantize(mp_replace.copy(dst, sd[0][src_name] if weight_quantizer.q_int8 else \
transpose(sd[0][src_name])))
setattr(module, dst_name, dst)
def maybe_copy1(module, dst_name, src_names, qkv=False):
if src_names[0] in sd[0]:
q = sd[0][src_names[0]]
k = sd[0][src_names[1]]
v = sd[0][src_names[2]]
qkv_data = torch.cat((q, k, v), dim=0)
dst = getattr(module, dst_name)
if len(dst.shape) == 1:
if qkv:
dst = mp_replace.qkv_copy(dst,
(qkv_data).contiguous())
else:
dst = mp_replace.copy(dst, qkv_data)
else:
if qkv:
dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, qkv_data if weight_quantizer.q_int8 else \
((transpose(qkv_data)).contiguous())))
else:
dst = weight_quantizer.quantize(mp_replace.copy(dst, qkv_data if weight_quantizer.q_int8 else \
transpose(qkv_data)))
setattr(module, dst_name, dst)
if len(param_names) == 12:
qkv_w, qkv_b, attn_ow, attn_ob, \
mlp_intw, mlp_intb, mlp_ow, mlp_ob, \
inp_normw, inp_normb, attn_nw, attn_nb = param_names
elif len(param_names) < 12:
q_w, k_w, v_w, attn_ow, \
mlp_intw, mlp_intb, mlp_ow, mlp_ob, \
inp_normw, inp_normb = param_names
else:
q_w, q_b, k_w, k_b, v_w, v_b, attn_ow, attn_ob, \
mlp_intw, mlp_intb, mlp_ow, mlp_ob, \
inp_normw, inp_normb, attn_nw, attn_nb = param_names
maybe_copy(module, 'norm_w', prefix + inp_normw)
maybe_copy(module, 'norm_b', prefix + inp_normb)
if len(param_names) == 12:
maybe_copy(module.attention, 'attn_qkvw', prefix + qkv_w, qkv=True)
maybe_copy(module.attention, 'attn_qkvb', prefix + qkv_b, qkv=True)
elif len(param_names) < 12:
maybe_copy1(module.attention,
'attn_qkvw',
[prefix + q_w,
prefix + k_w,
prefix + v_w])
else:
maybe_copy1(module.attention,
'attn_qkvw',
[prefix + q_w,
prefix + k_w,
prefix + v_w])
maybe_copy1(module.attention,
'attn_qkvb',
[prefix + q_b,
prefix + k_b,
prefix + v_b])
maybe_copy(module.attention, 'attn_ow', prefix + attn_ow)
if len(param_names) > 12:
maybe_copy(module.attention, 'attn_ob', prefix + attn_ob)
maybe_copy(module.mlp, 'attn_nw', prefix + attn_nw)
maybe_copy(module.mlp, 'attn_nb', prefix + attn_nb)
maybe_copy(module.mlp, 'inter_w', prefix + mlp_intw)
maybe_copy(module.mlp, 'inter_b', prefix + mlp_intb)
maybe_copy(module.mlp, 'output_w', prefix + mlp_ow)
maybe_copy(module.mlp, 'output_b', prefix + mlp_ob)
try:
import transformers
OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding
except:
OPTLearnedPositionalEmbedding = None
layer_policies = {
nn.Linear: load,
nn.Embedding: load,
nn.LayerNorm: load,
EmbeddingLayer: load,
LinearLayer: load,
Normalize: load,
transformer_inference.DeepSpeedTransformerInference: load_transformer_layer
transformer_inference.DeepSpeedTransformerInference: load_transformer_layer,
OPTLearnedPositionalEmbedding: load,
OPTEmbedding: load
}

all_ds_ids = {}
Expand Down Expand Up @@ -210,6 +270,9 @@ def load_module_recursive(module, prefix='', level=0):
elif child.__class__ is nn.Linear:
child = LinearLayer(weight=child.weight, bias=child.bias)
setattr(module, name, child)
elif child.__class__ is OPTLearnedPositionalEmbedding:
child = OPTEmbedding(weight_shape=ds_shape)
setattr(module, name, child)
else:
ds_id = None
if hasattr(child.weight, 'ds_id'):
@@ -222,20 +285,21 @@

layer_policies[child.__class__](child, prefix + name + '.')
else:

load_module_recursive(
child,
prefix if level == 0 and ckpt_type == 'pp' else prefix + name + '.',
level + 1)

load_module_recursive(r_module)

#XXX: hack to tie embedding w. lm_head for BLOOM, need to revisit soon
embedding_weight = None
for n, p in r_module.named_parameters():
if "word_embeddings." in n:
if "word_embeddings." in n or 'embed_tokens.' in n:
embedding_weight = p
assert hasattr(r_module, 'lm_head'), "attempting to set lm_head but it doesn't exist"
r_module.lm_head.weight = embedding_weight
if hasattr(r_module, 'lm_head'):
if embedding_weight is not None:
r_module.lm_head.weight = embedding_weight
for sd_ in sd:
del sd_
sd = None
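Note (not part of the diff): the rewritten load_transformer_layer no longer hard-codes BLOOM's state-dict keys; it unpacks a per-policy list of parameter-name suffixes and copies (and optionally quantizes and tensor-parallel-shards) whatever it finds in the checkpoint. The three list lengths distinguish fused-QKV checkpoints (12 names), separate q/k/v without attention biases (fewer than 12), and separate q/k/v with biases (16). A hypothetical illustration of what a policy's get_param_names() could return, following the unpacking order above; the BLOOM-style suffixes match the deleted hard-coded keys, the OPT-style ones are assumptions, and the real orderings live in the injection policies:

# Fused QKV, 12 entries: qkv_w, qkv_b, attn_ow, attn_ob, mlp_intw, mlp_intb,
#                        mlp_ow, mlp_ob, inp_normw, inp_normb, attn_nw, attn_nb
bloom_like_param_names = (
    'self_attention.query_key_value.weight', 'self_attention.query_key_value.bias',
    'self_attention.dense.weight', 'self_attention.dense.bias',
    'mlp.dense_h_to_4h.weight', 'mlp.dense_h_to_4h.bias',
    'mlp.dense_4h_to_h.weight', 'mlp.dense_4h_to_h.bias',
    'input_layernorm.weight', 'input_layernorm.bias',
    'post_attention_layernorm.weight', 'post_attention_layernorm.bias')

# Separate q/k/v with biases, 16 entries: q_w, q_b, k_w, k_b, v_w, v_b, attn_ow, attn_ob,
#                                         mlp_intw, mlp_intb, mlp_ow, mlp_ob,
#                                         inp_normw, inp_normb, attn_nw, attn_nb
opt_like_param_names = (
    'self_attn.q_proj.weight', 'self_attn.q_proj.bias',
    'self_attn.k_proj.weight', 'self_attn.k_proj.bias',
    'self_attn.v_proj.weight', 'self_attn.v_proj.bias',
    'self_attn.out_proj.weight', 'self_attn.out_proj.bias',
    'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias',
    'self_attn_layer_norm.weight', 'self_attn_layer_norm.bias',
    'final_layer_norm.weight', 'final_layer_norm.bias')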
20 changes: 17 additions & 3 deletions deepspeed/module_inject/replace_module.py
100755 → 100644
@@ -187,6 +187,9 @@ def quantize(self, inputs, qkv=True, count=1, parallel_dim=0):
return out


transformer_config_g = None
selected_policy_g = None

def _module_match(module):
for policy in generic_policies:
policy = policy()
@@ -285,7 +288,6 @@ def _replace_module(module, policy):
print(f"**** found and replaced {name} w. {type(new_module)}")
setattr(module, name, new_module)


def replace_transformer_layer(orig_layer_impl,
model,
policy=None,
@@ -361,6 +363,9 @@ def replace_with_policy(child,
inference=False,
layer_id=0):
policy = policy_cls(child, inference=inference)
global selected_policy_g
if selected_policy_g is None:
selected_policy_g = policy
if not policy.cuda_graph_supported:
# policy says cuda graph is not supported raise an error if set
assert not enable_cuda_graph, "cuda graph is not supported with this model, please disable"
@@ -472,6 +477,9 @@ def replace_with_policy(child,
bigscience_bloom=bigscience_bloom,
max_out_tokens=max_out_tokens,
scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx)
global transformer_config_g
if transformer_config_g is None:
transformer_config_g = transformer_config

if quantize and quantize_settings is not None:
(quantization_scales,
@@ -978,6 +986,8 @@ def replace_fn(child, _policy, layer_id=0):
mp_replace,
ckpt_type,
quantizer,
transformer_config=transformer_config_g,
param_names=selected_policy_g.get_param_names(),
)
pbar.update(1)
else:
@@ -1007,7 +1017,9 @@ def replace_fn(child, _policy, layer_id=0):
mp_replace,
ckpt_type,
quantizer,
int(rank % tp_split_size))
int(rank % tp_split_size),
transformer_config=transformer_config_g,
param_names=selected_policy_g.get_param_names())
sds = [None for _ in sds]
gc.collect()

@@ -1027,7 +1039,9 @@ def replace_fn(child, _policy, layer_id=0):
mp_replace,
ckpt_type,
quantizer,
int(rank % tp_split_size))
int(rank % tp_split_size),
transformer_config=transformer_config_g,
param_names=selected_policy_g.get_param_names())
sds = [None for _ in sds]
gc.collect()
print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")
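Note (not part of the diff): the two new module-level globals are set the first time replace_with_policy runs and are then forwarded to load_model_with_checkpoint, so the generic loader knows both the inference config and which checkpoint keys the selected policy expects. A minimal, hypothetical sketch of that capture-once pattern; ToyPolicy and select_policy are stand-ins for the injection-policy machinery, only get_param_names() is taken from the diff:

selected_policy_g = None

def select_policy(policy):
    global selected_policy_g
    if selected_policy_g is None:   # only the first policy that fires is remembered
        selected_policy_g = policy
    return policy

class ToyPolicy:
    def get_param_names(self):
        # suffixes of the checkpoint keys this model family uses (made up here)
        return ('attn.qkv.weight', 'attn.qkv.bias')

select_policy(ToyPolicy())
# later, at checkpoint-load time, the loader asks the remembered policy:
print(selected_policy_g.get_param_names())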