Skip to content

Commit

Permalink
relax reorder primitive's affineness check (apache#10887)
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest-intellif authored and altanh committed Apr 28, 2022
1 parent fbe5f69 commit c4a4f23
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 15 deletions.
11 changes: 11 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,17 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va
*/
void CheckAffineBinding(const ScheduleState& self, Block block);

/*!
 * \brief Check whether a block has an affine binding with respect to the loop var
 *        ranges collected strictly below the `high_exclusive` sref node, and throw
 *        an exception if it does not.
 * \param self The schedule state
 * \param block The block to be checked
 * \param high_exclusive The highest sref node (exclusive) under which loop var ranges
 *        are considered; NullOpt means the whole loop nest is considered, i.e. the
 *        check degenerates to the full affine-binding check.
 * \throw ScheduleError If the input block does not have a (partial) affine binding
 */
void CheckPartialAffineBinding(const ScheduleState& self, Block block,
                               const Optional<StmtSRef>& high_exclusive);

/*!
* \brief Extracts the ranges of loop variables in a path of the sref tree
* \param low_inclusive The lowest node in the path
Expand Down
50 changes: 43 additions & 7 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -544,26 +544,62 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va
return true;
}

void CheckAffineBinding(const ScheduleState& self, Block block) {
void CheckPartialAffineBinding(const ScheduleState& self, Block block,
const Optional<StmtSRef>& high_exclusive) {
class NotAffineBindingError : public ScheduleError {
public:
explicit NotAffineBindingError(IRModule mod, Block block)
: mod_(std::move(mod)), block_(std::move(block)) {}
explicit NotAffineBindingError(IRModule mod, Block block, Optional<StmtSRef> high_exclusive)
: mod_(std::move(mod)), block_(std::move(block)) {
if (high_exclusive.defined()) {
high_exclusive_loop_ = high_exclusive.value()->StmtAs<ForNode>();
}
}
String FastErrorString() const final {
return "ScheduleError: The block is required to have an affine binding";
std::ostringstream ss;
if (high_exclusive_loop_) {
ss << "ScheduleError: The block is required to have an partial affine binding under "
<< high_exclusive_loop_->loop_var;
} else {
ss << "ScheduleError: The block is required to have an affine binding";
}
return ss.str();
}
String DetailRenderTemplate() const final {
return "The block {0} is required to have an affine binding";
std::ostringstream ss;
if (high_exclusive_loop_) {
ss << "The block {0} is required to have an partial affine binding under "
<< high_exclusive_loop_->loop_var;
} else {
ss << "The block {0} is required to have an affine binding";
}
return ss.str();
}
IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
IRModule mod_;
Block block_;
const ForNode* high_exclusive_loop_{nullptr};
};

if (!self->IsAffineBlockBinding(self->stmt2ref.at(block.get()))) {
throw NotAffineBindingError(self->mod, std::move(block));
StmtSRef block_sref = self->stmt2ref.at(block.get());
if (self->IsAffineBlockBinding(block_sref)) {
// check block cached state for global affineness
return;
}
if (block_sref->parent && high_exclusive.defined()) {
// if it is not of global affine binding, check affineness under high_exclusive,
arith::Analyzer analyzer;
Map<Var, Range> dom_map =
LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent), high_exclusive);
if (IsAffineBinding(GetBlockRealize(self, block_sref), dom_map, &analyzer)) {
return;
}
}
throw NotAffineBindingError(self->mod, std::move(block), high_exclusive);
}

void CheckAffineBinding(const ScheduleState& self, Block block) {
CheckPartialAffineBinding(self, std::move(block), NullOpt);
}

Map<Var, Range> LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive,
Expand Down
21 changes: 13 additions & 8 deletions src/tir/schedule/primitive/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,30 +134,35 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator {
class BlockPropertyError : public ScheduleError {
public:
/*!
* \brief Check that all the blocks under the specific stmt have affine bindings and only have
* data-parallel or reduction block iters
* \brief Check that all the blocks under the specific stmt have affine bindings
* wrt top loop sref and only have data-parallel or reduction block iters
* \param self The state of the schedule
* \param sref The sref to the specific stmt
*/
static void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self,
static void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self, const StmtSRefNode* top,
const StmtSRefNode* sref) {
class BlockIterTypeAndAffineBindingChecker : public StmtVisitor {
public:
explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state) : state_(state) {}
explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state,
const StmtSRefNode* top)
: state_(state), top_(top) {}

