[TIR][Schedule] Relax reorder primitive's affine binding check #10887

Merged
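Summary, condensed from the tests added below: the reorder primitive no longer requires every block under the reordered loops to have a globally affine binding. It is now enough that the binding is affine with respect to the loop range being reordered, i.e. with the loops above that range treated as fixed symbols. A minimal sketch of the new behavior, adapted from test_reorder_with_partial_affineness (the standalone script below is illustrative, not part of the diff):

from tvm import tir
from tvm.script import tir as T


@T.prim_func
def func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
    # the first axis is written multiple times: i = v0 * 2 + v1 is not affine over (v0, v1)
    for v0, v1, v2 in T.grid(6, 4, 4):
        with T.block("block"):
            i = T.axis.spatial(14, v0 * 2 + v1)
            j = T.axis.spatial(4, v2)
            B[i, j] = A[i, j] + 1.0


sch = tir.Schedule(func)
v0, v1, v2 = sch.get_loops(sch.get_block("block"))
# sch.reorder(v0, v2, v1) would still raise a ScheduleError: with v0 at the top of the
# reorder range, a fully affine binding is required
sch.reorder(v2, v1)  # accepted now: the binding is affine once v0 is held fixed

The same relaxation makes it possible to reorder the inner loops of a producer after compute_at, as exercised by test_reorder_with_cascade_tiled_ops below.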
11 changes: 11 additions & 0 deletions src/tir/schedule/analysis.h
@@ -231,6 +231,17 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va
*/
void CheckAffineBinding(const ScheduleState& self, Block block);

/*!
* \brief Check whether a block has an affine binding under the given high-exclusive sref node,
* and throw an exception if it does not.
* \param self The schedule state
* \param block The block to be checked
* \param high_exclusive The highest sref node (exclusive) under which the binding is checked
* \throw ScheduleError If the input block does not have an affine binding
*/
void CheckPartialAffineBinding(const ScheduleState& self, Block block,
const Optional<StmtSRef>& high_exclusive);

/*!
* \brief Extracts the ranges of loop variables in a path of the sref tree
* \param low_inclusive The lowest node in the path
50 changes: 43 additions & 7 deletions src/tir/schedule/analysis/analysis.cc
@@ -544,26 +544,62 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va
return true;
}

void CheckAffineBinding(const ScheduleState& self, Block block) {
void CheckPartialAffineBinding(const ScheduleState& self, Block block,
const Optional<StmtSRef>& high_exclusive) {
class NotAffineBindingError : public ScheduleError {
public:
explicit NotAffineBindingError(IRModule mod, Block block)
: mod_(std::move(mod)), block_(std::move(block)) {}
explicit NotAffineBindingError(IRModule mod, Block block, Optional<StmtSRef> high_exclusive)
: mod_(std::move(mod)), block_(std::move(block)) {
if (high_exclusive.defined()) {
high_exclusive_loop_ = high_exclusive.value()->StmtAs<ForNode>();
}
}
String FastErrorString() const final {
return "ScheduleError: The block is required to have an affine binding";
std::ostringstream ss;
if (high_exclusive_loop_) {
ss << "ScheduleError: The block is required to have an partial affine binding under "
<< high_exclusive_loop_->loop_var;
} else {
ss << "ScheduleError: The block is required to have an affine binding";
}
return ss.str();
}
String DetailRenderTemplate() const final {
return "The block {0} is required to have an affine binding";
std::ostringstream ss;
if (high_exclusive_loop_) {
ss << "The block {0} is required to have an partial affine binding under "
<< high_exclusive_loop_->loop_var;
} else {
ss << "The block {0} is required to have an affine binding";
}
return ss.str();
}
IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
IRModule mod_;
Block block_;
const ForNode* high_exclusive_loop_{nullptr};
};

if (!self->IsAffineBlockBinding(self->stmt2ref.at(block.get()))) {
throw NotAffineBindingError(self->mod, std::move(block));
StmtSRef block_sref = self->stmt2ref.at(block.get());
if (self->IsAffineBlockBinding(block_sref)) {
// The block's cached state already marks the binding as globally affine
return;
}
if (block_sref->parent && high_exclusive.defined()) {
// The binding is not globally affine; check affineness under high_exclusive instead
arith::Analyzer analyzer;
Map<Var, Range> dom_map =
LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent), high_exclusive);
if (IsAffineBinding(GetBlockRealize(self, block_sref), dom_map, &analyzer)) {
return;
}
}
throw NotAffineBindingError(self->mod, std::move(block), high_exclusive);
}

void CheckAffineBinding(const ScheduleState& self, Block block) {
CheckPartialAffineBinding(self, std::move(block), NullOpt);
}

Map<Var, Range> LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive,
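In effect, CheckPartialAffineBinding first trusts the per-block cached flag for global affineness; only when that fails does it build the loop domain from the block's parent up to, but excluding, high_exclusive and rerun IsAffineBinding, which internally relies on iterator-map detection with the excluded outer loop variables acting as free symbols. A rough Python analogue of that restricted-domain detection (a hedged sketch: the helper itself is internal C++ and not exposed, and the exact return type of tvm.arith.detect_iter_map differs across TVM versions — a list of IterSumExpr in older releases, an IterMapResult in newer ones; an empty result means the bindings are not an affine map of the given loops):

import tvm
from tvm import tir


def r(extent):
    # loop domain [0, extent)
    return tvm.ir.Range.from_min_extent(0, extent)


# Loop variables and block bindings of the non_affine_func test below: i = v0 * 2 + v1, j = v2
v0 = tir.Var("v0", "int32")
v1 = tir.Var("v1", "int32")
v2 = tir.Var("v2", "int32")

full = tvm.arith.detect_iter_map([v0 * 2 + v1, v2], {v0: r(6), v1: r(4), v2: r(4)})
partial = tvm.arith.detect_iter_map([v0 * 2 + v1, v2], {v1: r(4), v2: r(4)})

print(full)     # empty: not affine over the whole loop nest, since v0 * 2 + v1 overlaps itself
print(partial)  # non-empty: affine once v0 is excluded from the domain and treated as a symbol

Loops at or above high_exclusive never enter the domain map, so their variables behave as symbolic constants; that is what makes bindings such as h_o * 4 + h_i acceptable after compute_at in the new cascade test.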
21 changes: 13 additions & 8 deletions src/tir/schedule/primitive/loop_transformation.cc
@@ -134,30 +134,35 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator {
class BlockPropertyError : public ScheduleError {
public:
/*!
* \brief Check that all the blocks under the specific stmt have affine bindings and only have
* data-parallel or reduction block iters
* \brief Check that all the blocks under the specific stmt have affine bindings
* w.r.t. the top loop sref, and only have data-parallel or reduction block iters
* \param self The state of the schedule
* \param sref The sref to the specific stmt
*/
static void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self,
static void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self, const StmtSRefNode* top,
const StmtSRefNode* sref) {
class BlockIterTypeAndAffineBindingChecker : public StmtVisitor {
public:
explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state) : state_(state) {}
explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state,
const StmtSRefNode* top)
: state_(state), top_(top) {}

private:
void VisitStmt_(const BlockNode* op) final {
for (const IterVar& iter_var : op->iter_vars) {
if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) {
throw BlockPropertyError(state_->mod, GetRef<Block>(op));
}
CheckAffineBinding(state_, GetRef<Block>(op));
Optional<StmtSRef> high_exclusive =
top_->parent ? GetRef<StmtSRef>(top_->parent) : Optional<StmtSRef>(NullOpt);
CheckPartialAffineBinding(state_, GetRef<Block>(op), high_exclusive);
}
}
const ScheduleState& state_;
const StmtSRefNode* top_;
};

BlockIterTypeAndAffineBindingChecker checker(self);
BlockIterTypeAndAffineBindingChecker checker(self, top);
checker(GetRef<Stmt>(sref->stmt));
}

@@ -708,8 +713,8 @@ void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) {
// Step 3. Collect all loops in the chain and check the loops are single-branch
std::vector<const StmtSRefNode*> chain = GetLoopsInReorderRange(self, top, bottom);
// Step 4. Check the block below has all its block_var to be data-parallel or reduction,
// and the block has an affine binding.
BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, bottom);
// and the block has an affine binding w.r.t. the top of the loop range.
BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, top, bottom);
// Step 5. Replace the original loops with the reordered loops and check that outer loop is
// not dependent on inner loop
For new_loop = ConstructNewLoopChain(self, std::move(chain), ordered_loop_srefs, loop_srefs);
89 changes: 89 additions & 0 deletions tests/python/unittest/test_tir_schedule_reorder.py
@@ -213,6 +213,95 @@ def test_reorder_with_opaque_access():
verify_trace_roundtrip(sch=sch, mod=opaque_access)


def test_reorder_with_partial_affineness():
@T.prim_func
def non_affine_func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
# example that writes the first axis multiple times
for v0, v1, v2 in T.grid(6, 4, 4):
with T.block("block"):
i = T.axis.spatial(14, v0 * 2 + v1)
j = T.axis.spatial(4, v2)
B[i, j] = A[i, j] + 1.0

@T.prim_func
def non_affine_func_reorder(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
# example that writes the first axis multiple times
for v0, v2, v1 in T.grid(6, 4, 4):
with T.block("block"):
i = T.axis.spatial(14, v0 * 2 + v1)
j = T.axis.spatial(4, v2)
B[i, j] = A[i, j] + 1.0

sch = tir.Schedule(non_affine_func, debug_mask="all")
v0, v1, v2 = sch.get_loops(sch.get_block("block"))
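# i = v0 * 2 + v1 is not affine over (v0, v1) jointly, so a reorder whose range starts at v0
# still requires a fully affine binding and is rejected, while reordering only v2 and v1
# (v0 stays outside the range) needs just a partial affine binding and succeeds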
with pytest.raises(tvm.tir.ScheduleError):
sch.reorder(v0, v2, v1)

sch.reorder(v2, v1)
tvm.ir.assert_structural_equal(non_affine_func_reorder, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=non_affine_func)


def test_reorder_with_cascade_tiled_ops():
@T.prim_func
def cascade_pool_ops(
x: T.Buffer[(1, 16, 112, 112), "float32"], y2: T.Buffer[(1, 16, 108, 108), "float32"]
) -> None:
y1 = T.alloc_buffer([1, 16, 110, 110], dtype="float32")
for n, c, h, w, kh, kw in T.grid(1, 16, 110, 110, 3, 3):
with T.block("pool_0"):
ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [n, c, h, w, kh, kw])
with T.init():
y1[ax0, ax1, ax2, ax3] = 0.0
y1[ax0, ax1, ax2, ax3] = y1[ax0, ax1, ax2, ax3] + x[ax0, ax1, ax2 + rv0, ax3 + rv1]
for n, c, h, w, kh, kw in T.grid(1, 16, 108, 108, 3, 3):
with T.block("pool_1"):
ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [n, c, h, w, kh, kw])
with T.init():
y2[ax0, ax1, ax2, ax3] = 0.0
y2[ax0, ax1, ax2, ax3] = y2[ax0, ax1, ax2, ax3] + y1[ax0, ax1, ax2 + rv0, ax3 + rv1]

@T.prim_func
def cascade_pool_ops_tile_reordered(
x: T.Buffer[(1, 16, 112, 112), "float32"], y2: T.Buffer[(1, 16, 108, 108), "float32"]
) -> None:
y1 = T.alloc_buffer([1, 16, 110, 110], dtype="float32")
for n, c, h_o in T.grid(1, 16, 27):
for w, h_i, kh, kw in T.grid(110, 6, 3, 3):
with T.block("pool_0"):
ax0 = T.axis.spatial(1, 0)
ax1 = T.axis.spatial(16, c)
ax2 = T.axis.spatial(110, h_o * 4 + h_i)
ax3, rv0, rv1 = T.axis.remap("SRR", [w, kh, kw])
with T.init():
y1[ax0, ax1, ax2, ax3] = 0.0
y1[ax0, ax1, ax2, ax3] = (
y1[ax0, ax1, ax2, ax3] + x[ax0, ax1, ax2 + rv0, ax3 + rv1]
)
for h_i, w, kh, kw in T.grid(4, 108, 3, 3):
with T.block("pool_1"):
ax0 = T.axis.spatial(1, 0)
ax1 = T.axis.spatial(16, c)
ax2 = T.axis.spatial(108, h_o * 4 + h_i)
ax3, rv0, rv1 = T.axis.remap("SRR", [w, kh, kw])
with T.init():
y2[ax0, ax1, ax2, ax3] = 0.0
y2[ax0, ax1, ax2, ax3] = (
y2[ax0, ax1, ax2, ax3] + y1[ax0, ax1, ax2 + rv0, ax3 + rv1]
)

sch = tvm.tir.schedule.Schedule(cascade_pool_ops)
pool_0 = sch.get_block("pool_0")
pool_1 = sch.get_block("pool_1")
_, _, h, w, _, _ = sch.get_loops(pool_1)
ho, _ = sch.split(h, factors=[None, 4])
sch.compute_at(pool_0, ho)
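# after compute_at, pool_0's ax2 binding is h_o * 4 + h_i (with h_i of extent 6), which is
# affine only with h_o fixed; the relaxed check is what allows the reorder below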
_, _, _, h_i, w, _, _ = sch.get_loops(pool_0)
sch.reorder(w, h_i)
tvm.ir.assert_structural_equal(cascade_pool_ops_tile_reordered, sch.mod["main"], True)
verify_trace_roundtrip(sch=sch, mod=cascade_pool_ops)


def test_reorder_with_predicate():
sch = tir.Schedule(elementwise_predicate, debug_mask="all")
block_b = sch.get_block("B")