From 1c97aa69632f2d8aae7622a178c225f8164bbe26 Mon Sep 17 00:00:00 2001
From: zhupengyang
Date: Thu, 27 Apr 2023 19:15:17 +0800
Subject: [PATCH] xpu quant weight only (#53306)

---
 .../fused_multi_transformer_xpu_quant_pass.cc     | 65 ++++++++++---------
 paddle/fluid/framework/ir/xpu/pass_utils.cc       |  7 ++
 paddle/fluid/framework/ir/xpu/quant_utils.cc      | 13 ++++
 paddle/fluid/inference/analysis/argument.h        |  6 ++
 .../inference/analysis/ir_pass_manager.cc         |  8 +++
 paddle/fluid/inference/api/analysis_config.cc     | 21 ++++++
 .../fluid/inference/api/analysis_predictor.cc     |  4 ++
 .../inference/api/paddle_analysis_config.h        | 14 ++++
 paddle/fluid/pybind/inference_api.cc              |  5 ++
 9 files changed, 113 insertions(+), 30 deletions(-)

diff --git a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc
index fab466a50637e..b9868a81135d4 100644
--- a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc
+++ b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc
@@ -280,6 +280,8 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
                                                  with_time_step,
                                                  with_seq_lengths,
                                                  with_src_mask);
+  int quant_weight_bits =
+      Has("quant_weight_bits") ? Get<int>("quant_weight_bits") : -1;
 
   int found_subgraph_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
@@ -312,36 +314,39 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
     // quant weight nodes
     // w_nodes_vec: [QKVW, OutLinearW, FFN1Weight, FFN2Weight]
     std::vector<std::vector<Node*>> w_nodes_vec(4);
-    std::vector<std::vector<Node*>> w_int16_nodes_vec(4);
+    std::vector<std::vector<Node*>> w_intx_nodes_vec(4);
     std::vector<std::vector<Node*>> w_max_nodes_vec(4);
-    std::vector<std::vector<std::string>> w_int16_names_vec(4);
+    std::vector<std::vector<std::string>> w_intx_names_vec(4);
     std::vector<std::vector<std::string>> w_max_names_vec(4);
     auto quant_func = [&](const std::string& input_name,
                           std::vector<Node*>* w_nodes,
-                          std::vector<Node*>* w_int16_nodes,
+                          std::vector<Node*>* w_intx_nodes,
                           std::vector<Node*>* w_max_nodes,
-                          std::vector<std::string>* w_int16_names,
+                          std::vector<std::string>* w_intx_names,
                           std::vector<std::string>* w_max_names,
                           bool need_transpose) {
-      typedef int16_t TW;
-
       auto w_names = fused_mt->Op()->Input(input_name);
       for (auto w_name : w_names) {
         Node* w_node = FindNodeWithName(graph, w_name);
-        Node* w_int16 = nullptr;
+        Node* w_intx = nullptr;
         Node* w_max = nullptr;
         PADDLE_ENFORCE_NE(
             w_node,
             nullptr,
             platform::errors::Fatal("w node should not be nullptr"));
-        PrepareWeight<TW>(
-            graph, scope, block, w_node, &w_int16, &w_max, need_transpose);
+        if (quant_weight_bits == 8) {
+          PrepareWeight<int8_t>(
+              graph, scope, block, w_node, &w_intx, &w_max, need_transpose);
+        } else {
+          PrepareWeight<int16_t>(
+              graph, scope, block, w_node, &w_intx, &w_max, need_transpose);
+        }
         w_nodes->push_back(w_node);
-        w_int16_nodes->push_back(w_int16);
+        w_intx_nodes->push_back(w_intx);
         w_max_nodes->push_back(w_max);
       }
       for (size_t i = 0; i < w_names.size(); ++i) {
-        w_int16_names->push_back(w_int16_nodes->at(i)->Name());
+        w_intx_names->push_back(w_intx_nodes->at(i)->Name());
         w_max_names->push_back(w_max_nodes->at(i)->Name());
       }
       PADDLE_ENFORCE_EQ(
           w_names.size(),
           w_nodes->size(),
           platform::errors::Fatal(
               "The size of w_names(%d) should be equal to w_nodes(%d)",
               static_cast<int>(w_names.size()),
               static_cast<int>(w_nodes->size())));
       PADDLE_ENFORCE_EQ(
           w_names.size(),
-          w_int16_nodes->size(),
+          w_intx_nodes->size(),
           platform::errors::Fatal(
-              "The size of w_names(%d) should be equal to w_int16_nodes(%d)",
+              "The size of w_names(%d) should be equal to w_intx_nodes(%d)",
               static_cast<int>(w_names.size()),
-              static_cast<int>(w_int16_nodes->size())));
+              static_cast<int>(w_intx_nodes->size())));
       PADDLE_ENFORCE_EQ(
           w_names.size(),
           w_max_nodes->size(),
@@ -367,11 +372,11 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
               static_cast<int>(w_max_nodes->size())));
       PADDLE_ENFORCE_EQ(
           w_names.size(),
-          w_int16_names->size(),
+          w_intx_names->size(),
           platform::errors::Fatal(
-              "The size of w_names(%d) should be equal to w_int16_names(%d)",
+              "The size of w_names(%d) should be equal to w_intx_names(%d)",
               static_cast<int>(w_names.size()),
-              static_cast<int>(w_int16_names->size())));
+              static_cast<int>(w_intx_names->size())));
       PADDLE_ENFORCE_EQ(
           w_names.size(),
           w_max_names->size(),
@@ -382,30 +387,30 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
     };
     quant_func("QKVW",
                &(w_nodes_vec[0]),
-               &(w_int16_nodes_vec[0]),
+               &(w_intx_nodes_vec[0]),
                &(w_max_nodes_vec[0]),
-               &(w_int16_names_vec[0]),
+               &(w_intx_names_vec[0]),
                &(w_max_names_vec[0]),
                false);
     quant_func("OutLinearW",
                &(w_nodes_vec[1]),
-               &(w_int16_nodes_vec[1]),
+               &(w_intx_nodes_vec[1]),
                &(w_max_nodes_vec[1]),
-               &(w_int16_names_vec[1]),
+               &(w_intx_names_vec[1]),
                &(w_max_names_vec[1]),
                true);
     quant_func("FFN1Weight",
                &(w_nodes_vec[2]),
-               &(w_int16_nodes_vec[2]),
+               &(w_intx_nodes_vec[2]),
                &(w_max_nodes_vec[2]),
-               &(w_int16_names_vec[2]),
+               &(w_intx_names_vec[2]),
                &(w_max_names_vec[2]),
                true);
     quant_func("FFN2Weight",
                &(w_nodes_vec[3]),
-               &(w_int16_nodes_vec[3]),
+               &(w_intx_nodes_vec[3]),
                &(w_max_nodes_vec[3]),
-               &(w_int16_names_vec[3]),
+               &(w_intx_names_vec[3]),
                &(w_max_names_vec[3]),
                true);
 
@@ -482,13 +487,13 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
                                      name_caches.at("CacheKVOut"));
     fused_mt_xpu_op_desc->SetOutput("out", name_caches.at("Out"));
 
-    fused_mt_xpu_op_desc->SetInput("qkvw", w_int16_names_vec[0]);
+    fused_mt_xpu_op_desc->SetInput("qkvw", w_intx_names_vec[0]);
     fused_mt_xpu_op_desc->SetInput("qkvw_max", w_max_names_vec[0]);
-    fused_mt_xpu_op_desc->SetInput("out_linear_w", w_int16_names_vec[1]);
+    fused_mt_xpu_op_desc->SetInput("out_linear_w", w_intx_names_vec[1]);
     fused_mt_xpu_op_desc->SetInput("out_linear_wmax", w_max_names_vec[1]);
-    fused_mt_xpu_op_desc->SetInput("ffn1_weight", w_int16_names_vec[2]);
+    fused_mt_xpu_op_desc->SetInput("ffn1_weight", w_intx_names_vec[2]);
     fused_mt_xpu_op_desc->SetInput("ffn1_weight_max", w_max_names_vec[2]);
-    fused_mt_xpu_op_desc->SetInput("ffn2_weight", w_int16_names_vec[3]);
+    fused_mt_xpu_op_desc->SetInput("ffn2_weight", w_intx_names_vec[3]);
     fused_mt_xpu_op_desc->SetInput("ffn2_weight_max", w_max_names_vec[3]);
     if (!fused_mt_xpu_op_desc->HasAttr("rotary_emb_dims")) {
       fused_mt_xpu_op_desc->SetAttr("rotary_emb_dims", 0);
@@ -501,7 +506,7 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
     }
     // link int16 format of QKVW/OutLinearW/FFN1Weight/FFN2Weight to
    // fused_mt_xpu
-    for (auto nodes : w_int16_nodes_vec) {
+    for (auto nodes : w_intx_nodes_vec) {
       for (auto node : nodes) {
         IR_NODE_LINK_TO(node, fused_mt);
       }
diff --git a/paddle/fluid/framework/ir/xpu/pass_utils.cc b/paddle/fluid/framework/ir/xpu/pass_utils.cc
index bec10d421c7c2..aaa117d363a5f 100644
--- a/paddle/fluid/framework/ir/xpu/pass_utils.cc
+++ b/paddle/fluid/framework/ir/xpu/pass_utils.cc
@@ -193,6 +193,13 @@ template void PrepareWeight<int16_t>(Graph* graph,
                                      Node** dst,
                                      Node** dst_max,
                                      bool transpose);
+template void PrepareWeight<int8_t>(Graph* graph,
+                                    Scope* scope,
+                                    BlockDesc* block,
+                                    Node* src,
+                                    Node** dst,
+                                    Node** dst_max,
+                                    bool transpose);
 
 void PrepareBias(
     Graph* graph, Scope* scope, BlockDesc* block, Node* src, Node** dst) {
diff --git a/paddle/fluid/framework/ir/xpu/quant_utils.cc b/paddle/fluid/framework/ir/xpu/quant_utils.cc
index ae697708ddcd7..d075d42d29506 100644
--- a/paddle/fluid/framework/ir/xpu/quant_utils.cc
+++ b/paddle/fluid/framework/ir/xpu/quant_utils.cc
@@ -207,6 +207,16 @@ void QuantFP32ToIntX<int16_t>(const float* src_ptr,
   }
 }
+template <>
+void QuantFP32ToIntX<int8_t>(const float* src_ptr,
+                             int8_t* dst_ptr,
+                             float max_val,
+                             int numel) {
+  for (int i = 0; i < numel; i++) {
+    dst_ptr[i] = Fp32ToIntx<int8_t, 127>(src_ptr[i], max_val);
+  }
+}
+
 template <typename T>
 void PrepareWeight(phi::DenseTensor* weight,
                    phi::DenseTensor* weight_max,
                    bool transpose) {
@@ -253,6 +263,9 @@ void PrepareWeight(phi::DenseTensor* weight,
 template void PrepareWeight<int16_t>(phi::DenseTensor* weight,
                                      phi::DenseTensor* weight_max,
                                      bool transpose);
+template void PrepareWeight<int8_t>(phi::DenseTensor* weight,
+                                    phi::DenseTensor* weight_max,
+                                    bool transpose);
 
 }  // namespace ir
 }  // namespace framework
diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index 014f4e828efc7..88e6749223008 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -289,6 +289,12 @@ struct Argument {
   DECL_ARGUMENT_FIELD(xpu_adaptive_seqlen, XpuAdaptiveSeqlen, bool);
   DECL_ARGUMENT_FIELD(xpu_device_id, XpuDeviceId, int);
   DECL_ARGUMENT_FIELD(xpu_enable_multi_stream, XpuEnableMultiStream, bool);
+  DECL_ARGUMENT_FIELD(xpu_quant_post_dynamic_weight_bits,
+                      XpuQuantPostDynamicWeightBits,
+                      int);
+  DECL_ARGUMENT_FIELD(xpu_quant_post_dynamic_op_types,
+                      XpuQuantPostDynamicOpTypes,
+                      std::vector<std::string>);
 
   DECL_ARGUMENT_FIELD(use_opencl, UseOpenCL, bool);
 
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index c3c2fb6d80ffd..1d87edcd3404c 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -308,6 +308,14 @@ void IRPassManager::CreatePasses(Argument *argument,
       }
       bool use_fc_padding = !fc_mkldnn_pass && argument->use_fc_padding();
       pass->Set("use_fc_padding", new bool(use_fc_padding));
+    } else if (pass_name == "fused_multi_transformer_xpu_quant_pass") {
+      auto op_types = argument->xpu_quant_post_dynamic_op_types();
+      if (std::count(op_types.begin(),
+                     op_types.end(),
+                     "fused_multi_transformer") > 0) {
+        pass->Set("quant_weight_bits",
+                  new int(argument->xpu_quant_post_dynamic_weight_bits()));
+      }
     }
 
     pre_pass = pass_name;
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index afb2dcd981fa8..3fa947ce27daf 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -196,6 +196,14 @@ void AnalysisConfig::SetXpuDeviceId(int device_id) {
   Update();
 }
 
+void AnalysisConfig::SetXpuConfig(
+    int quant_post_dynamic_weight_bits,
+    const std::vector<std::string> &quant_post_dynamic_op_types) {
+  xpu_quant_post_dynamic_weight_bits_ = quant_post_dynamic_weight_bits;
+  xpu_quant_post_dynamic_op_types_ = quant_post_dynamic_op_types;
+  Update();
+}
+
 void AnalysisConfig::EnableCustomDevice(const std::string &device_type,
                                         int device_id,
                                         Precision precision_mode) {
@@ -489,6 +497,8 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(xpu_precision_);
   CP_MEMBER(xpu_adaptive_seqlen_);
   CP_MEMBER(xpu_enable_multi_stream_);
+  CP_MEMBER(xpu_quant_post_dynamic_weight_bits_);
+  CP_MEMBER(xpu_quant_post_dynamic_op_types_);
 
   // Lite OpenCL Related
   CP_MEMBER(use_opencl_);
@@ -1091,6 +1101,10 @@ std::string AnalysisConfig::SerializeInfoCache() {
   ss << xpu_precision_;
   ss << xpu_adaptive_seqlen_;
   ss << xpu_enable_multi_stream_;
+  ss << xpu_quant_post_dynamic_weight_bits_;
+  for (auto op_type : xpu_quant_post_dynamic_op_types_) {
+    ss << op_type;
+  }
 
   ss << use_npu_;
   ss << npu_device_id_;
@@ -1331,6 +1345,13 @@ std::string AnalysisConfig::Summary() {
     os.InsertRow({"xpu_device_id", std::to_string(xpu_device_id_)});
     os.InsertRow(
         {"xpu_l3_workspace_size", std::to_string(xpu_l3_workspace_size_)});
+    os.InsertRow({"xpu_quant_post_dynamic_weight_bits",
+                  std::to_string(xpu_quant_post_dynamic_weight_bits_)});
+    std::vector<std::string> op_types{"xpu_quant_post_dynamic_op_types"};
+    for (auto op_type : xpu_quant_post_dynamic_op_types_) {
+      op_types.push_back(op_type);
+    }
+    os.InsertRow(op_types);
   }
   os.InsetDivider();
 
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 219c3c2754c68..4ef15fbbc15dc 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1426,6 +1426,10 @@ void AnalysisPredictor::PrepareArgument() {
     argument_->SetXpuAdaptiveSeqlen(config_.xpu_adaptive_seqlen_);
     argument_->SetXpuDeviceId(config_.xpu_device_id_);
     argument_->SetXpuEnableMultiStream(config_.xpu_enable_multi_stream_);
+    argument_->SetXpuQuantPostDynamicWeightBits(
+        config_.xpu_quant_post_dynamic_weight_bits_);
+    argument_->SetXpuQuantPostDynamicOpTypes(
+        config_.xpu_quant_post_dynamic_op_types_);
 #endif
 
     auto *pass_builder = config_.pass_builder();
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index 10b92debdace1..3300fe8a9b8dc 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -288,6 +288,18 @@ struct PD_INFER_DECL AnalysisConfig {
                  bool adaptive_seqlen = false,
                  bool enable_multi_stream = false);
 
+  ///
+  /// \brief configs of XPU
+  ///
+  /// \param quant_post_dynamic_weight_bits Weight bits used in dynamic post
+  /// quantization. Optional values are -1, 8 and 16. The default value is -1,
+  /// which means the recommended way is used.
+  /// \param quant_post_dynamic_op_types Ops used in dynamic post quantization.
+  ///
+  void SetXpuConfig(
+      int quant_post_dynamic_weight_bits = -1,
+      const std::vector<std::string>& quant_post_dynamic_op_types = {});
+
   ///
   /// \brief configs of IPU
   ///
@@ -1181,6 +1193,8 @@ struct PD_INFER_DECL AnalysisConfig {
   std::string xpu_precision_;
   bool xpu_adaptive_seqlen_;
   bool xpu_enable_multi_stream_;
+  int xpu_quant_post_dynamic_weight_bits_{-1};
+  std::vector<std::string> xpu_quant_post_dynamic_op_types_;
 
   // LITE OPENCL SETTINGS
   bool use_opencl_{false};
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index bce79c27c0585..e00c22423eb28 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -767,6 +767,11 @@ void BindAnalysisConfig(py::module *m) {
       .def("set_xpu_device_id",
            &AnalysisConfig::SetXpuDeviceId,
            py::arg("device_id") = 0)
+      .def(
+          "set_xpu_config",
+          &AnalysisConfig::SetXpuConfig,
+          py::arg("quant_post_dynamic_weight_bits") = -1,
+          py::arg("quant_post_dynamic_op_types") = std::vector<std::string>({}))
       .def("enable_custom_device",
            &AnalysisConfig::EnableCustomDevice,
            py::arg("device_type"),
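
Note (not part of the patch): a minimal C++ usage sketch of the option wired up above, assuming an inference build of Paddle with XPU support. The model paths, include path, and main() scaffolding are placeholders for illustration; only SetXpuConfig() itself comes from this patch.

// Build against the Paddle inference library with XPU enabled; the include
// path below depends on how the library is packaged.
#include "paddle_inference_api.h"

int main() {
  // Hypothetical model files exported for inference.
  paddle_infer::Config config("./model/inference.pdmodel",
                              "./model/inference.pdiparams");
  config.EnableXpu();
  // Added by this patch: quantize fused_multi_transformer weights to int8 at
  // load time. -1 (the default) keeps the recommended path, while values
  // other than 8 fall back to the int16 path in the quant pass.
  config.SetXpuConfig(8, {"fused_multi_transformer"});
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor != nullptr ? 0 : 1;
}

The same pair of arguments is exposed to Python through the pybind change above, e.g. config.set_xpu_config(quant_post_dynamic_weight_bits=8, quant_post_dynamic_op_types=["fused_multi_transformer"]).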