Skip to content

Commit

Permalink
refine cpp codes
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesLim-sy committed Oct 16, 2021
2 parents eec1fc6 + adb8049 commit dc103de
Show file tree
Hide file tree
Showing 44 changed files with 2,267 additions and 461 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/details/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass
paddle_to_cinn_pass fix_op_run_order_pass)
fix_op_run_order_pass build_cinn_pass)
if(NOT APPLE AND NOT WIN32 AND (WITH_GPU OR WITH_ROCM))
set(IR_PASS_DEPS ${IR_PASS_DEPS} fusion_group_pass)
endif()
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/framework/details/build_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {

// Note: This pass is used to enable cinn.
if (FLAGS_use_cinn) {
AppendPass("paddle_to_cinn_pass");
AppendPass("build_cinn_pass");
}
SetCollectiveContext();
}
Expand Down Expand Up @@ -486,6 +486,7 @@ USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(add_reader_dependency_pass);
USE_PASS(build_cinn_pass);
#ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass);
#endif
Expand Down
2 changes: 0 additions & 2 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass)
cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS graph graph_helper)

pass_library(graph_to_program_pass base)
pass_library(paddle_to_cinn_pass base DEPS cinn_runner)
pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base DEPS string_helper)
pass_library(fc_fuse_pass inference)
Expand Down Expand Up @@ -144,7 +143,6 @@ cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
cc_test(paddle_to_cinn_pass_test SRCS paddle_to_cinn_pass_test.cc DEPS paddle_to_cinn_pass proto_desc)
cc_test(cost_model_test SRCS cost_model_test.cc DEPS cost_model op_registry)
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
cc_test(test_op_compat_sensible_pass SRCS op_compat_sensible_pass_tester.cc DEPS op_compat_sensible_pass)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ void InplaceAddToOpPass::Run(Graph *graph) const {
out_var_ptr->GeneratedOp());

// NOTE(zhiqiu): currently, only conv2d_grad supports addto strategy
if (right_generated_op->Name() != "conv2d_grad") {
if (right_generated_op->Name() != "conv2d_grad" &&
right_generated_op->Name() != "resnet_unit_grad") {
continue;
}

Expand Down Expand Up @@ -224,11 +225,13 @@ static bool IsValidConv2DGradDataGradNode(const Node &node) {
if (node.inputs.empty()) return false;
auto *generated_op = node.inputs[0];
auto *op_desc = generated_op->Op();
if (op_desc == nullptr || op_desc->Type() != "conv2d_grad") {
if (op_desc == nullptr || (op_desc->Type() != "conv2d_grad" &&
op_desc->Type() != "resnet_unit_grad")) {
return false;
}
const auto &outputs = op_desc->Outputs();
auto iter = outputs.find(GradVarName("Input"));
std::string grad_var_name = op_desc->Type() == "conv2d_grad" ? "Input" : "X";
auto iter = outputs.find(GradVarName(grad_var_name));
return iter != outputs.end() && !iter->second.empty() &&
iter->second[0] == node.Name() &&
!op_desc->GetAttrIfExists<bool>("use_addto");
Expand Down
31 changes: 0 additions & 31 deletions paddle/fluid/framework/ir/paddle_to_cinn_pass.cc

This file was deleted.

30 changes: 0 additions & 30 deletions paddle/fluid/framework/ir/paddle_to_cinn_pass.h

This file was deleted.

40 changes: 0 additions & 40 deletions paddle/fluid/framework/ir/paddle_to_cinn_pass_test.cc

This file was deleted.

2 changes: 2 additions & 0 deletions paddle/fluid/framework/paddle2cinn/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
cc_library(cinn_cache_key SRCS cinn_cache_key.cc DEPS boost graph graph_helper lod_tensor proto_desc)
cc_library(cinn_compiled_object SRCS cinn_compiled_object.cc DEPS feed_fetch_method graph lod_tensor proto_desc)
cc_library(cinn_runner SRCS cinn_runner.cc DEPS cinn_cache_key cinn_compiled_object feed_fetch_method graph lod_tensor scope)
cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector)

cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key)
cc_test(cinn_runner_test SRCS cinn_runner_test.cc DEPS cinn_runner proto_desc)
cc_test(cinn_compiled_object_test SRCS cinn_compiled_object_test.cc DEPS cinn_compiled_object)
cc_test(test_build_cinn_pass SRCS build_cinn_pass_test.cc DEPS build_cinn_pass)
Loading

0 comments on commit dc103de

Please sign in to comment.