[CINN] Add new group scheduler (PaddlePaddle#56444)
* [CINN] Add new group scheduler

* [Fix] Fix priority of bind loops and output node

* [Fix] Set is_reduce_axis for Argmin and Argmax

* [Fix] Add producer consumer relation to Argmax

* [Fix] Add NodePriority and skip ExternCall block

* [Fix] Add prohibit schedule block

* [Fix] schedule block graph test

* [Fix] Skip external calls while auto inline

* [Fix] Fix relationship of block with Call nodes

* [CINN] Use common reduce while loop fusion

* [Fix] ScheduleBlockGraph unittest

* [Fix] reduction unittests

* [Fix] Skip group schedule of NonFusible nodes

* [Fix] Incomplete AutoSimplify

* [Fix] Adapt to new GraphCompilers

* [Fix] schedule block graph unittest

* [Fix] loop reorder to match master

* [Fix] elementwise loop alignment

* [Fix] cuda axis coeff and range

* [Fix] Add conditions to schedules related to cuda

* [Fix] fix conflict

* fix conflict

* [Fix] fix conflict

* Integrate ReductionFactoring

* [CINN] Upgrade ReductionReduction rule

* resolve conflict

* fix tensor in wb-block

* add reduce type in FactorizeReduction

* [CINN] Add cross thread reduction replacer

* Integrate cross thread reduction

* add anonymous namespace

* fix reduction factoring unittest

* Prohibit group schedule on single op

* Revert "Prohibit group schedule on single op"

This reverts commit 13ddff9.

* fix reduction factoring unittest

* fix reduction factoring unittest

* open group scheduler flag

* fix node priority

* fix cross thread reduction on cpu

* fix reduction_factoring with pre Fuse

* Revert "open group scheduler flag"

This reverts commit 192ccc1.

* Revert "fix reduction_factoring with pre Fuse"

This reverts commit 31889eb.

* simplify log of range

* add a TODO

* fix x86 reduction bug
BiynXu authored Nov 2, 2023
1 parent 8d266d5 commit 0ab3175
Showing 26 changed files with 2,603 additions and 117 deletions.
27 changes: 27 additions & 0 deletions paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc
@@ -18,6 +18,7 @@

#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule_block_graph.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"

@@ -94,13 +95,26 @@ void BindGPUIndex(ir::IRSchedule* ir_schedule,
auto all_loops = ir_schedule->GetLoops(block_name);
CHECK_LE(num_loops_to_bind, all_loops.size())
<< "The number of loops to be bind is greater than size of all_loops";
CHECK_GE(num_loops_to_bind, 0)
<< "The number of loops to be bind should be greater than 0";
// Check whether threadIdx has already been bound while blockIdx has not;
// threadIdx can only be bound in the first loop after the num_loops_to_bind
// loops because the other cases are excluded in CountLoopCanBinded.
bool gpu_thread_has_binded =
num_loops_to_bind < all_loops.size() &&
all_loops[num_loops_to_bind].As<ir::For>()->is_gpu_thread_binded();
ir::BlockOrderConstructor block_order_constructor;
std::map<std::vector<int>, ir::Expr> blocks_order_with_ctrl_stmt =
block_order_constructor(&all_loops[num_loops_to_bind - 1]);
for (auto& pair : blocks_order_with_ctrl_stmt) {
if (pair.first.size() == 2) {
ir::Expr stmt = pair.second;
if (stmt.As<ir::For>() && stmt.As<ir::For>()->is_gpu_thread_binded()) {
gpu_thread_has_binded = true;
}
}
}
Expr fused_loop = ir_schedule->Fuse(
{all_loops.begin(), all_loops.begin() + num_loops_to_bind});
int32_t extent = fused_loop.As<ir::For>()->extent.as_int32();
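
The scan above walks the blocks under the loop just before the fuse point and sets gpu_thread_has_binded when a For at the inspected depth is already bound to a GPU thread axis, so the subsequent binding can avoid assigning threadIdx twice. A minimal standalone sketch of that idea (a simplified Loop type instead of CINN's BlockOrderConstructor and ir::For, and no depth filter):

#include <vector>

// Illustrative only: walk a loop tree and report whether any nested loop is
// already bound to threadIdx, which is what gpu_thread_has_binded captures.
struct Loop {
  bool is_gpu_thread_binded = false;
  std::vector<const Loop*> children;
};

bool AnyGpuThreadBinded(const Loop* root) {
  if (root == nullptr) return false;
  if (root->is_gpu_thread_binded) return true;
  for (const Loop* child : root->children) {
    if (AnyGpuThreadBinded(child)) return true;
  }
  return false;
}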
@@ -181,5 +195,18 @@ std::vector<SearchState> AutoBind::ApplyOnBlock(SearchState state,
return {new_state};
}

void AutoBind::Apply(ir::IRSchedule* ir_schedule,
const std::string& block_name) {
int num_loop_can_bind =
CountLoopCanBinded(ir_schedule->GetLoops(block_name)[0].As<ir::For>());
if (num_loop_can_bind > 0) {
BindGPUIndex(ir_schedule,
block_name,
num_loop_can_bind,
kMaxBlocks,
target_->max_num_threads());
}
}

} // namespace auto_schedule
} // namespace cinn
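
The header change below makes this name-based Apply a public entry point on AutoBind, alongside the existing state-based ApplyOnBlock. A hedged usage sketch — the driver function, schedule object, and block-name list are assumptions for illustration, not part of this diff:

#include <string>
#include <vector>
// (CINN headers declaring AutoBind and ir::IRSchedule omitted.)

// Hypothetical caller: run AutoBind on every schedule block by name.
// Only AutoBind::Apply(ir::IRSchedule*, const std::string&) comes from this
// commit; the surrounding scaffolding is illustrative.
void BindAllBlocks(cinn::auto_schedule::AutoBind* auto_bind,
                   cinn::ir::IRSchedule* ir_schedule,
                   const std::vector<std::string>& block_names) {
  for (const std::string& name : block_names) {
    // Apply is a no-op when CountLoopCanBinded finds nothing to bind.
    auto_bind->Apply(ir_schedule, name);
  }
}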
@@ -42,6 +42,8 @@ class AutoBind : public AutoGenRule {
std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;

void Apply(ir::IRSchedule* ir_schedule, const std::string& block_name);

private:
std::vector<Expr> applicable_schedule_blocks_;
};
@@ -28,6 +28,7 @@
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"

@@ -49,6 +50,11 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
ir::Expr root = ir_sch->GetRootBlock(sche_block_realize_expr);

// Check the schedule block to be inlined is not a reduce tensor.
for (const ir::Var& iter_var : sche_block->iter_vars) {
if (iter_var->is_reduce_axis) {
return false;
}
}
std::set<ir::Expr> find_store = ir::ir_utils::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) { return x->As<ir::Store>(); });
if (find_store.size() != 1UL) {
@@ -69,6 +75,29 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
return false;
}

// the xxx_reduce_init block cannot be inlined.
if (ir::IsReduceInitTensorName(tensor->name)) {
return false;
}

// Skip external calls
std::vector<ir::Expr> consumers =
ir::GetConsumers(sche_block_realize_expr, root);
for (const ir::Expr& consumer : consumers) {
std::set<ir::Expr> find_load = ir::ir_utils::CollectIRNodesWithoutTensor(
consumer.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->body,
[&](const ir::Expr* x) {
return x->As<ir::Load>() &&
x->As<ir::Load>()->tensor.as_tensor_ref()->name ==
tensor->name;
});
if (find_load.empty()) {
return false;
}
}

// write_buffers.size() = 1 and read_buffers is empty, means const
// we can inline to consumer
if (sche_block->read_buffers.empty()) {
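
Taken together, the new checks in CanInlineIntoConsumer reject three kinds of producer blocks: blocks that iterate over a reduce axis, the generated *_reduce_init blocks, and blocks whose consumers never Load the produced tensor (the situation that arises when a consumer reaches it only through an external call). A condensed sketch of that decision using simplified stand-in types rather than the CINN IR classes:

#include <set>
#include <string>
#include <vector>

// Stand-ins for the data the real pass extracts from ScheduleBlockRealize.
struct InlineCandidate {
  bool has_reduce_axis;                               // any iter_var with is_reduce_axis
  std::string tensor_name;                            // tensor written by the block
  std::vector<std::set<std::string>> consumer_loads;  // tensor names each consumer Loads
};

bool CanInlineIntoConsumer(const InlineCandidate& c) {
  if (c.has_reduce_axis) return false;  // reductions are never inlined
  // The real code calls ir::IsReduceInitTensorName; matching the
  // "__reduce_init" name suffix here is an assumption of this sketch.
  if (c.tensor_name.find("__reduce_init") != std::string::npos) return false;
  for (const auto& loads : c.consumer_loads) {
    if (loads.count(c.tensor_name) == 0) return false;  // consumer has no Load of it
  }
  return true;
}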
@@ -63,7 +63,6 @@ class AutoInline : public AutoGenRule {
std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;

private:
void Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr); // NOLINT

private:
@@ -161,22 +161,42 @@ void ReductionFactoring::Apply(const std::string& block_name,
// 5. Split the reduction loop into 2 part
VLOG(6) << "before Split: " << ir_schedule->GetModule().GetExprs()[0];
int factor = 1;
int max_factor = 1024;
int extent = ir::GetLoopExtent(fused_reduce_loop);
for (int i = ceil(sqrt(extent)); i >= 1; --i) {
for (int i = max_factor; i >= 1; --i) {
if (extent % i == 0) {
factor = i;
break;
}
}
std::vector<cinn::ir::Expr> splited_reduction_loops =
ir_schedule->Split(fused_reduce_loop, {-1, factor});
ir_schedule->Split(fused_reduce_loop, {factor, -1});
// 6. Apply FactorizeReduction
VLOG(6) << "before FactorizeReduction: "
<< ir_schedule->GetModule().GetExprs()[0];
ir_schedule->FactorizeReduction(splited_reduction_loops[0],
num_spatial_loops);
VLOG(6) << "after FactorizeReduction: "
<< ir_schedule->GetModule().GetExprs()[0];

// 7. Loop fusion and cross thread reduction
std::vector<ir::Expr> rb_loops = ir_schedule->GetLoops(block_name);
ir::Expr rf_block = ir_schedule->GetBlock(block_name + "_rf");
ir_schedule->SimpleComputeAt(rf_block, rb_loops.back());

rb_loops = ir_schedule->GetLoops(block_name);
ir::Expr rf_init_block =
ir_schedule->GetBlock(block_name + "_rf__reduce_init");
ir_schedule->SimpleComputeAt(rf_init_block, rb_loops.back());

if (*target_ == common::DefaultNVGPUTarget()) {
rb_loops = ir_schedule->GetLoops(block_name);
rf_block = ir_schedule->GetBlock(block_name + "_rf");
ir_schedule->Bind(rb_loops.back(), "threadIdx.x");
ir_schedule->SetBuffer(rf_block, "shared");
}
VLOG(6) << "Loop fusion and cross thread reduction: "
<< ir_schedule->GetModule().GetExprs()[0];
}

} // namespace auto_schedule
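
In summary, the reworked ReductionFactoring::Apply fuses the reduce loops, splits the fused loop as {factor, -1} where factor is the largest divisor of the extent not exceeding max_factor (1024), applies FactorizeReduction to produce an rf buffer holding one partial result per outer reduce iteration, and, on the NVGPU target, binds that outer loop to threadIdx.x and places the rf buffer in shared memory so the final accumulation becomes a cross-thread reduction. A standalone sketch of the factor selection only (illustrative, not the CINN code):

#include <algorithm>

// Largest divisor of the fused reduction extent that fits in one CUDA block.
int ChooseReduceFactor(int extent, int max_factor = 1024) {
  for (int i = std::min(extent, max_factor); i >= 1; --i) {
    if (extent % i == 0) return i;  // threadIdx.x extent; inner serial extent = extent / i
  }
  return 1;
}

// With the shapes used in the unit tests below:
//   extent = 64               -> factor 64,   inner extent 1
//   extent = 64 * 128 = 8192  -> factor 1024, inner extent 8
//   extent = 64^3   = 262144  -> factor 1024, inner extent 256

The chosen factor becomes the extent of the loop that is later bound to threadIdx.x, so it directly controls how many threads cooperate on each output element.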
@@ -25,6 +25,8 @@
#include "paddle/cinn/ir/ir_printer.h"
#include "test/cpp/cinn/concrete_program_builder.h"

PD_DECLARE_bool(cinn_new_group_scheduler);

namespace cinn {
namespace auto_schedule {

@@ -37,7 +39,9 @@ class TestReductionFactoring : public TestAutoGenRuleBase {
const std::vector<int>& reduce_dim,
const std::string& block_name,
const std::string& expected_ir) {
Initialize(common::DefaultHostTarget());
Initialize(common::DefaultNVGPUTarget());
// In order to forcibly use the most basic Compute of reduction
FLAGS_cinn_new_group_scheduler = 1;
auto test_program = tests::ReduceBuilder().Build(
{{"X", shape}}, {{"reduce_dim", reduce_dim}});
// construct input parameter
@@ -66,7 +70,8 @@ class TestReductionFactoring : public TestAutoGenRuleBase {
};

TEST_F(TestReductionFactoring, AnalyseApplyType) {
Initialize(common::DefaultHostTarget());
Context::Global().ResetNameId();
Initialize(common::DefaultNVGPUTarget());
auto test_program =
tests::OpBuilder("elementwise_add").Build({{"X", {4, 5}}, {"Y", {4, 5}}});
ir::IRSchedule ir_schedule = MakeIRSchedule(test_program);
@@ -77,43 +82,44 @@ TEST_F(TestReductionFactoring, AnalyseApplyType) {
RuleApplyType::kCannotApply);
}

#ifdef CINN_WITH_CUDA

TEST_F(TestReductionFactoring, ApplyOnBlock1ReduceDim) {
Context::Global().ResetNameId();
std::string expected_ir = R"({
ScheduleBlock(root)
{
{
serial for (i, 0, 32)
{
serial for (reduce_k_0_0, 0, 8)
ScheduleBlock(var_0__reduce_init)
{
i0_0 = axis.bind(i)
var_0__reduce_init[i0_0] = 0.00000000f
}
thread_bind[threadIdx.x] for (reduce_k_0_0, 0, 64)
{
ScheduleBlock(var_0_rf__reduce_init)
{
vreduce_k_0_0, i0_0 = axis.bind(reduce_k_0_0, i)
var_0_rf__reduce_init[i0_0, vreduce_k_0_0] = 0.00000000f
}
serial for (reduce_k_0_1, 0, 8)
{
ScheduleBlock(var_0_rf)
serial for (reduce_k_0_1, 0, 1)
{
vreduce_k_0_0, i0_0, vreduce_k_0_1 = axis.bind(reduce_k_0_0, i, reduce_k_0_1)
var_0_rf[i0_0, vreduce_k_0_0] = (var_0_rf[i0_0, vreduce_k_0_0] + X[i0_0, ((8 * vreduce_k_0_0) + vreduce_k_0_1)])
ScheduleBlock(var_0_rf)
{
vreduce_k_0_0, i0_0, vreduce_k_0_1 = axis.bind(reduce_k_0_0, i, reduce_k_0_1)
var_0_rf[i0_0, vreduce_k_0_0] = (var_0_rf[i0_0, vreduce_k_0_0] + X[i0_0, (vreduce_k_0_0 + vreduce_k_0_1)])
}
}
{
ScheduleBlock(var_0)
{
vreduce_k_0_0, i0_0 = axis.bind(reduce_k_0_0, i)
var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_0])
}
}
}
}
}
serial for (i, 0, 32)
{
ScheduleBlock(var_0__reduce_init)
{
i0_0 = axis.bind(i)
var_0__reduce_init[i0_0] = 0.00000000f
}
serial for (reduce_k_0_0, 0, 8)
{
ScheduleBlock(var_0)
{
vreduce_k_0_0, i0_0 = axis.bind(reduce_k_0_0, i)
var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_0])
}
}
}
@@ -124,42 +130,41 @@ TEST_F(TestReductionFactoring, ApplyOnBlock1ReduceDim) {
}

TEST_F(TestReductionFactoring, ApplyOnBlock2ReduceDim) {
Context::Global().ResetNameId();
std::string expected_ir = R"({
ScheduleBlock(root)
{
{
serial for (i, 0, 32)
{
serial for (reduce_k_0_reduce_k_1_fused, 0, 128)
ScheduleBlock(var_0__reduce_init)
{
i0_0 = axis.bind(i)
var_0__reduce_init[i0_0] = 0.00000000f
}
thread_bind[threadIdx.x] for (reduce_k_0_reduce_k_1_fused, 0, 1024)
{
ScheduleBlock(var_0_rf__reduce_init)
{
vreduce_k_0_reduce_k_1_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i)
var_0_rf__reduce_init[i0_0, vreduce_k_0_reduce_k_1_fused] = 0.00000000f
}
serial for (reduce_k_0_reduce_k_1_fused_0, 0, 64)
{
ScheduleBlock(var_0_rf)
serial for (reduce_k_0_reduce_k_1_fused_0, 0, 8)
{
vreduce_k_0_reduce_k_1_fused, i0_0, vreduce_k_0_reduce_k_1_fused_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i, reduce_k_0_reduce_k_1_fused_0)
var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused] = (var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused] + X[i0_0, (((64 * vreduce_k_0_reduce_k_1_fused) + vreduce_k_0_reduce_k_1_fused_0) / 128), (((64 * vreduce_k_0_reduce_k_1_fused) + vreduce_k_0_reduce_k_1_fused_0) % 128)])
ScheduleBlock(var_0_rf)
{
vreduce_k_0_reduce_k_1_fused, i0_0, vreduce_k_0_reduce_k_1_fused_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i, reduce_k_0_reduce_k_1_fused_0)
var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused] = (var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused] + X[i0_0, (((8 * vreduce_k_0_reduce_k_1_fused) + vreduce_k_0_reduce_k_1_fused_0) / 128), (((8 * vreduce_k_0_reduce_k_1_fused) + vreduce_k_0_reduce_k_1_fused_0) % 128)])
}
}
{
ScheduleBlock(var_0)
{
vreduce_k_0_reduce_k_1_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i)
var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused])
}
}
}
}
}
serial for (i, 0, 32)
{
ScheduleBlock(var_0__reduce_init)
{
i0_0 = axis.bind(i)
var_0__reduce_init[i0_0] = 0.00000000f
}
serial for (reduce_k_0_reduce_k_1_fused, 0, 128)
{
ScheduleBlock(var_0)
{
vreduce_k_0_reduce_k_1_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_fused, i)
var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_fused])
}
}
}
@@ -170,42 +175,41 @@ TEST_F(TestReductionFactoring, ApplyOnBlock2ReduceDim) {
}

TEST_F(TestReductionFactoring, ApplyOnBlock3ReduceDim) {
Context::Global().ResetNameId();
std::string expected_ir = R"({
ScheduleBlock(root)
{
{
serial for (i, 0, 32)
{
serial for (reduce_k_0_reduce_k_1_reduce_k_2_fused, 0, 512)
ScheduleBlock(var_0__reduce_init)
{
i0_0 = axis.bind(i)
var_0__reduce_init[i0_0] = 0.00000000f
}
thread_bind[threadIdx.x] for (reduce_k_0_reduce_k_1_reduce_k_2_fused, 0, 1024)
{
ScheduleBlock(var_0_rf__reduce_init)
{
vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i)
var_0_rf__reduce_init[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] = 0.00000000f
}
serial for (reduce_k_0_reduce_k_1_reduce_k_2_fused_0, 0, 512)
{
ScheduleBlock(var_0_rf)
serial for (reduce_k_0_reduce_k_1_reduce_k_2_fused_0, 0, 256)
{
vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i, reduce_k_0_reduce_k_1_reduce_k_2_fused_0)
var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] = (var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] + X[i0_0, ((((512 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) / 64) / 64), ((((512 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) / 64) % 64), (((512 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) % 64)])
ScheduleBlock(var_0_rf)
{
vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i, reduce_k_0_reduce_k_1_reduce_k_2_fused_0)
var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] = (var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused] + X[i0_0, ((((256 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) / 64) / 64), ((((256 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) / 64) % 64), (((256 * vreduce_k_0_reduce_k_1_reduce_k_2_fused) + vreduce_k_0_reduce_k_1_reduce_k_2_fused_0) % 64)])
}
}
{
ScheduleBlock(var_0)
{
vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i)
var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused])
}
}
}
}
}
serial for (i, 0, 32)
{
ScheduleBlock(var_0__reduce_init)
{
i0_0 = axis.bind(i)
var_0__reduce_init[i0_0] = 0.00000000f
}
serial for (reduce_k_0_reduce_k_1_reduce_k_2_fused, 0, 512)
{
ScheduleBlock(var_0)
{
vreduce_k_0_reduce_k_1_reduce_k_2_fused, i0_0 = axis.bind(reduce_k_0_reduce_k_1_reduce_k_2_fused, i)
var_0[i0_0] = (var_0[i0_0] + var_0_rf[i0_0, vreduce_k_0_reduce_k_1_reduce_k_2_fused])
}
}
}
@@ -214,6 +218,7 @@ TEST_F(TestReductionFactoring, ApplyOnBlock3ReduceDim) {
})";
TestApplyOnReduce({32, 64, 64, 64}, {1, 2, 3}, "var_0", expected_ir);
}
#endif

} // namespace auto_schedule
} // namespace cinn
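
The expected IR in the tests above encodes the factorized reduction in two phases: each threadIdx.x lane first accumulates a partial sum into the var_0_rf buffer, and the partial sums are then folded into var_0 under the same thread loop. Read as plain serial loops for the 1-reduce-dim case (input shape {32, 64}; an illustrative reference only, ignoring the GPU binding and the shared-memory placement of var_0_rf):

// Plain-loop reading of the ApplyOnBlock1ReduceDim expected IR (sketch).
constexpr int I = 32;  // spatial extent (i)
constexpr int T = 64;  // reduce extent bound to threadIdx.x (reduce_k_0_0)

void FactorizedReduceReference(const float X[I][T], float var_0[I]) {
  static float var_0_rf[I][T];     // rf buffer: one partial sum per (i, thread)
  for (int i = 0; i < I; ++i) {
    var_0[i] = 0.0f;               // var_0__reduce_init
    for (int t = 0; t < T; ++t) {  // thread_bind[threadIdx.x] in the IR
      var_0_rf[i][t] = 0.0f;       // var_0_rf__reduce_init
      var_0_rf[i][t] += X[i][t];   // var_0_rf (inner serial loop has extent 1)
      var_0[i] += var_0_rf[i][t];  // var_0: the cross-thread reduction step
    }
  }
}

On the GPU, that last accumulation is presumably the statement the new cross-thread reduction replacer rewrites into an inter-thread combine; in this serial reading it is just an ordinary addition.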