diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 88b7c4ebcc86d..d09b6bd373440 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -221,6 +221,7 @@ if(WITH_XPU) SRCS xpu/pass_utils.cc DEPS pass) set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils) + pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu) pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) diff --git a/paddle/fluid/framework/ir/delete_dropout_op_pass.cc b/paddle/fluid/framework/ir/delete_dropout_op_pass.cc index e3c2e6cef2114..285c25c6a5e9d 100644 --- a/paddle/fluid/framework/ir/delete_dropout_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_dropout_op_pass.cc @@ -30,46 +30,50 @@ namespace ir { void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const { const std::string pattern_name = "delete_dropout_op_pattern"; FusePassBase::Init(pattern_name, graph); + int found_subgraph_count = 0; - GraphPatternDetector gpd; - patterns::DeleteDropoutOpPattern pattern(gpd.mutable_pattern(), pattern_name); - pattern(); + for (auto with_mask : {true, false}) { + GraphPatternDetector gpd; + patterns::DeleteDropoutOpPattern pattern(gpd.mutable_pattern(), + pattern_name); + pattern(with_mask); - int found_subgraph_count = 0; - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) { - GET_IR_NODE(dropout_op_x); - GET_IR_NODE(dropout_op); - GET_IR_NODE(dropout_op_out); - GET_IR_NODE(dropout_op_mask); + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE(dropout_op_x); + GET_IR_NODE(dropout_op); + GET_IR_NODE(dropout_op_out); - // link dropout_op_out to pre_op - auto dropout_op_x_name = dropout_op_x->Var()->Name(); - auto dropout_op_out_name = dropout_op_out->Var()->Name(); - auto pre_ops = dropout_op_x->inputs; - if (pre_ops.empty()) return; - auto pre_op_desc = pre_ops[0]->Op(); - auto pre_op_outs = pre_op_desc->Outputs(); - for (auto& out_var : pre_op_outs) { - auto names = out_var.second; - for (size_t i = 0; i < names.size(); i++) { - if (names[i] == dropout_op_x_name) { - names[i] = dropout_op_out_name; - pre_op_desc->SetOutput(out_var.first, names); - break; + // link dropout_op_x to next_op + auto dropout_op_x_name = dropout_op_x->Var()->Name(); + auto dropout_op_out_name = dropout_op_out->Var()->Name(); + auto next_op_nodes = dropout_op_out->outputs; + for (auto next_op_node : next_op_nodes) { + auto next_op_desc = next_op_node->Op(); + auto next_op_inputs = next_op_desc->Inputs(); + for (auto& input_var : next_op_inputs) { + auto names = input_var.second; + for (size_t i = 0; i < names.size(); i++) { + if (names[i] == dropout_op_out_name) { + names[i] = dropout_op_x_name; + next_op_desc->SetInput(input_var.first, names); + break; + } + } } + IR_NODE_LINK_TO(dropout_op_x, next_op_node); } - } - IR_NODE_LINK_TO(pre_ops[0], dropout_op_out); - - // delete useless node - std::unordered_set delete_nodes{ - dropout_op_x, dropout_op, dropout_op_mask}; - GraphSafeRemoveNodes(graph, delete_nodes); - found_subgraph_count++; - }; - - gpd(graph, handler); + // delete useless node + std::unordered_set delete_nodes{dropout_op, dropout_op_out}; + if (with_mask) { + GET_IR_NODE(dropout_op_mask); + delete_nodes.insert(dropout_op_mask); + } + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + gpd(graph, handler); + } AddStatis(found_subgraph_count); } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 4a1f80916ef2b..5e3f2dc42ce1b 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -3032,7 +3032,7 @@ PDNode *patterns::TransposeFlattenConcat::operator()( return concat_out; } -void patterns::DeleteDropoutOpPattern::operator()() { +void patterns::DeleteDropoutOpPattern::operator()(bool with_mask) { auto dropout_op_x = pattern->NewNode(dropout_op_x_repr()) ->assert_is_op_input("dropout", "X") ->AsInput(); @@ -3042,10 +3042,14 @@ void patterns::DeleteDropoutOpPattern::operator()() { std::string("upscale_in_train")); auto dropout_op_out = pattern->NewNode(dropout_op_out_repr()) ->assert_is_op_output("dropout", "Out"); - auto dropout_op_mask = pattern->NewNode(dropout_op_mask_repr()) - ->assert_is_op_output("dropout", "Mask"); - dropout_op->LinksFrom({dropout_op_x}) - .LinksTo({dropout_op_out, dropout_op_mask}); + if (with_mask) { + auto dropout_op_mask = pattern->NewNode(dropout_op_mask_repr()) + ->assert_is_op_output("dropout", "Mask"); + dropout_op->LinksFrom({dropout_op_x}) + .LinksTo({dropout_op_out, dropout_op_mask}); + } else { + dropout_op->LinksFrom({dropout_op_x}).LinksTo({dropout_op_out}); + } } void patterns::DeleteQuantOpFuse::operator()(PDNode *input_act_node, diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index f25260e0ce04e..80474d0d67acb 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1759,7 +1759,7 @@ struct DeleteDropoutOpPattern : public PatternBase { DeleteDropoutOpPattern(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "delete_dropout_op_pattern") {} - void operator()(); + void operator()(bool with_mask); PATTERN_DECL_NODE(dropout_op_x); PATTERN_DECL_NODE(dropout_op); diff --git a/paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc new file mode 100644 index 0000000000000..05975b6a1c24c --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc @@ -0,0 +1,313 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { + +static bool GetBoolFromEnv(const std::string& str, bool def = false) { + char* variable = std::getenv(str.c_str()); + if (!variable) { + return def; + } + if (strcmp(variable, "false") == 0 || strcmp(variable, "0") == 0) { + return false; + } else { + return true; + } +} + +namespace patterns { + +struct EmbeddingWithEltwiseAddXPUPattern : public PatternBase { + EmbeddingWithEltwiseAddXPUPattern(PDPattern* pattern, + const std::string& name_scope, + int n_embedding_, + const std::string& op_type, + const std::string& pre_op_type); + + // declare operator node's name + PATTERN_DECL_NODE(embedding0); + PATTERN_DECL_NODE(embedding1); + PATTERN_DECL_NODE(ewadd01); + // declare variable node's name + PATTERN_DECL_NODE(x0); + PATTERN_DECL_NODE(x1); + PATTERN_DECL_NODE(table0); + PATTERN_DECL_NODE(table1); + PATTERN_DECL_NODE(embedding_out0); + PATTERN_DECL_NODE(embedding_out1); + PATTERN_DECL_NODE(ewadd01_out); + + std::unordered_map node_reprs; + + private: + int n_embedding_; + std::string op_type_; + std::string pre_op_type_; +}; + +EmbeddingWithEltwiseAddXPUPattern::EmbeddingWithEltwiseAddXPUPattern( + PDPattern* pattern, + const std::string& name_scope, + int n_embedding, + const std::string& op_type, + const std::string& pre_op_type) + : PatternBase(pattern, name_scope, name_scope), + n_embedding_(n_embedding), + op_type_(op_type), + pre_op_type_(pre_op_type) { + for (int i = 0; i < n_embedding; i++) { + node_reprs["x" + std::to_string(i)] = + PDNodeName(name_scope_, repr_, id_, "x" + std::to_string(i)); + node_reprs["table" + std::to_string(i)] = + PDNodeName(name_scope_, repr_, id_, "table" + std::to_string(i)); + node_reprs["embedding" + std::to_string(i)] = + PDNodeName(name_scope_, repr_, id_, "embedding" + std::to_string(i)); + node_reprs["embedding_out" + std::to_string(i)] = PDNodeName( + name_scope_, repr_, id_, "embedding_out" + std::to_string(i)); + if (i - 1 >= 0) { + auto ewadd_name = string::Sprintf("ewadd%d%d", i - 1, i); + node_reprs[ewadd_name] = PDNodeName(name_scope_, repr_, id_, ewadd_name); + auto ewadd_out_name = string::Sprintf("ewadd%d%d_out", i - 1, i); + node_reprs[ewadd_out_name] = + PDNodeName(name_scope_, repr_, id_, ewadd_out_name); + } + } + PDNode* x0 = pattern->NewNode(x0_repr()) + ->assert_is_op_input(op_type_, "Ids") + ->assert_var_not_persistable() + ->AsInput(); + PDNode* x1 = pattern->NewNode(x1_repr()) + ->assert_is_op_input(op_type_, "Ids") + ->assert_var_not_persistable() + ->AsInput(); + PDNode* embedding0 = + pattern->NewNode(embedding0_repr())->assert_is_op(op_type_); + auto* table0 = pattern->NewNode(table0_repr()) + ->assert_is_op_input(op_type_, "W") + ->AsInput(); + auto* embedding_out0 = pattern->NewNode(embedding_out0_repr()) + ->assert_is_op_output(op_type_, "Out") + ->assert_is_op_input("elementwise_add", "X"); + auto* table1 = pattern->NewNode(table1_repr()) + ->assert_is_op_input(op_type_, "W") + ->AsInput(); + auto* embedding1 = + pattern->NewNode(embedding1_repr())->assert_is_op(op_type_); + + auto* embedding_out1 = pattern->NewNode(embedding_out1_repr()) + ->assert_is_op_output(op_type_, "Out") + ->assert_is_op_input("elementwise_add", "Y"); + auto* ewadd01 = + pattern->NewNode(ewadd01_repr())->assert_is_op("elementwise_add"); + auto* ewadd01_out = pattern->NewNode(ewadd01_out_repr()) + ->assert_is_op_output("elementwise_add", "Out"); + embedding0->LinksFrom({x0, table0}); + embedding1->LinksFrom({x1, table1}); + embedding0->LinksTo({embedding_out0}); + embedding1->LinksTo({embedding_out1}); + ewadd01->LinksFrom({embedding_out0, embedding_out1}); + ewadd01->LinksTo({ewadd01_out}); + + auto* last_ewadd_out = ewadd01_out; + for (int i = 2; i < n_embedding; ++i) { + auto x_name = node_reprs["x" + std::to_string(i)]; + auto table_name = node_reprs["table" + std::to_string(i)]; + auto embedding_name = node_reprs["embedding" + std::to_string(i)]; + auto embedding_out_name = node_reprs["embedding_out" + std::to_string(i)]; + auto* new_table = pattern->NewNode(table_name) + ->assert_is_op_input(op_type_, "W") + ->AsInput(); + auto* new_embedding = + pattern->NewNode(embedding_name)->assert_is_op(op_type_); + auto* new_embedding_out = pattern->NewNode(embedding_out_name) + ->assert_is_op_output(op_type_, "Out") + ->assert_is_op_input("elementwise_add", "Y"); + auto* new_x = pattern->NewNode(x_name) + ->assert_is_op_input(op_type_, "Ids") + ->AsInput(); + new_embedding->LinksFrom({new_x, new_table}); + new_embedding->LinksTo({new_embedding_out}); + auto ewadd_name = + node_reprs["ewadd" + std::to_string(i - 1) + std::to_string(i)]; + auto ewadd_out_name = node_reprs["ewadd" + std::to_string(i - 1) + + std::to_string(i) + "_out"]; + auto* new_ewadd = + pattern->NewNode(ewadd_name)->assert_is_op("elementwise_add"); + auto* new_ewadd_out = pattern->NewNode(ewadd_out_name) + ->assert_is_op_output("elementwise_add", "Out"); + new_ewadd->LinksFrom({last_ewadd_out, new_embedding_out}); + new_ewadd->LinksTo({new_ewadd_out}); + last_ewadd_out = new_ewadd_out; + } + last_ewadd_out->AsOutput(); +} + +} // namespace patterns + +class EmbeddingWithEltwiseAddXPUFusePass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + void ApplyImpl(ir::Graph* graph, + int n_embedding, + const std::string op_type, + const std::string pre_op_type) const; + + const std::string name_scope_{"embedding_with_eltwise_add_xpu_fuse_pass"}; +}; + +void EmbeddingWithEltwiseAddXPUFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + FusePassBase::Init(name_scope_, graph); + std::vector pre_op_types{"reshape2", "squeeze2", ""}; + std::vector op_types{"lookup_table", "lookup_table_v2"}; + for (auto& pre_op_type : pre_op_types) { + for (int n_embedding : {4, 3, 2}) { + for (auto& op_type : op_types) { + ApplyImpl(graph, n_embedding, op_type, pre_op_type); + } + } + } +} + +void EmbeddingWithEltwiseAddXPUFusePass::ApplyImpl( + ir::Graph* graph, + int n_embedding, + const std::string op_type, + const std::string pre_op_type) const { + GraphPatternDetector gpd; + patterns::EmbeddingWithEltwiseAddXPUPattern pattern( + gpd.mutable_pattern(), name_scope_, n_embedding, op_type, pre_op_type); + int found_subgraph_count = 0; +#define GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(name, rt_node, pat) \ + PADDLE_ENFORCE_NE( \ + subgraph.count(pat.PatternBase::pattern->RetrieveNode(name)), \ + 0UL, \ + platform::errors::NotFound("Node not found for PDNode %s", name)); \ + Node* rt_node = subgraph.at(pat.PatternBase::pattern->RetrieveNode(name)); \ + PADDLE_ENFORCE_NOT_NULL( \ + rt_node, \ + platform::errors::NotFound("node %s not exists in the sub-graph", \ + #rt_node)); + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + std::vector x_names; + std::vector table_names; + std::vector x_nodes; + std::vector table_nodes; + std::vector embedding_nodes; + auto output_name = pattern.node_reprs[string::Sprintf( + "ewadd%d%d_out", n_embedding - 2, n_embedding - 1)]; + GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(output_name, output_node, pattern) + std::unordered_set delete_nodes; + for (int i = 0; i < n_embedding; ++i) { + // Ids + auto x_name = pattern.node_reprs["x" + std::to_string(i)]; + GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(x_name, x_node, pattern) + x_nodes.push_back(x_node); + x_names.push_back(x_node->Name()); + // Tables + auto table_name = pattern.node_reprs["table" + std::to_string(i)]; + GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(table_name, table_node, pattern) + table_nodes.push_back(table_node); + table_names.push_back(table_node->Name()); + // Embedding + auto embedding_name = pattern.node_reprs["embedding" + std::to_string(i)]; + GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(embedding_name, embedding_node, pattern) + embedding_nodes.push_back(embedding_node); + delete_nodes.insert(embedding_node); + auto embedding_out_name = + pattern.node_reprs["embedding_out" + std::to_string(i)]; + GET_IR_NODE_FROM_SUBGRAPH_BY_NAME( + embedding_out_name, embedding_out_node, pattern) + delete_nodes.insert(embedding_out_node); + if (i - 1 >= 0) { + auto ewadd_name = + pattern.node_reprs[string::Sprintf("ewadd%d%d", i - 1, i)]; + GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(ewadd_name, ewadd_node, pattern) + delete_nodes.insert(ewadd_node); + auto ewadd_out_name = + pattern.node_reprs[string::Sprintf("ewadd%d%d_out", i - 1, i)]; + GET_IR_NODE_FROM_SUBGRAPH_BY_NAME( + ewadd_out_name, ewadd_out_node, pattern) + if (i != n_embedding - 1) { + delete_nodes.insert(ewadd_out_node); + } + } + } + // Generate embedding_with_eltwise_add_xpu op + framework::OpDesc embedding_with_eltwise_add_xpu_op_desc; + embedding_with_eltwise_add_xpu_op_desc.SetType( + "embedding_with_eltwise_add_xpu"); + embedding_with_eltwise_add_xpu_op_desc.SetInput("ids", x_names); + embedding_with_eltwise_add_xpu_op_desc.SetInput("tables", table_names); + embedding_with_eltwise_add_xpu_op_desc.SetOutput("out", + {output_node->Name()}); + embedding_with_eltwise_add_xpu_op_desc.SetAttr("n_embedding", n_embedding); + int64_t padding_idx = PADDLE_GET_CONST( + int64_t, embedding_nodes[0]->Op()->GetAttr("padding_idx")); + if (GetBoolFromEnv("XPU_PADDING_IDX", true)) { + padding_idx = -1; + } + embedding_with_eltwise_add_xpu_op_desc.SetAttr( + "padding_idx", static_cast(padding_idx)); + auto* embedding_with_eltwise_add_xpu_op = + graph->CreateOpNode(&embedding_with_eltwise_add_xpu_op_desc); + for (size_t i = 0; i < x_nodes.size(); i++) { + SAFE_IR_NODE_LINK_TO(x_nodes[i], embedding_with_eltwise_add_xpu_op); + SAFE_IR_NODE_LINK_TO(table_nodes[i], embedding_with_eltwise_add_xpu_op); + } + SAFE_IR_NODE_LINK_TO(embedding_with_eltwise_add_xpu_op, output_node); + // delete useless node + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(embedding_with_eltwise_add_xpu_fuse_pass, + paddle::framework::ir::EmbeddingWithEltwiseAddXPUFusePass); + +REGISTER_PASS_CAPABILITY(embedding_with_eltwise_add_xpu_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "embedding_with_eltwise_add_xpu", 0)); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 8f193fc8203f8..0b701a452f6c0 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -520,7 +520,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "generate_sequence_xpu_fuse_pass", "multi_encoder_xpu_fuse_pass", "multi_encoder_xpu_slice_fuse_pass", - // "embedding_with_eltwise_add_xpu_fuse_pass", + "embedding_with_eltwise_add_xpu_fuse_pass", "fc_xpu_fuse_pass", "link_xpu_op_max_pass", }); diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml index 7ba94a9f3da7d..ac80200a0f83f 100644 --- a/paddle/phi/api/yaml/static_ops.yaml +++ b/paddle/phi/api/yaml/static_ops.yaml @@ -1,3 +1,12 @@ +- op : embedding_with_eltwise_add_xpu + args : (Tensor[] ids, Tensor[] tables, int64_t padding_idx) + output: Tensor + infer_meta : + func: EmbeddingWithEltwiseAddXPUInferMeta + kernel: + func: embedding_with_eltwise_add_xpu + data_type: tables + - op : fc_xpu args : (Tensor x, Tensor x_max, Tensor w, Tensor w_max, Tensor bias, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha) output : Tensor(out), Tensor(out_max) diff --git a/paddle/phi/backends/xpu/xpu1_op_list.cc b/paddle/phi/backends/xpu/xpu1_op_list.cc index f40daae0c5dd5..e4ed23abf3a55 100644 --- a/paddle/phi/backends/xpu/xpu1_op_list.cc +++ b/paddle/phi/backends/xpu/xpu1_op_list.cc @@ -80,6 +80,8 @@ XPUOpMap& get_kl1_ops() { {"elementwise_pow", XPUKernelSet({phi::DataType::FLOAT32})}, {"elementwise_sub_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"elementwise_sub", XPUKernelSet({phi::DataType::FLOAT32})}, + {"embedding_with_eltwise_add_xpu", + XPUKernelSet({phi::DataType::FLOAT32})}, {"equal", XPUKernelSet({phi::DataType::INT64})}, {"expand_as_v2", XPUKernelSet({phi::DataType::INT32, diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 9690f22dcbbbc..9f39cd8eaa540 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -209,6 +209,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::FLOAT16, phi::DataType::INT64, phi::DataType::INT32})}, + {"embedding_with_eltwise_add_xpu", + XPUKernelSet({phi::DataType::FLOAT32})}, {"empty", XPUKernelSet({phi::DataType::INT64, phi::DataType::INT32, diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index d468b13e17d2a..feb95cb32a90a 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -21,6 +21,28 @@ limitations under the License. */ namespace phi { +void EmbeddingWithEltwiseAddXPUInferMeta( + const std::vector& ids, + const std::vector& tables, + MetaTensor* out) { + PADDLE_ENFORCE_GT(ids.size(), + 0UL, + phi::errors::InvalidArgument( + "The input ids in EmbeddingWithEltwiseAddXPUInferMeta " + "can't be empty.")); + PADDLE_ENFORCE_GT(tables.size(), + 0UL, + phi::errors::InvalidArgument( + "The input tables in " + "EmbeddingWithEltwiseAddXPUInferMeta can't be empty.")); + + auto id_dims = ids[0]->dims(); + auto table_dims = tables[0]->dims(); + out->set_dims(phi::make_ddim({id_dims[0], id_dims[1], table_dims[1]})); + out->set_dtype(tables[0]->dtype()); + out->set_layout(ids[0]->layout()); +} + void FcXPUInferMeta(const MetaTensor& x, const MetaTensor& x_max, const MetaTensor& w, diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index e1fe0c3c112a9..6cba0552b1abc 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -22,6 +22,11 @@ namespace phi { // Common InferMeta Functions for fusion operators. // NOTE: The InferMeta Functions in this file are arranged in alphabetic order. +void EmbeddingWithEltwiseAddXPUInferMeta( + const std::vector& ids, + const std::vector& tables, + MetaTensor* out); + void FcXPUInferMeta(const MetaTensor& x, const MetaTensor& x_max, const MetaTensor& w, diff --git a/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc new file mode 100644 index 0000000000000..afde2f8f3503b --- /dev/null +++ b/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc @@ -0,0 +1,84 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +namespace fusion { + +template +void EmbeddingWithEltwiseAddXpuKernel( + const Context& ctx, + const std::vector& ids, + const std::vector& tables, + int64_t padding_idx, + DenseTensor* out) { + auto& id_dims = ids[0]->dims(); + int idx_len = id_dims[0] * id_dims[1]; + int emb_layer_num = ids.size(); + int embed_dim = tables[0]->dims()[1]; + std::vector table_lens_cpu; + std::vector arg_tables; + for (auto* table : tables) { + auto& table_dims = table->dims(); + PADDLE_ENFORCE_EQ( + table_dims.size(), + 2, + errors::InvalidArgument( + "The table_dims size [%d] should be equal 2.", + table_dims.size())); /* shape like [table_len, embed_dim] */ + PADDLE_ENFORCE_EQ( + table_dims[1], + embed_dim, + errors::InvalidArgument( + "Every embed_dim [%d] should be equal the first one [%d].", + table_dims[1], + embed_dim)); + table_lens_cpu.push_back(table_dims[0]); + arg_tables.push_back(table->data()); + } + std::vector> int_idx(emb_layer_num, + std::vector(idx_len, 0)); + std::vector> arg_ids; + for (int i = 0; i < emb_layer_num; i++) { + for (int j = 0; j < idx_len; j++) { + int_idx[i][j] = static_cast(ids[i]->data()[j]); + } + arg_ids.push_back( + xpu::VectorParam{int_idx[i].data(), idx_len, nullptr}); + } + ctx.template Alloc(out); + int r = xpu::multi_embedding_fusion( + ctx.x_context(), + arg_tables, /* tables */ + out->data(), + arg_ids, + table_lens_cpu, + embed_dim, + std::vector(table_lens_cpu.size(), 1.0f), + std::vector(table_lens_cpu.size(), padding_idx)); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_with_eltwise_add_xpu"); +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(embedding_with_eltwise_add_xpu, + XPU, + ALL_LAYOUT, + phi::fusion::EmbeddingWithEltwiseAddXpuKernel, + float) { + kernel->InputAt(0).SetBackend(phi::Backend::CPU); +} diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_embedding_with_eltwise_add_xpu_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_embedding_with_eltwise_add_xpu_fuse_pass.py new file mode 100644 index 0000000000000..e4d545934136a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_embedding_with_eltwise_add_xpu_fuse_pass.py @@ -0,0 +1,167 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial + +import hypothesis.strategies as st +import numpy as np +from auto_scan_test import PassAutoScanTest +from program_config import OpConfig, ProgramConfig, TensorConfig + + +class TestEmbeddingWithEltwiseAddXPUFusePass(PassAutoScanTest): + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_xpu=True) + yield config, ["embedding_with_eltwise_add_xpu"], (1e-3, 1e-3) + + def sample_program_config(self, draw): + + # lookup_table_v2 + lookup_table_num = draw(st.sampled_from([2, 3, 4])) + print("lookup_table_num: ", lookup_table_num) + ids_shape = draw(st.sampled_from([[1, 32]])) + w_shape = draw(st.sampled_from([[1000, 32]])) + padding_idx = draw(st.sampled_from([-1])) + axis = draw(st.sampled_from([-1])) + + def gen_lookup_table_ops(): + lookup_table_op_config_list = [] + lookup_table_op_0 = OpConfig( + "lookup_table_v2", + inputs={ + "Ids": ["lookup_table_ids_0"], + "W": ["lookup_table_w_0"], + }, + outputs={"Out": ["lookup_table_out_0"]}, + padding_idx=padding_idx, + ) + lookup_table_op_1 = OpConfig( + "lookup_table_v2", + inputs={ + "Ids": ["lookup_table_ids_1"], + "W": ["lookup_table_w_1"], + }, + outputs={"Out": ["lookup_table_out_1"]}, + padding_idx=padding_idx, + ) + lookup_table_ops_list = [lookup_table_op_0, lookup_table_op_1] + if lookup_table_num >= 3: + lookup_table_op_2 = OpConfig( + "lookup_table_v2", + inputs={ + "Ids": ["lookup_table_ids_2"], + "W": ["lookup_table_w_2"], + }, + outputs={"Out": ["lookup_table_out_2"]}, + padding_idx=padding_idx, + ) + lookup_table_ops_list.append(lookup_table_op_2) + if lookup_table_num >= 4: + lookup_table_op_3 = OpConfig( + "lookup_table_v2", + inputs={ + "Ids": ["lookup_table_ids_3"], + "W": ["lookup_table_w_3"], + }, + outputs={"Out": ["lookup_table_out_3"]}, + padding_idx=padding_idx, + ) + lookup_table_ops_list.append(lookup_table_op_3) + return lookup_table_ops_list + + add_op_num = lookup_table_num - 1 + + def gen_eltwise_add_ops(): + add_op_0 = OpConfig( + "elementwise_add", + inputs={ + "X": ["lookup_table_out_0"], + "Y": ["lookup_table_out_1"], + }, + outputs={"Out": ["add_op_0_out"]}, + axis=axis, + ) + add_op_list = [add_op_0] + if add_op_num >= 2: + add_op_1 = OpConfig( + "elementwise_add", + inputs={"X": ["add_op_0_out"], "Y": ["lookup_table_out_2"]}, + outputs={"Out": ["add_op_1_out"]}, + axis=axis, + ) + add_op_list.append(add_op_1) + + if add_op_num >= 3: + add_op_2 = OpConfig( + "elementwise_add", + inputs={"X": ["add_op_1_out"], "Y": ["lookup_table_out_3"]}, + outputs={"Out": ["add_op_2_out"]}, + axis=axis, + ) + add_op_list.append(add_op_2) + return add_op_list + + lookup_table_op_list = gen_lookup_table_ops() + add_op_list = gen_eltwise_add_ops() + + # ops + ops = [] + ops.extend(lookup_table_op_list) + ops.extend(add_op_list) + + # inputs + def generate_input(*args, **kwargs): + return np.random.randint(0, w_shape[0], ids_shape).astype(np.int64) + + def gen_lookup_table_inputs_data(*args, **kwargs): + inputs = {} + for i in range(lookup_table_num): + input_name = "lookup_table_ids_{}".format(i) + inputs[input_name] = TensorConfig( + data_gen=partial(generate_input) + ) + return inputs + + inputs = gen_lookup_table_inputs_data() + + # weights + def gen_lookup_table_weights_data(): + weights = {} + for i in range(lookup_table_num): + w_name = "lookup_table_w_{}".format(i) + weights[w_name] = TensorConfig(shape=w_shape) + return weights + + weights = gen_lookup_table_weights_data() + + program_config = ProgramConfig( + ops=ops, + weights=weights, + inputs=inputs, + outputs=add_op_list[-1].outputs["Out"], + ) + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=3, + min_success_num=3, + passes=["embedding_with_eltwise_add_xpu_fuse_pass"], + ) + + +if __name__ == "__main__": + unittest.main()