Reject compute_inline / reverse_compute_inline on blocks that have an init statement.

Summary of this patch:
  * compute_inline.cc: adds the ScheduleError subclass HasInitBlock. Its static
    Check(mod, block) throws HasInitBlock when block->init is defined.
  * compute_inline.cc: ComputeInlineImpl and ReverseComputeInlineImpl call
    HasInitBlock::Check on the producer / consumer block right after the block
    is resolved from its sref, before any other validation step.
  * test_tir_schedule_compute_inline.py: adds the prim_func inline_block_with_init
    (two blocks, "tensor_rf" and "tensor", each with a T.init() section) and the test
    test_inline_block_with_init, which expects tvm.tir.ScheduleError when
    compute_inline is applied to block "tensor_rf".

diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc
index d7556ed73995..630a72cedee5 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -31,6 +31,30 @@ static const char kErrBodyReverseInline[] = R"(The body of the inlined block sho
 where A is the only buffer the block consumes, whose indices are distinct atomic variables,
 and there should not no variables other than the index variables)";
 
+class HasInitBlock : public ScheduleError {
+ public:
+  explicit HasInitBlock(IRModule mod, Block block) : mod_(mod), block_(block) {}
+
+  String FastErrorString() const final { return "ScheduleError: The block has init statement"; }
+
+  String DetailRenderTemplate() const final {
+    return "ScheduleError: The block has init statement: {0}";
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+  static void Check(const IRModule& mod, const Block& block) {
+    if (block->init.defined()) {
+      throw HasInitBlock(mod, block);
+    }
+  }
+
+ private:
+  IRModule mod_;
+  Block block_;
+};
+
 class NotSingleReadWriteBuffer : public ScheduleError {
  public:
   explicit NotSingleReadWriteBuffer(IRModule mod, bool is_read, Block block)
@@ -572,6 +596,7 @@ void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
                        bool check_only = false) {
   const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(_producer_block, producer_block_sref);
   Block producer_block = GetRef<Block>(_producer_block);
+  HasInitBlock::Check(self->mod, producer_block);
   Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block);
   // Step 1. Get the scope block
   StmtSRef scope_root_sref = GetScopeRoot(self, producer_block_sref,
@@ -616,6 +641,7 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block
                               bool check_only = false) {
   const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(_consumer_block, consumer_block_sref);
   Block consumer_block = GetRef<Block>(_consumer_block);
+  HasInitBlock::Check(self->mod, consumer_block);
   // Step 1. Get the scope block
   StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref,  //
                                           /*require_stage_pipeline=*/true);
diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py
index f8d767da4645..1259219a392a 100644
--- a/tests/python/unittest/test_tir_schedule_compute_inline.py
+++ b/tests/python/unittest/test_tir_schedule_compute_inline.py
@@ -365,6 +365,43 @@ def matmul_relu(var_A: T.handle, var_B: T.handle, var_compute: T.handle) -> None
             compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))
 
 
+@T.prim_func
+def inline_block_with_init(
+    A: T.Buffer[(1, 512, 7, 7), "float32"],
+    B: T.Buffer[(1, 512, 1, 1), "float32"],
+) -> None:
+    B_rf = T.alloc_buffer([1, 512, 1, 1, 49], dtype="float32")
+    for i0, i1, i2, i3, i4, i5 in T.grid(1, 512, 1, 1, 49, 1):
+        with T.block("tensor_rf"):
+            vi4 = T.axis.spatial(49, i4)
+            ax0 = T.axis.spatial(1, 0)
+            ax1 = T.axis.spatial(512, i1)
+            ax2 = T.axis.spatial(1, 0)
+            ax3 = T.axis.spatial(1, 0)
+            with T.init():
+                B_rf[ax0, ax1, ax2, ax3, vi4] = T.float32(0)
+            B_rf[ax0, ax1, ax2, ax3, vi4] = (
+                B_rf[ax0, ax1, ax2, ax3, vi4]
+                + A[
+                    ax0,
+                    ax1,
+                    ax2 * 7 + vi4 // 7,
+                    ax3 * 7 + vi4 % 7,
+                ]
+            )
+    for i0, i1 in T.grid(1, 512):
+        for ax0, ax1, ax2, ax3, ax4 in T.grid(49, 1, 1, 1, 1):
+            with T.block("tensor"):
+                vi4, ax0_1 = T.axis.remap("RS", [ax0, ax1])
+                ax1_1 = T.axis.spatial(512, i1 + ax2)
+                ax2_1, ax3_1 = T.axis.remap("SS", [ax3, ax4])
+                with T.init():
+                    B[ax0_1, ax1_1, ax2_1, ax3_1] = T.float32(0)
+                B[ax0_1, ax1_1, ax2_1, ax3_1] = (
+                    B[ax0_1, ax1_1, ax2_1, ax3_1] + B_rf[ax0_1, ax1_1, ax2_1, ax3_1, vi4]
+                )
+
+
 # pylint: enable=no-member,invalid-name,unused-variable
 
 
@@ -525,5 +562,12 @@ def test_compute_inline_with_opaque_access():
     tvm.ir.assert_structural_equal(access_opaque_ptr_then_elemwise_inline, sch.mod["main"])
 
 
+def test_inline_block_with_init():
+    sch = tir.Schedule(inline_block_with_init, debug_mask="all")
+    block = sch.get_block(name="tensor_rf", func_name="main")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.compute_inline(block=block)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))