diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 7c9ab0bfd3ea..619ed6923398 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -763,11 +763,18 @@ void quantized_gemm(void* output, at::Tensor& weight, at::Tensor& qscale, int groups, - int bsz) + int bsz, + int hidden_size) { - T* weight16 = (T*)Context::Instance().GetWorkSpace() + - 12 * Context::Instance().GetMaxTokenLenght() * weight.size(1); - + T* weight16 = (T*)Context::Instance().GetWorkSpace() + 12 * hidden_size * bsz; + + // auto options = at::TensorOptions() + // .dtype(at::kHalf) + // .layout(at::kStrided) + // .device(at::kCUDA) + // .requires_grad(false); + // auto tmp = torch::empty(weight.sizes(), options); + // T* weight16 = (T*)tmp.data_ptr(); launch_dequantize(weight16, (int8_t*)weight.data_ptr(), (float*)qscale.data_ptr(), @@ -814,7 +821,8 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, ds_layer_norm_internal(workspace, input, gamma, beta, epsilon); if (q_int8) { - quantized_gemm(output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz); + quantized_gemm( + output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz, input.size(2)); } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; @@ -1202,15 +1210,19 @@ at::Tensor ds_vector_matmul(at::Tensor& input, .layout(at::kStrided) .device(at::kCUDA) .requires_grad(false); - int out_size = q_int8 ? weight.size(0) : weight.size(1); int bsz = input.size(0) * input.size(1); T* workspace = (T*)Context::Instance().GetWorkSpace(); auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); if (q_int8) { - quantized_gemm( - output.data_ptr(), (T*)input.data_ptr(), weight, q_scale, q_scale.size(0), bsz); + quantized_gemm(output.data_ptr(), + (T*)input.data_ptr(), + weight, + q_scale, + q_scale.size(0), + bsz, + input.size(2)); } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; @@ -1293,9 +1305,9 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, } else { ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon); } - if (q_int8) { - quantized_gemm(intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz); + quantized_gemm( + intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz, input.size(2)); } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; @@ -1331,9 +1343,15 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, bsz, Context::Instance().GetCurrentStream()); } + if (q_int8) { - quantized_gemm( - output.data_ptr(), intermediate, weight1, q_scale1, q_scale1.size(0), bsz); + quantized_gemm(output.data_ptr(), + intermediate, + weight1, + q_scale1, + q_scale1.size(0), + bsz, + input.size(2)); } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; @@ -1449,64 +1467,95 @@ std::vector ds_mlp_gemm_int8(at::Tensor& input, template at::Tensor fused_gemm_gelu(at::Tensor& input, at::Tensor& weight, + at::Tensor& weight_scale, at::Tensor& bias, at::Tensor& weight_out, + at::Tensor& weight_out_scale, const float epsilon, bool preLayerNorm, + bool q_int8, bool async_op) { - auto input_cont = input.contiguous(); auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) + .dtype(input.options().dtype()) .layout(at::kStrided) .device(at::kCUDA) .requires_grad(false); - auto intermediate = - at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight_out.size(1)}, options); - int 
bsz = input_cont.size(0) * input_cont.size(1); + int intm_dim = q_int8 ? weight.size(0) : weight.size(1); + + // auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input), + // {input.size(0), input.size(1), out_size}, + // options); + // T* intermediate = (T*)input.data_ptr() + torch::numel(input); + auto intermediate = at::empty({input.size(0), input.size(1), intm_dim}, options); + + int bsz = input.size(0) * input.size(1); + float alpha = (T)1.0; float gemm_beta = (T)0.0; - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), - bsz, - input.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - (T*)input_cont.data_ptr(), - (T*)intermediate.data_ptr(), + if (q_int8) { + quantized_gemm(intermediate.data_ptr(), + (T*)input.data_ptr(), + weight, + weight_scale, + weight_scale.size(0), + bsz, + input.size(2)); + } else { + cublasSetStream(Context::Instance().GetCublasHandle(), + Context::Instance().GetCurrentStream()); + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + intm_dim, + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)input.data_ptr(), + (T*)intermediate.data_ptr(), #ifdef __HIP_PLATFORM_HCC__ - rocblas_gemm_algo_standard); + rocblas_gemm_algo_standard); #else - CUBLAS_GEMM_DEFAULT_TENSOR_OP); + CUBLAS_GEMM_DEFAULT_TENSOR_OP); #endif + } launch_bias_gelu((T*)intermediate.data_ptr(), (T*)bias.data_ptr(), - weight.size(1), + intm_dim, bsz, Context::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight_out.size(1), - bsz, - intermediate.size(2), - &alpha, - &gemm_beta, - (T*)weight_out.data_ptr(), - (T*)intermediate.data_ptr(), - (T*)output.data_ptr(), + int out_size = q_int8 ? 
weight_out.size(0) : weight_out.size(1); + auto output = at::empty({input.size(0), input.size(1), out_size}, options); + if (q_int8) { + quantized_gemm(output.data_ptr(), + (T*)intermediate.data_ptr(), + weight_out, + weight_out_scale, + weight_out_scale.size(0), + bsz, + input.size(2)); + } else { + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + out_size, + bsz, + intm_dim, + &alpha, + &gemm_beta, + (T*)weight_out.data_ptr(), + (T*)intermediate.data_ptr(), + (T*)output.data_ptr(), #ifdef __HIP_PLATFORM_HCC__ - rocblas_gemm_algo_standard); + rocblas_gemm_algo_standard); #else - CUBLAS_GEMM_DEFAULT_TENSOR_OP); + CUBLAS_GEMM_DEFAULT_TENSOR_OP); #endif + } // cudaEventRecord(Context::Instance().GetCompEvent(2), // Context::Instance().GetCurrentStream(true)); return output; diff --git a/deepspeed/model_implementations/transformers/ds_transformer.py b/deepspeed/model_implementations/transformers/ds_transformer.py index 49b7c81698a1..d281f8eba39c 100644 --- a/deepspeed/model_implementations/transformers/ds_transformer.py +++ b/deepspeed/model_implementations/transformers/ds_transformer.py @@ -68,7 +68,7 @@ def __init__(self, merge_count, mlp_extra_grouping) - device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' + device = torch.cuda.current_device() #if config.bigscience_bloom else 'cpu' self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device), @@ -131,7 +131,6 @@ def forward( if (self.config.fp16 or self.config.q_int8) \ and input.dtype == torch.float: input = input.half() - with torch.no_grad(): attention_output, key, value, context_outputtn_ctx, inp_norm = \ self.attention(input, diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py index 1ae5bd5ccf9f..cb738ec0c792 100644 --- a/deepspeed/module_inject/layers.py +++ b/deepspeed/module_inject/layers.py @@ -23,7 +23,7 @@ def forward(self, input): class LinearLayer(nn.Module): - def __init__(self, weight_shape=None, dtype=None, weight=None, bias=None): + def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None): super(LinearLayer, self).__init__() if weight is not None: self.weight = weight @@ -33,10 +33,12 @@ def __init__(self, weight_shape=None, dtype=None, weight=None, bias=None): torch.empty(weight_shape, dtype=dtype, device=torch.cuda.current_device())) + self.bias = Parameter( torch.empty(weight_shape[0], dtype=dtype, - device=torch.cuda.current_device())) + device=torch.cuda.current_device())) \ + if bias is not None else None def forward(self, input): output = torch.matmul(input, self.weight.transpose(-1, -2)) @@ -57,7 +59,7 @@ def forward(self, input): class EmbeddingLayer(nn.Module): - def __init__(self, weight_shape, dtype=torch.float): + def __init__(self, weight_shape, dtype=torch.half): super(EmbeddingLayer, self).__init__() self.weight = Parameter( torch.empty(weight_shape[0], @@ -67,3 +69,28 @@ def __init__(self, weight_shape, dtype=torch.float): def forward(self, input): return F.embedding(input, self.weight) + + +class OPTEmbedding(EmbeddingLayer): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + def __init__(self, weight_shape): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(weight_shape) + + def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + attention_mask = attention_mask.long() + + # create positions depending on attention_mask + positions = (torch.cumsum(attention_mask, + dim=1).type_as(attention_mask) * + attention_mask).long() - 1 + + # cut positions if `past_key_values_length` is > 0 + positions = positions[:, past_key_values_length:] + + return super().forward(positions + self.offset) diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index f577a1a0e1bc..13b743f6f781 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -1,7 +1,7 @@ from torch import nn import deepspeed.ops.transformer as transformer_inference from ..runtime.zero import GatheredParameters -from .layers import LinearLayer, Normalize, EmbeddingLayer +from .layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding import torch import gc @@ -11,14 +11,18 @@ def load_model_with_checkpoint(r_module, mp_replace, ckpt_type, weight_quantizer=None, - rank=0): + rank=0, + param_names=None, + transformer_config=None, + megatron_v2=False): error_msgs = [] def transpose(data): - data = data.contiguous() - data1 = data.transpose(-1, -2).reshape(-1) - data.reshape(-1).copy_(data1) - data1 = None + with torch.no_grad(): + data = data.contiguous() + data1 = data.transpose(-1, -2).reshape(-1) + data.reshape(-1).copy_(data1) + data1 = None return data.reshape(data.shape[-1], data.shape[-2]) def load(module, prefix): @@ -87,7 +91,7 @@ def load_parameters(module, prefix): else: assert tmp_data.dtype != torch.int8, \ '''Merging of the checkpoints are not supported when using INT8 checkpoint! \ - Please use a as many GPUs as TP-size for the checkpoint''' + Please use a as many GPUs as TP-size for the checkpoint''' all_data = [ sd[j][prefix + n] if type(sd[j][prefix + n]) is list else @@ -138,37 +142,146 @@ def load_parameters(module, prefix): for n, child in module.named_children(): load_parameters(child, prefix + n + '.') else: - module.norm_w.data.copy_(sd[0][prefix + 'input_layernorm.' + 'weight']) - module.norm_b.data.copy_(sd[0][prefix + 'input_layernorm.' + 'bias']) - module.attention.attn_qkvw = mp_replace.copy(module.attention.attn_qkvw, - weight_quantizer.quantize(sd[0][prefix + 'self_attention.query_key_value.' + 'weight']) if weight_quantizer.q_int8 else \ - weight_quantizer.quantize(transpose(sd[0][prefix + 'self_attention.query_key_value.' + 'weight']))) - module.attention.attn_qkvb = mp_replace.copy( - module.attention.attn_qkvb.data, - sd[0][prefix + 'self_attention.query_key_value.' + 'bias']) - module.attention.attn_ow = mp_replace.copy(module.attention.attn_ow, - weight_quantizer.quantize(sd[0][prefix + 'self_attention.dense.' + 'weight']) if weight_quantizer.q_int8 else \ - weight_quantizer.quantize(transpose(sd[0][prefix + 'self_attention.dense.' + 'weight']))) - module.attention.attn_ob = mp_replace.copy( - module.attention.attn_ob.data, - sd[0][prefix + 'self_attention.dense.' + 'bias']) - module.mlp.attn_nw.data.copy_(sd[0][prefix + 'post_attention_layernorm.' + - 'weight']) - module.mlp.attn_nb.data.copy_(sd[0][prefix + 'post_attention_layernorm.' + - 'bias']) - module.mlp.inter_w = mp_replace.copy(module.mlp.inter_w, - weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_h_to_4h.' 
+ 'weight']) if weight_quantizer.q_int8 else \ - weight_quantizer.quantize(transpose(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight']))) - module.mlp.inter_b = mp_replace.copy( - module.mlp.inter_b.data, - sd[0][prefix + 'mlp.dense_h_to_4h.' + 'bias']) - module.mlp.output_w = mp_replace.copy(module.mlp.output_w, - weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight']) if weight_quantizer.q_int8 else \ - weight_quantizer.quantize(transpose(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight']))) - module.mlp.output_b = mp_replace.copy( - module.mlp.output_b.data, - sd[0][prefix + 'mlp.dense_4h_to_h.' + 'bias']) + def _transpose(x): + heads = transformer_config.heads // mp_replace.mp_size + attention_head_size = x.shape[-1] // heads + new_x_shape = x.size()[:-1] + (heads, attention_head_size) + x_1 = x.view(*new_x_shape) + (q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=(x_1.dim() - 1)) + if len(q.shape) > 2: + return torch.cat((q.reshape(q.shape[0], + -1), + k.reshape(q.shape[0], + -1), + v.reshape(q.shape[0], + -1)), + dim=-1).reshape(x.shape) + else: + return torch.cat((q.reshape(-1), + k.reshape(-1), + v.reshape(-1)), + dim=-1).reshape(x.shape) + + # This checks if the parameter exits in the checkpoint file and maybe copies it into the corresponding destination tensor. + # Note that not all parameters are saved in one checkpoint, that's why we always need to check if they exist! + def maybe_copy(module, + dst_name, + src_name, + qkv=False, + megatron_v2=False, + split_qkv=False): + if src_name in sd[0]: + dst = getattr(module, dst_name) + tmp = sd[0][src_name].cuda() + if len(dst.shape) == 1: + if split_qkv: + dst = mp_replace.qkv_copy(dst, tmp) + else: + dst = mp_replace.copy(dst, tmp) + if qkv and megatron_v2: + dst = torch.nn.parameter.Parameter( + _transpose(dst).contiguous()) + else: + if split_qkv: + dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, tmp if weight_quantizer.q_int8 else \ + (transpose(tmp).contiguous()))) + else: + dst = weight_quantizer.quantize(mp_replace.copy(dst, tmp if weight_quantizer.q_int8 else \ + transpose(tmp))) + if qkv and megatron_v2: + scale1 = dst.scale + dst = torch.nn.parameter.Parameter( + _transpose(dst).contiguous()) + dst.scale = scale1 + setattr(module, dst_name, dst) + + # Extending the maybe_copy function for when the q, k, and v are in separate parameters! 
+ def maybe_copy_qkv(module, dst_name, src_names, split_qkv=False): + if src_names[0] in sd[0]: + q = sd[0][src_names[0]] + k = sd[0][src_names[1]] + v = sd[0][src_names[2]] + qkv_data = torch.cat((q, k, v), dim=0) + dst = getattr(module, dst_name) + if len(dst.shape) == 1: + if split_qkv: + dst = mp_replace.qkv_copy(dst, + (qkv_data.cuda()).contiguous()) + else: + dst = mp_replace.copy(dst, qkv_data.cuda()) + else: + if split_qkv: + dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, qkv_data.cuda() if weight_quantizer.q_int8 else \ + ((transpose(qkv_data.cuda())).contiguous()))) + else: + dst = weight_quantizer.quantize(mp_replace.copy(dst, qkv_data.cuda() if weight_quantizer.q_int8 else \ + transpose(qkv_data.cuda()))) + setattr(module, dst_name, dst) + + if len(param_names) == 14: + qkv_w, qkv_b, attn_ow, attn_ob, \ + mlp_intw, mlp_intb, mlp_ow, mlp_ob, \ + inp_normw, inp_normb, attn_nw, attn_nb, _, split_qkv = param_names + elif len(param_names) < 14: + q_w, k_w, v_w, attn_ow, \ + mlp_intw, mlp_intb, mlp_ow, mlp_ob, \ + inp_normw, inp_normb, _, split_qkv = param_names + else: + q_w, q_b, k_w, k_b, v_w, v_b, attn_ow, attn_ob, \ + mlp_intw, mlp_intb, mlp_ow, mlp_ob, \ + inp_normw, inp_normb, attn_nw, attn_nb, _, split_qkv = param_names + + maybe_copy(module, 'norm_w', prefix + inp_normw) + maybe_copy(module, 'norm_b', prefix + inp_normb) + if len(param_names) == 14: + maybe_copy(module.attention, + 'attn_qkvw', + prefix + qkv_w, + qkv=True, + megatron_v2=megatron_v2, + split_qkv=split_qkv) + maybe_copy(module.attention, + 'attn_qkvb', + prefix + qkv_b, + qkv=True, + megatron_v2=megatron_v2, + split_qkv=split_qkv) + elif len(param_names) < 14: + maybe_copy_qkv(module.attention, + 'attn_qkvw', + [prefix + q_w, + prefix + k_w, + prefix + v_w], + split_qkv=split_qkv) + else: + maybe_copy_qkv(module.attention, + 'attn_qkvw', + [prefix + q_w, + prefix + k_w, + prefix + v_w], + split_qkv=split_qkv) + maybe_copy_qkv(module.attention, + 'attn_qkvb', + [prefix + q_b, + prefix + k_b, + prefix + v_b], + split_qkv=split_qkv) + maybe_copy(module.attention, 'attn_ow', prefix + attn_ow) + if len(param_names) >= 14: + maybe_copy(module.attention, 'attn_ob', prefix + attn_ob) + maybe_copy(module.mlp, 'attn_nw', prefix + attn_nw) + maybe_copy(module.mlp, 'attn_nb', prefix + attn_nb) + maybe_copy(module.mlp, 'inter_w', prefix + mlp_intw) + maybe_copy(module.mlp, 'inter_b', prefix + mlp_intb) + maybe_copy(module.mlp, 'output_w', prefix + mlp_ow) + maybe_copy(module.mlp, 'output_b', prefix + mlp_ob) + + try: + import transformers + OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding + except: + OPTLearnedPositionalEmbedding = None layer_policies = { nn.Linear: load, nn.Embedding: load, @@ -176,7 +289,9 @@ def load_parameters(module, prefix): EmbeddingLayer: load, LinearLayer: load, Normalize: load, - transformer_inference.DeepSpeedTransformerInference: load_transformer_layer + transformer_inference.DeepSpeedTransformerInference: load_transformer_layer, + OPTLearnedPositionalEmbedding: load, + OPTEmbedding: load } all_ds_ids = {} @@ -201,14 +316,17 @@ def load_module_recursive(module, prefix='', level=0): ds_shape = child.weight.shape else: ds_shape = child.weight.ds_shape - if child.__class__ is nn.LayerNorm: child = Normalize(dim=ds_shape[-1], dtype=child.weight.dtype, eps=child.eps) setattr(module, name, child) elif child.__class__ is nn.Linear: - child = LinearLayer(weight=child.weight, bias=child.bias) + child = 
LinearLayer(weight_shape=child.weight.shape, + bias=child.bias) + setattr(module, name, child) + elif child.__class__ is OPTLearnedPositionalEmbedding: + child = OPTEmbedding(weight_shape=ds_shape) setattr(module, name, child) else: ds_id = None @@ -224,7 +342,8 @@ def load_module_recursive(module, prefix='', level=0): else: load_module_recursive( child, - prefix if level == 0 and ckpt_type == 'pp' else prefix + name + '.', + prefix if (level == 0 and ckpt_type == 'pp') and param_names[-2] else \ + prefix + name + '.', level + 1) load_module_recursive(r_module) @@ -232,10 +351,11 @@ def load_module_recursive(module, prefix='', level=0): #XXX: hack to tie embedding w. lm_head for BLOOM, need to revist soon embedding_weight = None for n, p in r_module.named_parameters(): - if "word_embeddings." in n: + if "word_embeddings." in n or "embed_tokens." in n: embedding_weight = p - assert hasattr(r_module, 'lm_head'), "attempting to set lm_head but it doesn't exist" - r_module.lm_head.weight = embedding_weight + + if embedding_weight is not None: + r_module.lm_head.weight = embedding_weight for sd_ in sd: del sd_ sd = None diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 8bf2268064ff..7233d9adfa6f 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -36,13 +36,13 @@ def qkv_copy(self, dst, src): return src src_shape = src.shape dst_shape = dst.shape + if self.out_dim == 0: src_split = torch.split(src.data, src_shape[self.out_dim] // self.mp_size, dim=0) else: src_split = torch.split(src.data, src.shape[-1] // 3, dim=-1) - if (len(src_shape) == 2 and len(dst_shape) == 2): if src_shape[self.out_dim] == dst_shape[self.out_dim]: return torch.nn.parameter.Parameter(src) @@ -54,7 +54,6 @@ def qkv_copy(self, dst, src): qkv_size, dim=self.out_dim) for src_s in src_split ] - weight_split = [ torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=self.out_dim) for i in range(len(qkv_split[0])) @@ -137,8 +136,7 @@ def get_transformer_name(replaced_module): class GroupQuantizer: - def __init__(self, q_int8=True, num_groups=32, group_size=32, num_bits=8): - self.num_groups = num_groups + def __init__(self, q_int8=True, group_size=1, num_bits=8): self.group_size = group_size self.num_bits = num_bits self.q_int8 = q_int8 @@ -149,8 +147,9 @@ def quantize(self, inputs, qkv=True, count=1, parallel_dim=0): inputs.scale = torch.empty(1) return inputs q_range = 2**self.num_bits + num_groups = inputs.shape[0] // self.group_size inputs = inputs.to(torch.cuda.current_device()) - input_flat = inputs.reshape(self.num_groups, -1).contiguous() + input_flat = inputs.reshape(num_groups, -1).contiguous() input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float() input_max = torch.max(input_flat, dim=1, keepdim=True)[0].float() scale = torch.max(input_min.abs(), input_max.abs()) * 2.0 / (q_range) @@ -160,7 +159,7 @@ def quantize(self, inputs, qkv=True, count=1, parallel_dim=0): #print(inputs.shape) inputs_split = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim) input_flat = [ - inputs_split[i].reshape(self.num_groups, + inputs_split[i].reshape(num_groups, -1).contiguous() for i in range(2) ] input_min = [ @@ -182,7 +181,7 @@ def quantize(self, inputs, qkv=True, count=1, parallel_dim=0): out.scale = torch.cat([scale.squeeze().unsqueeze(0), scale1[0], scale1[1]], - dim=0).reshape(self.num_groups, + dim=0).reshape(num_groups, -1).contiguous() return out @@ -286,6 +285,11 @@ def _replace_module(module, 
policy): setattr(module, name, new_module) +selected_policy_g = None +megatron_v2_g = False +transformer_config_g = None + + def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, @@ -325,6 +329,9 @@ def replace_with_policy(child, inference=False, layer_id=0): policy = policy_cls(child, inference=inference) + global selected_policy_g + if selected_policy_g is None: + selected_policy_g = policy if not policy.cuda_graph_supported: # policy says cuda graph is not supported raise an error if set assert not config.enable_cuda_graph, "cuda graph is not supported with this model, please disable" @@ -340,6 +347,8 @@ def replace_with_policy(child, moe = True attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention, megatron_v2 = policy.attention() + global megatron_v2_g + megatron_v2_g = megatron_v2 if not moe or config.moe.type == 'standard': mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp() else: @@ -439,6 +448,8 @@ def replace_with_policy(child, bigscience_bloom=bigscience_bloom, max_out_tokens=config.max_out_tokens, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) + global transformer_config_g + transformer_config_g = transformer_config if moe: new_module = transformer_inference.DeepSpeedMoEInference( @@ -553,6 +564,10 @@ def _transpose(x): if qkvw.is_meta or qkvw.numel() == 0 or qkvw.is_meta: if qkvw.is_meta or qkvw.ds_tensor.numel() < attn_block.attn_qkvw.numel(): + if qkvb is None: + attn_block.attn_qkvb = None + if dense_b is None: + attn_block.attn_ob = None pass else: with GatheredParameters([ @@ -911,7 +926,9 @@ def replace_fn(child, _policy, layer_id=0): mp_replace, ckpt_type, quantizer, - ) + param_names=selected_policy_g.get_param_names(), + transformer_config=transformer_config_g, + megatron_v2=megatron_v2_g) pbar.update(1) else: import gc @@ -935,12 +952,16 @@ def replace_fn(child, _policy, layer_id=0): torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files ] - load_model_with_checkpoint(replaced_module, - sds, - mp_replace, - ckpt_type, - quantizer, - int(rank % tp_split_size)) + load_model_with_checkpoint( + replaced_module, + sds, + mp_replace, + ckpt_type, + quantizer, + int(rank % tp_split_size), + param_names=selected_policy_g.get_param_names(), + transformer_config=transformer_config_g, + megatron_v2=megatron_v2_g) sds = [None for _ in sds] gc.collect() @@ -955,12 +976,16 @@ def replace_fn(child, _policy, layer_id=0): checkpoint["non_tp"][i] ) if base_dir1 else checkpoint["non_tp"][i] sds = [torch.load(ckpt_file, map_location='cpu')] - load_model_with_checkpoint(replaced_module, - sds, - mp_replace, - ckpt_type, - quantizer, - int(rank % tp_split_size)) + load_model_with_checkpoint( + replaced_module, + sds, + mp_replace, + ckpt_type, + quantizer, + int(rank % tp_split_size), + param_names=selected_policy_g.get_param_names(), + transformer_config=transformer_config_g, + megatron_v2=megatron_v2_g) sds = [None for _ in sds] gc.collect() print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec") @@ -986,6 +1011,7 @@ def replace_fn(child, _policy, layer_id=0): non_tp_ckpt_name = f'non-tp.pt' ckpt_files = [non_tp_ckpt_name] os.makedirs(config.save_mp_checkpoint_path, exist_ok=True) + if not dist.is_initialized() or dist.get_rank() == 0: print("Saving tp-sharded checkpoints") torch.save( @@ -996,7 +1022,7 @@ def replace_fn(child, _policy, layer_id=0): if transformer_name not in k }), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}') - config = json.dumps({ + new_config = json.dumps({ 
'type': ckpt_name, 'base_dir': @@ -1020,7 +1046,7 @@ def replace_fn(child, _policy, layer_id=0): }) with open(f"{config.save_mp_checkpoint_path}/ds-inference_config.json", "w") as cfg: - cfg.write(config) + cfg.write(new_config) rep_sd = replaced_module.state_dict() for n, p in replaced_module.named_parameters(): diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index 9dcb4ace234e..4dd9e5b0855e 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -92,15 +92,19 @@ class TransformerPolicy(DSPolicy): hf_model_config = None def __init__( - self, - inference=True, - linear_layer=True, - scale_attention=True, - megatron_v2=False, - # the type of activation function used in MLP - mlp_act_func_type=ActivationFuncType.GELU, - # applies layer norm before attention if `pre_attn_norm` is set to True - pre_attn_norm=True): + self, + inference=True, + linear_layer=True, + scale_attention=True, + megatron_v2=False, + # the type of activation function used in MLP + mlp_act_func_type=ActivationFuncType.GELU, + # applies layer norm before attention if `pre_attn_norm` is set to True + pre_attn_norm=True, + # this flag shows whether or not using prefix in loading the checkpoint + use_load_prefix=False, + # whether or not the qkv is stored in the split-format + split_qkv=True): super().__init__() self.inference = inference self.linear_layer = linear_layer @@ -108,6 +112,7 @@ def __init__( self.is_megatron_v2 = megatron_v2 self.mlp_act_func_type = mlp_act_func_type self.pre_attn_norm = pre_attn_norm + self.load_prefix = False def attention(self): """ @@ -139,6 +144,31 @@ def layerNorm(self): """ raise NotImplementedError + def get_param_names(self): + """ + Returns all the transformer parameter names to + be loaded from checkpoint files. The order of + the names is as follows: + 1. Attention weights and biases; + 2. MLP weights and biases; + 3. LayerNorm weights and biases; + In addition to the parameter names, we require two + more parameters to help read the the data correctly + from the checkpoint and split the qkv heads in the + right order: + 1. `use_load_prefix` (Default: False): this specifies + whether we need to use the name of first abstraction + layer of the model for searching the parameter's name + in a checkpoint file. For more information of how this + is used please see + https://github.com/microsoft/DeepSpeed/blob/fix-ckpt-loading/deepspeed/module_inject/load_checkpoint.py#L341 + 2. `split_qkv` (Default: True): we use this flag when splitting + the qkv parameter into heads. If it is False, it means the heads + of q, k, and v are stored together and needs to split in the + DeepSpeed-Inference API. 
+ """ + raise NotImplementedError + class HFBertLayerPolicy(TransformerPolicy): def __init__(self, client_module, inference=False): @@ -294,6 +324,22 @@ def layerNorm(self): self.client_module.ln_1.weight, \ self.client_module.ln_1.bias + def get_param_names(self): + return 'attention.query_key_value.weight', \ + 'attention.query_key_value.bias', \ + 'attention.dense.weight', \ + 'attention.dense.bias', \ + 'mlp.dense_h_to_4h.weight', \ + 'mlp.dense_h_to_4h.bias', \ + 'mlp.dense_4h_to_h.weight', \ + 'mlp.dense_4h_to_h.bias', \ + 'input_layernorm.weight', \ + 'input_layernorm.bias', \ + 'post_attention_layernorm.weight', \ + 'post_attention_layernorm.bias', \ + self.use_load_prefix, \ + self.split_qkv + class HFGPTJLayerPolicy(TransformerPolicy): _orig_layer_class = None @@ -339,6 +385,20 @@ def layerNorm(self): self.client_module.ln_1.weight, \ self.client_module.ln_1.bias + def get_param_names(self): + return 'attn.q_proj.weight', \ + 'attn.k_proj.weight', \ + 'attn.v_proj.weight', \ + 'attn.out_proj.weight', \ + 'mlp.fc_in.weight', \ + 'mlp.fc_in.bias', \ + 'mlp.fc_out.weight', \ + 'mlp.fc_out.bias', \ + 'ln_1.weight', \ + 'ln_1.bias', \ + self.use_load_prefix, \ + self.split_qkv + class MegatronLayerPolicy(TransformerPolicy): _orig_layer_class = None @@ -463,7 +523,11 @@ def layerNorm(self): class BLOOMLayerPolicy(TransformerPolicy): _orig_layer_class = None - def __init__(self, client_module, inference=True): + def __init__(self, + client_module, + inference=True, + use_load_prefix=True, + split_qkv=False): super().__init__(inference, linear_layer=True) self.client_module = client_module try: @@ -501,12 +565,28 @@ def layerNorm(self): self.client_module.input_layernorm.weight, \ self.client_module.input_layernorm.bias + def get_param_names(self): + return 'self_attention.query_key_value.weight', \ + 'self_attention.query_key_value.bias', \ + 'self_attention.dense.weight', \ + 'self_attention.dense.bias', \ + 'mlp.dense_h_to_4h.weight', \ + 'mlp.dense_h_to_4h.bias', \ + 'mlp.dense_4h_to_h.weight', \ + 'mlp.dense_4h_to_h.bias', \ + 'input_layernorm.weight', \ + 'input_layernorm.bias', \ + 'post_attention_layernorm.weight', \ + 'post_attention_layernorm.bias', \ + self.use_load_prefix, \ + self.split_qkv + class GPTNEOXLayerPolicy(TransformerPolicy): _orig_layer_class = None version = 0 - def __init__(self, client_module, inference=True, megatron_v2=True): + def __init__(self, client_module, inference=True, megatron_v2=True, split_qkv=False): super().__init__(inference, megatron_v2=megatron_v2) self.client_module = client_module if GPTNEOXLayerPolicy._orig_layer_class is None: @@ -555,11 +635,27 @@ def layerNorm(self): self.client_module.input_layernorm.weight, \ self.client_module.input_layernorm.bias + def get_param_names(self): + return 'attention.query_key_value.weight', \ + 'attention.query_key_value.bias', \ + 'attention.dense.weight', \ + 'attention.dense.bias', \ + 'mlp.dense_h_to_4h.weight', \ + 'mlp.dense_h_to_4h.bias', \ + 'mlp.dense_4h_to_h.weight', \ + 'mlp.dense_4h_to_h.bias', \ + 'input_layernorm.weight', \ + 'input_layernorm.bias', \ + 'post_attention_layernorm.weight', \ + 'post_attention_layernorm.bias', \ + self.use_load_prefix, \ + self.split_qkv + class HFOPTLayerPolicy(TransformerPolicy): _orig_layer_class = None - def __init__(self, client_module, inference=True): + def __init__(self, client_module, inference=True, use_load_prefix=True): super().__init__(inference, linear_layer=True, mlp_act_func_type=ActivationFuncType.ReLU, @@ -568,9 +664,9 @@ def __init__(self, 
client_module, inference=True): try: import transformers HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer - if isinstance(DSPolicy.hf_model_config, + if isinstance(TransformerPolicy.hf_model_config, transformers.models.opt.configuration_opt.OPTConfig): - self.pre_attn_norm = self.hf_model_config.do_layer_norm_before + self.pre_attn_norm = TransformerPolicy.hf_model_config.do_layer_norm_before except: HFOPTLayerPolicy._orig_layer_class = None @@ -612,6 +708,26 @@ def layerNorm(self): self.client_module.self_attn_layer_norm.weight, \ self.client_module.self_attn_layer_norm.bias + def get_param_names(self): + return 'self_attn.q_proj.weight', \ + 'self_attn.q_proj.bias', \ + 'self_attn.k_proj.weight', \ + 'self_attn.k_proj.bias', \ + 'self_attn.v_proj.weight', \ + 'self_attn.v_proj.bias', \ + 'self_attn.out_proj.weight', \ + 'self_attn.out_proj.bias', \ + 'fc1.weight', \ + 'fc1.bias', \ + 'fc2.weight', \ + 'fc2.bias', \ + 'self_attn_layer_norm.weight', \ + 'self_attn_layer_norm.bias', \ + 'final_layer_norm.weight', \ + 'final_layer_norm.bias', \ + self.use_load_prefix, \ + self.split_qkv + # transformer-based policies replace_policies = [ diff --git a/deepspeed/ops/transformer/inference/ds_attention.py b/deepspeed/ops/transformer/inference/ds_attention.py index 610bd882ecf4..d9df8e98a3de 100644 --- a/deepspeed/ops/transformer/inference/ds_attention.py +++ b/deepspeed/ops/transformer/inference/ds_attention.py @@ -331,7 +331,6 @@ def selfAttention_fp(): False, attn_ow.scale, config.q_int8) - return output, key_layer, value_layer, context_layer, qkv_out[-1] def selfAttention_int8(): @@ -394,7 +393,7 @@ def __init__(self, data_type_fp = torch.half if config.fp16 else torch.float self.config.layer_id = DeepSpeedSelfAttention.num_layers DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1 - device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' + device = torch.cuda.current_device() #if config.bigscience_bloom else 'cpu' qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3 self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size, qkv_size_per_partition, diff --git a/deepspeed/ops/transformer/inference/ds_mlp.py b/deepspeed/ops/transformer/inference/ds_mlp.py index 4f1c705c55ea..277ba1818286 100644 --- a/deepspeed/ops/transformer/inference/ds_mlp.py +++ b/deepspeed/ops/transformer/inference/ds_mlp.py @@ -41,10 +41,13 @@ def forward(ctx, if attn_nw is None: output = fused_gemm_gelu(residual_norm, inter_w, + inter_w.scale, inter_b, output_w, + output_w.scale, config.epsilon, config.pre_layer_norm, + config.q_int8, False) else: output, residual_add = mlp_gemm_func(input, @@ -96,7 +99,7 @@ def __init__(self, self.config = config data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float data_type_fp = torch.half if config.fp16 else torch.float - device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' + device = torch.cuda.current_device() #if config.bigscience_bloom else 'cpu' self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
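
A minimal sketch of the position-id arithmetic behind the new OPTEmbedding wrapper in deepspeed/module_inject/layers.py, pulled out as a standalone function so it can be checked in isolation. The function name opt_position_ids is illustrative (not part of the DeepSpeed API); the fixed offset of 2 and the cumsum-over-attention-mask trick come straight from the diff above.

import torch

def opt_position_ids(attention_mask: torch.LongTensor,
                     past_key_values_length: int = 0,
                     offset: int = 2) -> torch.LongTensor:
    # Real tokens count up from 0; padding positions collapse to -1,
    # mirroring OPTEmbedding.forward in this PR.
    attention_mask = attention_mask.long()
    positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask)
                 * attention_mask).long() - 1
    # Drop the part of the sequence that is already cached.
    positions = positions[:, past_key_values_length:]
    # OPT reserves the first two embedding rows, hence the fixed offset.
    return positions + offset

mask = torch.tensor([[0, 0, 1, 1, 1]])   # two left-padding tokens
print(opt_position_ids(mask))            # tensor([[1, 1, 2, 3, 4]])

Padded slots all collapse to index offset - 1; since those positions are masked out of attention, the embedding they pick up is irrelevant.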
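
The checkpoint-loading changes hang together around a small contract: each policy's get_param_names() lists its transformer parameter names (attention, then MLP, then LayerNorm) and appends two flags, use_load_prefix and split_qkv, and load_model_with_checkpoint() chooses the fused-QKV or separate-q/k/v path from the tuple length. Below is a rough, self-contained sketch of that dispatch under the assumptions visible in the diff; plan_attention_copies and its return format are invented for illustration and stand in for the real maybe_copy / maybe_copy_qkv helpers in load_checkpoint.py.

from typing import Sequence

def plan_attention_copies(param_names: Sequence):
    split_qkv = param_names[-1]        # last entry: split q/k/v across TP ranks?
    if len(param_names) == 14:         # fused query_key_value tensor (BLOOM/GPT-2 style)
        qkv_w, qkv_b = param_names[:2]
        return [('attn_qkvw', [qkv_w], split_qkv),
                ('attn_qkvb', [qkv_b], split_qkv)]
    if len(param_names) < 14:          # separate q/k/v weights, no attention bias (GPT-J style)
        q_w, k_w, v_w = param_names[:3]
        return [('attn_qkvw', [q_w, k_w, v_w], split_qkv)]
    q_w, q_b, k_w, k_b, v_w, v_b = param_names[:6]   # separate q/k/v with biases (OPT style)
    return [('attn_qkvw', [q_w, k_w, v_w], split_qkv),
            ('attn_qkvb', [q_b, k_b, v_b], split_qkv)]

# Example with a 14-entry, BLOOM-style tuple as returned in this PR:
bloom_names = ('self_attention.query_key_value.weight',
               'self_attention.query_key_value.bias',
               'self_attention.dense.weight', 'self_attention.dense.bias',
               'mlp.dense_h_to_4h.weight', 'mlp.dense_h_to_4h.bias',
               'mlp.dense_4h_to_h.weight', 'mlp.dense_4h_to_h.bias',
               'input_layernorm.weight', 'input_layernorm.bias',
               'post_attention_layernorm.weight', 'post_attention_layernorm.bias',
               True, False)             # use_load_prefix, split_qkv (placeholder values)
print(plan_attention_copies(bloom_names))

The second-to-last entry (use_load_prefix) is not consumed per layer; the loader reads it as param_names[-2] when deciding whether to keep the module prefix while recursing, as in load_module_recursive above.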
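
GroupQuantizer now derives the number of quantization groups from the weight's leading dimension (group_size defaults to 1, so each output row gets its own symmetric scale) instead of the previous fixed num_groups=32. The sketch below condenses that per-row scheme and ignores the qkv/parallel_dim special case in the diff; quantize_per_row is an illustrative name, and the real class attaches the scales to the returned tensor as a .scale attribute for the int8 GEMM kernels.

import torch

def quantize_per_row(weight: torch.Tensor, num_bits: int = 8, group_size: int = 1):
    q_range = 2 ** num_bits
    num_groups = weight.shape[0] // group_size         # was a fixed 32 before this change
    flat = weight.reshape(num_groups, -1).contiguous()
    g_min = flat.min(dim=1, keepdim=True)[0].float()
    g_max = flat.max(dim=1, keepdim=True)[0].float()
    # Symmetric scale: the int8 range covers +/- the largest magnitude in each group.
    scale = torch.max(g_min.abs(), g_max.abs()) * 2.0 / q_range
    q = torch.round(torch.clamp(flat / scale,
                                -(q_range // 2),
                                q_range // 2 - 1)).to(torch.int8)
    return q.reshape(weight.shape), scale

w = torch.randn(8, 16)
q, scale = quantize_per_row(w)      # scale has shape (8, 1): one scale per row
w_hat = q.reshape(scale.shape[0], -1).float() * scale   # rough dequantization check

Tying the group count to the tensor shape is what lets the C++ side pass q_scale.size(0) as the group count for quantized_gemm regardless of the layer's dimensions.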