Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
csy0225 committed Feb 22, 2023
1 parent 76c8ae1 commit cf647d7
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 17 deletions.
29 changes: 15 additions & 14 deletions paddle/fluid/framework/ir/delete_dropout_op_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,29 @@ void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const {
GET_IR_NODE(dropout_op_out);
GET_IR_NODE(dropout_op_mask);

// link dropout_op_out to pre_op
// link dropout_op_x to next_op
auto dropout_op_x_name = dropout_op_x->Var()->Name();
auto dropout_op_out_name = dropout_op_out->Var()->Name();
auto pre_ops = dropout_op_x->inputs;
if (pre_ops.empty()) return;
auto pre_op_desc = pre_ops[0]->Op();
auto pre_op_outs = pre_op_desc->Outputs();
for (auto& out_var : pre_op_outs) {
auto names = out_var.second;
for (size_t i = 0; i < names.size(); i++) {
if (names[i] == dropout_op_x_name) {
names[i] = dropout_op_out_name;
pre_op_desc->SetOutput(out_var.first, names);
break;
auto next_op_nodes = dropout_op_out->outputs;
for (auto next_op_node : next_op_nodes) {
auto next_op_desc = next_op_node->Op();
auto next_op_inputs = next_op_desc->Inputs();
for (auto& input_var : next_op_inputs) {
auto names = input_var.second;
for (size_t i = 0; i < names.size(); i++) {
if (names[i] == dropout_op_out_name) {
names[i] = dropout_op_x_name;
next_op_desc->SetInput(input_var.first, names);
break;
}
}
}
IR_NODE_LINK_TO(dropout_op_x, next_op_node);
}
IR_NODE_LINK_TO(pre_ops[0], dropout_op_out);

// delete useless node
std::unordered_set<const Node*> delete_nodes{
dropout_op_x, dropout_op, dropout_op_mask};
dropout_op, dropout_op_mask, dropout_op_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ void EmbeddingWithEltwiseAddXPUFusePass::ApplyImpl(
"padding_idx", static_cast<int64_t>(padding_idx));
auto* embedding_with_eltwise_add_xpu_op =
graph->CreateOpNode(&embedding_with_eltwise_add_xpu_op_desc);
for (int i = 0; i < x_nodes.size(); i++) {
for (size_t i = 0; i < x_nodes.size(); i++) {
SAFE_IR_NODE_LINK_TO(x_nodes[i], embedding_with_eltwise_add_xpu_op);
SAFE_IR_NODE_LINK_TO(table_nodes[i], embedding_with_eltwise_add_xpu_op);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ void EmbeddingWithEltwiseAddXpuKernel(
std::vector<std::vector<int>> int_idx(emb_layer_num,
std::vector<int>(idx_len, 0));
std::vector<xpu::VectorParam<int>> arg_ids;
for (size_t i = 0; i < emb_layer_num; i++) {
for (size_t j = 0; j < idx_len; j++) {
for (int i = 0; i < emb_layer_num; i++) {
for (int j = 0; j < idx_len; j++) {
int_idx[i][j] = static_cast<int>(ids[i]->data<int64_t>()[j]);
}
arg_ids.push_back(
Expand Down

0 comments on commit cf647d7

Please sign in to comment.