diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc index 92c3a542d135cd..e59ba8b4232932 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc @@ -18,6 +18,7 @@ #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/ir/schedule_block_graph.h" #include "paddle/cinn/ir/utils/ir_copy.h" #include "paddle/cinn/ir/utils/ir_nodes_collector.h" @@ -94,6 +95,8 @@ void BindGPUIndex(ir::IRSchedule* ir_schedule, auto all_loops = ir_schedule->GetLoops(block_name); CHECK_LE(num_loops_to_bind, all_loops.size()) << "The number of loops to be bind is greater than size of all_loops"; + CHECK_GE(num_loops_to_bind, 0) + << "The number of loops to be bind should be greater than 0"; // check whether it is the case that threadIdx has been binded but blockIdx // not, the threadIdx can only be binded in the first loop after // num_loops_to_bind loops because we has excluded other cases in @@ -101,6 +104,17 @@ void BindGPUIndex(ir::IRSchedule* ir_schedule, bool gpu_thread_has_binded = num_loops_to_bind < all_loops.size() && all_loops[num_loops_to_bind].As()->is_gpu_thread_binded(); + ir::BlockOrderConstructor block_order_constructor; + std::map, ir::Expr> blocks_order_with_ctrl_stmt = + block_order_constructor(&all_loops[num_loops_to_bind - 1]); + for (auto& pair : blocks_order_with_ctrl_stmt) { + if (pair.first.size() == 2) { + ir::Expr stmt = pair.second; + if (stmt.As() && stmt.As()->is_gpu_thread_binded()) { + gpu_thread_has_binded = true; + } + } + } Expr fused_loop = ir_schedule->Fuse( {all_loops.begin(), all_loops.begin() + num_loops_to_bind}); int32_t extent = fused_loop.As()->extent.as_int32(); @@ -181,5 +195,18 @@ std::vector AutoBind::ApplyOnBlock(SearchState state, return {new_state}; } +void AutoBind::Apply(ir::IRSchedule* ir_schedule, + const std::string& block_name) { + int num_loop_can_bind = + CountLoopCanBinded(ir_schedule->GetLoops(block_name)[0].As()); + if (num_loop_can_bind > 0) { + BindGPUIndex(ir_schedule, + block_name, + num_loop_can_bind, + kMaxBlocks, + target_->max_num_threads()); + } +} + } // namespace auto_schedule } // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h index a45bd31d4b33ff..c4baf8e7797e38 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h @@ -42,6 +42,8 @@ class AutoBind : public AutoGenRule { std::vector ApplyOnBlock(SearchState state, const std::string& block_name) override; + void Apply(ir::IRSchedule* ir_schedule, const std::string& block_name); + private: std::vector applicable_schedule_blocks_; }; diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc index e8cab5dd63fa29..57e13c00a1c76b 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc @@ -28,6 +28,7 @@ #include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/ir/schedule/ir_schedule_util.h" #include "paddle/cinn/ir/utils/ir_copy.h" #include 
"paddle/cinn/ir/utils/ir_nodes_collector.h" @@ -49,6 +50,11 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::Expr root = ir_sch->GetRootBlock(sche_block_realize_expr); // Check the schedule block to be inlined is not a reduce tensor. + for (const ir::Var& iter_var : sche_block->iter_vars) { + if (iter_var->is_reduce_axis) { + return false; + } + } std::set find_store = ir::ir_utils::CollectIRNodesWithoutTensor( compute_body, [&](const Expr* x) { return x->As(); }); if (find_store.size() != 1UL) { @@ -69,6 +75,29 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, return false; } + // the xxx_reduce_init block cannot be inlined. + if (ir::IsReduceInitTensorName(tensor->name)) { + return false; + } + + // Skip external calls + std::vector consumers = + ir::GetConsumers(sche_block_realize_expr, root); + for (const ir::Expr& consumer : consumers) { + std::set find_load = ir::ir_utils::CollectIRNodesWithoutTensor( + consumer.As() + ->schedule_block.As() + ->body, + [&](const ir::Expr* x) { + return x->As() && + x->As()->tensor.as_tensor_ref()->name == + tensor->name; + }); + if (find_load.empty()) { + return false; + } + } + // write_buffers.size() = 1 and read_buffers is empty, means const // we can inline to consumer if (sche_block->read_buffers.empty()) { diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h index 0ef60a01a9b0f6..9a0fc3e823361f 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h @@ -63,7 +63,6 @@ class AutoInline : public AutoGenRule { std::vector ApplyOnBlock(SearchState state, const std::string& block_name) override; - private: void Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr); // NOLINT private: diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.cc index c44d067610123a..c8b8fdeb0f554d 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.cc @@ -161,15 +161,16 @@ void ReductionFactoring::Apply(const std::string& block_name, // 5. Split the reduction loop into 2 part VLOG(6) << "before Split: " << ir_schedule->GetModule().GetExprs()[0]; int factor = 1; + int max_factor = 1024; int extent = ir::GetLoopExtent(fused_reduce_loop); - for (int i = ceil(sqrt(extent)); i >= 1; --i) { + for (int i = max_factor; i >= 1; --i) { if (extent % i == 0) { factor = i; break; } } std::vector splited_reduction_loops = - ir_schedule->Split(fused_reduce_loop, {-1, factor}); + ir_schedule->Split(fused_reduce_loop, {factor, -1}); // 6. Apply FactorizeReduction VLOG(6) << "before FactorizeReduction: " << ir_schedule->GetModule().GetExprs()[0]; @@ -177,6 +178,25 @@ void ReductionFactoring::Apply(const std::string& block_name, num_spatial_loops); VLOG(6) << "after FactorizeReduction: " << ir_schedule->GetModule().GetExprs()[0]; + + // 7. 
Loop fusion and cross thread reduction + std::vector rb_loops = ir_schedule->GetLoops(block_name); + ir::Expr rf_block = ir_schedule->GetBlock(block_name + "_rf"); + ir_schedule->SimpleComputeAt(rf_block, rb_loops.back()); + + rb_loops = ir_schedule->GetLoops(block_name); + ir::Expr rf_init_block = + ir_schedule->GetBlock(block_name + "_rf__reduce_init"); + ir_schedule->SimpleComputeAt(rf_init_block, rb_loops.back()); + + if (*target_ == common::DefaultNVGPUTarget()) { + rb_loops = ir_schedule->GetLoops(block_name); + rf_block = ir_schedule->GetBlock(block_name + "_rf"); + ir_schedule->Bind(rb_loops.back(), "threadIdx.x"); + ir_schedule->SetBuffer(rf_block, "shared"); + } + VLOG(6) << "Loop fusion and cross thread reduction: " + << ir_schedule->GetModule().GetExprs()[0]; } } // namespace auto_schedule diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring_test.cc index 63e808cfbd4a50..6848fba586944e 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring_test.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring_test.cc @@ -25,6 +25,8 @@ #include "paddle/cinn/ir/ir_printer.h" #include "test/cpp/cinn/concrete_program_builder.h" +PD_DECLARE_bool(cinn_new_group_scheduler); + namespace cinn { namespace auto_schedule { @@ -37,7 +39,9 @@ class TestReductionFactoring : public TestAutoGenRuleBase { const std::vector& reduce_dim, const std::string& block_name, const std::string& expected_ir) { - Initialize(common::DefaultHostTarget()); + Initialize(common::DefaultNVGPUTarget()); + // In order to forcibly use the most basic Compute of reduction + FLAGS_cinn_new_group_scheduler = 1; auto test_program = tests::ReduceBuilder().Build( {{"X", shape}}, {{"reduce_dim", reduce_dim}}); // construct input parameter @@ -66,7 +70,8 @@ class TestReductionFactoring : public TestAutoGenRuleBase { }; TEST_F(TestReductionFactoring, AnalyseApplyType) { - Initialize(common::DefaultHostTarget()); + Context::Global().ResetNameId(); + Initialize(common::DefaultNVGPUTarget()); auto test_program = tests::OpBuilder("elementwise_add").Build({{"X", {4, 5}}, {"Y", {4, 5}}}); ir::IRSchedule ir_schedule = MakeIRSchedule(test_program); @@ -77,43 +82,44 @@ TEST_F(TestReductionFactoring, AnalyseApplyType) { RuleApplyType::kCannotApply); } +#ifdef CINN_WITH_CUDA + TEST_F(TestReductionFactoring, ApplyOnBlock1ReduceDim) { + Context::Global().ResetNameId(); std::string expected_ir = R"({ ScheduleBlock(root) { { serial for (i, 0, 32) { - serial for (reduce_k_0_0, 0, 8) + ScheduleBlock(var_0__reduce_init) + { + i0_0 = axis.bind(i) + var_0__reduce_init[i0_0] = 0.00000000f + } + thread_bind[threadIdx.x] for (reduce_k_0_0, 0, 64) { ScheduleBlock(var_0_rf__reduce_init) { vreduce_k_0_0, i0_0 = axis.bind(reduce_k_0_0, i) var_0_rf__reduce_init[i0_0, vreduce_k_0_0] = 0.00000000f } - serial for (reduce_k_0_1, 0, 8) { - ScheduleBlock(var_0_rf) + serial for (reduce_k_0_1, 0, 1) { - vreduce_k_0_0, i0_0, vreduce_k_0_1 = axis.bind(reduce_k_0_0, i, reduce_k_0_1) - var_0_rf[i0_0, vreduce_k_0_0] = (var_0_rf[i0_0, vreduce_k_0_0] + X[i0_0, ((8 * vreduce_k_0_0) + vreduce_k_0_1)]) + ScheduleBlock(var_0_rf) + { + vreduce_k_0_0, i0_0, vreduce_k_0_1 = axis.bind(reduce_k_0_0, i, reduce_k_0_1) + var_0_rf[i0_0, vreduce_k_0_0] = (var_0_rf[i0_0, vreduce_k_0_0] + X[i0_0, (vreduce_k_0_0 + vreduce_k_0_1)]) + } + } + { + ScheduleBlock(var_0) + { + vreduce_k_0_0, i0_0 = axis.bind(reduce_k_0_0, i) + 
var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_0]) + } } - } - } - } - serial for (i, 0, 32) - { - ScheduleBlock(var_0__reduce_init) - { - i0_0 = axis.bind(i) - var_0__reduce_init[i0_0] = 0.00000000f - } - serial for (reduce_k_0_0, 0, 8) - { - ScheduleBlock(var_0) - { - vreduce_k_0_0, i0_0 = axis.bind(reduce_k_0_0, i) - var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_0]) } } } @@ -124,42 +130,41 @@ TEST_F(TestReductionFactoring, ApplyOnBlock1ReduceDim) { } TEST_F(TestReductionFactoring, ApplyOnBlock2ReduceDim) { + Context::Global().ResetNameId(); std::string expected_ir = R"({ ScheduleBlock(root) { { serial for (i, 0, 32) { - serial for (reduce_k_0_reduce_k_1_fused, 0, 128) + ScheduleBlock(var_0__reduce_init) + { + i0_0 = axis.bind(i) + var_0__reduce_init[i0_0] = 0.00000000f + } + thread_bind[threadIdx.x] for (reduce_k_0_reduce_k_1_fused, 0, 1024) { ScheduleBlock(var_0_rf__reduce_init) { vreduce_k_0_reduce_k_1_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i) var_0_rf__reduce_init[i0_0, vreduce_k_0_reduce_k_1_fused] = 0.00000000f } - serial for (reduce_k_0_reduce_k_1_fused_0, 0, 64) { - ScheduleBlock(var_0_rf) + serial for (reduce_k_0_reduce_k_1_fused_0, 0, 8) { - vreduce_k_0_reduce_k_1_fused, i0_0, vreduce_k_0_reduce_k_1_fused_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i, reduce_k_0_reduce_k_1_fused_0) - var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused] = (var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused] + X[i0_0, (((64 * vreduce_k_0_reduce_k_1_fused) + vreduce_k_0_reduce_k_1_fused_0) / 128), (((64 * vreduce_k_0_reduce_k_1_fused) + vreduce_k_0_reduce_k_1_fused_0) % 128)]) + ScheduleBlock(var_0_rf) + { + vreduce_k_0_reduce_k_1_fused, i0_0, vreduce_k_0_reduce_k_1_fused_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i, reduce_k_0_reduce_k_1_fused_0) + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused] = (var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused] + X[i0_0, (((8 * vreduce_k_0_reduce_k_1_fused) + vreduce_k_0_reduce_k_1_fused_0) / 128), (((8 * vreduce_k_0_reduce_k_1_fused) + vreduce_k_0_reduce_k_1_fused_0) % 128)]) + } + } + { + ScheduleBlock(var_0) + { + vreduce_k_0_reduce_k_1_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i) + var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused]) + } } - } - } - } - serial for (i, 0, 32) - { - ScheduleBlock(var_0__reduce_init) - { - i0_0 = axis.bind(i) - var_0__reduce_init[i0_0] = 0.00000000f - } - serial for (reduce_k_0_reduce_k_1_fused, 0, 128) - { - ScheduleBlock(var_0) - { - vreduce_k_0_reduce_k_1_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i) - var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused]) } } } @@ -170,42 +175,41 @@ TEST_F(TestReductionFactoring, ApplyOnBlock2ReduceDim) { } TEST_F(TestReductionFactoring, ApplyOnBlock3ReduceDim) { + Context::Global().ResetNameId(); std::string expected_ir = R"({ ScheduleBlock(root) { { serial for (i, 0, 32) { - serial for (reduce_k_0_reduce_k_1_reduce_k_2_fused, 0, 512) + ScheduleBlock(var_0__reduce_init) + { + i0_0 = axis.bind(i) + var_0__reduce_init[i0_0] = 0.00000000f + } + thread_bind[threadIdx.x] for (reduce_k_0_reduce_k_1_reduce_k_2_fused, 0, 1024) { ScheduleBlock(var_0_rf__reduce_init) { vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i) var_0_rf__reduce_init[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] = 0.00000000f } - serial for (reduce_k_0_reduce_k_1_reduce_k_2_fused_0, 0, 512) { - ScheduleBlock(var_0_rf) + serial for (reduce_k_0_reduce_k_1_reduce_k_2_fused_0, 0, 256) { - 
vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i, reduce_k_0_reduce_k_1_reduce_k_2_fused_0) - var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] = (var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] + X[i0_0, ((((512 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) / 64) / 64), ((((512 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) / 64) % 64), (((512 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) % 64)]) + ScheduleBlock(var_0_rf) + { + vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i, reduce_k_0_reduce_k_1_reduce_k_2_fused_0) + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] = (var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] + X[i0_0, ((((256 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) / 64) / 64), ((((256 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) / 64) % 64), (((256 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) % 64)]) + } + } + { + ScheduleBlock(var_0) + { + vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i) + var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused]) + } } - } - } - } - serial for (i, 0, 32) - { - ScheduleBlock(var_0__reduce_init) - { - i0_0 = axis.bind(i) - var_0__reduce_init[i0_0] = 0.00000000f - } - serial for (reduce_k_0_reduce_k_1_reduce_k_2_fused, 0, 512) - { - ScheduleBlock(var_0) - { - vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i) - var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused]) } } } @@ -214,6 +218,7 @@ TEST_F(TestReductionFactoring, ApplyOnBlock3ReduceDim) { })"; TestApplyOnReduce({32, 64, 64, 64}, {1, 2, 3}, "var_0", expected_ir); } +#endif } // namespace auto_schedule } // namespace cinn diff --git a/paddle/cinn/hlir/framework/CMakeLists.txt b/paddle/cinn/hlir/framework/CMakeLists.txt index c353eb3810ff89..be4c353e421ace 100755 --- a/paddle/cinn/hlir/framework/CMakeLists.txt +++ b/paddle/cinn/hlir/framework/CMakeLists.txt @@ -22,7 +22,8 @@ gather_srcs( op_lowering_impl.cc accuracy_checker.cc visualize_helper.cc - compile_error.cc) + compile_error.cc + group_scheduler.cc) # TODO(Aurelius84): pir_compiler depends on pd_op_dialect and could # not found under CINN_ONLY mode @@ -43,6 +44,8 @@ endif() if(WITH_CUDA) cinn_cc_test(test_hlir_framework_op_lowering SRCS op_lowering_test.cc DEPS cinncore decomposer_test_helper) + cinn_cc_test(test_group_scheduler SRCS group_scheduler_test.cc DEPS cinncore + decomposer_test_helper) endif() cinn_cc_test(test_hlir_framework_tensor SRCS tensor_test.cc DEPS cinncore) cinn_cc_test(test_hlir_framework_scope SRCS scope_test.cc DEPS cinncore) diff --git a/paddle/cinn/hlir/framework/group_scheduler.cc b/paddle/cinn/hlir/framework/group_scheduler.cc new file mode 100644 index 00000000000000..2920a3f358c7eb --- /dev/null +++ b/paddle/cinn/hlir/framework/group_scheduler.cc @@ -0,0 +1,1212 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/hlir/framework/group_scheduler.h" +#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h" +#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h" +#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.h" +#include "paddle/cinn/common/cas.h" +#include "paddle/cinn/ir/ir_printer.h" +#include "paddle/cinn/ir/op/ir_operators.h" +#include "paddle/cinn/ir/schedule/ir_schedule_util.h" +#include "paddle/cinn/ir/tensor.h" +#include "paddle/cinn/ir/utils/ir_copy.h" +#include "paddle/cinn/ir/utils/ir_nodes_collector.h" +#include "paddle/cinn/optim/replace_var_with_expr.h" + +namespace cinn { +namespace hlir { +namespace framework { + +static const std::unordered_set + kProhibitScheduleExternalFuncNames = { +#define CINN_NVGPU_FUNC2STRING(str) #str +#define CINN_NVGPU_FUNC_TYPE(FUNC, TYPE) \ + CINN_NVGPU_FUNC2STRING(cinn_nvgpu_##FUNC##TYPE) + +#define GEN_FUNC_NAME(_, impl) \ + _(impl, gt_num) \ + _(impl, lt_num) \ + _(impl, index_add) \ + _(impl, next_smallest) + +#define GEN_FUNC_NAME_WITH_TYPE(_, ...) \ + _(__VA_ARGS__, _bool), _(__VA_ARGS__, _fp16), _(__VA_ARGS__, _fp32), \ + _(__VA_ARGS__, _fp64), _(__VA_ARGS__, _uint8), _(__VA_ARGS__, _int8), \ + _(__VA_ARGS__, _int16), _(__VA_ARGS__, _int32), _(__VA_ARGS__, _int64), + + GEN_FUNC_NAME(GEN_FUNC_NAME_WITH_TYPE, CINN_NVGPU_FUNC_TYPE) +#undef GEN_FUNC_NAME +}; + +bool IsProhibitScheduleExternCallBlock(ir::Expr block) { + ir::ScheduleBlockRealize* sch_block_realize = + block.As(); + CHECK_NOTNULL(sch_block_realize); + ir::ScheduleBlock* sch_block = + sch_block_realize->schedule_block.As(); + CHECK_NOTNULL(sch_block); + + auto find_call = ir::ir_utils::CollectIRNodesWithoutTensor( + sch_block->body, [&](const Expr* x) { return x->As(); }); + for (ir::Expr call : find_call) { + ir::Call* call_node = call.As(); + if (call.As() && kProhibitScheduleExternalFuncNames.count( + call.As()->name) != 0) { + return true; + } + } + return false; +} + +// Find loops with same extents of 2 ScheduleBlock +std::vector> FindSameOuterLoops( + ir::ScheduleBlockNode* source_node, ir::ScheduleBlockNode* target_node) { + std::vector src_ctrl_stmts = source_node->ControlStmts(); + std::vector tgt_ctrl_stmts = target_node->ControlStmts(); + std::vector> same_loops; + int min_stmt_size = std::min(src_ctrl_stmts.size(), tgt_ctrl_stmts.size()); + for (int i = 0; i < min_stmt_size; ++i) { + if (src_ctrl_stmts[i].As() && tgt_ctrl_stmts[i].As() && + ir::GetLoopExtent(src_ctrl_stmts[i]) == + GetLoopExtent(tgt_ctrl_stmts[i])) { + same_loops.push_back( + std::make_tuple(src_ctrl_stmts[i], tgt_ctrl_stmts[i])); + } else { + break; + } + } + + return same_loops; +} + +std::unordered_set GetReduceLoopVarNames(ir::Expr block) { + ir::ScheduleBlockRealize* schedule_block_realize = + block.As(); + ir::ScheduleBlock* schedule_block = + schedule_block_realize->schedule_block.As(); + std::vector iter_values = schedule_block_realize->iter_values; + std::vector iter_vars = schedule_block->iter_vars; + std::unordered_set reduce_loop_var_names; + for (int i = 0; i < iter_vars.size(); 
++i) { + if (iter_vars[i]->is_reduce_axis) { + ir::ir_utils::CollectIRNodesWithoutTensor( + iter_values[i], [&](const ir::Expr* x) { + if (x->as_var()) { + reduce_loop_var_names.insert(x->as_var_ref()->name); + } + return false; + }); + } + } + return reduce_loop_var_names; +} + +std::unordered_set GetReduceVarNames(ir::Expr block) { + ir::ScheduleBlockRealize* schedule_block_realize = + block.As(); + ir::ScheduleBlock* schedule_block = + schedule_block_realize->schedule_block.As(); + std::vector iter_vars = schedule_block->iter_vars; + std::unordered_set reduce_var_names; + for (int i = 0; i < iter_vars.size(); ++i) { + if (iter_vars[i]->is_reduce_axis) { + reduce_var_names.insert(iter_vars[i]->name); + } + } + return reduce_var_names; +} + +GroupScheduler::GroupScheduler(ir::IRSchedule* ir_sch, + const std::shared_ptr& group, + const common::Target& target) + : ir_sch_(ir_sch), group_(group), target_(target) { + schedule_block_graph_ = std::make_unique(*ir_sch_); +} + +void GroupScheduler::operator()() { + feasible_conditions_.emplace_back(&GroupScheduler::IsKeepGraphDependency); + DoLoopAlignment(); + DoComputeInline(); +#ifdef CINN_WITH_CUDA + OptimizeReduction(); +#endif + DoHorizontalLoopFusion(); + DoVerticalLoopFusion(); +#ifdef CINN_WITH_CUDA + BindCudaAxis(); + AllocateStorage(); +#endif +} + +NodePriority GroupScheduler::CalculateNodePriority( + const ir::ScheduleBlockNode* node) const { + bool has_loop_binded = false; + std::unordered_set reduce_loop_var_names = + GetReduceLoopVarNames(node->Block()); + + int64_t reduce_score = 1; + double score = 1; + for (Expr expr : node->ControlStmts()) { + ir::For* for_node = expr.As(); + if (for_node != nullptr) { + score *= ir::GetLoopExtent(expr); + } + if (reduce_loop_var_names.count(for_node->loop_var->name) != 0) { + reduce_score *= ir::GetLoopExtent(expr); + } + if (for_node->is_binded()) { + has_loop_binded = true; + } + } + if (reduce_score > 1) { + score *= (reduce_score * std::log2(reduce_score)); + } + + VLOG(6) << "The priority score of node " << node->id() << " is " << score; + VLOG(6) << "The node has_loop_binded: " << has_loop_binded; + return NodePriority{has_loop_binded, score}; +} + +ir::ScheduleBlockNode* GroupScheduler::FindGlobalMasterNode() const { + NodePriority max{false, std::numeric_limits::min()}; + ir::ScheduleBlockNode* master = nullptr; + auto FindMaster = [&](ir::ScheduleBlockNode* node) { + NodePriority priority = CalculateNodePriority(node); + VLOG(6) << "The priority score of node " << node->id() << " is " + << priority.score + << ", has_loop_binded: " << priority.has_loop_binded; + if (max < priority) { + max = priority; + master = node; + } + }; + + schedule_block_graph_->NodesWalk(FindMaster); + CHECK(master) << "Cannot find global master node"; + VLOG(6) << "Find the global master node: " << master->id(); + return master; +} + +std::unordered_set GroupScheduler::OutputTensorNames() const { + std::unordered_set output_tensor_names; + std::transform( + group_->output_nodes.begin(), + group_->output_nodes.end(), + std::inserter(output_tensor_names, output_tensor_names.begin()), + [](const Node* node) { + NodeData* node_data = + (*node->outlinks().begin())->sink()->safe_as(); + CHECK(node_data); + return node_data->id(); + }); + + for (ir::ScheduleBlockNode* node : schedule_block_graph_->EndPoints()) { + output_tensor_names.insert(node->id()); + } + return output_tensor_names; +} + +void GroupScheduler::DoLoopAlignment() { + VLOG(5) << "[Start LoopAlignment] func body: " + << 
ir_sch_->GetModule().GetExprs().front(); + ir::ScheduleBlockNode* global_master = FindGlobalMasterNode(); + ir::Expr master_block = global_master->Block(); + std::vector original_master_loop_extents; + std::vector spacial_master_loop_extents; + std::vector original_master_loop_order; + std::vector recover_loop_order; + + std::vector master_iter_values = + master_block.As()->iter_values; + std::vector master_iter_vars = + master_block.As() + ->schedule_block.As() + ->iter_vars; + std::vector master_loops = ir_sch_->GetLoops(master_block); + + std::unordered_set reduce_var_names = + GetReduceVarNames(master_block); + if (!reduce_var_names.empty()) { + std::set reduce_loads = ir::ir_utils::CollectIRNodesWithoutTensor( + master_block, + [&](const ir::Expr* x) { + bool find_reduce_var = false; + if (x->As()) { + int i = 0; + for (ir::Expr index : x->As()->indices) { + if (index.as_var() && + reduce_var_names.count(index.as_var_ref()->name) > 0) { + find_reduce_var = true; + } + ++i; + } + } + return find_reduce_var; + }, + /* uniq_target = */ true); + CHECK_EQ(reduce_loads.size(), 1); + + std::vector indices = + reduce_loads.begin()->As()->indices; + for (ir::Expr index : indices) { + CHECK_NOTNULL(index.as_var()); + int idx = 0; + bool is_reduce_var = false; + for (const ir::Var& iter_var : master_iter_vars) { + if (iter_var->name == index.as_var_ref()->name) { + is_reduce_var = iter_var->is_reduce_axis; + break; + } + ++idx; + } + std::vector loop_vars_in_order; + ir::ir_utils::CollectIRNodesInOrder( + master_iter_values[idx], [&](const ir::Expr* x) { + if (x->as_var()) { + loop_vars_in_order.push_back(x->as_var_ref()); + } + return false; + }); + for (const ir::Var& loop_var : loop_vars_in_order) { + for (int i = 0; i < master_loops.size(); ++i) { + if (master_loops[i].As()->loop_var->name == loop_var->name) { + original_master_loop_order.push_back(i); + int extent = ir::GetLoopExtent(master_loops[i]); + original_master_loop_extents.push_back(extent); + if (!is_reduce_var) { + spacial_master_loop_extents.push_back(extent); + } + } + } + } + } + + for (int i = 0; i < original_master_loop_order.size(); ++i) { + for (int j = 0; j < original_master_loop_order.size(); ++j) { + if (original_master_loop_order[j] == i) { + recover_loop_order.push_back(j); + break; + } + } + } + CHECK_EQ(original_master_loop_order.size(), recover_loop_order.size()); + } else { + for (int i = 0; i < master_loops.size(); ++i) { + original_master_loop_extents.push_back( + ir::GetLoopExtent(master_loops[i])); + spacial_master_loop_extents.push_back(ir::GetLoopExtent(master_loops[i])); + original_master_loop_order.push_back(i); + recover_loop_order.push_back(i); + } + } + + int total_master_loop_extents = 1; + int total_spacial_loop_extents = 1; + for (int extent : original_master_loop_extents) { + total_master_loop_extents *= extent; + } + for (int extent : spacial_master_loop_extents) { + total_spacial_loop_extents *= extent; + } + + auto LoopAlignmentFunc = [&](ir::ScheduleBlockNode* node) { + if (IsProhibitScheduleExternCallBlock(node->Block())) { + return false; + } + + if (node == global_master) { + return false; + } + + for (ir::Expr expr : node->ControlStmts()) { + if (expr.As() != nullptr && + (expr.As()->for_type() == ir::ForType::GPUBlock || + expr.As()->for_type() == ir::ForType::GPUThread)) { + return false; + } + if (expr.As()->body.As() && + expr.As()->body.As()->stmts.size() > 1) { + return false; + } + } + + VLOG(6) << "try to align loops of block: " << node->id() + << " with block: " << 
global_master->id(); + + // 1. Fuse source loops + ir::Expr source_loop = ir_sch_->Fuse(node->ControlStmts()); + int total_source_extent = ir::GetLoopExtent(source_loop); + + // 2. Split source loop to align with the target loops + std::vector target_loop_extents; + if (total_source_extent < total_spacial_loop_extents) { + int cur_extent = 1; + for (int extent : spacial_master_loop_extents) { + cur_extent *= extent; + if (cur_extent == total_source_extent) { + target_loop_extents.push_back(extent); + break; + } else if (cur_extent > total_source_extent) { + target_loop_extents.push_back(-1); + break; + } else { + target_loop_extents.push_back(extent); + } + } + } else if (total_source_extent == total_spacial_loop_extents) { + target_loop_extents = spacial_master_loop_extents; + } else if (total_source_extent < total_master_loop_extents) { + target_loop_extents = spacial_master_loop_extents; + target_loop_extents.push_back(-1); + } else if (total_source_extent == total_master_loop_extents) { + target_loop_extents = original_master_loop_extents; + } + std::vector source_loops; + if (target_loop_extents.size() > 0 && + target_loop_extents[0] < total_source_extent) { + source_loops = ir_sch_->Split(source_loop, target_loop_extents); + } else { + source_loops = {source_loop}; + } + + // 3. Rerorder loops to match the target loops + if (total_source_extent == total_master_loop_extents) { + ir_sch_->Reorder(node->id(), recover_loop_order); + } + + return true; + }; + + schedule_block_graph_->DFSTopoWalk(LoopAlignmentFunc); + VLOG(5) << "[After LoopAlignment] func body: " + << ir_sch_->GetModule().GetExprs().front(); +} + +void GroupScheduler::DoComputeInline() { + VLOG(5) << "[Start DoComputeInline] func body: " + << ir_sch_->GetModule().GetExprs().front(); + + std::unordered_set no_inline_output_names = OutputTensorNames(); + auto_schedule::AutoInline inliner(target_, no_inline_output_names); + + auto InlineFunc = [&](ir::ScheduleBlockNode* node) { + if (IsProhibitScheduleExternCallBlock(node->Block())) { + return; + } + VLOG(6) << "try ComputeInline on: " << node->id() + << ", before ComputeInline, func body: " + << ir_sch_->GetModule().GetExprs().front(); + ir::Expr schedule_block = node->Block(); + inliner.Apply(ir_sch_, schedule_block); + VLOG(6) << "try ComputeInline on: " << node->id() + << ", after ComputeInline, func body: " + << ir_sch_->GetModule().GetExprs().front(); + }; + + schedule_block_graph_->DFSTopoWalk(InlineFunc); + schedule_block_graph_->Update(*ir_sch_); + VLOG(5) << "[After DoComputeInline] func body: " + << ir_sch_->GetModule().GetExprs().front(); +} + +void GroupScheduler::DoHorizontalLoopFusion() { + VLOG(5) << "[Start DoHorizontalLoopFusion] func body: " + << ir_sch_->GetModule().GetExprs().front(); + + std::vector end_nodes = + schedule_block_graph_->EndPoints(); + std::reverse(end_nodes.begin(), end_nodes.end()); + ir::ScheduleBlockNode* master_node = end_nodes.front(); + CHECK_NOTNULL(master_node); + for (int i = 1; i < end_nodes.size(); ++i) { + if (IsProhibitScheduleExternCallBlock(end_nodes[i]->Block())) { + continue; + } + VLOG(6) << "try to fuse loop of " << end_nodes[i]->id() << " to " + << master_node->id(); + std::vector>&& same_loops = + FindSameOuterLoops(end_nodes[i], master_node); + if (same_loops.size() == 0) { + continue; + } + ir::Expr target_loop = std::get<1>(same_loops.back()); + VLOG(6) << "target_loop: " << target_loop; + ir_sch_->SimpleComputeAt(end_nodes[i]->Block(), target_loop); + VLOG(6) << "after fuse: " << 
ir_sch_->GetModule().GetExprs().front(); + } + + VLOG(5) << "[After DoHorizontalLoopFusion] func body: " + << ir_sch_->GetModule().GetExprs().front(); +} + +void GroupScheduler::DoVerticalLoopFusion() { + VLOG(5) << "[Start DoVerticalLoopFusion] func body: " + << ir_sch_->GetModule().GetExprs().front(); + UpdateBlockOrder(); + + auto FindMaster = + [&](ir::ScheduleBlockNode* node) -> std::vector { + std::vector masters = node->Consumers(); + std::sort( + masters.begin(), + masters.end(), + [&](const ir::ScheduleBlockNode* a, const ir::ScheduleBlockNode* b) { + return this->CalculateNodePriority(b) < + this->CalculateNodePriority(a); + }); + return masters; + }; + + auto ComputeAtFunc = [&](ir::ScheduleBlockNode* node) { + if (IsProhibitScheduleExternCallBlock(node->Block())) { + return; + } + std::vector masters = FindMaster(node); + if (masters.size() == 0) { + return; + } + ir::Expr target_loop; + bool find_target_loop = false; + // Collect infomation of original loops + std::vector original_ctrl_stmts = node->ControlStmts(); + int64_t original_total_loop_extent = 1; + std::vector> original_loop_infos; + std::unordered_set original_loop_node_ptrs; + for (ir::Expr stmt : original_ctrl_stmts) { + if (stmt.As()) { + int extent = ir::GetLoopExtent(stmt); + original_total_loop_extent *= extent; + std::string thread_axis = ""; + ir::ForType target_for_type = stmt.As()->for_type(); + if (target_for_type == ir::ForType::GPUBlock) { + thread_axis += "blockIdx."; + } else if (target_for_type == ir::ForType::GPUThread) { + thread_axis += "threadIdx."; + } else { + original_loop_infos.push_back(std::make_pair(thread_axis, extent)); + continue; + } + int offset = stmt.As()->bind_info().offset; + thread_axis += ('x' + offset); + original_loop_infos.push_back(std::make_pair(thread_axis, extent)); + original_loop_node_ptrs.insert(stmt.ptr()); + } + } + + std::unordered_set src_reduce_loop_var_names = + GetReduceLoopVarNames(node->Block()); + for (ir::ScheduleBlockNode* master : masters) { + // Find the target loop candidates; + std::vector target_loop_candidates; + int64_t total_loop_extent = 1; + std::unordered_set tgt_reduce_loop_var_names = + GetReduceLoopVarNames(master->Block()); + std::vector> same_loops = + FindSameOuterLoops(node, master); + for (const std::tuple& same_loop : + same_loops) { + ir::Expr source_loop = std::get<0>(same_loop); + ir::Expr target_loop = std::get<1>(same_loop); + bool is_src_loop_reduce = + src_reduce_loop_var_names.count( + source_loop.As()->loop_var->name) > 0; + bool is_tgt_loop_reduce = + tgt_reduce_loop_var_names.count( + target_loop.As()->loop_var->name) > 0; + if (source_loop.ptr() != target_loop.ptr() && !is_src_loop_reduce && + !is_tgt_loop_reduce) { + target_loop_candidates.push_back(target_loop); + } + } + // Find the target loop with the highest priority and passing the + // feasibility condition check + for (std::vector::reverse_iterator iter = + target_loop_candidates.rbegin(); + iter != target_loop_candidates.rend(); + ++iter) { + ir::Expr candidate_loop = *iter; + if (candidate_loop.As() && + this->MeetConditions(node->Block(), candidate_loop, 0)) { + target_loop = candidate_loop; + find_target_loop = true; + break; + } + } + if (find_target_loop) { + VLOG(6) << "try to fuse loop of " << node->id() << " to " + << master->id(); + break; + } + } + + // Do schedule + if (find_target_loop) { + ir_sch_->SimpleComputeAt(node->Block(), target_loop); + VLOG(6) << "after compute at: " << ir_sch_->GetModule().GetExprs()[0]; + std::vector new_stmts = 
node->ControlStmts(); + for (int idx = 0; idx < original_loop_infos.size(); ++idx) { + if (original_loop_infos[idx].first.empty()) { + continue; + } + if (idx < new_stmts.size()) { + CHECK(new_stmts[idx].As()); + if (new_stmts[idx].As()->is_serial()) { + ir_sch_->Bind(new_stmts[idx], original_loop_infos[idx].first); + } + } else { + ir::Expr unit_loop = ir_sch_->AddUnitLoop(node->Block()); + ir_sch_->Bind(unit_loop, original_loop_infos[idx].first); + } + } + VLOG(6) << "after loop info copy: " << ir_sch_->GetModule().GetExprs()[0]; + // Update block and control stmts order after schedule. + this->UpdateBlockOrder(); + } else { + LOG(INFO) << "Cannot find a loop of masters to ComputeAt, do not merge.\n" + << "The schedule block: " << node->Block(); + } + }; + + schedule_block_graph_->DFSTopoWalk(ComputeAtFunc); + VLOG(5) << "[After DoVerticalLoopFusion] func body: " + << ir_sch_->GetModule().GetExprs().front(); +} + +void GroupScheduler::BindCudaAxis() { + if (target_.arch != Target::Arch::NVGPU) return; + VLOG(5) << "[Start BindCudaAxis] func body: " + << ir_sch_->GetModule().GetExprs().front(); + + auto_schedule::AutoBind binder(target_); + + auto BindFunc = [&](ir::ScheduleBlockNode* node) { + if (IsProhibitScheduleExternCallBlock(node->Block())) { + return; + } + VLOG(6) << "try bind cuda axis on: " << node->id() + << ", before bind, func body: " + << ir_sch_->GetModule().GetExprs().front(); + binder.Apply(ir_sch_, node->id()); + VLOG(6) << "try bind cuda axis on: " << node->id() + << ", after bind, func body: " + << ir_sch_->GetModule().GetExprs().front(); + }; + + schedule_block_graph_->DFSTopoWalk(BindFunc); + + VLOG(5) << "[After BindCudaAxis] func body: " + << ir_sch_->GetModule().GetExprs().front(); +} + +struct Range { + int min; + int max; +}; + +std::ostream& operator<<(std::ostream& os, const Range& x) { + os << "(" << x.min << ", " << x.max << ")"; + return os; +} + +// TODO(BiynXu): After implementing auxiliary data structures such as IntegerSet +// and MultiDimIntegerSet, re implement this function to simplify these ugly +// codes. 
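// --- Illustrative sketch (annotation, not part of the patch) ----------------
// The IsCrossThread / IsCrossBlock lambdas inside AllocateStorage() below
// essentially reduce to interval containment on the flattened store/load
// indices (plus a per-CUDA-axis coefficient compatibility check): the range a
// producer writes per thread/block must cover the range its consumer reads,
// otherwise the buffer is promoted from registers to shared memory
// (cross-thread) or the fusion is rejected (cross-block).
// `Covers` is a hypothetical helper name used only for this sketch.
static bool Covers(const Range& store_range, const Range& load_range) {
  // The producer's accessed interval must enclose the consumer's interval.
  return store_range.min <= load_range.min &&
         store_range.max >= load_range.max;
}
// e.g. Covers({0, 255}, {0, 31}) is true  -> the data can stay thread-local;
//      Covers({0, 31}, {0, 255}) is false -> shared (or global) memory needed.
// -----------------------------------------------------------------------------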
+void GroupScheduler::AllocateStorage() { + if (target_.arch != Target::Arch::NVGPU) return; + VLOG(5) << "[Start AllocateStorage] func body: " + << ir_sch_->GetModule().GetExprs().front(); + + // Record ir::For using index structure: > + std::unordered_map> + for_map; + std::unordered_set sync_mark; + + // function to update for_map + auto UpdateVarNameToForMap = [&](ir::Expr root) { + std::vector all_blocks = ir_sch_->GetAllBlocks(); + for (const ir::Expr& block : all_blocks) { + std::string block_name = block.As() + ->schedule_block.As() + ->name; + std::vector for_expr = ir_sch_->GetLoops(block); + for (ir::Expr for_expr : for_expr) { + for_map[block_name][for_expr.As()->loop_var->name] = for_expr; + VLOG(6) << "for_map.insert: <" << block_name << ", " + << for_expr.As()->loop_var->name << ">"; + } + } + }; + + // function to analyze and flatten indices to one dim of load_or_store node + auto AnalyzeIndiceValue = [](ir::Expr load_or_store, + ir::Expr block) -> ir::Expr { + std::vector indices; + ir::Tensor tensor; + if (load_or_store.As()) { + indices = load_or_store.As()->indices; + tensor = load_or_store.As()->tensor.as_tensor_ref(); + } else { + indices = load_or_store.As()->indices; + tensor = load_or_store.As()->tensor.as_tensor_ref(); + } + std::vector iter_vars = + block.As() + ->schedule_block.As() + ->iter_vars; + std::vector iter_values = + block.As()->iter_values; + struct VarHash { + size_t operator()(const ir::Var& var) const { + std::string name = var->name; + return std::hash()(name); + } + }; + std::vector strides; + int extent = 1; + for (int idx = tensor->shape.size() - 1; idx >= 0; --idx) { + strides.insert(strides.begin(), extent); + tensor->shape[idx] = common::AutoSimplify(tensor->shape[idx]); + CHECK(tensor->shape[idx].is_constant()) + << "Shape of tensor: " << tensor << " is not constant"; + extent *= tensor->shape[idx].get_constant(); + } + ir::Expr flatten_indice(0); + for (int idx = 0; idx < indices.size(); ++idx) { + flatten_indice = flatten_indice + ir::Expr(strides[idx]) * indices[idx]; + } + flatten_indice = common::AutoSimplify(flatten_indice); + for (int idx = 0; idx < iter_vars.size(); ++idx) { + optim::ReplaceVarWithExpr( + &flatten_indice, iter_vars[idx], iter_values[idx]); + } + flatten_indice = common::AutoSimplify(flatten_indice); + VLOG(6) << "flatten_indice of " << load_or_store << " : " << flatten_indice; + return flatten_indice; + }; + + enum class CudaBindInfo : int { + kCudaBlock, + kCudaThread, + kSerial, + kCudaThreadAndSerial, + }; + + // function to calculate the range of the specified CUDA axis in a indice + // expression + auto CalculateRange = [&for_map](ir::Expr indice_value, + const CudaBindInfo& bind_info, + const std::string& block_name) { + ir::Expr copy_for_upper_bound = ir::ir_utils::IRCopy(indice_value); + ir::Expr copy_for_lower_bound = ir::ir_utils::IRCopy(indice_value); + std::set var_set = ir::ir_utils::CollectIRNodesWithoutTensor( + indice_value, [](const ir::Expr* x) { return x->as_var(); }); + for (ir::Expr var : var_set) { + std::string name = var.as_var_ref()->name; + CHECK(for_map.find(block_name) != for_map.end()); + CHECK(for_map[block_name].find(name) != for_map[block_name].end()); + ir::Expr for_expr = for_map[block_name][name]; + if (bind_info == CudaBindInfo::kCudaBlock) { + if (for_expr.As()->is_gpu_block_binded()) { + optim::ReplaceVarWithExpr(©_for_upper_bound, + var.as_var_ref(), + for_expr.As()->min + + for_expr.As()->extent - + Expr(1)); + optim::ReplaceVarWithExpr(©_for_lower_bound, + var.as_var_ref(), + 
for_expr.As()->min); + } else { + optim::ReplaceVarWithExpr( + ©_for_upper_bound, var.as_var_ref(), ir::Expr(0)); + optim::ReplaceVarWithExpr( + ©_for_lower_bound, var.as_var_ref(), ir::Expr(0)); + } + } else if (bind_info == CudaBindInfo::kCudaThread) { + if (for_expr.As()->is_gpu_thread_binded()) { + optim::ReplaceVarWithExpr(©_for_upper_bound, + var.as_var_ref(), + for_expr.As()->min + + for_expr.As()->extent - + Expr(1)); + optim::ReplaceVarWithExpr(©_for_lower_bound, + var.as_var_ref(), + for_expr.As()->min); + } else { + optim::ReplaceVarWithExpr( + ©_for_upper_bound, var.as_var_ref(), ir::Expr(0)); + optim::ReplaceVarWithExpr( + ©_for_lower_bound, var.as_var_ref(), ir::Expr(0)); + } + } else if (bind_info == CudaBindInfo::kSerial) { + if (!for_expr.As()->is_gpu_thread_binded() && + !for_expr.As()->is_gpu_block_binded()) { + optim::ReplaceVarWithExpr(©_for_upper_bound, + var.as_var_ref(), + for_expr.As()->min + + for_expr.As()->extent - + Expr(1)); + optim::ReplaceVarWithExpr(©_for_lower_bound, + var.as_var_ref(), + for_expr.As()->min); + } else { + optim::ReplaceVarWithExpr( + ©_for_upper_bound, var.as_var_ref(), ir::Expr(0)); + optim::ReplaceVarWithExpr( + ©_for_lower_bound, var.as_var_ref(), ir::Expr(0)); + } + } else if (bind_info == CudaBindInfo::kCudaThreadAndSerial) { + if (!for_expr.As()->is_gpu_block_binded()) { + optim::ReplaceVarWithExpr(©_for_upper_bound, + var.as_var_ref(), + for_expr.As()->min + + for_expr.As()->extent - + Expr(1)); + optim::ReplaceVarWithExpr(©_for_lower_bound, + var.as_var_ref(), + for_expr.As()->min); + } else { + optim::ReplaceVarWithExpr( + ©_for_upper_bound, var.as_var_ref(), ir::Expr(0)); + optim::ReplaceVarWithExpr( + ©_for_lower_bound, var.as_var_ref(), ir::Expr(0)); + } + } + } + VLOG(6) << "lower_bound before simplify of " << indice_value << " = " + << copy_for_lower_bound; + copy_for_lower_bound = + common::AutoSimplify(common::AutoSimplify(copy_for_lower_bound)); + VLOG(6) << "upper_bound before simplify of " << indice_value << " = " + << copy_for_upper_bound; + copy_for_upper_bound = + common::AutoSimplify(common::AutoSimplify(copy_for_upper_bound)); + VLOG(6) << "lower_bound of " << indice_value << " = " + << copy_for_lower_bound; + VLOG(6) << "upper_bound of " << indice_value << " = " + << copy_for_upper_bound; + return Range{static_cast(copy_for_lower_bound.get_constant()), + static_cast(copy_for_upper_bound.get_constant())}; + }; + + // function to calculate the coefficient and range of the specified for_type + // in a indice expression + auto GetCoefficientAndRange = [&for_map](ir::Expr indice_value, + const ir::ForType& for_type, + const std::string& block_name) { + std::vector> coef_and_ranges(3); + std::vector indice_copies; + for (int i = 0; i < 3; ++i) { + indice_copies.push_back(ir::ir_utils::IRCopy(indice_value)); + } + std::set var_set = ir::ir_utils::CollectIRNodesWithoutTensor( + indice_value, [](const ir::Expr* x) { return x->as_var(); }); + std::unordered_set visited_var_names; + for (ir::Expr var : var_set) { + std::string name = var.as_var_ref()->name; + if (visited_var_names.count(name) > 0) { + continue; + } + visited_var_names.insert(name); + CHECK(for_map.find(block_name) != for_map.end()); + CHECK(for_map[block_name].find(name) != for_map[block_name].end()); + ir::Expr for_expr = for_map[block_name][name]; + for (int i = 0; i < 3; ++i) { + if (for_type == for_expr.As()->for_type() && + for_expr.As()->bind_info().offset == i && + for_expr.As()->extent.get_constant() > 1) { + optim::ReplaceVarWithExpr( + 
&(indice_copies[i]), var.as_var_ref(), ir::Expr(1)); + coef_and_ranges[i].second.min = + for_expr.As()->min.get_constant(); + coef_and_ranges[i].second.max = + for_expr.As()->min.get_constant() + + for_expr.As()->extent.get_constant(); + } else { + optim::ReplaceVarWithExpr( + &(indice_copies[i]), var.as_var_ref(), ir::Expr(0)); + } + } + } + for (int i = 0; i < 3; ++i) { + VLOG(6) << "before simplify [" << i << "], the coefficient of " + << indice_value << " = " << indice_copies[i] << ", range = (" + << coef_and_ranges[i].second.min << ", " + << coef_and_ranges[i].second.max << ")"; + indice_copies[i] = common::AutoSimplify(indice_copies[i]); + VLOG(6) << "after simplify [" << i << "], the coefficient of " + << indice_value << " = " << indice_copies << ", range = (" + << coef_and_ranges[i].second.min << ", " + << coef_and_ranges[i].second.max << ")"; + coef_and_ranges[i].first = + static_cast(indice_copies[i].get_constant()); + } + return coef_and_ranges; + }; + + // Determine whether the indice of a pair of Store and Load cross CUDA threads + auto IsCrossThread = [&](ir::Expr store_indice_value, + ir::Expr load_indice_value, + const std::string& store_block_name, + const std::string& load_block_name) { + Range store_thread_overall_range = CalculateRange( + store_indice_value, CudaBindInfo::kCudaThread, store_block_name); + Range load_thread_overall_range = CalculateRange( + load_indice_value, CudaBindInfo::kCudaThread, load_block_name); + Range store_serial_overall_range = CalculateRange( + store_indice_value, CudaBindInfo::kSerial, store_block_name); + Range load_serial_overall_range = CalculateRange( + load_indice_value, CudaBindInfo::kSerial, load_block_name); + auto store_thread_coefficient_and_range = GetCoefficientAndRange( + store_indice_value, ir::ForType::GPUThread, store_block_name); + auto load_thread_coefficient_and_range = GetCoefficientAndRange( + load_indice_value, ir::ForType::GPUThread, load_block_name); + VLOG(6) << "store_block_name: " << store_block_name + << ", load_block_name: " << load_block_name; + VLOG(6) << "store_indice_value: " << store_indice_value + << ", load_indice_value: " << load_indice_value; + VLOG(6) << "store_thread_overall_range = " << store_thread_overall_range; + VLOG(6) << "load_thread_overall_range = " << load_thread_overall_range; + VLOG(6) << "store_serial_overall_range = " << store_serial_overall_range; + VLOG(6) << "load_serial_overall_range = " << load_serial_overall_range; + VLOG(6) << "store_thread_coefficient_and_range[0] = <" + << store_thread_coefficient_and_range[0].first << ", " + << store_thread_coefficient_and_range[0].second << ">"; + VLOG(6) << "load_thread_coefficient_and_range[0] = <" + << load_thread_coefficient_and_range[0].first << ", " + << load_thread_coefficient_and_range[0].second << ">"; + VLOG(6) << "store_thread_coefficient_and_range[1] = <" + << store_thread_coefficient_and_range[1].first << ", " + << store_thread_coefficient_and_range[1].second << ">"; + VLOG(6) << "load_thread_coefficient_and_range[1] = <" + << load_thread_coefficient_and_range[1].first << ", " + << load_thread_coefficient_and_range[1].second << ">"; + VLOG(6) << "store_thread_coefficient_and_range[2] = <" + << store_thread_coefficient_and_range[2].first << ", " + << store_thread_coefficient_and_range[2].second << ">"; + VLOG(6) << "load_thread_coefficient_and_range[2] = <" + << load_thread_coefficient_and_range[2].first << ", " + << load_thread_coefficient_and_range[2].second << ">"; + return !(store_thread_overall_range.min <= 
load_thread_overall_range.min && + store_thread_overall_range.max >= load_thread_overall_range.max && + store_serial_overall_range.min <= load_serial_overall_range.min && + store_serial_overall_range.max >= load_serial_overall_range.max && + (store_thread_coefficient_and_range[0].first == + load_thread_coefficient_and_range[0].first || + load_thread_coefficient_and_range[0].first == 0) && + store_thread_coefficient_and_range[0].second.min <= + load_thread_coefficient_and_range[0].second.min && + store_thread_coefficient_and_range[0].second.max >= + load_thread_coefficient_and_range[0].second.max && + (store_thread_coefficient_and_range[1].first == + load_thread_coefficient_and_range[1].first || + load_thread_coefficient_and_range[1].first == 0) && + store_thread_coefficient_and_range[1].second.min <= + load_thread_coefficient_and_range[1].second.min && + store_thread_coefficient_and_range[1].second.max >= + load_thread_coefficient_and_range[1].second.max && + (store_thread_coefficient_and_range[2].first == + load_thread_coefficient_and_range[2].first || + load_thread_coefficient_and_range[2].first == 0) && + store_thread_coefficient_and_range[2].second.min <= + load_thread_coefficient_and_range[2].second.min && + store_thread_coefficient_and_range[2].second.max >= + load_thread_coefficient_and_range[2].second.max); + }; + + // Determine whether the indice of a pair of Store and Load cross CUDA block + auto IsCrossBlock = [&](ir::Expr store_indice_value, + ir::Expr load_indice_value, + const std::string& store_block_name, + const std::string& load_block_name) { + Range store_block_overall_range = CalculateRange( + store_indice_value, CudaBindInfo::kCudaBlock, store_block_name); + Range load_block_overall_range = CalculateRange( + load_indice_value, CudaBindInfo::kCudaBlock, load_block_name); + Range store_thread_and_serial_overall_range = + CalculateRange(store_indice_value, + CudaBindInfo::kCudaThreadAndSerial, + store_block_name); + Range load_thread_and_serial_overall_range = CalculateRange( + load_indice_value, CudaBindInfo::kCudaThreadAndSerial, load_block_name); + auto store_block_coefficient_and_range = GetCoefficientAndRange( + store_indice_value, ir::ForType::GPUBlock, store_block_name); + auto load_block_coefficient_and_range = GetCoefficientAndRange( + load_indice_value, ir::ForType::GPUBlock, load_block_name); + VLOG(6) << "store_block_name: " << store_block_name + << ", load_block_name: " << load_block_name; + VLOG(6) << "store_indice_value: " << store_indice_value + << ", load_indice_value: " << load_indice_value; + VLOG(6) << "store_block_overall_range = " << store_block_overall_range; + VLOG(6) << "load_block_overall_range = " << load_block_overall_range; + VLOG(6) << "store_thread_and_serial_overall_range = " + << store_thread_and_serial_overall_range; + VLOG(6) << "load_thread_and_serial_overall_range = " + << load_thread_and_serial_overall_range; + VLOG(6) << "store_block_coefficient_and_range[0] = <" + << store_block_coefficient_and_range[0].first << ", " + << store_block_coefficient_and_range[0].second << ">"; + VLOG(6) << "load_block_coefficient_and_range[0] = <" + << load_block_coefficient_and_range[0].first << ", " + << load_block_coefficient_and_range[0].second << ">"; + VLOG(6) << "store_block_coefficient_and_range[1] = <" + << store_block_coefficient_and_range[1].first << ", " + << store_block_coefficient_and_range[1].second << ">"; + VLOG(6) << "load_block_coefficient_and_range[1] = <" + << load_block_coefficient_and_range[1].first << ", " + << 
load_block_coefficient_and_range[1].second << ">"; + VLOG(6) << "store_block_coefficient_and_range[2] = <" + << store_block_coefficient_and_range[2].first << ", " + << store_block_coefficient_and_range[2].second << ">"; + VLOG(6) << "load_block_coefficient_and_range[2] = <" + << load_block_coefficient_and_range[2].first << ", " + << load_block_coefficient_and_range[2].second << ">"; + return !(store_block_overall_range.min <= load_block_overall_range.min && + store_block_overall_range.max >= load_block_overall_range.max && + store_thread_and_serial_overall_range.min <= + load_thread_and_serial_overall_range.min && + store_thread_and_serial_overall_range.max >= + load_thread_and_serial_overall_range.max && + (store_block_coefficient_and_range[0].first == + load_block_coefficient_and_range[0].first || + load_block_coefficient_and_range[0].first == 0) && + store_block_coefficient_and_range[0].second.min <= + load_block_coefficient_and_range[0].second.min && + store_block_coefficient_and_range[0].second.max >= + load_block_coefficient_and_range[0].second.max && + (store_block_coefficient_and_range[1].first == + load_block_coefficient_and_range[1].first || + load_block_coefficient_and_range[1].first == 0) && + store_block_coefficient_and_range[1].second.min <= + load_block_coefficient_and_range[1].second.min && + store_block_coefficient_and_range[1].second.max >= + load_block_coefficient_and_range[1].second.max && + (store_block_coefficient_and_range[2].first == + load_block_coefficient_and_range[2].first || + load_block_coefficient_and_range[2].first == 0) && + store_block_coefficient_and_range[2].second.min <= + load_block_coefficient_and_range[2].second.min && + store_block_coefficient_and_range[2].second.max >= + load_block_coefficient_and_range[2].second.max); + }; + + // function to set storage of each tensor + auto SetStorage = [&](ir::ScheduleBlockNode* node) { + if (IsProhibitScheduleExternCallBlock(node->Block())) { + return; + } + ir::MemoryType memory_type = ir::MemoryType::GPULocal; + ir::Expr cur_block = node->Block(); + ir::Expr root_block = ir_sch_->GetRootBlock(cur_block); + UpdateVarNameToForMap(root_block); + std::vector consumer_blocks = + ir::GetConsumers(cur_block, root_block); + // find store and corresponding load nodes + ir::Expr find_store = + *ir::ir_utils::CollectIRNodesWithoutTensor( + cur_block, + [&](const ir::Expr* x) { return x->As(); }, + true) + .begin(); + ir::Expr store_indice_value = AnalyzeIndiceValue(find_store, cur_block); + std::vector> loads_and_blocks; + for (const ir::Expr& consumer_block : consumer_blocks) { + ir::ir_utils::CollectIRNodesWithoutTensor( + consumer_block, [&](const Expr* x) { + if (x->As() && (x->As()->name() == + find_store.As()->name())) { + loads_and_blocks.push_back(std::make_tuple(*x, consumer_block)); + } + return false; + }); + } + // Traverse load nodes to check if there are loads that cross cuda blocks or + // threads + for (const auto& load_and_block : loads_and_blocks) { + ir::Expr load = std::get<0>(load_and_block); + ir::Expr consumer_block = std::get<1>(load_and_block); + std::string consumer_block_name = + consumer_block.As() + ->schedule_block.As() + ->name; + ir::Expr load_indice_value = AnalyzeIndiceValue(load, consumer_block); + if (IsCrossBlock(store_indice_value, + load_indice_value, + node->id(), + consumer_block_name)) { + // TODO(BiynXu): Return error information to the front-end instead of + // terminating the program. 
+ LOG(FATAL) << "Fusion requires synchronization across blocks, but " + "currently we do not support it."; + break; + } else if (IsCrossThread(store_indice_value, + load_indice_value, + node->id(), + consumer_block_name)) { + memory_type = ir::MemoryType::GPUShared; + } + } + // Set output node to global + std::unordered_set output_names = OutputTensorNames(); + if (output_names.count(node->id()) > 0) { + memory_type = ir::MemoryType::Auto; + } + // Set the reduce_init tensor and the real tensor to the same memory + if (ir::IsReduceInitTensorName(node->id())) { + ir::Expr block = + ir_sch_->GetBlock(ir::GetOriginalReduceTensorName(node->id())); + memory_type = ir::GetTensor(block)->buffer->memory_type; + } + // Do schedule + if (memory_type == ir::MemoryType::Auto) { + VLOG(6) << "Set store tensor of block " << node->id() << " to global"; + } else if (memory_type == ir::MemoryType::GPUShared) { + VLOG(6) << "Set store tensor of block " << node->id() << " to shared"; + ir_sch_->SetBuffer(cur_block, "shared"); + std::vector loops = ir_sch_->GetLoops(cur_block); + if (sync_mark.count(ir::GetOriginalReduceTensorName(node->id())) == 0) { + ir_sch_->SyncThreads(loops.back(), true); + sync_mark.insert(ir::GetOriginalReduceTensorName(node->id())); + } + } else if (memory_type == ir::MemoryType::GPULocal) { + VLOG(6) << "Set store tensor of block " << node->id() << " to register"; + ir_sch_->SetBuffer(cur_block, "local"); + } + }; + schedule_block_graph_->DFSTopoWalk(SetStorage); + VLOG(5) << "[After AllocateStorage] func body: " + << ir_sch_->GetModule().GetExprs().front(); +} + +void GroupScheduler::OptimizeReduction() { + VLOG(5) << "[Start OptimizeReduction] func body: " + << ir_sch_->GetModule().GetExprs().front(); + + auto_schedule::ReductionFactoring rf(target_); + + auto ReductionFactoring = [&](ir::ScheduleBlockNode* node) { + if (IsProhibitScheduleExternCallBlock(node->Block())) { + return; + } + VLOG(6) << "try ReductionFactoring on: " << node->id() + << ", before ReductionFactoring, func body: " + << ir_sch_->GetModule().GetExprs().front(); + rf.Apply(node->id(), ir_sch_); + VLOG(6) << "try ReductionFactoring on: " << node->id() + << ", after ReductionFactoring, func body: " + << ir_sch_->GetModule().GetExprs().front(); + }; + + schedule_block_graph_->DFSTopoWalk(ReductionFactoring); + schedule_block_graph_->Update(*ir_sch_); + + VLOG(5) << "[After OptimizeReduction] func body: " + << ir_sch_->GetModule().GetExprs().front(); +} + +void GroupScheduler::UpdateBlockOrder() { + ir::Expr root_block = ir_sch_->GetRootBlock(ir_sch_->GetAllBlocks()[0]); + ir::BlockOrderConstructor block_order_constructor; + blocks_order_with_ctrl_stmt_ = block_order_constructor(&root_block); +} + +bool GroupScheduler::IsKeepGraphDependency(Expr schedule_block, + Expr target_loop, + int insert_pos) const { + // Assuming inserting the schedule_block into the target_loop, + // obtain the transformed upstream and downstream blocks. 
+ std::unordered_set blocks_above; + std::unordered_set blocks_below; + bool is_below = false; + bool find_target_loop = false; + int pos_count = -1; + std::map, ir::Expr>::const_iterator iter; + for (iter = blocks_order_with_ctrl_stmt_.begin(); + iter != blocks_order_with_ctrl_stmt_.end(); + ++iter) { + if (iter->second.get() == schedule_block.get()) { + continue; + } + if (iter->second.get() == target_loop.get()) { + find_target_loop = true; + } + if (find_target_loop) { + ++pos_count; + } + if (pos_count == insert_pos) { + is_below = true; + } + if (iter->second.As()) { + std::string block_id = iter->second.As() + ->schedule_block.As() + ->name; + if (is_below) { + blocks_below.insert(block_id); + } else { + blocks_above.insert(block_id); + } + } + } + + // Obtain the real upstream and downstream nodes + std::string src_id = schedule_block.As() + ->schedule_block.As() + ->name; + ir::ScheduleBlockNode* node = schedule_block_graph_->RetrieveNode(src_id); + std::unordered_set upstream_ids = node->UpstreamNodes(); + std::unordered_set downstream_ids = node->DownstreamNodes(); + + // Check that the transformed upstream and downstream blocks + // still meet the relationship between the + // original upstream and downstream nodes. + for (const std::string& id : upstream_ids) { + if (blocks_above.count(id) == 0) { + VLOG(6) << "[Breaking Graph Level Dependency] ScheduleBlock: " << src_id + << " cannot be inserted into the target loop at insert_pos: " + << insert_pos << " because its upstream block: " << id + << " will appear downstream."; + VLOG(6) << "The target loop:\n" << target_loop; + return false; + } + } + for (const std::string& id : downstream_ids) { + if (blocks_below.count(id) == 0) { + VLOG(6) << "[Breaking Graph Level Dependency] ScheduleBlock: " << src_id + << " cannot be inserted into the target loop at insert_pos: " + << insert_pos << " because its downstream block: " << id + << " will appear upstream."; + VLOG(6) << "The target loop:\n" << target_loop; + return false; + } + } + VLOG(6) << "[Meet Graph Level Dependency] ScheduleBlock: " << src_id + << " can be inserted into the target loop at insert_pos: " << insert_pos; + VLOG(6) << "The target loop:\n" << target_loop; + return true; +} + +bool GroupScheduler::MeetConditions(Expr schedule_block, + Expr target_loop, + int insert_pos) const { + for (const auto& condition_func : feasible_conditions_) { + if (!(this->*condition_func)(schedule_block, target_loop, insert_pos)) { + return false; + } + } + return true; +} + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/group_scheduler.h b/paddle/cinn/hlir/framework/group_scheduler.h new file mode 100644 index 00000000000000..90922e93b778ef --- /dev/null +++ b/paddle/cinn/hlir/framework/group_scheduler.h @@ -0,0 +1,171 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#pragma once +#include "paddle/cinn/hlir/framework/graph.h" +#include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/ir/schedule_block_graph.h" + +namespace cinn { +namespace hlir { +namespace framework { + +// The priority of a ScheduleBlockNode: +// nodes whose loops are already bound to cuda axes rank higher, +// and ties are broken by the amount of data they compute. +struct NodePriority { + bool has_loop_binded; + double score; + + bool operator<(const NodePriority& other) const { + if (has_loop_binded ^ other.has_loop_binded) { + return !has_loop_binded; + } else { + return score < other.score; + } + } +}; + +/** + * The class used for scheduling fusion groups. + * Its responsibility is to perform loop alignment, + * automatic inlining, automatic loop fusion, + * and to optimize the storage location of intermediate variables. + * Note: Currently only the CUDA backend is supported. + */ +class GroupScheduler { + public: + GroupScheduler(ir::IRSchedule* ir_sch, + const std::shared_ptr& group, + const common::Target& target); + + void operator()(); + + private: + // Automatically align loops for each ScheduleBlock. + void DoLoopAlignment(); + + // Automatically inline the ScheduleBlocks that meet the conditions. + void DoComputeInline(); + + // Make every effort to automatically merge the loops of ScheduleBlockNodes + // with a horizontal relationship. + void DoHorizontalLoopFusion(); + + // Make every effort to automatically merge the loops of ScheduleBlockNodes + // with a vertical relationship. + void DoVerticalLoopFusion(); + + // Automatically bind cuda axes to loops. + void BindCudaAxis(); + + // Automatically allocate storage locations for variables to optimize IO. + void AllocateStorage(); + + // Automatically optimize reduction computations. + void OptimizeReduction(); + + // Evaluate the priority of a ScheduleBlockNode. + // The node where the performance bottleneck is located + // has a higher priority, while a node with a lower priority + // needs to compromise and align its loops with the node with the highest + // priority. + NodePriority CalculateNodePriority(const ir::ScheduleBlockNode* node) const; + + // Find the highest-priority ScheduleBlockNode; + // other nodes need to align their loops with it. + ir::ScheduleBlockNode* FindGlobalMasterNode() const; + + // Obtain the latest order of ScheduleBlocks and the control structures + // throughout the entire IR. + void UpdateBlockOrder(); + + // Get the output tensor names of the group. + std::unordered_set OutputTensorNames() const; + + /** + * @brief Determine whether the graph level dependency is still maintained + * after the schedule_block is placed at the insert position of target_loop. + * @param schedule_block The src schedule_block to be replaced. + * @param target_loop The target loop into which the schedule_block will be + * inserted. + * @param insert_pos The insert position of the new schedule_block in the + * target_loop. + */ + bool IsKeepGraphDependency(Expr schedule_block, + Expr target_loop, + int insert_pos) const; + + /** + * @brief Determine whether all feasible conditions are met + * after the schedule_block is placed at the insert position of target_loop. + * @param schedule_block The src schedule_block to be replaced. + * @param target_loop The target loop into which the schedule_block will be + * inserted. + * @param insert_pos The insert position of the new schedule_block in the + * target_loop.
*/ + bool MeetConditions(Expr schedule_block, + Expr target_loop, + int insert_pos) const; + + private: + ir::IRSchedule* ir_sch_; + const std::shared_ptr& group_; + const common::Target& target_; + // Graph in units of ScheduleBlockNode; each node corresponds to a + // ScheduleBlock in the IR. + std::unique_ptr schedule_block_graph_; + /** + * @brief The interface of a feasibility condition. + * @param schedule_block The src schedule_block to be replaced. + * @param target_loop The target loop into which the schedule_block will be + * inserted. + * @param insert_pos The insert position of the new schedule_block in the + * target_loop. + */ + + using FeasibleCondition = bool (GroupScheduler::*)(Expr schedule_block, + Expr target_loop, + int insert_pos) const; + // All feasible conditions. + std::vector feasible_conditions_; + + /** + * The order of blocks and their control statements; + * only For, IfThenElse and ScheduleBlock are considered. + * + * Example: + * for0: + * for1: + * block0 + * block1 + * block2 + * for2: + * block3 + * block4 + * + * The result is: + * [0]: for0 + * [0, 0]: for1 + * [0, 0, 0]: block0 + * [0, 0, 1]: block1 + * [0, 1]: block2 + * [0, 2]: for2 + * [0, 2, 0]: block3 + * [0, 2, 1]: block4 + */ + std::map, ir::Expr> blocks_order_with_ctrl_stmt_; +}; + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/group_scheduler_test.cc b/paddle/cinn/hlir/framework/group_scheduler_test.cc new file mode 100644 index 00000000000000..24d5973eb122d2 --- /dev/null +++ b/paddle/cinn/hlir/framework/group_scheduler_test.cc @@ -0,0 +1,770 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "paddle/cinn/hlir/framework/group_scheduler.h" + +#include + +#include "paddle/cinn/common/target.h" +#include "paddle/cinn/frontend/decomposer/test_helper.h" +#include "paddle/cinn/hlir/framework/op_lowering.h" + +PD_DECLARE_bool(cinn_new_group_scheduler); + +namespace cinn { +namespace hlir { +namespace framework { + +using frontend::NetBuilder; +using frontend::RunDecomposer; + +void Compile(NetBuilder* net_builder) { + auto program = net_builder->Build(); + auto target = common::DefaultTarget(); + RunDecomposer(&program, target); + + auto graph = std::make_shared(program, target); + hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); + hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); + CHECK_EQ(graph->fusion_groups.size(), 1); + + auto& dtype_dict = + graph->GetMutableAttrs>( + "inferdtype"); + auto& shape_dict = + graph->GetMutableAttrs>( + "infershape"); + + auto op_lowerer = + hlir::framework::CreateOpLowerer(dtype_dict, shape_dict, target); + for (auto& fusion_group : graph->fusion_groups) { + std::vector lowered_funcs = + op_lowerer.Lower(fusion_group, + /* apply_op_schedule = */ true, + /* apply_group_schedule = */ false); + CHECK_EQ(lowered_funcs.size(), 1); + VLOG(1) << "without group schedule, lowered_func: " + << lowered_funcs.front(); + + FLAGS_cinn_new_group_scheduler = true; + lowered_funcs = op_lowerer.Lower(fusion_group, + /* apply_op_schedule = */ true, + /* apply_group_schedule = */ true); + CHECK_EQ(lowered_funcs.size(), 1); + VLOG(1) << "after group schedule, lowered_func: " << lowered_funcs.front(); + } +} + +void CheckAccuracy(NetBuilder* net_builder, + const std::vector& input_names) { + FLAGS_cinn_new_group_scheduler = true; + auto program = net_builder->Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + hlir::framework::ApplyPasses(graph.get(), + {"OpFusionPass", "FusionMergePass"}); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses( + graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + + auto scope = BuildScope(target, graph); + hlir::framework::CompilationContext context(graph, scope, target); + hlir::framework::GraphCompiler gc(context); + + for (size_t i = 0; i < input_names.size(); ++i) { + scope->Var(input_names[i]); + auto tensor = scope->GetTensor(input_names[i]); + + std::vector vec; + frontend::InitRandomVector( + &vec, tensor->shape().numel(), 0.0f, 1.0f); + frontend::CopyFromVector(vec, tensor, target); + } + + auto runtime_program = gc.Build(); + runtime_program->Execute(); +} + +// Each unittest below tests a single reduce, +// these unittests are only used to observe the generated IR and debug. +// Accuracy testing is guaranteed by Python unittests named +// test_reduce_op_xxx.py. 
+TEST(GROUP_SCHEDULER, last_reduce_only_1) { + NetBuilder net_builder("last_reduce_only_1"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {128, 64, 32}, "A"); + auto B = net_builder.ReduceSum(A, {2}); + }; + + CreateModel(); + Compile(&net_builder); +} + +TEST(GROUP_SCHEDULER, last_reduce_only_2) { + NetBuilder net_builder("last_reduce_only_2"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {1024}, "A"); + auto B = net_builder.ReduceSum(A, {0}); + }; + + CreateModel(); + Compile(&net_builder); +} + +TEST(GROUP_SCHEDULER, last_reduce_only_3) { + NetBuilder net_builder("last_reduce_only_3"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {512, 256}, "A"); + auto B = net_builder.ReduceSum(A, {1}); + }; + + CreateModel(); + Compile(&net_builder); +} + +TEST(GROUP_SCHEDULER, non_last_reduce_only_1) { + NetBuilder net_builder("non_last_reduce_only_1"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {10, 10, 10}, "A"); + auto B = net_builder.ReduceSum(A, {0, 1}, /* keep_dim = */ true); + }; + + CreateModel(); + Compile(&net_builder); +} + +TEST(GROUP_SCHEDULER, non_last_reduce_only_2) { + NetBuilder net_builder("non_last_reduce_only_2"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {64, 32, 16, 8, 4}, "A"); + auto B = net_builder.ReduceSum(A, {1, 2, 3}, /* keep_dim = */ true); + }; + + CreateModel(); + Compile(&net_builder); +} + +TEST(GROUP_SCHEDULER, shuffle_reduce_only_1) { + NetBuilder net_builder("shuffle_reduce_only_1"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {32, 32, 32, 32}, "A"); + auto B = net_builder.ReduceSum(A, {0, 2, 3}); + }; + + CreateModel(); + Compile(&net_builder); +} + +TEST(GROUP_SCHEDULER, shuffle_reduce_only_2) { + NetBuilder net_builder("shuffle_reduce_only_2"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {32, 64, 56, 56}, "A"); + auto B = net_builder.ReduceSum(A, {0, 2, 3}); + }; + + CreateModel(); + Compile(&net_builder); +} + +// Each of the following unittests tests a basic pattern composed of multiple +// basic ops, and applies accuracy checks to ensure that the results of the +// fusion group and of running each op independently are consistent.
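To make the structure of these accuracy-checked cases explicit, here is a minimal sketch of the pattern they all follow; the test name and tensor shapes are hypothetical, and this block is not one of the tests added by the patch:

// Illustrative sketch only; assumes the Compile and CheckAccuracy helpers
// defined earlier in this test file.
TEST(GROUP_SCHEDULER, example_pattern) {
  NetBuilder net_builder("example_pattern");
  std::vector<std::string> input_names = {"A", "B"};
  // create model
  auto CreateModel = [&]() {
    auto A = net_builder.CreateInput(Float(32), {32, 32}, "A");
    auto B = net_builder.CreateInput(Float(32), {32, 32}, "B");
    auto C = net_builder.Add(A, B);
    auto D = net_builder.ReduceSum(C, {1});
  };

  CreateModel();
  Compile(&net_builder);  // lower the fused group once without group schedule
                          // and once with the new GroupScheduler, logging both IRs
  CreateModel();
  CheckAccuracy(&net_builder, input_names);  // execute the fused program and
                                             // compare against per-op results
                                             // via CheckFusionAccuracyPass
}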
+TEST(GROUP_SCHEDULER, elementwise_1) { + int h = 128, w = 128; + NetBuilder net_builder("elementwise_1"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.Add(B, C); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_2) { + int h = 128, w = 128; + NetBuilder net_builder("elementwise_2"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.Cast(C, "float16"); + auto E = net_builder.Cast(C, "float16"); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_3) { + int h = 128, w = 128; + NetBuilder net_builder("elementwise_3"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.Cast(C, "float16"); + auto E = net_builder.Cast(C, "float16"); + auto F = net_builder.Cast(D, "float32"); + auto G = net_builder.Cast(E, "float32"); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_4) { + int h = 128, w = 128; + NetBuilder net_builder("elementwise_4"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.Cast(C, "float16"); + auto E = net_builder.Cast(C, "float16"); + auto F = net_builder.Add(D, E); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_broadcast) { + NetBuilder net_builder("elementwise_broadcast"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {128}, "A"); + auto B = net_builder.CreateInput(Float(32), {128}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.BroadcastTo(C, {128, 128}); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_double_broadcast) { + NetBuilder net_builder("elementwise_double_broadcast"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {128}, "A"); + auto B = net_builder.CreateInput(Float(32), {128}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.BroadcastTo(C, {128, 128}); + auto E = net_builder.BroadcastTo(C, {128, 128}); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, non_last_reduce_elementwise_1) { + int h = 128, w = 128; + NetBuilder net_builder("non_last_reduce_elementwise_1"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = 
[&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.ReduceSum(A, {0}); + auto C = net_builder.Cast(B, "float16"); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, last_reduce_elementwise) { + NetBuilder net_builder("last_reduce_elementwise"); + std::vector input_names = {"A", "C"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {128, 64}, "A"); + auto B = net_builder.ReduceSum(A, {1}); + auto C = net_builder.CreateInput(Float(32), {128}, "C"); + auto D = net_builder.Add(B, C); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, keep_dim_reduce_elementwise_1) { + NetBuilder net_builder("keep_dim_reduce_elementwise"); + std::vector input_names = {"A", "C"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {16, 64, 112, 112}, "A"); + auto B = net_builder.CreateInput(Float(32), {1, 64, 1, 1}, "B"); + auto C = net_builder.ReduceSum(A, {0, 2, 3}, true); + auto D = net_builder.Add(B, C); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, keep_dim_reduce_elementwise_2) { + NetBuilder net_builder("keep_dim_reduce_elementwise_2"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {16, 64, 112, 112}, "A"); + auto B = net_builder.CreateInput(Float(32), {16, 64, 1, 1}, "B"); + auto C = net_builder.ReduceSum(A, {2, 3}, true); + auto D = net_builder.Add(B, C); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, keep_dim_reduce_elementwise_3) { + NetBuilder net_builder("keep_dim_reduce_elementwise_3"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {16, 64, 2048}, "A"); + auto B = net_builder.CreateInput(Float(32), {16, 64, 1}, "B"); + auto C = net_builder.ReduceSum(A, {2}, true); + auto D = net_builder.Add(B, C); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, keep_dim_reduce_elementwise_4) { + NetBuilder net_builder("keep_dim_reduce_elementwise_4"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {16, 64, 2048}, "A"); + auto B = net_builder.CreateInput(Float(32), {16, 1, 2048}, "B"); + auto C = net_builder.ReduceSum(A, {1}, true); + auto D = net_builder.Add(B, C); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, keep_dim_reduce_elementwise_5) { + NetBuilder net_builder("keep_dim_reduce_elementwise_5"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {16, 64, 16, 1024}, "A"); + auto B = net_builder.CreateInput(Float(32), {16, 1, 16, 1}, "B"); + auto C = net_builder.ReduceSum(A, {1, 3}, true); + auto D = net_builder.Add(B, C); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_non_last_reduce) { + int h = 128, w = 128; + NetBuilder 
net_builder("elementwise_non_last_reduce"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.ReduceSum(C, {0}); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_last_reduce) { + int h = 128, w = 128; + NetBuilder net_builder("elementwise_last_reduce"); + std::vector input_names = {"A", "C"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.ReduceSum(C, {1}); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_non_last_reduce_elementwise) { + int h = 128, w = 128; + NetBuilder net_builder("elementwise_non_last_reduce_elementwise"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto E = net_builder.ReduceSum(C, {0}); + auto F = net_builder.Cast(E, "float16"); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_last_reduce_elementwise) { + int h = 128, w = 128; + NetBuilder net_builder("elementwise_non_last_reduce_elementwise"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto E = net_builder.ReduceSum(C, {1}); + auto F = net_builder.Cast(E, "float16"); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_double_non_last_reduce_elementwise) { + int h = 128, w = 128; + NetBuilder net_builder("elementwise_double_non_last_reduce_elementwise"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto E = net_builder.ReduceSum(C, {0}); + auto F = net_builder.ReduceSum(C, {0}); + auto G = net_builder.Add(E, F); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, double_non_last_reduce_elementwise) { + int h = 128, w = 128; + NetBuilder net_builder("double_non_last_reduce_elementwise"); + std::vector input_names = {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h * 2, w}, "B"); + auto E = net_builder.ReduceSum(A, {0}); + auto F = net_builder.ReduceSum(B, {0}); + auto G = net_builder.Add(E, F); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, triple_non_last_reduce) { + int h = 128, w = 1024; + NetBuilder net_builder("triple_non_last_reduce"); + std::vector input_names 
= {"A", "B"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {128, 1024}, "A"); + auto B = net_builder.ReduceSum(A, {0}); + auto C = net_builder.ReduceSum(A, {0}); + auto D = net_builder.ReduceSum(A, {0}); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, reduce_broadcast_1) { + int h = 32, w = 32; + NetBuilder net_builder("reduce_broadcast_1"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h * w}, "A"); + auto B = net_builder.ReduceSum(A, {0}); + auto C = net_builder.BroadcastTo(B, {h * w}, {0}); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, reduce_broadcast_2) { + int h = 32, w = 32; + NetBuilder net_builder("reduce_broadcast_2"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.ReduceSum(A, {0, 1}); + auto C = net_builder.BroadcastTo(B, {h, w}, {1}); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, reduce_broadcast_3) { + int h = 32, w = 32; + NetBuilder net_builder("reduce_broadcast_3"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, h, w}, "A"); + auto B = net_builder.ReduceSum(A, {1, 2}); + auto C = net_builder.BroadcastTo(B, {h, h, w}, {0}); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, reduce_broadcast_reduce_broadcast) { + int h = 32, w = 32; + NetBuilder net_builder("reduce_broadcast_reduce_broadcast"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, h, w}, "A"); + auto B = net_builder.ReduceSum(A, {1, 2}); + auto C = net_builder.BroadcastTo(B, {h, h, w}, {0}); + auto D = net_builder.ReduceSum(C, {1, 2}); + auto E = net_builder.BroadcastTo(D, {h, h, w}, {0}); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, reduce_broadcast_elementwise) { + int h = 32, w = 32; + NetBuilder net_builder("reduce_broadcast_elementwise"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {h, h, w}, "A"); + auto B = net_builder.ReduceSum(A, {1, 2}); + auto C = net_builder.BroadcastTo(B, {h, h, w}, {0}); + auto D = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto E = net_builder.BroadcastTo(D, {h, h, w}, {1, 2}); + auto F = net_builder.Add(C, E); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_double_reduce_elementwise_1) { + NetBuilder net_builder("elementwise_double_reduce_elementwise_1"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {32, 32}, "A"); + auto B = net_builder.CreateInput(Float(32), {32, 32}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.ReduceSum(C, {1}, false); + auto E = net_builder.ReduceSum(C, {1}, false); + auto F = net_builder.Add(D, E); + }; + + 
CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, elementwise_double_reduce_elementwise_2) { + NetBuilder net_builder("elementwise_double_reduce_elementwise_2"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + auto A = net_builder.CreateInput(Float(32), {1, 1000}, "A"); + auto B = net_builder.CreateInput(Float(32), {1, 1000}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.ReduceSum(C, {1}, false); + auto E = net_builder.ReduceSum(C, {1}, false); + auto F = net_builder.Add(D, E); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +// Each of following unittests tests a group composed of typical operators +TEST(GROUP_SCHEDULER, layernorm) { + int h = 32, w = 1024; + NetBuilder net_builder("layernorm"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + // x + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + // x * x + auto B = net_builder.Multiply(A, A); + // sum x + auto C = net_builder.ReduceSum(A, {1}); + // sum x*x + auto D = net_builder.ReduceSum(B, {1}); + // constant w + auto E = net_builder.FillConstant({h}, 1024.0f, "E"); + // mean + auto F = net_builder.Divide(C, E); + auto FF = net_builder.BroadcastTo(F, {h, w}, {0}); + // mean x*x + auto G = net_builder.Divide(D, E); + // mean * mean + auto H = net_builder.Multiply(F, F); + // var^2 + auto I = net_builder.Subtract(G, H); + // eps + auto J = net_builder.FillConstant({h}, 1e-10f, "J"); + // eps + delta + auto K = net_builder.Add(I, J); + // var + auto L = net_builder.Sqrt(K); + auto LL = net_builder.BroadcastTo(L, {h, w}, {0}); + // x - mean + auto M = net_builder.Subtract(A, FF); + // /var + auto N = net_builder.Divide(M, LL); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +TEST(GROUP_SCHEDULER, softmax) { + int h = 32, w = 1024; + NetBuilder net_builder("softmax"); + std::vector input_names = {"A"}; + // create model + auto CreateModel = [&]() { + // softmax + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + // reduce max + auto B = net_builder.ReduceMax(A, {1}); + // broadcast + auto C = net_builder.BroadcastTo(B, {h, w}, {0}); + // x - max(x) + auto D = net_builder.Subtract(A, C); + // exp(x) + auto E = net_builder.Exp(D); + // reduce sum + auto F = net_builder.ReduceSum(E, {1}); + // broadcast + auto G = net_builder.BroadcastTo(F, {h, w}, {0}); + // exp(x)/sum(exp(x)) + auto H = net_builder.Divide(E, G); + }; + + CreateModel(); + Compile(&net_builder); + CreateModel(); + CheckAccuracy(&net_builder, input_names); +} + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/op_lowering_impl.cc b/paddle/cinn/hlir/framework/op_lowering_impl.cc index cc5d88554432ca..a8006e147946e7 100644 --- a/paddle/cinn/hlir/framework/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/op_lowering_impl.cc @@ -17,12 +17,15 @@ #include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/hlir/framework/compile_error.h" #include "paddle/cinn/hlir/framework/graph_compiler_util.h" +#include "paddle/cinn/hlir/framework/group_scheduler.h" #include "paddle/cinn/hlir/framework/op_lowering_util.h" #include "paddle/cinn/hlir/op/external_api_registry.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/optim/transform_gpu_forloop.h" +#include 
"paddle/cinn/runtime/flags.h" PD_DECLARE_bool(cinn_use_cuda_vectorize); +PD_DECLARE_bool(cinn_new_group_scheduler); namespace cinn { namespace hlir { @@ -123,7 +126,10 @@ std::vector OpLowererImpl::LowerGroup( ir::IRSchedule ir_sch(mod_expr); ir_sch.MergeExprs(); VLOG(3) << "After lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); - if (apply_group_schedule) { + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + if (apply_group_schedule && + !(nodes.size() == 1 && + op_pattern_dict[nodes[0]->op()] == OpPatternKind::kNonFusible)) { DoGroupSchedule(ir_sch, group, tensor_map); VLOG(3) << "After group schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); @@ -463,6 +469,11 @@ ir::Expr OpLowererImpl::DoGroupSchedule( ir::IRSchedule& ir_sch, const GroupPtr& group, const std::unordered_map& tensor_map) { + if (FLAGS_cinn_new_group_scheduler) { + GroupScheduler group_scheduler(&ir_sch, group, target_); + group_scheduler(); + return ir_sch.GetModule().GetExprs().at(0); + } // topological order. auto nodes_set = group->NodeSet(); auto v_consumers = BuildVirtualConsumer(group, this->shape_dict_); diff --git a/paddle/cinn/hlir/op/contrib/argmax.cc b/paddle/cinn/hlir/op/contrib/argmax.cc index 2a1f19a5d2608a..041cfe7dc47a50 100644 --- a/paddle/cinn/hlir/op/contrib/argmax.cc +++ b/paddle/cinn/hlir/op/contrib/argmax.cc @@ -161,6 +161,25 @@ std::shared_ptr StrategyForArgmax( ir_sch.SetBuffer(blocks[0], "local"); ir_sch.SetBuffer(blocks[1], "local"); + int iter_var_size = blocks[0] + .As() + ->schedule_block.As() + ->iter_vars.size(); + int real_axis = axis; + if (real_axis < 0) { + real_axis += iter_var_size; + } + blocks[0] + .As() + ->schedule_block.As() + ->iter_vars[real_axis] + ->is_reduce_axis = true; + blocks[1] + .As() + ->schedule_block.As() + ->iter_vars[real_axis] + ->is_reduce_axis = true; + int64_t prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, diff --git a/paddle/cinn/hlir/op/contrib/argmin.cc b/paddle/cinn/hlir/op/contrib/argmin.cc index dfd88deb6f380a..3caaf45c46a5eb 100644 --- a/paddle/cinn/hlir/op/contrib/argmin.cc +++ b/paddle/cinn/hlir/op/contrib/argmin.cc @@ -158,6 +158,26 @@ std::shared_ptr StrategyForArgmin( // variables, because the size will exceed the limit. 
ir_sch.SetBuffer(blocks[0], "local"); ir_sch.SetBuffer(blocks[1], "local"); + + int iter_var_size = blocks[0] + .As() + ->schedule_block.As() + ->iter_vars.size(); + int real_axis = axis; + if (real_axis < 0) { + real_axis += iter_var_size; + } + blocks[0] + .As() + ->schedule_block.As() + ->iter_vars[real_axis] + ->is_reduce_axis = true; + blocks[1] + .As() + ->schedule_block.As() + ->iter_vars[real_axis] + ->is_reduce_axis = true; + int64_t prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, diff --git a/paddle/cinn/hlir/op/reduction.cc b/paddle/cinn/hlir/op/reduction.cc index a396aec315af46..0c279737a2a72f 100644 --- a/paddle/cinn/hlir/op/reduction.cc +++ b/paddle/cinn/hlir/op/reduction.cc @@ -29,6 +29,8 @@ #include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/optim/ir_simplify.h" +PD_DECLARE_bool(cinn_new_group_scheduler); + namespace cinn { namespace hlir { namespace op { @@ -58,7 +60,7 @@ std::shared_ptr StrategyForReduce( const std::string &op_name, BlockReduceFunc gpu_reduce_with_last_axis_func, BlockReduceFunc gpu_reduce_without_last_axis_func, - ReduceFunc cpu_reduce_func) { + ReduceFunc common_reduce_func) { std::vector reduce_axes; auto ndim = inputs[0]->shape.size(); if (attrs.attr_store.count("dim")) { @@ -127,7 +129,8 @@ std::shared_ptr StrategyForReduce( << "The type of input argument " << x->name << " of " << op_name << " should be bool, but get " << x->type() << "! Please check."; - if (target == common::DefaultNVGPUTarget()) { + if (!FLAGS_cinn_new_group_scheduler && + target == common::DefaultNVGPUTarget()) { if (!WithoutLastDimInReduce(inputs[0]->shape, reduce_axes)) { VLOG(3) << "Do Two Step Block Reduce Compute!"; auto res = gpu_reduce_with_last_axis_func( @@ -155,7 +158,7 @@ std::shared_ptr StrategyForReduce( } } else { VLOG(3) << "Do Reduce Compute!"; - auto out = cpu_reduce_func(x, reduce_axes, keep_dim, tensor_name); + auto out = common_reduce_func(x, reduce_axes, keep_dim, tensor_name); auto stages = CreateStages({out}); std::vector cinn_values{CINNValue(out), CINNValue(stages)}; @@ -193,7 +196,7 @@ std::shared_ptr StrategyForReduce( ir::ModuleExpr mod_expr(vec_ast); ir::IRSchedule ir_sch(mod_expr); ir_sch.MergeExprs(); - if (target.arch == Target::Arch::NVGPU) { + if (!FLAGS_cinn_new_group_scheduler && target.arch == Target::Arch::NVGPU) { if (!WithoutLastDimInReduce(inputs[0]->shape, reduce_axes)) { if (arg_pack.size() == 4) { CHECK_EQ(vec_tensor.size(), 2); @@ -313,7 +316,7 @@ std::shared_ptr StrategyForReduce( reduce_op_, \ gpu_reduce_with_last_axis_func, \ gpu_reduce_without_last_axis_func, \ - cpu_reduce_func) \ + common_reduce_func) \ std::shared_ptr StrategyFor##reduce_op_( \ const framework::NodeAttr &attrs, \ const std::vector &inputs, \ @@ -328,7 +331,7 @@ std::shared_ptr StrategyForReduce( #op_name_, \ gpu_reduce_with_last_axis_func, \ gpu_reduce_without_last_axis_func, \ - cpu_reduce_func); \ + common_reduce_func); \ } STRATEGY_FOR_REDUCE(reduce_sum, diff --git a/paddle/cinn/hlir/op/reduction_test.cc b/paddle/cinn/hlir/op/reduction_test.cc index ca20c0d3fdd769..953dd82017d9bd 100644 --- a/paddle/cinn/hlir/op/reduction_test.cc +++ b/paddle/cinn/hlir/op/reduction_test.cc @@ -39,6 +39,9 @@ #include "paddle/cinn/hlir/pe/nn.h" #include "paddle/cinn/runtime/cinn_runtime.h" #include "paddle/cinn/runtime/cuda/cuda_module.h" + +PD_DECLARE_bool(cinn_new_group_scheduler); + namespace cinn { namespace hlir { namespace framework { @@ -362,6 +365,9 @@ void TestCaseForReduce(const float init_val, dim3 block; grid = {c, 
1, 1}; int block_dim_x = n * w * h > 1024 ? 1024 : n * w * h; + if (FLAGS_cinn_new_group_scheduler) { + block_dim_x = 1; + } block = {block_dim_x, 1, 1}; void* args[] = {&dev_x, &dev_z}; @@ -531,7 +537,8 @@ TEST(Operator, Operator_Reduction_Case_Warp_Reduce) { std::vector dim = {1}; auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Warp_Reduce"); - CHECK(res.second.find("threadIdx.x < 32") != std::string::npos); + if (!FLAGS_cinn_new_group_scheduler) + CHECK(res.second.find("threadIdx.x < 32") != std::string::npos); } TEST(Operator, Operator_Reduction_Case_Block_Reduce) { @@ -544,7 +551,8 @@ TEST(Operator, Operator_Reduction_Case_Block_Reduce) { std::vector dim = {1}; auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Block_Reduce"); - CHECK(res.second.find("threadIdx.x < 32") == std::string::npos); + if (!FLAGS_cinn_new_group_scheduler) + CHECK(res.second.find("threadIdx.x < 32") == std::string::npos); } TEST(Operator, Operator_Reduction_Case_Warp_Reduce_Case_1) { @@ -558,7 +566,8 @@ TEST(Operator, Operator_Reduction_Case_Warp_Reduce_Case_1) { auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Warp_Reduce_Case_1"); - CHECK(res.second.find("threadIdx.x < 32") != std::string::npos); + if (!FLAGS_cinn_new_group_scheduler) + CHECK(res.second.find("threadIdx.x < 32") != std::string::npos); } TEST(Operator, Operator_Reduction_Case_Block_Reduce_Case_1) { @@ -572,7 +581,8 @@ TEST(Operator, Operator_Reduction_Case_Block_Reduce_Case_1) { auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Block_Reduce_Case_2"); - CHECK(res.second.find("threadIdx.x < 32") == std::string::npos); + if (!FLAGS_cinn_new_group_scheduler) + CHECK(res.second.find("threadIdx.x < 32") == std::string::npos); } } // namespace framework } // namespace hlir diff --git a/paddle/cinn/hlir/pe/ir_schedule_pe.cc b/paddle/cinn/hlir/pe/ir_schedule_pe.cc index 6600905b083c1f..b8f6d170996b38 100644 --- a/paddle/cinn/hlir/pe/ir_schedule_pe.cc +++ b/paddle/cinn/hlir/pe/ir_schedule_pe.cc @@ -31,14 +31,39 @@ #include "paddle/cinn/hlir/pe/schedule.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_base.h" +#include "paddle/cinn/ir/utils/ir_copy.h" #include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/cinn/optim/replace_var_with_expr.h" #include "paddle/cinn/poly/isl_utils.h" #include "paddle/cinn/utils/string.h" +PD_DECLARE_bool(cinn_new_group_scheduler); namespace cinn { namespace hlir { namespace pe { +void SetReduceAxis(ir::Expr loop, ir::Expr block) { + std::string var_name = loop.As()->loop_var->name; + std::vector iter_vars = block.As() + ->schedule_block.As() + ->iter_vars; + std::vector iter_values = + block.As()->iter_values; + CHECK_EQ(iter_vars.size(), iter_values.size()); + for (int i = 0; i < iter_values.size(); ++i) { + std::set contains = ir::ir_utils::CollectIRNodesWithoutTensor( + iter_values[i], + [&var_name](const Expr *expr) { + return expr->As() != nullptr && + expr->As()->name == var_name; + }, + true); + if (!contains.empty()) { + iter_vars[i]->is_reduce_axis = true; + } + } +} + void IRElementwiseSchedule(ir::IRSchedule &ir_sch, // NOLINT const std::vector &output_shape, const common::Target &target) { @@ -457,9 +482,15 @@ void IRCudaScheduleBlockReduceInternal(ir::IRSchedule &ir_sch, // NOLINT if (loops_tmp_out.size() == 1) { ir_sch.Bind(loops_tmp_out[0], "threadIdx.x"); ir_sch.Bind(loops_out[0], "threadIdx.x"); + if (FLAGS_cinn_new_group_scheduler) { + SetReduceAxis(loops_tmp_out[0], ir_sch.GetBlock(tmp_out->name)); + } } else { 
ir_sch.Bind(loops_tmp_out[0], "blockIdx.x"); ir_sch.Bind(loops_tmp_out[1], "threadIdx.x"); + if (FLAGS_cinn_new_group_scheduler) { + SetReduceAxis(loops_tmp_out[1], ir_sch.GetBlock(tmp_out->name)); + } if (loops_out.size() == 1) { ir_sch.Split(loops_out[0], {-1, 1}); @@ -471,7 +502,11 @@ void IRCudaScheduleBlockReduceInternal(ir::IRSchedule &ir_sch, // NOLINT for (auto &tensor : {tmp_out}) { auto block = ir_sch.GetBlock(tensor->name); - ir_sch.SetBuffer(block, "local", true); + if (FLAGS_cinn_new_group_scheduler) { + ir_sch.SetBuffer(block, "local"); + } else { + ir_sch.SetBuffer(block, "local", true); + } } VLOG(3) << "After IRCudaScheduleBlockReduceInternal : " @@ -600,6 +635,9 @@ void IRCudaScheduleBlockReduce(ir::IRSchedule &ir_sch, // NOLINT ir_sch.Bind(loops[0], "blockIdx.x"); ir_sch.Bind(loops[1], "threadIdx.x"); + if (FLAGS_cinn_new_group_scheduler) { + SetReduceAxis(loops[1], ir_sch.GetBlock(tmp_out->name)); + } } // out { @@ -614,7 +652,11 @@ void IRCudaScheduleBlockReduce(ir::IRSchedule &ir_sch, // NOLINT for (auto &tensor : {reduce_tmp_out, tmp_out}) { auto block = ir_sch.GetBlock(tensor->name); - ir_sch.SetBuffer(block, "local", true); + if (FLAGS_cinn_new_group_scheduler) { + ir_sch.SetBuffer(block, "local"); + } else { + ir_sch.SetBuffer(block, "local", true); + } } VLOG(3) << "After IRCudaScheduleBlockReduce : " @@ -673,8 +715,10 @@ void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch, // NOLINT auto load = exprs.front().As(); load->indices = {index}; }; - hand_write_simplify(ir_sch.GetLoops(reshape->name), - ir_sch.GetBlock(reshape->name)); + if (!FLAGS_cinn_new_group_scheduler) { + hand_write_simplify(ir_sch.GetLoops(reshape->name), + ir_sch.GetBlock(reshape->name)); + } auto block = ir_sch.GetBlock(reshape->name); ir_sch.ComputeInline(block); VLOG(4) << "After simplify reshape index : " @@ -955,10 +999,14 @@ void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch, // NOLINT ir_sch.ComputeInline(reshape_block); auto internal_block = ir_sch.GetBlock(internal->name); - ir_sch.SetBuffer(internal_block, "local", true); - auto tmp_out_block = ir_sch.GetBlock(tmp_out->name); - ir_sch.SetBuffer(tmp_out_block, "local", true); + if (FLAGS_cinn_new_group_scheduler) { + ir_sch.SetBuffer(internal_block, "local"); + ir_sch.SetBuffer(tmp_out_block, "local"); + } else { + ir_sch.SetBuffer(internal_block, "local", true); + ir_sch.SetBuffer(tmp_out_block, "local", true); + } // The current one-dimensional reduce does not make full use of SM. // This case is optimized into a two-dimensional. 
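The comment above refers to reshaping a one-dimensional reduce so that more SMs are used. As a rough sketch of what such a two-dimensionalization amounts to (the helper name and the split factor are hypothetical, not part of this patch), the single loop is split and the outer piece is bound to the grid:

// Illustrative sketch only, assuming a schedule whose iteration space is
// currently a single loop; 128 is an arbitrary per-block thread count.
#include <string>
#include <vector>
#include "paddle/cinn/ir/schedule/ir_schedule.h"

void TwoDimensionalizeSketch(cinn::ir::IRSchedule* ir_sch,
                             const std::string& block_name) {
  std::vector<cinn::ir::Expr> loops = ir_sch->GetLoops(block_name);
  // Split the single loop into (grid, block) parts.
  std::vector<cinn::ir::Expr> split_loops = ir_sch->Split(loops[0], {-1, 128});
  ir_sch->Bind(split_loops[0], "blockIdx.x");   // spread work across SMs
  ir_sch->Bind(split_loops[1], "threadIdx.x");  // threads within one block
}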
@@ -978,9 +1026,15 @@ void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch, // NOLINT ir_sch.Bind(loops[0], "blockIdx.x"); ir_sch.Bind(loops[1], "threadIdx.y"); ir_sch.Bind(loops[2], "threadIdx.x"); + if (FLAGS_cinn_new_group_scheduler && tensor->name == tmp_out->name) { + SetReduceAxis(loops[2], ir_sch.GetBlock(tmp_out->name)); + } } else { ir_sch.Bind(loops[0], "blockIdx.x"); ir_sch.Bind(loops[1], "threadIdx.x"); + if (FLAGS_cinn_new_group_scheduler && tensor->name == tmp_out->name) { + SetReduceAxis(loops[1], ir_sch.GetBlock(tmp_out->name)); + } } } VLOG(3) << "After IRCudaTwoStepReduceSchedule : " diff --git a/paddle/cinn/ir/schedule/factorize_reduction.h b/paddle/cinn/ir/schedule/factorize_reduction.h index 0973d123fd40c1..4075feb93599e0 100644 --- a/paddle/cinn/ir/schedule/factorize_reduction.h +++ b/paddle/cinn/ir/schedule/factorize_reduction.h @@ -33,7 +33,7 @@ namespace ir { Tensor CreateRFTensor(const Tensor& original_tensor, const Expr& rf_loop, int rf_axis) { - std::string name = original_tensor->name + "_rf"; + std::string name = common::UniqName(original_tensor->name + "_rf"); std::vector new_shape = original_tensor->shape; new_shape.insert(new_shape.begin() + rf_axis, rf_loop.As()->extent); Tensor rf_tensor = _Tensor_::Make(name, @@ -80,19 +80,23 @@ class ReduceBlockCreater { ->schedule_block.As() ->name; if (is_rf_block_) { - new_update_block_name += "_rf"; + new_update_block_name = rf_tensor_->name; } std::string new_init_block_name = ir::GenReduceInitTensorNameOf(new_update_block_name); VLOG(5) << "new_init_block_name = " << new_init_block_name; - Expr init_value = rf_tensor_->GetReduceInitVal(); - const std::vector& domain = rf_tensor_->domain_without_reduce_axis(); + const ir::Tensor& real_tensor = + is_rf_block_ + ? rf_tensor_ + : original_update_stmt_.As()->tensor.as_tensor_ref(); + Expr init_value = real_tensor->GetReduceInitVal(); + const std::vector& domain = real_tensor->domain_without_reduce_axis(); ir::Tensor init_tensor = lang::Compute( domain, [=](const std::vector& axis) { return init_value; }, new_init_block_name); - init_tensor->Bind(rf_tensor_->buffer); + init_tensor->Bind(real_tensor->buffer); Expr init_stmt = ir::Store::Make( init_tensor, init_value, new_update_stmt_.As()->indices); new_init_sch_block_ = ScheduleBlock::Make( @@ -299,6 +303,12 @@ class RFBlockCreater : public ReduceBlockCreater { REPLACE_RF_TENSOR(Mul) REPLACE_RF_TENSOR(Max) REPLACE_RF_TENSOR(Min) + REPLACE_RF_TENSOR(And) + REPLACE_RF_TENSOR(Or) + REPLACE_RF_TENSOR(LT) + REPLACE_RF_TENSOR(LE) + REPLACE_RF_TENSOR(GT) + REPLACE_RF_TENSOR(GE) #undef REPLACE_RF_TENSOR new_update_stmt_ = @@ -388,6 +398,12 @@ class RBBlockCreater : public ReduceBlockCreater { REPLACE_RF_TENSOR(Mul) REPLACE_RF_TENSOR(Max) REPLACE_RF_TENSOR(Min) + REPLACE_RF_TENSOR(And) + REPLACE_RF_TENSOR(Or) + REPLACE_RF_TENSOR(LT) + REPLACE_RF_TENSOR(LE) + REPLACE_RF_TENSOR(GT) + REPLACE_RF_TENSOR(GE) #undef REPLACE_RF_TENSOR Expr original_store_tensor = original_update_stmt_.As()->tensor; diff --git a/paddle/cinn/ir/schedule/ir_schedule.cc b/paddle/cinn/ir/schedule/ir_schedule.cc index bdafc61ff88151..2baebcbacc61b8 100644 --- a/paddle/cinn/ir/schedule/ir_schedule.cc +++ b/paddle/cinn/ir/schedule/ir_schedule.cc @@ -2640,6 +2640,13 @@ void IRSchedule::SetBuffer(Expr& block, {})); } +Expr IRSchedule::AddUnitLoop(const Expr& block) { + Expr ret = impl_->AddUnitLoop(block); + trace_.Append(ScheduleDesc::Step( + "AddUnitLoop", {{"block", std::vector({block})}}, {}, {ret})); + return ret; +} + Expr IRSchedule::Reorder(const 
std::vector& loops) { Expr ret = impl_->Reorder(loops); trace_.Append(ScheduleDesc::Step("Reorder", {{"loops", loops}}, {}, {ret})); diff --git a/paddle/cinn/ir/schedule/ir_schedule.h b/paddle/cinn/ir/schedule/ir_schedule.h index 4c5fc1d10f1b69..b33afd03a799a8 100644 --- a/paddle/cinn/ir/schedule/ir_schedule.h +++ b/paddle/cinn/ir/schedule/ir_schedule.h @@ -244,7 +244,7 @@ class IRSchedule { */ void SyncThreads(const Expr& ir_node, bool after_node = true); - /*! + /** * \brief Set a tensor's buffer type(memory_type) * \param block The ScheduleBlockRealize corresponding to an unique tensor. * \param memory_type The memory type we want to set. Should be "local", @@ -254,6 +254,13 @@ class IRSchedule { const std::string& memory_type, bool fixed = false); // NOLINT + /** + * \brief Create a new unit loop on top of the block. + * @param block The block to be added the new loop. + * @return The new unit loop. + */ + Expr AddUnitLoop(const Expr& block); + /** * \brief Reorder the loops in the order of vector. * @param loops The loops to be reordered. diff --git a/paddle/cinn/ir/schedule/ir_schedule_util.cc b/paddle/cinn/ir/schedule/ir_schedule_util.cc index 7a2daa3106612f..db378eba741945 100644 --- a/paddle/cinn/ir/schedule/ir_schedule_util.cc +++ b/paddle/cinn/ir/schedule/ir_schedule_util.cc @@ -367,8 +367,16 @@ IterRange GetAccessedRange(const Expr& index, Expr indice_extent; Expr mod_extent(0); - if (indice_min.As() && indice_min.As()->b().is_constant()) + if (indice_min.As() && indice_min.As()->b().is_constant()) { + Expr mod_right_min = indice_min.As()->a(); + Expr mod_right_max = indice_max.As()->a(); + Expr mod_right_extent = + common::AutoSimplify(mod_right_max - mod_right_min + 1); mod_extent = indice_min.As()->b(); + if (mod_right_extent.get_constant() < mod_extent.get_constant()) { + mod_extent = mod_right_extent; + } + } if (indice_min == indice_max) { if (common::is_zero(mod_extent)) { @@ -875,7 +883,7 @@ std::vector GetProducers(const Expr& block, const Expr& root) { ->name; ir::ir_utils::CollectIRNodesWithoutTensor( compute_body, [&producer_tensor_names, &block_name](const Expr* x) { - auto* load = x->As(); + const ir::Load* load = x->As(); if (load) { producer_tensor_names.insert(load->tensor.as_tensor()->name); if (load->tensor.as_tensor()->name == block_name) { @@ -884,6 +892,22 @@ std::vector GetProducers(const Expr& block, const Expr& root) { } return true; } + const ir::Store* store = x->As(); + if (store) { + std::set call_nodes = + ir::ir_utils::CollectIRNodesWithoutTensor( + store->value, + [](const ir::Expr* x) { return x->As(); }); + for (ir::Expr call : call_nodes) { + const std::vector& read_args = + call.As()->read_args; + for (const ir::Expr& arg : read_args) { + if (arg.as_tensor()) { + producer_tensor_names.insert(arg.as_tensor_ref()->name); + } + } + } + } return false; }); @@ -936,13 +960,23 @@ std::vector GetConsumers(const Expr& block, const Expr& root) { auto block_body = i.As() ->schedule_block.As() ->body; - auto find_load = ir::ir_utils::CollectIRNodesWithoutTensor( + auto find_load_or_call = ir::ir_utils::CollectIRNodesWithoutTensor( block_body, [&](const Expr* x) { + if (x->As()) { + const std::vector& read_args = + x->As()->read_args; + for (const ir::Expr& arg : read_args) { + if (arg.as_tensor() && + arg.as_tensor_ref()->name == block_tensor) { + return true; + } + } + } return x->As() && x->As()->tensor.as_tensor_ref()->name == block_tensor; }); - if (!find_load.empty()) consumers.emplace_back(i); + if (!find_load_or_call.empty()) 
consumers.emplace_back(i); } return consumers; } diff --git a/paddle/cinn/ir/schedule/ir_schedule_util.h b/paddle/cinn/ir/schedule/ir_schedule_util.h index 9c9418b4d577ec..edd202c6093d63 100644 --- a/paddle/cinn/ir/schedule/ir_schedule_util.h +++ b/paddle/cinn/ir/schedule/ir_schedule_util.h @@ -436,9 +436,11 @@ IterRange RangeUnion(const IterRange& range1, const IterRange& range2); * \param loop The loop where we will insert the block under it * @param root The root of the whole AST. * \param required_blocks vector of ScheduleBlockRealize nodes that require the - * block \param is_store_provided Whether Store nodes of the block provide the + * block + * \param is_store_provided Whether Store nodes of the block provide the * tensor, true means it is in compute_at case, otherwise false means in - * reverse_compuate_at case \return Each index's range of block's tensor. + * reverse_compuate_at case + * \return Each index's range and can_keep_loop flag of block's tensor. * Indicating the buffer region being required. */ std::vector CalculateRequiredRegions( diff --git a/paddle/cinn/ir/schedule/schedule_desc.cc b/paddle/cinn/ir/schedule/schedule_desc.cc index e0d5f4ab217018..c9a26dfa1643d6 100644 --- a/paddle/cinn/ir/schedule/schedule_desc.cc +++ b/paddle/cinn/ir/schedule/schedule_desc.cc @@ -422,6 +422,12 @@ CINN_BUILD_STEP_KIND(SetBuffer) .SetApplyFn( APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::SetBuffer))); +CINN_BUILD_STEP_KIND(AddUnitLoop) + .Inputs({"block"}) + .SetApplyFn(APPLY_FUNC_UNIFORM( + FREE_FUNCTION_CONVERTER(static_cast( + &IRSchedule::AddUnitLoop)))); + CINN_BUILD_STEP_KIND(Reorder).Inputs({"loops"}).SetApplyFn( APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER( static_cast&)>( diff --git a/paddle/cinn/ir/test/schedule_block_graph_test.cc b/paddle/cinn/ir/test/schedule_block_graph_test.cc index 20c7f03b4d235d..78c809dc117d46 100644 --- a/paddle/cinn/ir/test/schedule_block_graph_test.cc +++ b/paddle/cinn/ir/test/schedule_block_graph_test.cc @@ -20,6 +20,8 @@ #include "paddle/cinn/hlir/framework/op_lowering.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" +PD_DECLARE_bool(cinn_new_group_scheduler); + namespace cinn { namespace ir { @@ -95,6 +97,7 @@ frontend::Program CreateReduceProgram() { } TEST(ScheduleBlockGraph, elementwise) { + Context::Global().ResetNameId(); frontend::Program program = CreateElementwiseProgram(); IRSchedule ir_sch = MakeIRSchedule(&program); LOG(INFO) << GetIR(ir_sch); @@ -136,23 +139,72 @@ TEST(ScheduleBlockGraph, elementwise) { #ifdef CINN_WITH_CUDA TEST(ScheduleBlockGraph, reduce) { - frontend::Program program = CreateReduceProgram(); + if (FLAGS_cinn_new_group_scheduler) { + Context::Global().ResetNameId(); + frontend::Program program = CreateReduceProgram(); + IRSchedule ir_sch = MakeIRSchedule(&program); + ScheduleBlockGraph sbg(ir_sch); + LOG(INFO) << GetIR(ir_sch); + LOG(INFO) << sbg.Visualize(); + CHECK_EQ(sbg.BlockIdsInOrder().size(), 5); + CHECK_EQ(sbg.nodes().size(), 5); + + ScheduleBlockNode* v_reduce_init = sbg.RetrieveNode("var_2__reduce_init"); + CHECK(v_reduce_init); + CHECK_EQ(v_reduce_init->UpstreamNodes().size(), 0); + CHECK_EQ(v_reduce_init->DownstreamNodes().size(), 3); + + ScheduleBlockNode* v = sbg.RetrieveNode("var_2"); + CHECK(v); + CHECK_EQ(v->UpstreamNodes().size(), 2); + CHECK_EQ(v->DownstreamNodes().size(), 2); + + std::vector reverse_dfs_topo_order_ids; + sbg.DFSTopoWalk( + [&reverse_dfs_topo_order_ids](const ScheduleBlockNode* node) { + reverse_dfs_topo_order_ids.push_back(node->id()); + }); + for (const std::string& 
id : reverse_dfs_topo_order_ids) { + LOG(INFO) << id; + } + CHECK_EQ(reverse_dfs_topo_order_ids.size(), 5); + + std::vector dfs_topo_order_ids; + sbg.DFSTopoWalk( + [&dfs_topo_order_ids](const ScheduleBlockNode* node) { + dfs_topo_order_ids.push_back(node->id()); + }, + false); + for (const std::string& id : dfs_topo_order_ids) { + LOG(INFO) << id; + } + CHECK_EQ(dfs_topo_order_ids.size(), 5); + } +} + +TEST(ScheduleBlockGraph, arg_max) { + Context::Global().ResetNameId(); + frontend::NetBuilder builder("net_builder"); + auto x = builder.CreateInput(Float(32), {8, 16}, "X"); + auto y = builder.Argmax(x, 0); + frontend::Program program = builder.Build(); + IRSchedule ir_sch = MakeIRSchedule(&program); LOG(INFO) << GetIR(ir_sch); ScheduleBlockGraph sbg(ir_sch); LOG(INFO) << sbg.Visualize(); - CHECK_EQ(sbg.BlockIdsInOrder().size(), 8); - CHECK_EQ(sbg.nodes().size(), 8); + CHECK_EQ(sbg.BlockIdsInOrder().size(), 3); + CHECK_EQ(sbg.nodes().size(), 3); - ScheduleBlockNode* v_reduce_init = sbg.RetrieveNode("var_48__reduce_init"); - CHECK(v_reduce_init); - CHECK_EQ(v_reduce_init->UpstreamNodes().size(), 0); - CHECK_EQ(v_reduce_init->DownstreamNodes().size(), 3); + ScheduleBlockNode* v0_idx = sbg.RetrieveNode("var_0_index"); + CHECK(v0_idx); + CHECK_EQ(v0_idx->UpstreamNodes().size(), 1); + CHECK_EQ(v0_idx->DownstreamNodes().size(), 1); - ScheduleBlockNode* v = sbg.RetrieveNode("var_48"); - CHECK(v); - CHECK_EQ(v->UpstreamNodes().size(), 5); - CHECK_EQ(v->DownstreamNodes().size(), 2); + ScheduleBlockNode* v0 = sbg.RetrieveNode("var_0"); + CHECK(v0); + CHECK_EQ(v0->UpstreamNodes().size(), 2); + CHECK_EQ(v0->DownstreamNodes().size(), 0); std::vector reverse_dfs_topo_order_ids; sbg.DFSTopoWalk([&reverse_dfs_topo_order_ids](const ScheduleBlockNode* node) { @@ -161,7 +213,7 @@ TEST(ScheduleBlockGraph, reduce) { for (const std::string& id : reverse_dfs_topo_order_ids) { LOG(INFO) << id; } - CHECK_EQ(reverse_dfs_topo_order_ids.size(), 8); + CHECK_EQ(reverse_dfs_topo_order_ids.size(), 3); std::vector dfs_topo_order_ids; sbg.DFSTopoWalk( @@ -172,7 +224,7 @@ TEST(ScheduleBlockGraph, reduce) { for (const std::string& id : dfs_topo_order_ids) { LOG(INFO) << id; } - CHECK_EQ(dfs_topo_order_ids.size(), 8); + CHECK_EQ(dfs_topo_order_ids.size(), 3); } #endif diff --git a/paddle/cinn/optim/replace_cross_thread_reduction.cc b/paddle/cinn/optim/replace_cross_thread_reduction.cc index 5102e8bc6468ff..83a7bb83498d66 100644 --- a/paddle/cinn/optim/replace_cross_thread_reduction.cc +++ b/paddle/cinn/optim/replace_cross_thread_reduction.cc @@ -29,6 +29,7 @@ namespace cinn { namespace optim { +namespace { struct CrossThreadReductionReplacer : public ir::IRMutator { void operator()(ir::Expr* expr) { Visit(expr); } @@ -148,6 +149,8 @@ struct CrossThreadReductionReplacer : public ir::IRMutator { std::vector cur_loops_; }; +} // namespace + void ReplaceCrossThreadReduction(Expr* e) { CrossThreadReductionReplacer()(e); } } // namespace optim diff --git a/paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh b/paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh index 2aefb019eac731..aef8907b81a431 100644 --- a/paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh +++ b/paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh @@ -474,11 +474,11 @@ __device__ inline bool cinn_any(const bool left, const bool right) { return left tmp_val = __shfl_sync(mask, tmp_val, 0, 32); \ return tmp_val; \ } else { \ - tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 16, 32)); \ - tmp_val = cinn_##REDUCE_TYPE(tmp_val, 
__shfl_xor_sync(mask, tmp_val, 8, 32)); \ - tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 4, 32)); \ - tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 2, 32)); \ - tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 1, 32)); \ + tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 16, 32)); \ + tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 8, 32)); \ + tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 4, 32)); \ + tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 2, 32)); \ + tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 1, 32)); \ return tmp_val; \ } \ } diff --git a/paddle/cinn/runtime/flags.cc b/paddle/cinn/runtime/flags.cc index 3d3801fa675fb0..cad18f4084a5de 100644 --- a/paddle/cinn/runtime/flags.cc +++ b/paddle/cinn/runtime/flags.cc @@ -61,6 +61,10 @@ PD_DEFINE_bool(general_fusion_merge_pass, BoolFromEnv("FLAGS_general_fusion_merge_pass", true), "Whether to use general fusion_merge pass."); +PD_DEFINE_bool(cinn_new_group_scheduler, + BoolFromEnv("FLAGS_cinn_new_group_scheduler", false), + "Whether to use new group scheduler."); + PD_DEFINE_bool(cinn_use_common_subexpression_elimination, BoolFromEnv("FLAGS_cinn_use_common_subexpression_elimination", false),