Skip to content

Commit

Permalink
[BugFix][TIR] Error check: Inline Block with Init Stmt (#11033)
Browse files Browse the repository at this point in the history
Should fix #10900
  • Loading branch information
junrushao authored Apr 17, 2022
1 parent 8d868f6 commit 9c2df39
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
26 changes: 26 additions & 0 deletions src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,30 @@ static const char kErrBodyReverseInline[] = R"(The body of the inlined block sho
where A is the only buffer the block consumes, whose indices are distinct atomic variables,
and there should not no variables other than the index variables)";

class HasInitBlock : public ScheduleError {
public:
explicit HasInitBlock(IRModule mod, Block block) : mod_(mod), block_(block) {}

String FastErrorString() const final { return "ScheduleError: The block has init statement"; }

String DetailRenderTemplate() const final {
return "ScheduleError: The block has init statement: {0}";
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

static void Check(const IRModule& mod, const Block& block) {
if (block->init.defined()) {
throw HasInitBlock(mod, block);
}
}

private:
IRModule mod_;
Block block_;
};

class NotSingleReadWriteBuffer : public ScheduleError {
public:
explicit NotSingleReadWriteBuffer(IRModule mod, bool is_read, Block block)
Expand Down Expand Up @@ -572,6 +596,7 @@ void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
bool check_only = false) {
const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(_producer_block, producer_block_sref);
Block producer_block = GetRef<Block>(_producer_block);
HasInitBlock::Check(self->mod, producer_block);
Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block);
// Step 1. Get the scope block
StmtSRef scope_root_sref = GetScopeRoot(self, producer_block_sref,
Expand Down Expand Up @@ -616,6 +641,7 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block
bool check_only = false) {
const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(_consumer_block, consumer_block_sref);
Block consumer_block = GetRef<Block>(_consumer_block);
HasInitBlock::Check(self->mod, consumer_block);
// Step 1. Get the scope block
StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref, //
/*require_stage_pipeline=*/true);
Expand Down
44 changes: 44 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,43 @@ def matmul_relu(var_A: T.handle, var_B: T.handle, var_compute: T.handle) -> None
compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))


@T.prim_func
def inline_block_with_init(
A: T.Buffer[(1, 512, 7, 7), "float32"],
B: T.Buffer[(1, 512, 1, 1), "float32"],
) -> None:
B_rf = T.alloc_buffer([1, 512, 1, 1, 49], dtype="float32")
for i0, i1, i2, i3, i4, i5 in T.grid(1, 512, 1, 1, 49, 1):
with T.block("tensor_rf"):
vi4 = T.axis.spatial(49, i4)
ax0 = T.axis.spatial(1, 0)
ax1 = T.axis.spatial(512, i1)
ax2 = T.axis.spatial(1, 0)
ax3 = T.axis.spatial(1, 0)
with T.init():
B_rf[ax0, ax1, ax2, ax3, vi4] = T.float32(0)
B_rf[ax0, ax1, ax2, ax3, vi4] = (
B_rf[ax0, ax1, ax2, ax3, vi4]
+ A[
ax0,
ax1,
ax2 * 7 + vi4 // 7,
ax3 * 7 + vi4 % 7,
]
)
for i0, i1 in T.grid(1, 512):
for ax0, ax1, ax2, ax3, ax4 in T.grid(49, 1, 1, 1, 1):
with T.block("tensor"):
vi4, ax0_1 = T.axis.remap("RS", [ax0, ax1])
ax1_1 = T.axis.spatial(512, i1 + ax2)
ax2_1, ax3_1 = T.axis.remap("SS", [ax3, ax4])
with T.init():
B[ax0_1, ax1_1, ax2_1, ax3_1] = T.float32(0)
B[ax0_1, ax1_1, ax2_1, ax3_1] = (
B[ax0_1, ax1_1, ax2_1, ax3_1] + B_rf[ax0_1, ax1_1, ax2_1, ax3_1, vi4]
)


# pylint: enable=no-member,invalid-name,unused-variable


Expand Down Expand Up @@ -525,5 +562,12 @@ def test_compute_inline_with_opaque_access():
tvm.ir.assert_structural_equal(access_opaque_ptr_then_elemwise_inline, sch.mod["main"])


def test_inline_block_with_init():
sch = tir.Schedule(inline_block_with_init, debug_mask="all")
block = sch.get_block(name="tensor_rf", func_name="main")
with pytest.raises(tvm.tir.ScheduleError):
sch.compute_inline(block=block)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 9c2df39

Please sign in to comment.