From cf647d734449a99d31ac11a7133f39df68d5eecf Mon Sep 17 00:00:00 2001 From: csy0225 Date: Wed, 22 Feb 2023 04:05:50 +0000 Subject: [PATCH] update --- .../framework/ir/delete_dropout_op_pass.cc | 29 ++++++++++--------- ...mbedding_with_eltwise_add_xpu_fuse_pass.cc | 2 +- .../embedding_with_eltwise_add_xpu_kernel.cc | 4 +-- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/framework/ir/delete_dropout_op_pass.cc b/paddle/fluid/framework/ir/delete_dropout_op_pass.cc index e3c2e6cef2114..b1765440159da 100644 --- a/paddle/fluid/framework/ir/delete_dropout_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_dropout_op_pass.cc @@ -43,28 +43,29 @@ void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const { GET_IR_NODE(dropout_op_out); GET_IR_NODE(dropout_op_mask); - // link dropout_op_out to pre_op + // link dropout_op_x to next_op auto dropout_op_x_name = dropout_op_x->Var()->Name(); auto dropout_op_out_name = dropout_op_out->Var()->Name(); - auto pre_ops = dropout_op_x->inputs; - if (pre_ops.empty()) return; - auto pre_op_desc = pre_ops[0]->Op(); - auto pre_op_outs = pre_op_desc->Outputs(); - for (auto& out_var : pre_op_outs) { - auto names = out_var.second; - for (size_t i = 0; i < names.size(); i++) { - if (names[i] == dropout_op_x_name) { - names[i] = dropout_op_out_name; - pre_op_desc->SetOutput(out_var.first, names); - break; + auto next_op_nodes = dropout_op_out->outputs; + for (auto next_op_node : next_op_nodes) { + auto next_op_desc = next_op_node->Op(); + auto next_op_inputs = next_op_desc->Inputs(); + for (auto& input_var : next_op_inputs) { + auto names = input_var.second; + for (size_t i = 0; i < names.size(); i++) { + if (names[i] == dropout_op_out_name) { + names[i] = dropout_op_x_name; + next_op_desc->SetInput(input_var.first, names); + break; + } } } + IR_NODE_LINK_TO(dropout_op_x, next_op_node); } - IR_NODE_LINK_TO(pre_ops[0], dropout_op_out); // delete useless node std::unordered_set delete_nodes{ - dropout_op_x, dropout_op, dropout_op_mask}; + dropout_op, dropout_op_mask, dropout_op_out}; GraphSafeRemoveNodes(graph, delete_nodes); found_subgraph_count++; }; diff --git a/paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc index 8cd2c528b10d0..05975b6a1c24c 100644 --- a/paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc @@ -287,7 +287,7 @@ void EmbeddingWithEltwiseAddXPUFusePass::ApplyImpl( "padding_idx", static_cast(padding_idx)); auto* embedding_with_eltwise_add_xpu_op = graph->CreateOpNode(&embedding_with_eltwise_add_xpu_op_desc); - for (int i = 0; i < x_nodes.size(); i++) { + for (size_t i = 0; i < x_nodes.size(); i++) { SAFE_IR_NODE_LINK_TO(x_nodes[i], embedding_with_eltwise_add_xpu_op); SAFE_IR_NODE_LINK_TO(table_nodes[i], embedding_with_eltwise_add_xpu_op); } diff --git a/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc index f41dde931130e..afde2f8f3503b 100644 --- a/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc @@ -52,8 +52,8 @@ void EmbeddingWithEltwiseAddXpuKernel( std::vector> int_idx(emb_layer_num, std::vector(idx_len, 0)); std::vector> arg_ids; - for (size_t i = 0; i < emb_layer_num; i++) { - for (size_t j = 0; j < idx_len; j++) { + for (int i = 0; i < emb_layer_num; i++) { + for (int j = 0; j < idx_len; j++) { int_idx[i][j] = static_cast(ids[i]->data()[j]); } arg_ids.push_back(