xpu quant weight only #53306

Merged (4 commits) on Apr 27, 2023

@@ -280,6 +280,8 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
with_time_step,
with_seq_lengths,
with_src_mask);
int quant_weight_bits =
Has("quant_weight_bits") ? Get<int>("quant_weight_bits") : -1;

int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
@@ -312,36 +314,39 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
// quant weight nodes
// w_nodes_vec: [QKVW, OutLinearW, FFN1Weight, FFN2Weight]
std::vector<std::vector<Node*>> w_nodes_vec(4);
std::vector<std::vector<Node*>> w_int16_nodes_vec(4);
std::vector<std::vector<Node*>> w_intx_nodes_vec(4);
std::vector<std::vector<Node*>> w_max_nodes_vec(4);
std::vector<std::vector<std::string>> w_int16_names_vec(4);
std::vector<std::vector<std::string>> w_intx_names_vec(4);
std::vector<std::vector<std::string>> w_max_names_vec(4);
auto quant_func = [&](const std::string& input_name,
std::vector<Node*>* w_nodes,
std::vector<Node*>* w_int16_nodes,
std::vector<Node*>* w_intx_nodes,
std::vector<Node*>* w_max_nodes,
std::vector<std::string>* w_int16_names,
std::vector<std::string>* w_intx_names,
std::vector<std::string>* w_max_names,
bool need_transpose) {
typedef int16_t TW;

auto w_names = fused_mt->Op()->Input(input_name);
for (auto w_name : w_names) {
Node* w_node = FindNodeWithName(graph, w_name);
Node* w_int16 = nullptr;
Node* w_intx = nullptr;
Node* w_max = nullptr;
PADDLE_ENFORCE_NE(
w_node,
nullptr,
platform::errors::Fatal("w node should not be nullptr"));
PrepareWeight<TW>(
graph, scope, block, w_node, &w_int16, &w_max, need_transpose);
if (quant_weight_bits == 8) {
PrepareWeight<int8_t>(
graph, scope, block, w_node, &w_intx, &w_max, need_transpose);
} else {
PrepareWeight<int16_t>(
graph, scope, block, w_node, &w_intx, &w_max, need_transpose);
}
w_nodes->push_back(w_node);
w_int16_nodes->push_back(w_int16);
w_intx_nodes->push_back(w_intx);
w_max_nodes->push_back(w_max);
}
for (size_t i = 0; i < w_names.size(); ++i) {
w_int16_names->push_back(w_int16_nodes->at(i)->Name());
w_intx_names->push_back(w_intx_nodes->at(i)->Name());
w_max_names->push_back(w_max_nodes->at(i)->Name());
}
PADDLE_ENFORCE_EQ(
@@ -353,11 +358,11 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
static_cast<int>(w_nodes->size())));
PADDLE_ENFORCE_EQ(
w_names.size(),
w_int16_nodes->size(),
w_intx_nodes->size(),
platform::errors::Fatal(
"The size of w_names(%d) should be equal to w_int16_nodes(%d)",
"The size of w_names(%d) should be equal to w_intx_nodes(%d)",
static_cast<int>(w_names.size()),
static_cast<int>(w_int16_nodes->size())));
static_cast<int>(w_intx_nodes->size())));
PADDLE_ENFORCE_EQ(
w_names.size(),
w_max_nodes->size(),
@@ -367,11 +372,11 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
static_cast<int>(w_max_nodes->size())));
PADDLE_ENFORCE_EQ(
w_names.size(),
w_int16_names->size(),
w_intx_names->size(),
platform::errors::Fatal(
"The size of w_names(%d) should be equal to w_int16_names(%d)",
"The size of w_names(%d) should be equal to w_intx_names(%d)",
static_cast<int>(w_names.size()),
static_cast<int>(w_int16_names->size())));
static_cast<int>(w_intx_names->size())));
PADDLE_ENFORCE_EQ(
w_names.size(),
w_max_names->size(),
@@ -382,30 +387,30 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
};
quant_func("QKVW",
&(w_nodes_vec[0]),
&(w_int16_nodes_vec[0]),
&(w_intx_nodes_vec[0]),
&(w_max_nodes_vec[0]),
&(w_int16_names_vec[0]),
&(w_intx_names_vec[0]),
&(w_max_names_vec[0]),
false);
quant_func("OutLinearW",
&(w_nodes_vec[1]),
&(w_int16_nodes_vec[1]),
&(w_intx_nodes_vec[1]),
&(w_max_nodes_vec[1]),
&(w_int16_names_vec[1]),
&(w_intx_names_vec[1]),
&(w_max_names_vec[1]),
true);
quant_func("FFN1Weight",
&(w_nodes_vec[2]),
&(w_int16_nodes_vec[2]),
&(w_intx_nodes_vec[2]),
&(w_max_nodes_vec[2]),
&(w_int16_names_vec[2]),
&(w_intx_names_vec[2]),
&(w_max_names_vec[2]),
true);
quant_func("FFN2Weight",
&(w_nodes_vec[3]),
&(w_int16_nodes_vec[3]),
&(w_intx_nodes_vec[3]),
&(w_max_nodes_vec[3]),
&(w_int16_names_vec[3]),
&(w_intx_names_vec[3]),
&(w_max_names_vec[3]),
true);

@@ -482,13 +487,13 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
name_caches.at("CacheKVOut"));
fused_mt_xpu_op_desc->SetOutput("out", name_caches.at("Out"));

fused_mt_xpu_op_desc->SetInput("qkvw", w_int16_names_vec[0]);
fused_mt_xpu_op_desc->SetInput("qkvw", w_intx_names_vec[0]);
fused_mt_xpu_op_desc->SetInput("qkvw_max", w_max_names_vec[0]);
fused_mt_xpu_op_desc->SetInput("out_linear_w", w_int16_names_vec[1]);
fused_mt_xpu_op_desc->SetInput("out_linear_w", w_intx_names_vec[1]);
fused_mt_xpu_op_desc->SetInput("out_linear_wmax", w_max_names_vec[1]);
fused_mt_xpu_op_desc->SetInput("ffn1_weight", w_int16_names_vec[2]);
fused_mt_xpu_op_desc->SetInput("ffn1_weight", w_intx_names_vec[2]);
fused_mt_xpu_op_desc->SetInput("ffn1_weight_max", w_max_names_vec[2]);
fused_mt_xpu_op_desc->SetInput("ffn2_weight", w_int16_names_vec[3]);
fused_mt_xpu_op_desc->SetInput("ffn2_weight", w_intx_names_vec[3]);
fused_mt_xpu_op_desc->SetInput("ffn2_weight_max", w_max_names_vec[3]);
if (!fused_mt_xpu_op_desc->HasAttr("rotary_emb_dims")) {
fused_mt_xpu_op_desc->SetAttr("rotary_emb_dims", 0);
@@ -501,7 +506,7 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
}
// link the quantized (int16/int8) QKVW/OutLinearW/FFN1Weight/FFN2Weight to
// fused_mt_xpu
for (auto nodes : w_int16_nodes_vec) {
for (auto nodes : w_intx_nodes_vec) {
for (auto node : nodes) {
IR_NODE_LINK_TO(node, fused_mt);
}
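
For context on the weight/"_max" pairs wired up above (qkvw/qkvw_max, out_linear_w/out_linear_wmax, ...): each weight is stored in a narrow integer format together with a per-tensor absmax scale, and the fused kernel dequantizes with that scale. A minimal sketch of the relationship for the int8 case, where 127 matches the Fp32ToIntx<int8_t, 127> call added in quant_utils.cc below; Dequant is an illustrative helper, not a Paddle API:

#include <cstdint>

// Illustrative only: recover an approximate FP32 weight from its int8 value
// and the per-tensor absmax carried by the companion "*_max" tensor.
inline float Dequant(int8_t q, float max_val) {
  return static_cast<float>(q) * max_val / 127.0f;  // w_fp32 ~= w_int8 * max / 127
}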

7 changes: 7 additions & 0 deletions paddle/fluid/framework/ir/xpu/pass_utils.cc
@@ -193,6 +193,13 @@ template void PrepareWeight<int16_t>(Graph* graph,
Node** dst,
Node** dst_max,
bool transpose);
template void PrepareWeight<int8_t>(Graph* graph,
Scope* scope,
BlockDesc* block,
Node* src,
Node** dst,
Node** dst_max,
bool transpose);

void PrepareBias(
Graph* graph, Scope* scope, BlockDesc* block, Node* src, Node** dst) {

13 changes: 13 additions & 0 deletions paddle/fluid/framework/ir/xpu/quant_utils.cc
@@ -207,6 +207,16 @@ void QuantFP32ToIntX<int16_t>(const float* src_ptr,
}
}

template <>
void QuantFP32ToIntX<int8_t>(const float* src_ptr,
int8_t* dst_ptr,
float max_val,
int numel) {
for (int i = 0; i < numel; i++) {
dst_ptr[i] = Fp32ToIntx<int8_t, 127>(src_ptr[i], max_val);
}
}

template <typename T>
void PrepareWeight(phi::DenseTensor* weight,
phi::DenseTensor* weight_max,
@@ -253,6 +263,9 @@ void PrepareWeight(phi::DenseTensor* weight,
template void PrepareWeight<int16_t>(phi::DenseTensor* weight,
phi::DenseTensor* weight_max,
bool transpose);
template void PrepareWeight<int8_t>(phi::DenseTensor* weight,
phi::DenseTensor* weight_max,
bool transpose);

} // namespace ir
} // namespace framework
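
The int8 specialization of QuantFP32ToIntX above delegates to Fp32ToIntx<int8_t, 127>, whose definition is outside this diff. As a self-contained sketch of the symmetric per-tensor quantization it is assumed to perform (QuantToInt8 and QuantBufferToInt8 are illustrative names, not Paddle helpers):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative symmetric quantization: scale by 127/absmax, round, clamp.
inline int8_t QuantToInt8(float x, float max_val) {
  float scaled = std::round(x / max_val * 127.0f);
  scaled = std::max(-127.0f, std::min(127.0f, scaled));
  return static_cast<int8_t>(scaled);
}

// Buffer form, mirroring the loop shape of QuantFP32ToIntX<int8_t>.
inline void QuantBufferToInt8(const float* src, int8_t* dst, float max_val, int numel) {
  for (int i = 0; i < numel; ++i) dst[i] = QuantToInt8(src[i], max_val);
}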

6 changes: 6 additions & 0 deletions paddle/fluid/inference/analysis/argument.h
@@ -289,6 +289,12 @@ struct Argument {
DECL_ARGUMENT_FIELD(xpu_adaptive_seqlen, XpuAdaptiveSeqlen, bool);
DECL_ARGUMENT_FIELD(xpu_device_id, XpuDeviceId, int);
DECL_ARGUMENT_FIELD(xpu_enable_multi_stream, XpuEnableMultiStream, bool);
DECL_ARGUMENT_FIELD(xpu_quant_post_dynamic_weight_bits,
XpuQuantPostDynamicWeightBits,
int);
DECL_ARGUMENT_FIELD(xpu_quant_post_dynamic_op_types,
XpuQuantPostDynamicOpTypss,
std::vector<std::string>);

DECL_ARGUMENT_FIELD(use_opencl, UseOpenCL, bool);


8 changes: 8 additions & 0 deletions paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -308,6 +308,14 @@ void IRPassManager::CreatePasses(Argument *argument,
}
bool use_fc_padding = !fc_mkldnn_pass && argument->use_fc_padding();
pass->Set("use_fc_padding", new bool(use_fc_padding));
} else if (pass_name == "fused_multi_transformer_xpu_quant_pass") {
auto op_types = argument->xpu_quant_post_dynamic_op_types();
if (std::count(op_types.begin(),
op_types.end(),
"fused_multi_transformer") > 0) {
pass->Set("quant_weight_bits",
new int(argument->xpu_quant_post_dynamic_weight_bits()));
}
}
pre_pass = pass_name;


21 changes: 21 additions & 0 deletions paddle/fluid/inference/api/analysis_config.cc
@@ -196,6 +196,14 @@ void AnalysisConfig::SetXpuDeviceId(int device_id) {
Update();
}

void AnalysisConfig::SetXpuConfig(
int quant_post_dynamic_weight_bits,
const std::vector<std::string> &quant_post_dynamic_op_types) {
xpu_quant_post_dynamic_weight_bits_ = quant_post_dynamic_weight_bits;
xpu_quant_post_dynamic_op_types_ = quant_post_dynamic_op_types;
Update();
}

void AnalysisConfig::EnableCustomDevice(const std::string &device_type,
int device_id,
Precision precision_mode) {
@@ -489,6 +497,8 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(xpu_precision_);
CP_MEMBER(xpu_adaptive_seqlen_);
CP_MEMBER(xpu_enable_multi_stream_);
CP_MEMBER(xpu_quant_post_dynamic_weight_bits_);
CP_MEMBER(xpu_quant_post_dynamic_op_types_);

// Lite OpenCL Related
CP_MEMBER(use_opencl_);
@@ -1091,6 +1101,10 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << xpu_precision_;
ss << xpu_adaptive_seqlen_;
ss << xpu_enable_multi_stream_;
ss << xpu_quant_post_dynamic_weight_bits_;
for (auto op_type : xpu_quant_post_dynamic_op_types_) {
ss << op_type;
}

ss << use_npu_;
ss << npu_device_id_;
@@ -1331,6 +1345,13 @@ std::string AnalysisConfig::Summary() {
os.InsertRow({"xpu_device_id", std::to_string(xpu_device_id_)});
os.InsertRow(
{"xpu_l3_workspace_size", std::to_string(xpu_l3_workspace_size_)});
os.InsertRow({"xpu_quant_post_dynamic_weight_bits",
std::to_string(xpu_quant_post_dynamic_weight_bits_)});
std::vector<std::string> op_types{"xpu_quant_post_dynamic_op_types"};
for (auto op_type : xpu_quant_post_dynamic_op_types_) {
op_types.push_back(op_type);
}
os.InsertRow(op_types);
}
os.InsetDivider();


4 changes: 4 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -1426,6 +1426,10 @@ void AnalysisPredictor::PrepareArgument() {
argument_->SetXpuAdaptiveSeqlen(config_.xpu_adaptive_seqlen_);
argument_->SetXpuDeviceId(config_.xpu_device_id_);
argument_->SetXpuEnableMultiStream(config_.xpu_enable_multi_stream_);
argument_->SetXpuQuantPostDynamicWeightBits(
config_.xpu_quant_post_dynamic_weight_bits_);
argument_->SetXpuQuantPostDynamicOpTypss(
config_.xpu_quant_post_dynamic_op_types_);
#endif

auto *pass_builder = config_.pass_builder();

14 changes: 14 additions & 0 deletions paddle/fluid/inference/api/paddle_analysis_config.h
@@ -288,6 +288,18 @@ struct PD_INFER_DECL AnalysisConfig {
bool adaptive_seqlen = false,
bool enable_multi_stream = false);

///
/// \brief configs of XPU
///
/// \param quant_post_dynamic_weight_bits Weight bits used in dynamic post
/// quantization. Optional value: -1, 8, 16. Default value is -1, means using
/// the recommended way. \param quant_post_dynamic_op_types Ops used in
/// dynamic post quantization.
///
void SetXpuConfig(
int quant_post_dynamic_weight_bits = -1,
const std::vector<std::string>& quant_post_dynamic_op_types = {});

///
/// \brief configs of IPU
///
@@ -1181,6 +1193,8 @@ struct PD_INFER_DECL AnalysisConfig {
std::string xpu_precision_;
bool xpu_adaptive_seqlen_;
bool xpu_enable_multi_stream_;
int xpu_quant_post_dynamic_weight_bits_{-1};
std::vector<std::string> xpu_quant_post_dynamic_op_types_;

// LITE OPENCL SETTINGS
bool use_opencl_{false};
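A minimal usage sketch for the SetXpuConfig entry point declared above, from the C++ inference API (the model path and the surrounding EnableXpu/CreatePredictor calls are illustrative setup, not part of this change); the same options are exposed to Python as set_xpu_config in the pybind change below:

#include <memory>
#include "paddle_inference_api.h"  // assumed header name of the installed inference library

std::shared_ptr<paddle_infer::Predictor> BuildXpuWeightOnlyInt8Predictor() {
  // Hypothetical model directory; request 8-bit weight-only quantization for
  // fused_multi_transformer when running on XPU.
  paddle_infer::Config config("/path/to/model_dir");
  config.EnableXpu();
  config.SetXpuConfig(8, {"fused_multi_transformer"});
  return paddle_infer::CreatePredictor(config);
}
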
5 changes: 5 additions & 0 deletions paddle/fluid/pybind/inference_api.cc
@@ -767,6 +767,11 @@ void BindAnalysisConfig(py::module *m) {
.def("set_xpu_device_id",
&AnalysisConfig::SetXpuDeviceId,
py::arg("device_id") = 0)
.def(
"set_xpu_config",
&AnalysisConfig::SetXpuConfig,
py::arg("quant_post_dynamic_weight_bits") = -1,
py::arg("quant_post_dynamic_op_types") = std::vector<std::string>({}))
.def("enable_custom_device",
&AnalysisConfig::EnableCustomDevice,
py::arg("device_type"),