private:
void VisitStmt_(const BlockNode* op) final {
for (const IterVar& iter_var : op->iter_vars) {
if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) {
throw BlockPropertyError(state_->mod, GetRef<Block>(op));
}
CheckAffineBinding(state_, GetRef<Block>(op));
Optional<StmtSRef> high_exclusive =
top_->parent ? GetRef<StmtSRef>(top_->parent) : Optional<StmtSRef>(NullOpt);
CheckPartialAffineBinding(state_, GetRef<Block>(op), high_exclusive);
}
}
const ScheduleState& state_;
const StmtSRefNode* top_;
};

BlockIterTypeAndAffineBindingChecker checker(self);
BlockIterTypeAndAffineBindingChecker checker(self, top);
checker(GetRef<Stmt>(sref->stmt));
}

Expand Down Expand Up @@ -708,8 +713,8 @@ void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) {
// Step 3. Collect all loops in the chain and check the loops are single-branch
std::vector<const StmtSRefNode*> chain = GetLoopsInReorderRange(self, top, bottom);
// Step 4. Check the block below has all its block_var to be data-parallel or reduction,
// and the block has an affine binding.
BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, bottom);
// and the block has an affine binding wrt top of the loop range.
BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, top, bottom);
// Step 5. Replace the original loops with the reordered loops and check that outer loop is
// not dependent on inner loop
For new_loop = ConstructNewLoopChain(self, std::move(chain), ordered_loop_srefs, loop_srefs);
Expand Down
89 changes: 89 additions & 0 deletions tests/python/unittest/test_tir_schedule_reorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,95 @@ def test_reorder_with_opaque_access():
verify_trace_roundtrip(sch=sch, mod=opaque_access)


