From 57597f62b44f0ab17adae34b89f5f526b816b759 Mon Sep 17 00:00:00 2001 From: Wei Tao <1136862851@qq.com> Date: Mon, 30 Oct 2023 11:29:30 +0800 Subject: [PATCH] [Fix][TIR]fix symbolic strides lower (#16000) * [Fix][TIR]fix symbolic strides lower * [Fix][TIR] run the black formatter --- src/tir/transforms/ir_utils.cc | 3 +- .../test_tir_transform_lower_opaque_block.py | 48 +++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 99ed4376590e..25c10dd6828d 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -417,7 +417,8 @@ Array GetBufferAllocationShape(const Buffer& buffer) { if (buffer->strides.size()) { ICHECK_EQ(buffer->shape.size(), buffer->strides.size()); for (size_t i = buffer->strides.size() - 1; i > 0; --i) { - ICHECK(is_zero(floormod(buffer->strides[i - 1], buffer->strides[i]))); + ICHECK( + arith::Analyzer().CanProveEqual(floormod(buffer->strides[i - 1], buffer->strides[i]), 0)); alloc_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]); } } diff --git a/tests/python/unittest/test_tir_transform_lower_opaque_block.py b/tests/python/unittest/test_tir_transform_lower_opaque_block.py index 444e36bfbb7a..ae44d2127595 100644 --- a/tests/python/unittest/test_tir_transform_lower_opaque_block.py +++ b/tests/python/unittest/test_tir_transform_lower_opaque_block.py @@ -250,6 +250,50 @@ def transformed_strided_buffer_func( C[i0 * 4 + i1, j] = B[i1, j] * T.float32(2) +@T.prim_func +def compacted_symbolic_strided_buffer_func(a: T.handle) -> None: + n = T.int32() + A = T.match_buffer(a, (1, n, 10240)) + padded_size = T.meta_var(T.min((n + 63) // 64 * 64, 96)) + # with T.block("root"): + for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160): + with T.block(""): + A_pad_shared_dyn = T.alloc_buffer( + (1, padded_size, 64), strides=(72 * padded_size, 72, 1), scope="shared.dyn" + ) + for ax0, ax1 in T.grid(96, 64): + with T.block("A_pad_shared.dyn"): + T.where(i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64) + A_pad_shared_dyn[0, ax0, ax1] = T.if_then_else( + i * 128 + j * 32 + ax0 < n, + A[0, i * 128 + j * 32 + ax0, k * 64 + ax1], + T.float32(0), + ) + + +@T.prim_func +def transformed_symbolic_strided_buffer_func(a: T.handle): + n = T.int32() + A = T.match_buffer(a, (1, n, 10240)) + for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160): + A_pad_shared_dyn = T.allocate( + [1, T.min((n + 63) // 64 * 64, 96), 72], "float32", "shared.dyn" + ) + A_pad_shared_dyn_1 = T.decl_buffer( + (1, T.min((n + 63) // 64 * 64, 96), 64), + data=A_pad_shared_dyn, + strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1), + scope="shared.dyn", + ) + for ax0, ax1 in T.grid(96, 64): + if i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64: + A_pad_shared_dyn_1[0, ax0, ax1] = T.if_then_else( + i * 128 + j * 32 + ax0 < n, + A[0, i * 128 + j * 32 + ax0, k * 64 + ax1], + T.float32(0), + ) + + @T.prim_func def annotated_loops(a: T.handle) -> None: A = T.match_buffer(a, (16,), "float32") @@ -301,6 +345,10 @@ def test_strided_buffer(): _check(compacted_strided_buffer_func, transformed_strided_buffer_func) +def test_symbolic_strided_buffer(): + _check(compacted_symbolic_strided_buffer_func, transformed_symbolic_strided_buffer_func) + + def test_lower_te(): x = te.placeholder((1,)) y = te.compute((1,), lambda i: x[i] + 2)