
Commit d8745c1

[PaddleInference] support ptq and cachekv_quant in BlockMultiHeadAttention op (#59951) (#60073)

* support cachekv_quant in blha

---------

Co-authored-by: Wanglongzhi2001 <[email protected]>
RichardWooSJTU and Wanglongzhi2001 authored Dec 18, 2023
1 parent 36c402b commit d8745c1
Showing 9 changed files with 3,688 additions and 112 deletions.
7 changes: 6 additions & 1 deletion paddle/fluid/pybind/inference_api.cc
@@ -345,10 +345,15 @@ void PaddleTensorShareExternalData(paddle_infer::Tensor &tensor,  // NOLINT
         static_cast<int64_t *>(paddle_tensor.data<int64_t>()),
         shape,
         ToPaddleInferPlace(paddle_tensor.place().GetType()));
+  } else if (paddle_tensor.dtype() == phi::DataType::UINT8) {
+    tensor.ShareExternalData(
+        static_cast<uint8_t *>(paddle_tensor.data()),
+        shape,
+        ToPaddleInferPlace(paddle_tensor.place().GetType()));
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "Unsupported data type. Now share_external_data only supports INT32, "
-        "INT64, FLOAT32, FLOAT16, BFLOAT16 and BOOL."));
+        "INT64, UINT8, FLOAT32, FLOAT16, BFLOAT16 and BOOL."));
   }
 }

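With this change, share_external_data in the Python binding accepts UINT8 tensors alongside the existing types; the same entry point exists on the C++ side as the templated paddle_infer::Tensor::ShareExternalData. Below is a minimal, hypothetical C++ sketch of feeding a uint8 buffer to a predictor input zero-copy; the model paths, input shape, and data are placeholders and are not part of this commit.

// Hypothetical usage sketch (not from this commit): share a uint8 buffer with a
// predictor input without copying. Model paths, shape, and data are placeholders.
#include <cstdint>
#include <vector>
#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("model.pdmodel", "model.pdiparams");  // placeholder paths
  auto predictor = paddle_infer::CreatePredictor(config);

  std::vector<int> shape = {1, 3, 224, 224};        // placeholder shape
  std::vector<uint8_t> data(1 * 3 * 224 * 224, 0);  // placeholder data

  auto input_names = predictor->GetInputNames();
  auto input = predictor->GetInputHandle(input_names[0]);
  // Zero-copy: the predictor reads directly from data's memory, so the buffer
  // must stay alive until Run() finishes.
  input->ShareExternalData<uint8_t>(
      data.data(), shape, paddle_infer::PlaceType::kCPU);

  predictor->Run();
  return 0;
}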
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/fused_ops.yaml
@@ -33,14 +33,14 @@
     data_type : x
 
 - op : block_multihead_attention_
-  args : (Tensor qkv, Tensor key_cache, Tensor value_cache, Tensor seq_lens_encoder, Tensor seq_lens_decoder, Tensor seq_lens_this_time, Tensor padding_offsets, Tensor cum_offsets, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor block_tables, Tensor pre_key_cache, Tensor pre_value_cache, Tensor rope_emb, Tensor mask, Tensor tgt_mask, int max_seq_len, int block_size, bool use_neox_style)
+  args : (Tensor qkv, Tensor key_cache, Tensor value_cache, Tensor seq_lens_encoder, Tensor seq_lens_decoder, Tensor seq_lens_this_time, Tensor padding_offsets, Tensor cum_offsets, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor block_tables, Tensor pre_key_cache, Tensor pre_value_cache, Tensor rope_emb, Tensor mask, Tensor tgt_mask, Tensor cache_k_quant_scales, Tensor cache_v_quant_scales, Tensor cache_k_dequant_scales, Tensor cache_v_dequant_scales, Tensor qkv_out_scale, Tensor qkv_bias, Tensor out_shift, Tensor out_smooth, int max_seq_len, int block_size, bool use_neox_style, bool dynamic_cachekv_quant=false, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0, float out_scale=-1, str compute_dtype = "default")
   output : Tensor(fmha_out), Tensor(qkv_out), Tensor(key_cache_out), Tensor(value_cache_out)
   infer_meta :
     func : BlockMultiheadAttentionInferMeta
   kernel :
     func : block_multihead_attention
     data_type : qkv
-  optional : pre_key_cache, pre_value_cache, rope_emb, mask, tgt_mask
+  optional : pre_key_cache, pre_value_cache, rope_emb, mask, tgt_mask, cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales, cache_v_dequant_scales, qkv_out_scale, qkv_bias, out_shift, out_smooth
   inplace : (qkv -> qkv_out), (key_cache -> key_cache_out), (value_cache -> value_cache_out)
   support_dygraph_mode : true

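For context, the new optional inputs cache_k_quant_scales / cache_v_quant_scales (and their dequant counterparts), together with quant_round_type, quant_max_bound = 127.0, and quant_min_bound = -127.0, point to the usual symmetric round-and-clip INT8 scheme for the cached keys and values; dynamic_cachekv_quant presumably selects on-the-fly scales instead of the precomputed tensors. The sketch below illustrates that assumed arithmetic only; it is not the kernel code added by this commit, and both helper names are made up.

// Assumed cache-KV quantization arithmetic implied by the new attributes
// (quant_round_type, quant_max_bound = 127, quant_min_bound = -127).
// Illustrative only; not the kernel code added by this commit.
#include <algorithm>
#include <cmath>
#include <cstdint>

int8_t QuantizeCacheKV(float x,
                       float quant_scale,
                       int round_type,  // assumed: 1 = round half away from zero
                       float max_bound = 127.0f,
                       float min_bound = -127.0f) {
  float scaled = x * quant_scale;
  float rounded =
      (round_type == 1) ? std::round(scaled) : std::nearbyint(scaled);
  return static_cast<int8_t>(std::max(min_bound, std::min(max_bound, rounded)));
}

float DequantizeCacheKV(int8_t q, float dequant_scale) {
  // dequant_scale is expected to be roughly 1 / quant_scale per head/channel.
  return static_cast<float>(q) * dequant_scale;
}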
79 changes: 77 additions & 2 deletions paddle/phi/infermeta/fusion.cc
@@ -132,9 +132,23 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
                                       const MetaTensor& rope_emb,
                                       const MetaTensor& mask,
                                       const MetaTensor& tgt_mask,
+                                      const MetaTensor& cache_k_quant_scales,
+                                      const MetaTensor& cache_v_quant_scales,
+                                      const MetaTensor& cache_k_dequant_scales,
+                                      const MetaTensor& cache_v_dequant_scales,
+                                      const MetaTensor& qkv_out_scale,
+                                      const MetaTensor& qkv_bias,
+                                      const MetaTensor& out_shift,
+                                      const MetaTensor& out_smooth,
                                       int max_seq_len,
                                       int block_size,
                                       bool use_neox_style,
+                                      bool dynamic_cachekv_quant,
+                                      const int quant_round_type,
+                                      const float quant_max_bound,
+                                      const float quant_min_bound,
+                                      const float out_scale,
+                                      const std::string& compute_dtype,
                                       MetaTensor* fmha_out,
                                       MetaTensor* qkv_out,
                                       MetaTensor* key_cache_out,
@@ -159,13 +173,74 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
           "The input_dims[1] must be equal to 3 * num_head * dim_head"));
 
   fmha_out->set_dims({input_dims[0], num_head * dim_head});
-  fmha_out->set_dtype(qkv.dtype());
   qkv_out->set_dims(qkv.dims());
   qkv_out->set_dtype(qkv.dtype());
   key_cache_out->set_dims(key_cache_dims);
   key_cache_out->set_dtype(key_cache.dtype());
   value_cache_out->set_dims(key_cache_dims);
   value_cache_out->set_dtype(value_cache.dtype());
 
+  auto FBADtypeCheck = [](const MetaTensor& check_tensor,
+                          const std::string& tensor_name,
+                          const std::string& compute_dtype) {
+    if (compute_dtype == "bf16") {
+      PADDLE_ENFORCE_EQ(
+          check_tensor.dtype(),
+          phi::DataType::BFLOAT16,
+          phi::errors::InvalidArgument(
+              "Input(%s) dtype must be the same with Attr(compute_dtype)",
+              tensor_name));
+    } else if (compute_dtype == "fp16") {
+      PADDLE_ENFORCE_EQ(
+          check_tensor.dtype(),
+          phi::DataType::FLOAT16,
+          phi::errors::InvalidArgument(
+              "Input(%s) dtype must be the same with Attr(compute_dtype)",
+              tensor_name));
+    } else if (compute_dtype == "fp32") {
+      PADDLE_ENFORCE_EQ(
+          check_tensor.dtype(),
+          phi::DataType::FLOAT32,
+          phi::errors::InvalidArgument(
+              "Input(%s) dtype must be the same with Attr(compute_dtype)",
+              tensor_name));
+    }
+  };
+
+  // In the case of quantization enabled, the dtype for computation is
+  // determined based on compute_dtype.
+  if (qkv.dtype() == phi::DataType::INT32) {
+    PADDLE_ENFORCE_NE(
+        compute_dtype,
+        "default",
+        phi::errors::InvalidArgument(
+            "If Input(x) dtype is INT32, Attr(compute_dtype) must be set."));
+    if (out_scale > 0) {
+      fmha_out->set_dtype(phi::DataType::INT8);
+    } else {
+      if (compute_dtype == "bf16") {
+        fmha_out->set_dtype(phi::DataType::BFLOAT16);
+      } else if (compute_dtype == "fp16") {
+        fmha_out->set_dtype(phi::DataType::FLOAT16);
+      } else if (compute_dtype == "fp32") {
+        fmha_out->set_dtype(phi::DataType::FLOAT32);
+      } else {
+        PADDLE_THROW(phi::errors::InvalidArgument(
+            "In the case of quantization enabled with Input(x) INT32, "
+            "Attr(compute_dtype) must be set in (bf16, fp16, fp32), "
+            "but get compute_dtype (%s)",
+            compute_dtype));
+      }
+    }
+  } else {
+    if (compute_dtype != "default") {
+      FBADtypeCheck(qkv, "qkv", compute_dtype);
+    }
+    if (out_scale > 0) {
+      fmha_out->set_dtype(phi::DataType::INT8);
+    } else {
+      fmha_out->set_dtype(qkv.dtype());
+    }
+  }
 }
 
 void Conv1dXPUInferMeta(const MetaTensor& x,
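The added InferMeta logic derives fmha_out's dtype from three things: the qkv dtype (INT32 marks a PTQ-quantized input), compute_dtype, and out_scale. Below is a condensed restatement of that rule as a standalone helper, purely for illustration; it is not a function that exists in the codebase.

// Condensed restatement of the fmha_out dtype rule implemented above.
// Illustrative helper only; the real code also rejects invalid compute_dtype
// values via PADDLE_ENFORCE_* checks, which this sketch only notes in comments.
#include <string>
#include "paddle/phi/common/data_type.h"

phi::DataType ResolveFmhaOutDtype(phi::DataType qkv_dtype,
                                  const std::string& compute_dtype,
                                  float out_scale) {
  if (out_scale > 0) {
    return phi::DataType::INT8;  // an output quant scale requests INT8 output
  }
  if (qkv_dtype == phi::DataType::INT32) {
    // PTQ path: the output dtype follows compute_dtype, which must be set.
    if (compute_dtype == "bf16") return phi::DataType::BFLOAT16;
    if (compute_dtype == "fp16") return phi::DataType::FLOAT16;
    if (compute_dtype == "fp32") return phi::DataType::FLOAT32;
    // "default" (or anything else) is rejected by the InferMeta above.
  }
  // Non-quantized input: the output follows the qkv dtype.
  return qkv_dtype;
}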
14 changes: 14 additions & 0 deletions paddle/phi/infermeta/fusion.h
@@ -54,9 +54,23 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
                                       const MetaTensor& rope_emb,
                                       const MetaTensor& mask,
                                       const MetaTensor& tgt_mask,
+                                      const MetaTensor& cache_k_quant_scales,
+                                      const MetaTensor& cache_v_quant_scales,
+                                      const MetaTensor& cache_k_dequant_scales,
+                                      const MetaTensor& cache_v_dequant_scales,
+                                      const MetaTensor& qkv_out_scale,
+                                      const MetaTensor& qkv_bias,
+                                      const MetaTensor& out_shift,
+                                      const MetaTensor& out_smooth,
                                       int max_seq_len,
                                       int block_size,
                                       bool use_neox_style,
+                                      bool dynamic_cachekv_quant,
+                                      const int quant_round_type,
+                                      const float quant_max_bound,
+                                      const float quant_min_bound,
+                                      const float out_scale,
+                                      const std::string& compute_dtype,
                                       MetaTensor* fmha_out,
                                       MetaTensor* qkv_out,
                                       MetaTensor* key_cache_out,