def test_reorder_with_partial_affineness():
    """reorder() on a block whose binding is only partially affine.

    The binding i = v0*2 + v1 revisits rows of B, so it is not affine over
    the whole (v0, v1, v2) nest.  Moving v1 across v0 must be rejected, but
    swapping the inner loops v2 and v1 — which only needs affineness below
    v0 — must be accepted.
    """

    @T.prim_func
    def non_affine_func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
        # example to write first axis multiple times
        for v0, v1, v2 in T.grid(6, 4, 4):
            with T.block("block"):
                i = T.axis.spatial(14, v0 * 2 + v1)
                j = T.axis.spatial(4, v2)
                B[i, j] = A[i, j] + 1.0

    @T.prim_func
    def non_affine_func_reorder(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
        # example to write first axis multiple times
        for v0, v2, v1 in T.grid(6, 4, 4):
            with T.block("block"):
                i = T.axis.spatial(14, v0 * 2 + v1)
                j = T.axis.spatial(4, v2)
                B[i, j] = A[i, j] + 1.0

    sch = tir.Schedule(non_affine_func, debug_mask="all")
    v0, v1, v2 = sch.get_loops(sch.get_block("block"))
    # Reordering v1 above v0 would break the i = v0*2 + v1 binding -> error.
    with pytest.raises(tvm.tir.ScheduleError):
        sch.reorder(v0, v2, v1)

    # Swapping the two innermost loops keeps the binding affine under v0.
    sch.reorder(v2, v1)
    tvm.ir.assert_structural_equal(non_affine_func_reorder, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=non_affine_func)


def test_reorder_with_cascade_tiled_ops():
    """reorder() inside a cascade of two pooling ops after compute_at.

    After splitting pool_1's h loop and computing pool_0 at the outer h
    loop, pool_0's bindings (e.g. ax2 = h_o*4 + h_i) are affine only below
    that outer loop, not globally — the relaxed (partial) affineness check
    must still allow reordering w and h_i inside pool_0.
    """

    @T.prim_func
    def cascade_pool_ops(
        x: T.Buffer[(1, 16, 112, 112), "float32"], y2: T.Buffer[(1, 16, 108, 108), "float32"]
    ) -> None:
        # Two back-to-back 3x3 sum-pools: x -> y1 -> y2.
        y1 = T.alloc_buffer([1, 16, 110, 110], dtype="float32")
        for n, c, h, w, kh, kw in T.grid(1, 16, 110, 110, 3, 3):
            with T.block("pool_0"):
                ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [n, c, h, w, kh, kw])
                with T.init():
                    y1[ax0, ax1, ax2, ax3] = 0.0
                y1[ax0, ax1, ax2, ax3] = y1[ax0, ax1, ax2, ax3] + x[ax0, ax1, ax2 + rv0, ax3 + rv1]
        for n, c, h, w, kh, kw in T.grid(1, 16, 108, 108, 3, 3):
            with T.block("pool_1"):
                ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [n, c, h, w, kh, kw])
                with T.init():
                    y2[ax0, ax1, ax2, ax3] = 0.0
                y2[ax0, ax1, ax2, ax3] = y2[ax0, ax1, ax2, ax3] + y1[ax0, ax1, ax2 + rv0, ax3 + rv1]

    @T.prim_func
    def cascade_pool_ops_tile_reordered(
        x: T.Buffer[(1, 16, 112, 112), "float32"], y2: T.Buffer[(1, 16, 108, 108), "float32"]
    ) -> None:
        # Expected result: pool_0 computed at h_o with its (w, h_i) loops swapped.
        y1 = T.alloc_buffer([1, 16, 110, 110], dtype="float32")
        for n, c, h_o in T.grid(1, 16, 27):
            for w, h_i, kh, kw in T.grid(110, 6, 3, 3):
                with T.block("pool_0"):
                    ax0 = T.axis.spatial(1, 0)
                    ax1 = T.axis.spatial(16, c)
                    ax2 = T.axis.spatial(110, h_o * 4 + h_i)
                    ax3, rv0, rv1 = T.axis.remap("SRR", [w, kh, kw])
                    with T.init():
                        y1[ax0, ax1, ax2, ax3] = 0.0
                    y1[ax0, ax1, ax2, ax3] = (
                        y1[ax0, ax1, ax2, ax3] + x[ax0, ax1, ax2 + rv0, ax3 + rv1]
                    )
            for h_i, w, kh, kw in T.grid(4, 108, 3, 3):
                with T.block("pool_1"):
                    ax0 = T.axis.spatial(1, 0)
                    ax1 = T.axis.spatial(16, c)
                    ax2 = T.axis.spatial(108, h_o * 4 + h_i)
                    ax3, rv0, rv1 = T.axis.remap("SRR", [w, kh, kw])
                    with T.init():
                        y2[ax0, ax1, ax2, ax3] = 0.0
                    y2[ax0, ax1, ax2, ax3] = (
                        y2[ax0, ax1, ax2, ax3] + y1[ax0, ax1, ax2 + rv0, ax3 + rv1]
                    )

    sch = tvm.tir.schedule.Schedule(cascade_pool_ops)
    pool_0 = sch.get_block("pool_0")
    pool_1 = sch.get_block("pool_1")
    _, _, h, w, _, _ = sch.get_loops(pool_1)
    # Tile pool_1's h axis and fuse pool_0's computation under the outer tile loop.
    ho, _ = sch.split(h, factors=[None, 4])
    sch.compute_at(pool_0, ho)
    # pool_0's bindings are now only partially affine; reorder must still succeed.
    _, _, _, h_i, w, _, _ = sch.get_loops(pool_0)
    sch.reorder(w, h_i)
    tvm.ir.assert_structural_equal(cascade_pool_ops_tile_reordered, sch.mod["main"], True)
    verify_trace_roundtrip(sch=sch, mod=cascade_pool_ops)


def test_reorder_with_predicate():
sch = tir.Schedule(elementwise_predicate, debug_mask="all")
block_b = sch.get_block("B")
Expand Down

0 comments on commit c4a4f23

Please sign in to comment.