From 8129bfbca13d99393730660c0ae7c657c8fd2868 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Tue, 25 Apr 2023 03:21:23 +0000 Subject: [PATCH 1/3] xpu quant weight only --- .../fused_multi_transformer_xpu_quant_pass.cc | 65 ++++++++++--------- paddle/fluid/framework/ir/xpu/pass_utils.cc | 7 ++ paddle/fluid/framework/ir/xpu/quant_utils.cc | 13 ++++ paddle/fluid/inference/analysis/argument.h | 3 + .../inference/analysis/ir_pass_manager.cc | 3 + paddle/fluid/inference/api/analysis_config.cc | 9 +++ .../fluid/inference/api/analysis_predictor.cc | 1 + .../inference/api/paddle_analysis_config.h | 3 + paddle/fluid/pybind/inference_api.cc | 3 + 9 files changed, 77 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc index fab466a50637e..ba57ad3b4aaf2 100644 --- a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc +++ b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc @@ -280,6 +280,8 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph, with_time_step, with_seq_lengths, with_src_mask); + bool quant_weight_only = + Has("quant_weight_only") && Get("quant_weight_only"); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -312,36 +314,39 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph, // quant weight nodes // w_nodes_vec: [QKVW, OutLinearW, FFN1Weight, FFN2Weight] std::vector> w_nodes_vec(4); - std::vector> w_int16_nodes_vec(4); + std::vector> w_intx_nodes_vec(4); std::vector> w_max_nodes_vec(4); - std::vector> w_int16_names_vec(4); + std::vector> w_intx_names_vec(4); std::vector> w_max_names_vec(4); auto quant_func = [&](const std::string& input_name, std::vector* w_nodes, - std::vector* w_int16_nodes, + std::vector* w_intx_nodes, std::vector* w_max_nodes, - std::vector* w_int16_names, + std::vector* w_intx_names, std::vector* w_max_names, bool need_transpose) { - typedef int16_t TW; - auto w_names = fused_mt->Op()->Input(input_name); for (auto w_name : w_names) { Node* w_node = FindNodeWithName(graph, w_name); - Node* w_int16 = nullptr; + Node* w_intx = nullptr; Node* w_max = nullptr; PADDLE_ENFORCE_NE( w_node, nullptr, platform::errors::Fatal("w node should not be nullptr")); - PrepareWeight( - graph, scope, block, w_node, &w_int16, &w_max, need_transpose); + if (quant_weight_only) { + PrepareWeight( + graph, scope, block, w_node, &w_intx, &w_max, need_transpose); + } else { + PrepareWeight( + graph, scope, block, w_node, &w_intx, &w_max, need_transpose); + } w_nodes->push_back(w_node); - w_int16_nodes->push_back(w_int16); + w_intx_nodes->push_back(w_intx); w_max_nodes->push_back(w_max); } for (size_t i = 0; i < w_names.size(); ++i) { - w_int16_names->push_back(w_int16_nodes->at(i)->Name()); + w_intx_names->push_back(w_intx_nodes->at(i)->Name()); w_max_names->push_back(w_max_nodes->at(i)->Name()); } PADDLE_ENFORCE_EQ( @@ -353,11 +358,11 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph, static_cast(w_nodes->size()))); PADDLE_ENFORCE_EQ( w_names.size(), - w_int16_nodes->size(), + w_intx_nodes->size(), platform::errors::Fatal( - "The size of w_names(%d) should be equal to w_int16_nodes(%d)", + "The size of w_names(%d) should be equal to w_intx_nodes(%d)", static_cast(w_names.size()), - static_cast(w_int16_nodes->size()))); + static_cast(w_intx_nodes->size()))); PADDLE_ENFORCE_EQ( w_names.size(), w_max_nodes->size(), 
@@ -367,11 +372,11 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph, static_cast(w_max_nodes->size()))); PADDLE_ENFORCE_EQ( w_names.size(), - w_int16_names->size(), + w_intx_names->size(), platform::errors::Fatal( - "The size of w_names(%d) should be equal to w_int16_names(%d)", + "The size of w_names(%d) should be equal to w_intx_names(%d)", static_cast(w_names.size()), - static_cast(w_int16_names->size()))); + static_cast(w_intx_names->size()))); PADDLE_ENFORCE_EQ( w_names.size(), w_max_names->size(), @@ -382,30 +387,30 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph, }; quant_func("QKVW", &(w_nodes_vec[0]), - &(w_int16_nodes_vec[0]), + &(w_intx_nodes_vec[0]), &(w_max_nodes_vec[0]), - &(w_int16_names_vec[0]), + &(w_intx_names_vec[0]), &(w_max_names_vec[0]), false); quant_func("OutLinearW", &(w_nodes_vec[1]), - &(w_int16_nodes_vec[1]), + &(w_intx_nodes_vec[1]), &(w_max_nodes_vec[1]), - &(w_int16_names_vec[1]), + &(w_intx_names_vec[1]), &(w_max_names_vec[1]), true); quant_func("FFN1Weight", &(w_nodes_vec[2]), - &(w_int16_nodes_vec[2]), + &(w_intx_nodes_vec[2]), &(w_max_nodes_vec[2]), - &(w_int16_names_vec[2]), + &(w_intx_names_vec[2]), &(w_max_names_vec[2]), true); quant_func("FFN2Weight", &(w_nodes_vec[3]), - &(w_int16_nodes_vec[3]), + &(w_intx_nodes_vec[3]), &(w_max_nodes_vec[3]), - &(w_int16_names_vec[3]), + &(w_intx_names_vec[3]), &(w_max_names_vec[3]), true); @@ -482,13 +487,13 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph, name_caches.at("CacheKVOut")); fused_mt_xpu_op_desc->SetOutput("out", name_caches.at("Out")); - fused_mt_xpu_op_desc->SetInput("qkvw", w_int16_names_vec[0]); + fused_mt_xpu_op_desc->SetInput("qkvw", w_intx_names_vec[0]); fused_mt_xpu_op_desc->SetInput("qkvw_max", w_max_names_vec[0]); - fused_mt_xpu_op_desc->SetInput("out_linear_w", w_int16_names_vec[1]); + fused_mt_xpu_op_desc->SetInput("out_linear_w", w_intx_names_vec[1]); fused_mt_xpu_op_desc->SetInput("out_linear_wmax", w_max_names_vec[1]); - fused_mt_xpu_op_desc->SetInput("ffn1_weight", w_int16_names_vec[2]); + fused_mt_xpu_op_desc->SetInput("ffn1_weight", w_intx_names_vec[2]); fused_mt_xpu_op_desc->SetInput("ffn1_weight_max", w_max_names_vec[2]); - fused_mt_xpu_op_desc->SetInput("ffn2_weight", w_int16_names_vec[3]); + fused_mt_xpu_op_desc->SetInput("ffn2_weight", w_intx_names_vec[3]); fused_mt_xpu_op_desc->SetInput("ffn2_weight_max", w_max_names_vec[3]); if (!fused_mt_xpu_op_desc->HasAttr("rotary_emb_dims")) { fused_mt_xpu_op_desc->SetAttr("rotary_emb_dims", 0); @@ -501,7 +506,7 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph, } // link int16 format of QKVW/OutLinearW/FFN1Weight/FFN2Weight to // fused_mt_xpu - for (auto nodes : w_int16_nodes_vec) { + for (auto nodes : w_intx_nodes_vec) { for (auto node : nodes) { IR_NODE_LINK_TO(node, fused_mt); } diff --git a/paddle/fluid/framework/ir/xpu/pass_utils.cc b/paddle/fluid/framework/ir/xpu/pass_utils.cc index bec10d421c7c2..aaa117d363a5f 100644 --- a/paddle/fluid/framework/ir/xpu/pass_utils.cc +++ b/paddle/fluid/framework/ir/xpu/pass_utils.cc @@ -193,6 +193,13 @@ template void PrepareWeight(Graph* graph, Node** dst, Node** dst_max, bool transpose); +template void PrepareWeight(Graph* graph, + Scope* scope, + BlockDesc* block, + Node* src, + Node** dst, + Node** dst_max, + bool transpose); void PrepareBias( Graph* graph, Scope* scope, BlockDesc* block, Node* src, Node** dst) { diff --git a/paddle/fluid/framework/ir/xpu/quant_utils.cc b/paddle/fluid/framework/ir/xpu/quant_utils.cc 
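For reference on the quant_func hunks above: the angle-bracket template arguments appear to have been dropped in this copy of the patch. Matching the explicit int8_t instantiation added to pass_utils.cc here, the per-weight dispatch presumably calls PrepareWeight<int8_t> when quant_weight_only is set and PrepareWeight<int16_t> otherwise (and reads the flag via Get<bool>("quant_weight_only")). The quant_utils.cc hunk below adds the int8 specialization of QuantFP32ToIntX; conceptually it performs symmetric, per-tensor abs-max quantization. A standalone sketch of that conversion (an illustration, not the Paddle helper itself):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Map a float weight into [-127, 127] using the tensor-wide abs-max.
    // max_val is stored next to the int8 data so the kernel can rescale
    // the products back to floating point during the matmul.
    int8_t QuantizeToInt8(float v, float max_val) {
      float scaled = v * 127.0f / max_val;
      scaled = std::max(-127.0f, std::min(127.0f, scaled));
      return static_cast<int8_t>(std::round(scaled));
    }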
index ae697708ddcd7..d075d42d29506 100644 --- a/paddle/fluid/framework/ir/xpu/quant_utils.cc +++ b/paddle/fluid/framework/ir/xpu/quant_utils.cc @@ -207,6 +207,16 @@ void QuantFP32ToIntX(const float* src_ptr, } } +template <> +void QuantFP32ToIntX(const float* src_ptr, + int8_t* dst_ptr, + float max_val, + int numel) { + for (int i = 0; i < numel; i++) { + dst_ptr[i] = Fp32ToIntx(src_ptr[i], max_val); + } +} + template void PrepareWeight(phi::DenseTensor* weight, phi::DenseTensor* weight_max, @@ -253,6 +263,9 @@ void PrepareWeight(phi::DenseTensor* weight, template void PrepareWeight(phi::DenseTensor* weight, phi::DenseTensor* weight_max, bool transpose); +template void PrepareWeight(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose); } // namespace ir } // namespace framework diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index b02e2e05efd6b..6243299007674 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -289,6 +289,9 @@ struct Argument { DECL_ARGUMENT_FIELD(xpu_adaptive_seqlen, XpuAdaptiveSeqlen, bool); DECL_ARGUMENT_FIELD(xpu_device_id, XpuDeviceId, int); DECL_ARGUMENT_FIELD(xpu_enable_multi_stream, XpuEnableMultiStream, bool); + DECL_ARGUMENT_FIELD(xpu_enable_quant_weight_only, + XpuEnableQuantWeightOnly, + bool); DECL_ARGUMENT_FIELD(use_opencl, UseOpenCL, bool); diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index c3c2fb6d80ffd..d14c2357d6e51 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -308,6 +308,9 @@ void IRPassManager::CreatePasses(Argument *argument, } bool use_fc_padding = !fc_mkldnn_pass && argument->use_fc_padding(); pass->Set("use_fc_padding", new bool(use_fc_padding)); + } else if (pass_name == "fused_multi_transformer_xpu_quant_pass") { + pass->Set("quant_weight_only", + new bool(argument->xpu_enable_quant_weight_only())); } pre_pass = pass_name; diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index b0f53c1f639ac..506731cb4062c 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -195,6 +195,11 @@ void AnalysisConfig::SetXpuDeviceId(int device_id) { Update(); } +void AnalysisConfig::SetXpuConfig(bool xpu_enable_quant_weight_only) { + xpu_enable_quant_weight_only_ = xpu_enable_quant_weight_only; + Update(); +} + void AnalysisConfig::EnableCustomDevice(const std::string &device_type, int device_id, Precision precision_mode) { @@ -488,6 +493,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { CP_MEMBER(xpu_precision_); CP_MEMBER(xpu_adaptive_seqlen_); CP_MEMBER(xpu_enable_multi_stream_); + CP_MEMBER(xpu_enable_quant_weight_only_); // Lite OpenCL Related CP_MEMBER(use_opencl_); @@ -1090,6 +1096,7 @@ std::string AnalysisConfig::SerializeInfoCache() { ss << xpu_precision_; ss << xpu_adaptive_seqlen_; ss << xpu_enable_multi_stream_; + ss << xpu_enable_quant_weight_only_; ss << use_npu_; ss << npu_device_id_; @@ -1330,6 +1337,8 @@ std::string AnalysisConfig::Summary() { os.InsertRow({"xpu_device_id", std::to_string(xpu_device_id_)}); os.InsertRow( {"xpu_l3_workspace_size", std::to_string(xpu_l3_workspace_size_)}); + os.InsertRow({"xpu_enable_quant_weight_only", + std::to_string(xpu_enable_quant_weight_only_)}); } os.InsetDivider(); diff --git 
a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index b0ef79a0c7bdf..1eb4c5d271994 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1426,6 +1426,7 @@ void AnalysisPredictor::PrepareArgument() { argument_->SetXpuAdaptiveSeqlen(config_.xpu_adaptive_seqlen_); argument_->SetXpuDeviceId(config_.xpu_device_id_); argument_->SetXpuEnableMultiStream(config_.xpu_enable_multi_stream_); + argument_->SetXpuEnableQuantWeightOnly(config_.xpu_enable_quant_weight_only_); #endif auto *pass_builder = config_.pass_builder(); diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index 10b92debdace1..4f4efd5d52768 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -288,6 +288,8 @@ struct PD_INFER_DECL AnalysisConfig { bool adaptive_seqlen = false, bool enable_multi_stream = false); + void SetXpuConfig(bool xpu_enable_quant_weight_only = false); + /// /// \brief configs of IPU /// @@ -1181,6 +1183,7 @@ struct PD_INFER_DECL AnalysisConfig { std::string xpu_precision_; bool xpu_adaptive_seqlen_; bool xpu_enable_multi_stream_; + bool xpu_enable_quant_weight_only_{false}; // LITE OPENCL SETTINGS bool use_opencl_{false}; diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index bce79c27c0585..49ae120d7afc5 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -767,6 +767,9 @@ void BindAnalysisConfig(py::module *m) { .def("set_xpu_device_id", &AnalysisConfig::SetXpuDeviceId, py::arg("device_id") = 0) + .def("set_xpu_config", + &AnalysisConfig::SetXpuConfig, + py::arg("xpu_enable_quant_weight_only") = false) .def("enable_custom_device", &AnalysisConfig::EnableCustomDevice, py::arg("device_type"), From adbc9e625af63fbd1ee23c1e1cb4c33f6353cba6 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Tue, 25 Apr 2023 05:40:43 +0000 Subject: [PATCH 2/3] test --- paddle/fluid/inference/api/paddle_analysis_config.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index 4f4efd5d52768..d8887cfc942de 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -288,6 +288,12 @@ struct PD_INFER_DECL AnalysisConfig { bool adaptive_seqlen = false, bool enable_multi_stream = false); + /// + /// \brief configs of XPU + /// + /// \param xpu_enable_quant_weight_only Whether to enable weight only optimize + /// on fused_multi_transformer. 
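+ /// "Weight only" here means only the transformer weights are quantized
+ /// (stored as int8 alongside a per-tensor max); this pass does not touch
+ /// activations.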
+ /// void SetXpuConfig(bool xpu_enable_quant_weight_only = false); /// From c819b61980df56155ffac6a8d02daec73e725391 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Wed, 26 Apr 2023 10:47:28 +0000 Subject: [PATCH 3/3] fix --- .../fused_multi_transformer_xpu_quant_pass.cc | 6 ++--- paddle/fluid/inference/analysis/argument.h | 9 ++++--- .../inference/analysis/ir_pass_manager.cc | 9 +++++-- paddle/fluid/inference/api/analysis_config.cc | 24 ++++++++++++++----- .../fluid/inference/api/analysis_predictor.cc | 5 +++- .../inference/api/paddle_analysis_config.h | 13 ++++++---- paddle/fluid/pybind/inference_api.cc | 8 ++++--- 7 files changed, 52 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc index ba57ad3b4aaf2..b9868a81135d4 100644 --- a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc +++ b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc @@ -280,8 +280,8 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph, with_time_step, with_seq_lengths, with_src_mask); - bool quant_weight_only = - Has("quant_weight_only") && Get("quant_weight_only"); + int quant_weight_bits = + Has("quant_weight_bits") ? Get("quant_weight_bits") : -1; int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -334,7 +334,7 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph, w_node, nullptr, platform::errors::Fatal("w node should not be nullptr")); - if (quant_weight_only) { + if (quant_weight_bits == 8) { PrepareWeight( graph, scope, block, w_node, &w_intx, &w_max, need_transpose); } else { diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 8f28c10a2602f..88e6749223008 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -289,9 +289,12 @@ struct Argument { DECL_ARGUMENT_FIELD(xpu_adaptive_seqlen, XpuAdaptiveSeqlen, bool); DECL_ARGUMENT_FIELD(xpu_device_id, XpuDeviceId, int); DECL_ARGUMENT_FIELD(xpu_enable_multi_stream, XpuEnableMultiStream, bool); - DECL_ARGUMENT_FIELD(xpu_enable_quant_weight_only, - XpuEnableQuantWeightOnly, - bool); + DECL_ARGUMENT_FIELD(xpu_quant_post_dynamic_weight_bits, + XpuQuantPostDynamicWeightBits, + int); + DECL_ARGUMENT_FIELD(xpu_quant_post_dynamic_op_types, + XpuQuantPostDynamicOpTypss, + std::vector); DECL_ARGUMENT_FIELD(use_opencl, UseOpenCL, bool); diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index d14c2357d6e51..1d87edcd3404c 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -309,8 +309,13 @@ void IRPassManager::CreatePasses(Argument *argument, bool use_fc_padding = !fc_mkldnn_pass && argument->use_fc_padding(); pass->Set("use_fc_padding", new bool(use_fc_padding)); } else if (pass_name == "fused_multi_transformer_xpu_quant_pass") { - pass->Set("quant_weight_only", - new bool(argument->xpu_enable_quant_weight_only())); + auto op_types = argument->xpu_quant_post_dynamic_op_types(); + if (std::count(op_types.begin(), + op_types.end(), + "fused_multi_transformer") > 0) { + pass->Set("quant_weight_bits", + new int(argument->xpu_quant_post_dynamic_weight_bits())); + } } pre_pass = pass_name; diff --git a/paddle/fluid/inference/api/analysis_config.cc 
b/paddle/fluid/inference/api/analysis_config.cc index c9d737cae423d..3fa947ce27daf 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -196,8 +196,11 @@ void AnalysisConfig::SetXpuDeviceId(int device_id) { Update(); } -void AnalysisConfig::SetXpuConfig(bool xpu_enable_quant_weight_only) { - xpu_enable_quant_weight_only_ = xpu_enable_quant_weight_only; +void AnalysisConfig::SetXpuConfig( + int quant_post_dynamic_weight_bits, + const std::vector &quant_post_dynamic_op_types) { + xpu_quant_post_dynamic_weight_bits_ = quant_post_dynamic_weight_bits; + xpu_quant_post_dynamic_op_types_ = quant_post_dynamic_op_types; Update(); } @@ -494,7 +497,8 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { CP_MEMBER(xpu_precision_); CP_MEMBER(xpu_adaptive_seqlen_); CP_MEMBER(xpu_enable_multi_stream_); - CP_MEMBER(xpu_enable_quant_weight_only_); + CP_MEMBER(xpu_quant_post_dynamic_weight_bits_); + CP_MEMBER(xpu_quant_post_dynamic_op_types_); // Lite OpenCL Related CP_MEMBER(use_opencl_); @@ -1097,7 +1101,10 @@ std::string AnalysisConfig::SerializeInfoCache() { ss << xpu_precision_; ss << xpu_adaptive_seqlen_; ss << xpu_enable_multi_stream_; - ss << xpu_enable_quant_weight_only_; + ss << xpu_quant_post_dynamic_weight_bits_; + for (auto op_type : xpu_quant_post_dynamic_op_types_) { + ss << op_type; + } ss << use_npu_; ss << npu_device_id_; @@ -1338,8 +1345,13 @@ std::string AnalysisConfig::Summary() { os.InsertRow({"xpu_device_id", std::to_string(xpu_device_id_)}); os.InsertRow( {"xpu_l3_workspace_size", std::to_string(xpu_l3_workspace_size_)}); - os.InsertRow({"xpu_enable_quant_weight_only", - std::to_string(xpu_enable_quant_weight_only_)}); + os.InsertRow({"xpu_quant_post_dynamic_weight_bits", + std::to_string(xpu_quant_post_dynamic_weight_bits_)}); + std::vector op_types{"xpu_quant_post_dynamic_op_types"}; + for (auto op_type : xpu_quant_post_dynamic_op_types_) { + op_types.push_back(op_type); + } + os.InsertRow(op_types); } os.InsetDivider(); diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 553d8dcfbd45b..4ef15fbbc15dc 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1426,7 +1426,10 @@ void AnalysisPredictor::PrepareArgument() { argument_->SetXpuAdaptiveSeqlen(config_.xpu_adaptive_seqlen_); argument_->SetXpuDeviceId(config_.xpu_device_id_); argument_->SetXpuEnableMultiStream(config_.xpu_enable_multi_stream_); - argument_->SetXpuEnableQuantWeightOnly(config_.xpu_enable_quant_weight_only_); + argument_->SetXpuQuantPostDynamicWeightBits( + config_.xpu_quant_post_dynamic_weight_bits_); + argument_->SetXpuQuantPostDynamicOpTypss( + config_.xpu_quant_post_dynamic_op_types_); #endif auto *pass_builder = config_.pass_builder(); diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index d8887cfc942de..3300fe8a9b8dc 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -291,10 +291,14 @@ struct PD_INFER_DECL AnalysisConfig { /// /// \brief configs of XPU /// - /// \param xpu_enable_quant_weight_only Whether to enable weight only optimize - /// on fused_multi_transformer. + /// \param quant_post_dynamic_weight_bits Weight bits used in dynamic post + /// quantization. Optional value: -1, 8, 16. Default value is -1, means using + /// the recommended way. 
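+ /// For fused_multi_transformer, 8 selects int8 weight storage, while -1 and
+ /// 16 keep the int16 path.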
\param quant_post_dynamic_op_types Ops used in + /// dynamic post quantization. /// - void SetXpuConfig(bool xpu_enable_quant_weight_only = false); + void SetXpuConfig( + int quant_post_dynamic_weight_bits = -1, + const std::vector& quant_post_dynamic_op_types = {}); /// /// \brief configs of IPU @@ -1189,7 +1193,8 @@ struct PD_INFER_DECL AnalysisConfig { std::string xpu_precision_; bool xpu_adaptive_seqlen_; bool xpu_enable_multi_stream_; - bool xpu_enable_quant_weight_only_{false}; + int xpu_quant_post_dynamic_weight_bits_{-1}; + std::vector xpu_quant_post_dynamic_op_types_; // LITE OPENCL SETTINGS bool use_opencl_{false}; diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 49ae120d7afc5..e00c22423eb28 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -767,9 +767,11 @@ void BindAnalysisConfig(py::module *m) { .def("set_xpu_device_id", &AnalysisConfig::SetXpuDeviceId, py::arg("device_id") = 0) - .def("set_xpu_config", - &AnalysisConfig::SetXpuConfig, - py::arg("xpu_enable_quant_weight_only") = false) + .def( + "set_xpu_config", + &AnalysisConfig::SetXpuConfig, + py::arg("quant_post_dynamic_weight_bits") = -1, + py::arg("quant_post_dynamic_op_types") = std::vector({})) .def("enable_custom_device", &AnalysisConfig::EnableCustomDevice, py::arg("device_type"),
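Taken together, the three patches expose the new knobs through AnalysisConfig (and through set_xpu_config in the Python binding above). A minimal C++ usage sketch, assuming an XPU build of the Paddle Inference library and a hypothetical "model_dir" model:

    #include "paddle_inference_api.h"  // from the Paddle Inference C++ package

    int main() {
      paddle_infer::Config config("model_dir");  // hypothetical model path
      config.EnableXpu();
      // Post-training, weight-only int8 for fused_multi_transformer weights;
      // the defaults (-1 and an empty op type list) keep the int16 behavior.
      config.SetXpuConfig(8, {"fused_multi_transformer"});
      auto predictor = paddle_infer::CreatePredictor(config);
      return 0;
    }

Note that ir_pass_manager only hands quant_weight_bits to the pass when "fused_multi_transformer" appears in quant_post_dynamic_op_types, so an empty op type list leaves the quant pass at its default int16 path.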