Fix quantized-inference & Add generic support of checkpoint loading #2547

Merged: 13 commits, Dec 6, 2022
Changes from 9 commits
141 changes: 95 additions & 46 deletions csrc/transformer/inference/csrc/pt_binding.cpp
@@ -763,11 +763,18 @@ void quantized_gemm(void* output,
at::Tensor& weight,
at::Tensor& qscale,
int groups,
int bsz)
int bsz,
int hidden_size)
{
T* weight16 = (T*)Context::Instance().GetWorkSpace() +
12 * Context::Instance().GetMaxTokenLenght() * weight.size(1);

T* weight16 = (T*)Context::Instance().GetWorkSpace() + 12 * hidden_size * bsz;

// auto options = at::TensorOptions()
// .dtype(at::kHalf)
// .layout(at::kStrided)
// .device(at::kCUDA)
// .requires_grad(false);
// auto tmp = torch::empty(weight.sizes(), options);
// T* weight16 = (T*)tmp.data_ptr();
launch_dequantize(weight16,
(int8_t*)weight.data_ptr(),
(float*)qscale.data_ptr(),
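
Note on the kernel change above: `quantized_gemm` dequantizes the int8 weight into a scratch region of the shared inference workspace and then runs a normal fp16 GEMM; the new `hidden_size` argument places that scratch buffer at an offset derived from the actual activation size (`12 * hidden_size * bsz`) instead of the preallocated max-token bound. A rough PyTorch sketch of the dequantize-then-matmul idea — the group layout and helper name are illustrative, not the exact kernel behavior:

```python
import torch

def quantized_gemm_ref(activations, qweight, qscale, groups):
    # activations: [bsz, hidden] fp16; qweight: [out_dim, hidden] int8;
    # qscale: [groups] -- one scale per contiguous block of output rows (assumed layout).
    out_dim, hidden = qweight.shape
    rows_per_group = out_dim // groups
    scales = qscale.float().repeat_interleave(rows_per_group).view(out_dim, 1)
    weight16 = (qweight.float() * scales).half()   # dequantize into a temporary fp16 buffer
    return activations @ weight16.t()              # then a standard fp16 GEMM
```
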
@@ -814,7 +821,8 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
ds_layer_norm_internal<T>(workspace, input, gamma, beta, epsilon);

if (q_int8) {
quantized_gemm<T>(output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz);
quantized_gemm<T>(
output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz, input.size(2));
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
@@ -1202,15 +1210,19 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);

int out_size = q_int8 ? weight.size(0) : weight.size(1);
int bsz = input.size(0) * input.size(1);

T* workspace = (T*)Context::Instance().GetWorkSpace();
auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
if (q_int8) {
quantized_gemm<T>(
output.data_ptr(), (T*)input.data_ptr(), weight, q_scale, q_scale.size(0), bsz);
quantized_gemm<T>(output.data_ptr(),
(T*)input.data_ptr(),
weight,
q_scale,
q_scale.size(0),
bsz,
input.size(2));
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
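
The `out_size` branch above is worth calling out: in the int8 path the output dimension is read from `weight.size(0)`, while the fp16 path reads it from `weight.size(1)`, which suggests the quantized weight is stored with the output dimension first. A small shape sketch under that assumption (sizes are made up for illustration):

```python
import torch

hidden, out = 4096, 12288                                   # hypothetical model dims
weight_fp16 = torch.empty(hidden, out, dtype=torch.half)    # [in_size, out_size]
weight_int8 = torch.empty(out, hidden, dtype=torch.int8)    # [out_size, in_size]

q_int8 = True
weight = weight_int8 if q_int8 else weight_fp16
out_size = weight.size(0) if q_int8 else weight.size(1)     # 12288 in either branch
```
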
@@ -1293,9 +1305,9 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
} else {
ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon);
}

if (q_int8) {
quantized_gemm<T>(intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz);
quantized_gemm<T>(
intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz, input.size(2));
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
@@ -1331,9 +1343,15 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
bsz,
Context::Instance().GetCurrentStream());
}

if (q_int8) {
quantized_gemm<T>(
output.data_ptr(), intermediate, weight1, q_scale1, q_scale1.size(0), bsz);
quantized_gemm<T>(output.data_ptr(),
intermediate,
weight1,
q_scale1,
q_scale1.size(0),
bsz,
input.size(2));
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
@@ -1449,64 +1467,95 @@ std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
template <typename T>
at::Tensor fused_gemm_gelu(at::Tensor& input,
at::Tensor& weight,
at::Tensor& weight_scale,
at::Tensor& bias,
at::Tensor& weight_out,
at::Tensor& weight_out_scale,
const float epsilon,
bool preLayerNorm,
bool q_int8,
bool async_op)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.dtype(input.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);

auto intermediate =
at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight_out.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
int intm_dim = q_int8 ? weight.size(0) : weight.size(1);

// auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input),
// {input.size(0), input.size(1), out_size},
// options);
// T* intermediate = (T*)input.data_ptr() + torch::numel(input);
auto intermediate = at::empty({input.size(0), input.size(1), intm_dim}, options);

int bsz = input.size(0) * input.size(1);

float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)intermediate.data_ptr(),
if (q_int8) {
quantized_gemm<T>(intermediate.data_ptr(),
(T*)input.data_ptr(),
weight,
weight_scale,
weight_scale.size(0),
bsz,
input.size(2));
} else {
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
intm_dim,
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input.data_ptr(),
(T*)intermediate.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
launch_bias_gelu((T*)intermediate.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
intm_dim,
bsz,
Context::Instance().GetCurrentStream());

cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight_out.size(1),
bsz,
intermediate.size(2),
&alpha,
&gemm_beta,
(T*)weight_out.data_ptr(),
(T*)intermediate.data_ptr(),
(T*)output.data_ptr(),
int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1);
auto output = at::empty({input.size(0), input.size(1), out_size}, options);
if (q_int8) {
quantized_gemm<T>(output.data_ptr(),
(T*)intermediate.data_ptr(),
weight_out,
weight_out_scale,
weight_out_scale.size(0),
bsz,
input.size(2));
} else {
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
out_size,
bsz,
intm_dim,
&alpha,
&gemm_beta,
(T*)weight_out.data_ptr(),
(T*)intermediate.data_ptr(),
(T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
// cudaEventRecord(Context::Instance().GetCompEvent(2),
// Context::Instance().GetCurrentStream(true));
return output;
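
With this hunk, `fused_gemm_gelu` gains an int8 path on both projections: GEMM #1 (dequantized when `q_int8` is set), then the fused bias + GeLU, then GEMM #2. A rough PyTorch equivalent of that control flow — the dequantization helper and per-group layout are illustrative assumptions, and the cuBLAS-specific layout details are elided:

```python
import torch
import torch.nn.functional as F

def dequant(qweight, qscale):
    # Assumed per-group layout: one scale per contiguous block of output rows.
    rows_per_group = qweight.size(0) // qscale.numel()
    scales = qscale.float().repeat_interleave(rows_per_group).view(-1, 1)
    return (qweight.float() * scales).half()

def fused_gemm_gelu_ref(x, weight, weight_scale, bias, weight_out, weight_out_scale, q_int8):
    # x: [batch, seq, hidden]; two back-to-back projections with a bias-GeLU in between.
    h = x.reshape(-1, x.size(-1))
    w1 = dequant(weight, weight_scale).t() if q_int8 else weight          # -> [hidden, intm_dim]
    intermediate = F.gelu(h @ w1 + bias)                                  # GEMM #1 + bias + GeLU
    w2 = dequant(weight_out, weight_out_scale).t() if q_int8 else weight_out
    out = intermediate @ w2                                               # GEMM #2
    return out.view(x.size(0), x.size(1), -1)
```
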
@@ -68,7 +68,7 @@ def __init__(self,
merge_count,
mlp_extra_grouping)

device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu'
device = torch.cuda.current_device() #if config.bigscience_bloom else 'cpu'
self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size,
dtype=data_type,
device=device),
@@ -131,7 +131,6 @@ def forward(
if (self.config.fp16 or self.config.q_int8) \
and input.dtype == torch.float:
input = input.half()

with torch.no_grad():
attention_output, key, value, context_outputtn_ctx, inp_norm = \
self.attention(input,
33 changes: 30 additions & 3 deletions deepspeed/module_inject/layers.py
@@ -23,7 +23,7 @@ def forward(self, input):


class LinearLayer(nn.Module):
def __init__(self, weight_shape=None, dtype=None, weight=None, bias=None):
def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
super(LinearLayer, self).__init__()
if weight is not None:
self.weight = weight
@@ -33,10 +33,12 @@ def __init__(self, weight_shape=None, dtype=None, weight=None, bias=None):
torch.empty(weight_shape,
dtype=dtype,
device=torch.cuda.current_device()))

self.bias = Parameter(
torch.empty(weight_shape[0],
dtype=dtype,
device=torch.cuda.current_device()))
device=torch.cuda.current_device())) \
if bias is not None else None

def forward(self, input):
output = torch.matmul(input, self.weight.transpose(-1, -2))
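
For the `LinearLayer` change: the default dtype is now fp16 and a bias parameter is only materialized when one is passed in; the forward pass is a plain matmul against the transposed `[out, in]` weight. A minimal usage sketch, assuming a CUDA device (the layer allocates on `torch.cuda.current_device()`), weights populated elsewhere from a checkpoint, and a forward that only adds the bias when it exists (the rest of forward is outside the visible hunk):

```python
import torch
from deepspeed.module_inject.layers import LinearLayer

layer = LinearLayer(weight_shape=(4096, 1024))                 # fp16 weight [out, in], no bias
x = torch.randn(8, 16, 1024, dtype=torch.half, device='cuda')  # [batch, seq, in]
y = layer(x)                                                   # matmul(x, weight.T) -> [8, 16, 4096]
```
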
@@ -57,7 +59,7 @@ def forward(self, input):


class EmbeddingLayer(nn.Module):
def __init__(self, weight_shape, dtype=torch.float):
def __init__(self, weight_shape, dtype=torch.half):
super(EmbeddingLayer, self).__init__()
self.weight = Parameter(
torch.empty(weight_shape[0],
@@ -67,3 +69,28 @@ def __init__(self, weight_shape, dtype=torch.float):

def forward(self, input):
return F.embedding(input, self.weight)


class OPTEmbedding(EmbeddingLayer):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, weight_shape):
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
super().__init__(weight_shape)

def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
attention_mask = attention_mask.long()

# create positions depending on attention_mask
positions = (torch.cumsum(attention_mask,
dim=1).type_as(attention_mask) *
attention_mask).long() - 1

# cut positions if `past_key_values_length` is > 0
positions = positions[:, past_key_values_length:]

return super().forward(positions + self.offset)
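
The new `OPTEmbedding` mirrors the position handling of Hugging Face's OPT learned positional embedding: positions are a running count of non-padded tokens minus one, shifted by the fixed offset of 2, with positions already covered by the KV cache sliced off. A small worked example of just the index arithmetic:

```python
import torch

attention_mask = torch.tensor([[0, 1, 1, 1]])    # one padded token, then three real tokens
positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask)
             * attention_mask).long() - 1         # [[-1, 0, 1, 2]]
past_key_values_length = 0
positions = positions[:, past_key_values_length:]
embedding_ids = positions + 2                     # OPT offset of 2 -> [[1, 2, 3, 4]]
```
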