From b44ac82ca5c0404eebb943bb6ab5d97ee196ce2e Mon Sep 17 00:00:00 2001
From: csy0225
Date: Wed, 15 Feb 2023 06:49:16 +0000
Subject: [PATCH 1/4] migrate pass1

---
 paddle/fluid/framework/ir/CMakeLists.txt      |   1 +
 ...mbedding_with_eltwise_add_xpu_fuse_pass.cc | 314 ++++++++++++++++++
 .../inference/api/paddle_pass_builder.cc      |   2 +-
 paddle/phi/api/yaml/static_ops.yaml           |   9 +
 paddle/phi/backends/xpu/xpu1_op_list.cc       |   2 +
 paddle/phi/backends/xpu/xpu2_op_list.cc       |   2 +
 paddle/phi/infermeta/fusion.cc                |  22 ++
 paddle/phi/infermeta/fusion.h                 |   5 +
 .../embedding_with_eltwise_add_xpu_kernel.cc  |  81 +++++
 9 files changed, 437 insertions(+), 1 deletion(-)
 create mode 100644 paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc
 create mode 100644 paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 067bc82a0189a..77e395c64eceb 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -217,6 +217,7 @@ if(WITH_XPU)
     SRCS xpu/quant_utils.cc
     DEPS pass)
   pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS quant_utils)
+  pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu)
 endif()
 
 cc_library(
diff --git a/paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc
new file mode 100644
index 0000000000000..c7929c72aaffc
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc
@@ -0,0 +1,314 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstdlib>
+#include <cstring>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+static bool GetBoolFromEnv(const std::string& str, bool def = false) {
+  char* variable = std::getenv(str.c_str());
+  if (!variable) {
+    return def;
+  }
+  if (strcmp(variable, "false") == 0 || strcmp(variable, "0") == 0) {
+    return false;
+  } else {
+    return true;
+  }
+}
+
+namespace patterns {
+
+struct EmbeddingWithEltwiseAddXPUPattern : public PatternBase {
+  EmbeddingWithEltwiseAddXPUPattern(PDPattern* pattern,
+                                    const std::string& name_scope,
+                                    int n_embedding,
+                                    const std::string& op_type,
+                                    const std::string& pre_op_type);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(embedding0);
+  PATTERN_DECL_NODE(embedding1);
+  PATTERN_DECL_NODE(ewadd01);
+  // declare variable node's name
+  PATTERN_DECL_NODE(x0);
+  PATTERN_DECL_NODE(x1);
+  PATTERN_DECL_NODE(table0);
+  PATTERN_DECL_NODE(table1);
+  PATTERN_DECL_NODE(embedding_out0);
+  PATTERN_DECL_NODE(embedding_out1);
+  PATTERN_DECL_NODE(ewadd01_out);
+
+  std::unordered_map<std::string, std::string> node_reprs;
+
+ private:
+  int n_embedding_;
+  std::string op_type_;
+  std::string pre_op_type_;
+};
+
+EmbeddingWithEltwiseAddXPUPattern::EmbeddingWithEltwiseAddXPUPattern(
+    PDPattern* pattern,
+    const std::string& name_scope,
+    int n_embedding,
+    const std::string& op_type,
+    const std::string& pre_op_type)
+    : PatternBase(pattern, name_scope, name_scope),
+      n_embedding_(n_embedding),
+      op_type_(op_type),
+      pre_op_type_(pre_op_type) {
+  for (int i = 0; i < n_embedding; i++) {
+    node_reprs["x" + std::to_string(i)] =
+        PDNodeName(name_scope_, repr_, id_, "x" + std::to_string(i));
+    node_reprs["table" + std::to_string(i)] =
+        PDNodeName(name_scope_, repr_, id_, "table" + std::to_string(i));
+    node_reprs["embedding" + std::to_string(i)] =
+        PDNodeName(name_scope_, repr_, id_, "embedding" + std::to_string(i));
+    node_reprs["embedding_out" + std::to_string(i)] = PDNodeName(
+        name_scope_, repr_, id_, "embedding_out" + std::to_string(i));
+    if (i - 1 >= 0) {
+      auto ewadd_name = string::Sprintf("ewadd%d%d", i - 1, i);
+      node_reprs[ewadd_name] = PDNodeName(name_scope_, repr_, id_, ewadd_name);
+      auto ewadd_out_name = string::Sprintf("ewadd%d%d_out", i - 1, i);
+      node_reprs[ewadd_out_name] =
+          PDNodeName(name_scope_, repr_, id_, ewadd_out_name);
+    }
+  }
+  PDNode* x0 = pattern->NewNode(x0_repr())
+                   ->assert_is_op_input(op_type_, "Ids")
+                   ->assert_var_not_persistable()
+                   ->AsInput();
+  PDNode* x1 = pattern->NewNode(x1_repr())
+                   ->assert_is_op_input(op_type_, "Ids")
+                   ->assert_var_not_persistable()
+                   ->AsInput();
+  PDNode* embedding0 =
+      pattern->NewNode(embedding0_repr())->assert_is_op(op_type_);
+  auto* table0 = pattern->NewNode(table0_repr())
+                     ->assert_is_op_input(op_type_, "W")
+                     ->AsInput();
+  auto* embedding_out0 = pattern->NewNode(embedding_out0_repr())
+                             ->assert_is_op_output(op_type_, "Out")
+                             ->assert_is_op_input("elementwise_add", "X");
+  auto* table1 = pattern->NewNode(table1_repr())
+                     ->assert_is_op_input(op_type_, "W")
+                     ->AsInput();
+  auto* embedding1 =
+      pattern->NewNode(embedding1_repr())->assert_is_op(op_type_);
+  auto* embedding_out1 = pattern->NewNode(embedding_out1_repr())
+                             ->assert_is_op_output(op_type_, "Out")
+                             ->assert_is_op_input("elementwise_add", "Y");
+  auto* ewadd01 =
+      pattern->NewNode(ewadd01_repr())->assert_is_op("elementwise_add");
+  auto* ewadd01_out = pattern->NewNode(ewadd01_out_repr())
+                          ->assert_is_op_output("elementwise_add", "Out");
+  embedding0->LinksFrom({x0, table0});
+  embedding1->LinksFrom({x1, table1});
+  embedding0->LinksTo({embedding_out0});
+  embedding1->LinksTo({embedding_out1});
+  ewadd01->LinksFrom({embedding_out0, embedding_out1});
+  ewadd01->LinksTo({ewadd01_out});
+
+  auto* last_ewadd_out = ewadd01_out;
+  for (int i = 2; i < n_embedding; ++i) {
+    auto x_name = node_reprs["x" + std::to_string(i)];
+    auto table_name = node_reprs["table" + std::to_string(i)];
+    auto embedding_name = node_reprs["embedding" + std::to_string(i)];
+    auto embedding_out_name = node_reprs["embedding_out" + std::to_string(i)];
+    auto* new_table = pattern->NewNode(table_name)
+                          ->assert_is_op_input(op_type_, "W")
+                          ->AsInput();
+    auto* new_embedding =
+        pattern->NewNode(embedding_name)->assert_is_op(op_type_);
+    auto* new_embedding_out = pattern->NewNode(embedding_out_name)
+                                  ->assert_is_op_output(op_type_, "Out")
+                                  ->assert_is_op_input("elementwise_add", "Y");
+    auto* new_x = pattern->NewNode(x_name)
+                      ->assert_is_op_input(op_type_, "Ids")
+                      ->AsInput();
+    new_embedding->LinksFrom({new_x, new_table});
+    new_embedding->LinksTo({new_embedding_out});
+    auto ewadd_name =
+        node_reprs["ewadd" + std::to_string(i - 1) + std::to_string(i)];
+    auto ewadd_out_name = node_reprs["ewadd" + std::to_string(i - 1) +
+                                     std::to_string(i) + "_out"];
+    auto* new_ewadd =
+        pattern->NewNode(ewadd_name)->assert_is_op("elementwise_add");
+    auto* new_ewadd_out = pattern->NewNode(ewadd_out_name)
+                              ->assert_is_op_output("elementwise_add", "Out");
+    new_ewadd->LinksFrom({last_ewadd_out, new_embedding_out});
+    new_ewadd->LinksTo({new_ewadd_out});
+    last_ewadd_out = new_ewadd_out;
+  }
+  last_ewadd_out->AsOutput();
+}
+
+}  // namespace patterns
+
+class EmbeddingWithEltwiseAddXPUFusePass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  void ApplyImpl(ir::Graph* graph,
+                 int n_embedding,
+                 const std::string& op_type,
+                 const std::string& pre_op_type) const;
+
+  const std::string name_scope_{"embedding_with_eltwise_add_xpu_fuse_pass"};
+};
+
+void EmbeddingWithEltwiseAddXPUFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph,
+      platform::errors::PreconditionNotMet("graph should not be null."));
+  FusePassBase::Init(name_scope_, graph);
+  std::vector<std::string> pre_op_types{"reshape2", "squeeze2", ""};
+  std::vector<std::string> op_types{"lookup_table", "lookup_table_v2"};
+  for (auto& pre_op_type : pre_op_types) {
+    for (int n_embedding : {4, 3, 2}) {
+      for (auto& op_type : op_types) {
+        ApplyImpl(graph, n_embedding, op_type, pre_op_type);
+      }
+    }
+  }
+}
+
+void EmbeddingWithEltwiseAddXPUFusePass::ApplyImpl(
+    ir::Graph* graph,
+    int n_embedding,
+    const std::string& op_type,
+    const std::string& pre_op_type) const {
+  GraphPatternDetector gpd;
+  patterns::EmbeddingWithEltwiseAddXPUPattern pattern(
+      gpd.mutable_pattern(), name_scope_, n_embedding, op_type, pre_op_type);
+  int found_subgraph_count = 0;
+#define GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(name, rt_node, pat)                \
+  PADDLE_ENFORCE_NE(                                                         \
+      subgraph.count(pat.PatternBase::pattern->RetrieveNode(name)),          \
+      0UL,                                                                   \
+      platform::errors::NotFound("Node not found for PDNode %s", name));     \
+  Node* rt_node = subgraph.at(pat.PatternBase::pattern->RetrieveNode(name)); \
+  PADDLE_ENFORCE_NOT_NULL(                                                   \
+      rt_node,                                                               \
+      platform::errors::NotFound("node %s not exists in the sub-graph",      \
+                                 #rt_node));
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    std::vector<std::string> x_names;
+    std::vector<std::string> table_names;
+    std::vector<Node*> x_nodes;
+    std::vector<Node*> table_nodes;
+    std::vector<Node*> embedding_nodes;
+    auto output_name = pattern.node_reprs[string::Sprintf(
+        "ewadd%d%d_out", n_embedding - 2, n_embedding - 1)];
+    GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(output_name, output_node, pattern)
+    std::unordered_set<const Node*> delete_nodes;
+    for (int i = 0; i < n_embedding; ++i) {
+      // Ids
+      auto x_name = pattern.node_reprs["x" + std::to_string(i)];
+      GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(x_name, x_node, pattern)
+      x_nodes.push_back(x_node);
+      x_names.push_back(x_node->Name());
+      // Tables
+      auto table_name = pattern.node_reprs["table" + std::to_string(i)];
+      GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(table_name, table_node, pattern)
+      table_nodes.push_back(table_node);
+      table_names.push_back(table_node->Name());
+      // Embedding
+      auto embedding_name = pattern.node_reprs["embedding" + std::to_string(i)];
+      GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(embedding_name, embedding_node, pattern)
+      embedding_nodes.push_back(embedding_node);
+      delete_nodes.insert(embedding_node);
+      auto embedding_out_name =
+          pattern.node_reprs["embedding_out" + std::to_string(i)];
+      GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(
+          embedding_out_name, embedding_out_node, pattern)
+      delete_nodes.insert(embedding_out_node);
+      if (i - 1 >= 0) {
+        auto ewadd_name =
+            pattern.node_reprs[string::Sprintf("ewadd%d%d", i - 1, i)];
+        GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(ewadd_name, ewadd_node, pattern)
+        delete_nodes.insert(ewadd_node);
+        auto ewadd_out_name =
+            pattern.node_reprs[string::Sprintf("ewadd%d%d_out", i - 1, i)];
+        GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(
+            ewadd_out_name, ewadd_out_node, pattern)
+        if (i != n_embedding - 1) {
+          delete_nodes.insert(ewadd_out_node);
+        }
+      }
+    }
+    // Generate embedding_with_eltwise_add_xpu op
+    auto* scope = param_scope();
+    framework::OpDesc embedding_with_eltwise_add_xpu_op_desc;
+    embedding_with_eltwise_add_xpu_op_desc.SetType(
+        "embedding_with_eltwise_add_xpu");
+    embedding_with_eltwise_add_xpu_op_desc.SetInput("ids", x_names);
+    embedding_with_eltwise_add_xpu_op_desc.SetInput("tables", table_names);
+    embedding_with_eltwise_add_xpu_op_desc.SetOutput("out",
+                                                     {output_node->Name()});
+    embedding_with_eltwise_add_xpu_op_desc.SetAttr("n_embedding", n_embedding);
+    int64_t padding_idx = PADDLE_GET_CONST(
+        int64_t, embedding_nodes[0]->Op()->GetAttr("padding_idx"));
+    if (GetBoolFromEnv("XPU_PADDING_IDX", true)) {
+      padding_idx = -1;
+    }
+    embedding_with_eltwise_add_xpu_op_desc.SetAttr(
+        "padding_idx", static_cast<int64_t>(padding_idx));
+    auto* embedding_with_eltwise_add_xpu_op =
+        graph->CreateOpNode(&embedding_with_eltwise_add_xpu_op_desc);
+    for (int i = 0; i < x_nodes.size(); i++) {
+      SAFE_IR_NODE_LINK_TO(x_nodes[i], embedding_with_eltwise_add_xpu_op);
+      SAFE_IR_NODE_LINK_TO(table_nodes[i], embedding_with_eltwise_add_xpu_op);
+    }
+    SAFE_IR_NODE_LINK_TO(embedding_with_eltwise_add_xpu_op, output_node);
+    // delete useless node
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
+  };
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(embedding_with_eltwise_add_xpu_fuse_pass,
+              paddle::framework::ir::EmbeddingWithEltwiseAddXPUFusePass);
+
+REGISTER_PASS_CAPABILITY(embedding_with_eltwise_add_xpu_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "embedding_with_eltwise_add_xpu", 0));
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 5107fe2ace0f6..db99283e9f84c 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -518,7 +518,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
   passes_.assign({
       "delete_dropout_op_pass",  //
       "multi_encoder_xpu_fuse_pass",
-      // "embedding_with_eltwise_add_xpu_fuse_pass",
+      "embedding_with_eltwise_add_xpu_fuse_pass",
       "fc_xpu_fuse_pass",  //
       "multi_encoder_slice_link_xpu_fuse_pass",  //
       "generate_sequence_xpu_fuse_pass",
diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml
index eb918ce5b1081..44f76b8c516a1 100644
--- a/paddle/phi/api/yaml/static_ops.yaml
+++ b/paddle/phi/api/yaml/static_ops.yaml
@@ -1,3 +1,12 @@
+- op : embedding_with_eltwise_add_xpu
+  args : (Tensor[] ids, Tensor[] tables, int64_t padding_idx)
+  output: Tensor
+  infer_meta :
+    func: EmbeddingWithEltwiseAddXPUInferMeta
+  kernel:
+    func: embedding_with_eltwise_add_xpu
+    data_type: tables
+
 - op : fc_xpu
   args : (Tensor x, Tensor w, Tensor w_max, Tensor bias, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha)
   output : Tensor
diff --git a/paddle/phi/backends/xpu/xpu1_op_list.cc b/paddle/phi/backends/xpu/xpu1_op_list.cc
index 6b8f9b47011e1..93a0036078e8c 100644
--- a/paddle/phi/backends/xpu/xpu1_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu1_op_list.cc
@@ -80,6 +80,8 @@ XPUOpMap& get_kl1_ops() {
       {"elementwise_pow", XPUKernelSet({phi::DataType::FLOAT32})},
       {"elementwise_sub_grad", XPUKernelSet({phi::DataType::FLOAT32})},
       {"elementwise_sub", XPUKernelSet({phi::DataType::FLOAT32})},
+      {"embedding_with_eltwise_add_xpu",
+       XPUKernelSet({phi::DataType::FLOAT32})},
       {"equal", XPUKernelSet({phi::DataType::INT64})},
       {"expand_as_v2",
        XPUKernelSet({phi::DataType::INT32,
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index b780a25e33b19..ac38d9168229b 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -198,6 +198,8 @@ XPUOpMap& get_kl2_ops() {
                                     phi::DataType::FLOAT16,
                                     phi::DataType::INT64,
                                     phi::DataType::INT32})},
+      {"embedding_with_eltwise_add_xpu",
+       XPUKernelSet({phi::DataType::FLOAT32})},
       {"empty",
        XPUKernelSet({phi::DataType::INT64,
                      phi::DataType::INT32,
diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc
index f7188cdf77e53..b6c449bb0b68e 100644
--- a/paddle/phi/infermeta/fusion.cc
+++ b/paddle/phi/infermeta/fusion.cc
@@ -42,4 +42,26 @@ void FcXPUInferMeta(const MetaTensor& x,
   out->set_layout(x.layout());
 }
 
+void EmbeddingWithEltwiseAddXPUInferMeta(
+    const std::vector<const MetaTensor*>& ids,
+    const std::vector<const MetaTensor*>& tables,
+    MetaTensor* out) {
+  PADDLE_ENFORCE_GT(ids.size(),
+                    0UL,
+                    phi::errors::InvalidArgument(
+                        "The input ids in EmbeddingWithEltwiseAddXPUInferMeta "
+                        "can't be empty."));
+  PADDLE_ENFORCE_GT(tables.size(),
+                    0UL,
+                    phi::errors::InvalidArgument(
+                        "The input tables in "
+                        "EmbeddingWithEltwiseAddXPUInferMeta can't be empty."));
+
+  auto id_dims = ids[0]->dims();
+  auto table_dims = tables[0]->dims();
+  out->set_dims(phi::make_ddim({id_dims[0], id_dims[1], table_dims[1]}));
+  out->set_dtype(tables[0]->dtype());
+  out->set_layout(ids[0]->layout());
+}
+
 }  // namespace phi
diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h
index ba60dacb34020..0a47923c27090 100644
--- a/paddle/phi/infermeta/fusion.h
+++ b/paddle/phi/infermeta/fusion.h
@@ -34,4 +34,9 @@ void FcXPUInferMeta(const MetaTensor& x,
                     float act_alpha,
                     MetaTensor* out);
 
+void EmbeddingWithEltwiseAddXPUInferMeta(
+    const std::vector<const MetaTensor*>& ids,
+    const std::vector<const MetaTensor*>& tables,
+    MetaTensor* out);
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc
new file mode 100644
index 0000000000000..3ffee846f27fb
--- /dev/null
+++ b/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc
@@ -0,0 +1,81 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+namespace fusion {
+
+template <typename T, typename Context>
+void EmbeddingWithEltwiseAddXpuKernel(
+    const Context& ctx,
+    const std::vector<const DenseTensor*>& ids,
+    const std::vector<const DenseTensor*>& tables,
+    int64_t padding_idx,
+    DenseTensor* out) {
+  auto& id_dims = ids[0]->dims();
+  int idx_len = id_dims[0] * id_dims[1];
+  int emb_layer_num = ids.size();
+  int embed_dim = tables[0]->dims()[1];
+  std::vector<int> table_lens_cpu;
+  std::vector<const float*> arg_tables;
+  for (auto* table : tables) {
+    auto& table_dims = table->dims();
+    PADDLE_ENFORCE_EQ(
+        table_dims.size(),
+        2,
+        errors::InvalidArgument(
+            "The table_dims size [%d] should be equal 2.",
+            table_dims.size())); /* shape like [table_len, embed_dim] */
+    PADDLE_ENFORCE_EQ(
+        table_dims[1],
+        embed_dim,
+        errors::InvalidArgument(
+            "Every embed_dim [%d] should be equal the first one [%d].",
+            table_dims[1],
+            embed_dim));
+    table_lens_cpu.push_back(table_dims[0]);
+    arg_tables.push_back(table->data<float>());
+  }
+  std::vector<std::vector<int>> int_idx(emb_layer_num,
+                                        std::vector<int>(idx_len, 0));
+  std::vector<xpu::VectorParam<int>> arg_ids;
+  for (size_t i = 0; i < emb_layer_num; i++) {
+    for (size_t j = 0; j < idx_len; j++) {
+      int_idx[i][j] = static_cast<int>(ids[i]->data<int64_t>()[j]);
+    }
+    arg_ids.push_back(
+        xpu::VectorParam<int>{int_idx[i].data(), idx_len, nullptr});
+  }
+  ctx.template Alloc<T>(out);
+  int r = xpu::multi_embedding_fusion<float, float, int>(
+      ctx.x_context(),
+      arg_tables, /* tables */
+      out->data<float>(),
+      arg_ids,
+      table_lens_cpu,
+      embed_dim,
+      std::vector<float>(table_lens_cpu.size(), 1.0f),
+      std::vector<int>(table_lens_cpu.size(), padding_idx));
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_with_eltwise_add_xpu");
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(embedding_with_eltwise_add_xpu,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::fusion::EmbeddingWithEltwiseAddXpuKernel,
+                   float) {
+  kernel->InputAt(0).SetBackend(phi::Backend::CPU);
+}
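
Note on the fused op's semantics: embedding_with_eltwise_add_xpu replaces n_embedding lookup_table(_v2) ops chained by elementwise_add, so for every id position it produces the sum of the per-table embedding rows. A minimal CPU reference sketch of that contract, assuming int64 ids and float tables shaped [table_len_k, embed_dim] (the helper name and flat layout are illustrative, not part of the patch):

#include <cstdint>
#include <vector>

// out[p][:] = sum over k of tables[k][ids[k][p]][:], for p in [0, idx_len).
// Rows selected by padding_idx contribute zeros, matching embedding semantics.
void embedding_with_eltwise_add_ref(
    const std::vector<const int64_t*>& ids,   // n_embedding id buffers
    const std::vector<const float*>& tables,  // n_embedding lookup tables
    int idx_len,                              // batch * seq_len
    int embed_dim,
    int64_t padding_idx,
    float* out) {                             // [idx_len, embed_dim]
  for (int p = 0; p < idx_len; ++p) {
    for (int d = 0; d < embed_dim; ++d) out[p * embed_dim + d] = 0.0f;
    for (size_t k = 0; k < ids.size(); ++k) {
      const int64_t id = ids[k][p];
      if (id == padding_idx) continue;  // padding row adds nothing
      const float* row = tables[k] + id * embed_dim;
      for (int d = 0; d < embed_dim; ++d) out[p * embed_dim + d] += row[d];
    }
  }
}

xpu::multi_embedding_fusion performs the same reduction in a single device kernel; the per-table scales are pinned to 1.0f in the kernel above, so the fusion is a pure sum of lookups.
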
From 2c0f308e82b5c1f84d6a82fa2ca009ec30afebb0 Mon Sep 17 00:00:00 2001
From: csy0225
Date: Tue, 21 Feb 2023 02:59:26 +0000
Subject: [PATCH 2/4] update

---
 .../framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc | 1 -
 1 file changed, 1 deletion(-)

diff --git a/paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc
index c7929c72aaffc..8cd2c528b10d0 100644
--- a/paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc
@@ -270,7 +270,6 @@ void EmbeddingWithEltwiseAddXPUFusePass::ApplyImpl(
       }
     }
     // Generate embedding_with_eltwise_add_xpu op
-    auto* scope = param_scope();
     framework::OpDesc embedding_with_eltwise_add_xpu_op_desc;
     embedding_with_eltwise_add_xpu_op_desc.SetType(
         "embedding_with_eltwise_add_xpu");
From cf647d734449a99d31ac11a7133f39df68d5eecf Mon Sep 17 00:00:00 2001
From: csy0225
Date: Wed, 22 Feb 2023 04:05:50 +0000
Subject: [PATCH 3/4] update

---
 .../framework/ir/delete_dropout_op_pass.cc    | 29 ++++++++++---------
 ...mbedding_with_eltwise_add_xpu_fuse_pass.cc |  2 +-
 .../embedding_with_eltwise_add_xpu_kernel.cc  |  4 +--
 3 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/paddle/fluid/framework/ir/delete_dropout_op_pass.cc b/paddle/fluid/framework/ir/delete_dropout_op_pass.cc
index e3c2e6cef2114..b1765440159da 100644
--- a/paddle/fluid/framework/ir/delete_dropout_op_pass.cc
+++ b/paddle/fluid/framework/ir/delete_dropout_op_pass.cc
@@ -43,28 +43,29 @@ void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const {
     GET_IR_NODE(dropout_op_out);
     GET_IR_NODE(dropout_op_mask);
 
-    // link dropout_op_out to pre_op
+    // link dropout_op_x to next_op
     auto dropout_op_x_name = dropout_op_x->Var()->Name();
     auto dropout_op_out_name = dropout_op_out->Var()->Name();
-    auto pre_ops = dropout_op_x->inputs;
-    if (pre_ops.empty()) return;
-    auto pre_op_desc = pre_ops[0]->Op();
-    auto pre_op_outs = pre_op_desc->Outputs();
-    for (auto& out_var : pre_op_outs) {
-      auto names = out_var.second;
-      for (size_t i = 0; i < names.size(); i++) {
-        if (names[i] == dropout_op_x_name) {
-          names[i] = dropout_op_out_name;
-          pre_op_desc->SetOutput(out_var.first, names);
-          break;
+    auto next_op_nodes = dropout_op_out->outputs;
+    for (auto next_op_node : next_op_nodes) {
+      auto next_op_desc = next_op_node->Op();
+      auto next_op_inputs = next_op_desc->Inputs();
+      for (auto& input_var : next_op_inputs) {
+        auto names = input_var.second;
+        for (size_t i = 0; i < names.size(); i++) {
+          if (names[i] == dropout_op_out_name) {
+            names[i] = dropout_op_x_name;
+            next_op_desc->SetInput(input_var.first, names);
+            break;
+          }
         }
       }
+      IR_NODE_LINK_TO(dropout_op_x, next_op_node);
     }
-    IR_NODE_LINK_TO(pre_ops[0], dropout_op_out);
 
     // delete useless node
     std::unordered_set<const Node*> delete_nodes{
-        dropout_op_x, dropout_op, dropout_op_mask};
+        dropout_op, dropout_op_mask, dropout_op_out};
     GraphSafeRemoveNodes(graph, delete_nodes);
     found_subgraph_count++;
   };
diff --git a/paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc
index 8cd2c528b10d0..05975b6a1c24c 100644
--- a/paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc
@@ -287,7 +287,7 @@ void EmbeddingWithEltwiseAddXPUFusePass::ApplyImpl(
         "padding_idx", static_cast<int64_t>(padding_idx));
     auto* embedding_with_eltwise_add_xpu_op =
         graph->CreateOpNode(&embedding_with_eltwise_add_xpu_op_desc);
-    for (int i = 0; i < x_nodes.size(); i++) {
+    for (size_t i = 0; i < x_nodes.size(); i++) {
       SAFE_IR_NODE_LINK_TO(x_nodes[i], embedding_with_eltwise_add_xpu_op);
       SAFE_IR_NODE_LINK_TO(table_nodes[i], embedding_with_eltwise_add_xpu_op);
     }
diff --git a/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc
index f41dde931130e..afde2f8f3503b 100644
--- a/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc
+++ b/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc
@@ -52,8 +52,8 @@ void EmbeddingWithEltwiseAddXpuKernel(
   std::vector<std::vector<int>> int_idx(emb_layer_num,
                                         std::vector<int>(idx_len, 0));
   std::vector<xpu::VectorParam<int>> arg_ids;
-  for (size_t i = 0; i < emb_layer_num; i++) {
-    for (size_t j = 0; j < idx_len; j++) {
+  for (int i = 0; i < emb_layer_num; i++) {
+    for (int j = 0; j < idx_len; j++) {
       int_idx[i][j] = static_cast<int>(ids[i]->data<int64_t>()[j]);
     }
     arg_ids.push_back(
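
Note on patch 3's rewiring change in delete_dropout_op_pass: consumers of dropout's Out are now re-pointed at dropout's X, instead of renaming the producer's output, which also keeps the graph correct when dropout's input is a graph input with no producer op. Deleting dropout outright is only sound because the pattern asserts dropout_implementation == "upscale_in_train", under which dropout is an identity at inference time. A standalone sketch of that semantics (illustrative helper, not Paddle API):

#include <cstdlib>
#include <vector>

// "upscale_in_train" dropout: training scales kept values by 1/(1-p) so that
// inference can return the input unchanged -- hence the op can be deleted.
std::vector<float> dropout_upscale_in_train(const std::vector<float>& x,
                                            float p,
                                            bool is_training) {
  if (!is_training) return x;  // identity at inference: safe to remove
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    const bool keep = (std::rand() / static_cast<float>(RAND_MAX)) >= p;
    out[i] = keep ? x[i] / (1.0f - p) : 0.0f;
  }
  return out;
}

With the other mode, "downgrade_in_infer", inference computes x * (1 - p), so the op could not simply be dropped; that is why the pattern pins the attribute.
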
From 52a1a11983183fc87c92b51749e9be2ad49bc2de Mon Sep 17 00:00:00 2001
From: csy0225
Date: Wed, 22 Feb 2023 07:58:43 +0000
Subject: [PATCH 4/4] fix dropout miss mask output

---
 .../framework/ir/delete_dropout_op_pass.cc    | 73 ++++++++++---------
 .../framework/ir/graph_pattern_detector.cc    | 14 ++--
 .../framework/ir/graph_pattern_detector.h     |  2 +-
 3 files changed, 48 insertions(+), 41 deletions(-)

diff --git a/paddle/fluid/framework/ir/delete_dropout_op_pass.cc b/paddle/fluid/framework/ir/delete_dropout_op_pass.cc
index b1765440159da..285c25c6a5e9d 100644
--- a/paddle/fluid/framework/ir/delete_dropout_op_pass.cc
+++ b/paddle/fluid/framework/ir/delete_dropout_op_pass.cc
@@ -30,47 +30,50 @@ namespace ir {
 void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "delete_dropout_op_pattern";
   FusePassBase::Init(pattern_name, graph);
+  int found_subgraph_count = 0;
 
-  GraphPatternDetector gpd;
-  patterns::DeleteDropoutOpPattern pattern(gpd.mutable_pattern(), pattern_name);
-  pattern();
+  for (auto with_mask : {true, false}) {
+    GraphPatternDetector gpd;
+    patterns::DeleteDropoutOpPattern pattern(gpd.mutable_pattern(),
+                                             pattern_name);
+    pattern(with_mask);
 
-  int found_subgraph_count = 0;
-  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
-                     Graph* g) {
-    GET_IR_NODE(dropout_op_x);
-    GET_IR_NODE(dropout_op);
-    GET_IR_NODE(dropout_op_out);
-    GET_IR_NODE(dropout_op_mask);
+    auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                       Graph* g) {
+      GET_IR_NODE(dropout_op_x);
+      GET_IR_NODE(dropout_op);
+      GET_IR_NODE(dropout_op_out);
 
-    // link dropout_op_x to next_op
-    auto dropout_op_x_name = dropout_op_x->Var()->Name();
-    auto dropout_op_out_name = dropout_op_out->Var()->Name();
-    auto next_op_nodes = dropout_op_out->outputs;
-    for (auto next_op_node : next_op_nodes) {
-      auto next_op_desc = next_op_node->Op();
-      auto next_op_inputs = next_op_desc->Inputs();
-      for (auto& input_var : next_op_inputs) {
-        auto names = input_var.second;
-        for (size_t i = 0; i < names.size(); i++) {
-          if (names[i] == dropout_op_out_name) {
-            names[i] = dropout_op_x_name;
-            next_op_desc->SetInput(input_var.first, names);
-            break;
+      // link dropout_op_x to next_op
+      auto dropout_op_x_name = dropout_op_x->Var()->Name();
+      auto dropout_op_out_name = dropout_op_out->Var()->Name();
+      auto next_op_nodes = dropout_op_out->outputs;
+      for (auto next_op_node : next_op_nodes) {
+        auto next_op_desc = next_op_node->Op();
+        auto next_op_inputs = next_op_desc->Inputs();
+        for (auto& input_var : next_op_inputs) {
+          auto names = input_var.second;
+          for (size_t i = 0; i < names.size(); i++) {
+            if (names[i] == dropout_op_out_name) {
+              names[i] = dropout_op_x_name;
+              next_op_desc->SetInput(input_var.first, names);
+              break;
+            }
           }
         }
+        IR_NODE_LINK_TO(dropout_op_x, next_op_node);
       }
-      IR_NODE_LINK_TO(dropout_op_x, next_op_node);
-    }
 
-    // delete useless node
-    std::unordered_set<const Node*> delete_nodes{
-        dropout_op, dropout_op_mask, dropout_op_out};
-    GraphSafeRemoveNodes(graph, delete_nodes);
-    found_subgraph_count++;
-  };
-
-  gpd(graph, handler);
+      // delete useless node
+      std::unordered_set<const Node*> delete_nodes{dropout_op, dropout_op_out};
+      if (with_mask) {
+        GET_IR_NODE(dropout_op_mask);
+        delete_nodes.insert(dropout_op_mask);
+      }
+      GraphSafeRemoveNodes(graph, delete_nodes);
+      found_subgraph_count++;
+    };
+    gpd(graph, handler);
+  }
   AddStatis(found_subgraph_count);
 }
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 4a1f80916ef2b..5e3f2dc42ce1b 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -3032,7 +3032,7 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
   return concat_out;
 }
 
-void patterns::DeleteDropoutOpPattern::operator()() {
+void patterns::DeleteDropoutOpPattern::operator()(bool with_mask) {
   auto dropout_op_x = pattern->NewNode(dropout_op_x_repr())
                           ->assert_is_op_input("dropout", "X")
                           ->AsInput();
@@ -3042,10 +3042,14 @@ void patterns::DeleteDropoutOpPattern::operator()() {
           std::string("upscale_in_train"));
   auto dropout_op_out = pattern->NewNode(dropout_op_out_repr())
                             ->assert_is_op_output("dropout", "Out");
-  auto dropout_op_mask = pattern->NewNode(dropout_op_mask_repr())
-                             ->assert_is_op_output("dropout", "Mask");
-  dropout_op->LinksFrom({dropout_op_x})
-      .LinksTo({dropout_op_out, dropout_op_mask});
+  if (with_mask) {
+    auto dropout_op_mask = pattern->NewNode(dropout_op_mask_repr())
+                               ->assert_is_op_output("dropout", "Mask");
+    dropout_op->LinksFrom({dropout_op_x})
+        .LinksTo({dropout_op_out, dropout_op_mask});
+  } else {
+    dropout_op->LinksFrom({dropout_op_x}).LinksTo({dropout_op_out});
+  }
 }
 
 void patterns::DeleteQuantOpFuse::operator()(PDNode *input_act_node,
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index f25260e0ce04e..80474d0d67acb 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1759,7 +1759,7 @@ struct DeleteDropoutOpPattern : public PatternBase {
   DeleteDropoutOpPattern(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "delete_dropout_op_pattern") {}
 
-  void operator()();
+  void operator()(bool with_mask);
 
   PATTERN_DECL_NODE(dropout_op_x);
   PATTERN_DECL_NODE(dropout_op);
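
Usage note: with patch 1's paddle_pass_builder.cc change, the fuse pass is part of XpuPassStrategy and runs automatically when IR optimization is enabled on an XPU predictor. A hedged sketch of driving it through the C++ inference API (model paths and the include path are placeholders; EnableXpu's optional arguments are left at their defaults):

#include "paddle_inference_api.h"  // header name may vary by install layout

int main() {
  paddle_infer::Config config;
  config.SetModel("model.pdmodel", "model.pdiparams");  // placeholder paths
  config.EnableXpu();          // selects XpuPassStrategy, which now includes
                               // embedding_with_eltwise_add_xpu_fuse_pass
  config.SwitchIrOptim(true);  // IR passes (including the fusion) run here
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor != nullptr ? 0 : 1;
}

Setting the XPU_PADDING_IDX environment variable to false/0 keeps the original padding_idx attribute instead of forcing it to -1 during fusion (see GetBoolFromEnv in patch 1).
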