diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 8f87ef920784..dd01aed61c52 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -110,8 +110,11 @@ void BlockReadWriteDetector::operator()(const Stmt& stmt) { ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey(); for (const MatchBufferRegion& match_buffer : block->match_buffers) { const Var& target_var = match_buffer->buffer->data; - match_buffers_[target_var.get()] = match_buffer; - buffer_var_map_.Set(target_var, match_buffer->buffer); + const Var& source_var = match_buffer->source->buffer->data; + if (buffer_var_map_.find(source_var) != buffer_var_map_.end()) { + match_buffers_[target_var.get()] = match_buffer; + buffer_var_map_.Set(target_var, match_buffer->buffer); + } } StmtExprVisitor::operator()(stmt); } diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index bee11ad72280..59f9170786b6 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -75,8 +75,6 @@ class BufferAllocationLocator : public StmtExprMutator { Stmt VisitStmt_(const BlockNode* op) final { ICHECK(!op->init.defined()); - bool is_root = is_root_; - is_root_ = false; Array alloc_buffers; auto it = alloc_buffers_.find(op); if (it != alloc_buffers_.end()) { @@ -85,11 +83,23 @@ class BufferAllocationLocator : public StmtExprMutator { buffer_data_to_buffer_.Set(buf->data, buf); } } + for (const MatchBufferRegion match_buffer : op->match_buffers) { + const Var& target_var = match_buffer->buffer->data; + const Var& source_var = match_buffer->source->buffer->data; + ICHECK(buffer_data_to_buffer_.count(source_var)); + buffer_data_to_buffer_.Set(target_var, match_buffer->buffer); + } Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); ICHECK(op != nullptr); - // Ignore buffer allocated inside the block when getting access region. + // No longer consider buffers created by match_buffer inside the block when updating access + // region. + for (const MatchBufferRegion match_buffer : op->match_buffers) { + const Var& target_var = match_buffer->buffer->data; + buffer_data_to_buffer_.erase(target_var); + } + // No longer consider buffers allocated inside the block when updating access region. if (it != alloc_buffers_.end()) { for (const Buffer& buf : it->second) { buffer_data_to_buffer_.erase(buf->data); @@ -98,12 +108,9 @@ class BufferAllocationLocator : public StmtExprMutator { ObjectPtr n = CopyOnWrite(op); n->alloc_buffers = std::move(alloc_buffers); - // The read/write regions of root block are always empty. - if (!is_root) { - // Recalculate block access region - CollectReadWrite(GetRef(op), &n->reads, &n->writes); - } - + // Erase buffer allocated inside the block from access region. + n->reads = RemoveRedundantBufferRegion(n->reads); + n->writes = RemoveRedundantBufferRegion(n->writes); return Stmt(n); } @@ -127,8 +134,18 @@ class BufferAllocationLocator : public StmtExprMutator { return std::move(realize); } + Array RemoveRedundantBufferRegion(const Array& region) const { + Array result; + for (const BufferRegion& buffer_region : region) { + if (buffer_data_to_buffer_.count(buffer_region->buffer->data)) { + result.push_back(buffer_region); + } + } + return result; + } + void CollectReadWrite(const Block& block, Array* reads, - Array* writes) { + Array* writes) const { Array> access = GetBlockAccessRegion(block, buffer_data_to_buffer_); *reads = access[0]; *writes = access[1]; @@ -142,8 +159,6 @@ class BufferAllocationLocator : public StmtExprMutator { std::unordered_map> alloc_buffers_; /*! \brief The buffer already allocated during recursive visiting. */ Map buffer_data_to_buffer_; - /*! \brief indicate the whether the block is root. */ - bool is_root_{true}; }; PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) { diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index 7641f0ac46cb..9c95b9819e6f 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -114,20 +114,29 @@ def test_match_buffer(): root_block = match_buffer_func.body.block block = root_block.body.body.body.block block_inner = block.body[0].body.body.block - alloc_buffers = func.body.block.alloc_buffers + alloc_buffers = match_buffer_func.body.block.alloc_buffers buffer_var_map = {buf.data: buf for buf in alloc_buffers} - # Check inner block AAA - ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map) - tvm.ir.assert_structural_equal(block_inner.reads, ret[0]) - tvm.ir.assert_structural_equal(block_inner.writes, ret[1]) - # Check block ret = tir.analysis.get_block_access_region(block, buffer_var_map) tvm.ir.assert_structural_equal(block.writes, ret[1]) # B is opaque access tvm.ir.assert_structural_equal(block.reads, ret[2]) + # Check inner block AAA without updating buffer_var_map + ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map) + # Since AA is not in the buffer_var_map, region of AA will not be collected. + tvm.ir.assert_structural_equal([], ret[1]) + + # Check inner block AAA + for match_buffer in block.match_buffers: + target_buffer = match_buffer.buffer + buffer_var_map[target_buffer.data] = target_buffer + + ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map) + tvm.ir.assert_structural_equal(block_inner.reads, ret[0]) + tvm.ir.assert_structural_equal(block_inner.writes, ret[1]) + if __name__ == "__main__": test_block_access_region_detector() diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index 8418e192d060..07140ab458e6 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -137,6 +137,63 @@ def transformed_match_buffer_func() -> None: C1[()] = 0 +@tvm.script.tir +def opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [1024]) + B = tir.match_buffer(b, [1024]) + A_cache = tir.alloc_buffer([1024]) + for i in tir.serial(0, 8): + with tir.block([8]) as [vi]: + with tir.block([8]) as [v]: + tir.bind(v, vi) + tir.reads([A[(v * 128) : ((v * 128) + 128)]]) + tir.writes([A_cache[(v * 128) : ((v * 128) + 128)]]) + tir.evaluate( + tir.call_extern( + "test", + A_cache.data, + (v * 128), + 128, + A.data, + (v * 128), + 128, + dtype="float32", + ) + ) + for j in tir.serial(0, 128): + with tir.block([1024]) as [v]: + tir.bind(v, ((vi * 128) + j)) + tir.reads([A_cache[v]]) + tir.writes([B[v]]) + B[v] = A_cache[v] + + +@tvm.script.tir +def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [1024]) + B = tir.match_buffer(b, [1024]) + for i in tir.serial(0, 8): + with tir.block([8]) as [vi]: + tir.reads(A[vi * 128 : vi * 128 + 128]) + tir.writes(B[vi * 128 : vi * 128 + 128]) + A_cache = tir.alloc_buffer([1024]) + with tir.block([8]) as [v]: + tir.bind(v, vi) + tir.reads([A[v * 128 : v * 128 + 128]]) + tir.writes([A_cache[v * 128 : v * 128 + 128]]) + tir.evaluate( + tir.call_extern( + "test", A_cache.data, v * 128, 128, A.data, v * 128, 128, dtype="float32" + ) + ) + for j in tir.serial(0, 128): + with tir.block([1024]) as [v]: + tir.bind(v, ((vi * 128) + j)) + tir.reads([A_cache[v]]) + tir.writes([B[v]]) + B[v] = A_cache[v] + + def test_elementwise(): _check(element_func, transformed_element_func) @@ -149,6 +206,10 @@ def test_match_buffer_allocation(): _check(match_buffer_func, transformed_match_buffer_func) +def test_opaque_access(): + _check(opaque_access, transformed_opaque_access) + + def test_lower_te(): x = te.placeholder((1,)) y = te.compute((1,), lambda i: x[i] + 2) @@ -164,4 +225,5 @@ def test_lower_te(): test_elementwise() test_locate_buffer_allocation() test_match_buffer_allocation() + test_opaque_access() test_lower_te()