From 1d8177a945d1fb25686d1e6d3862e7da79a183e6 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 12 May 2022 18:52:05 -0700 Subject: [PATCH] Prevent simplifing unit IterVar in CreatePrimFunc (#11292) Simplifying unit iter vars in CreatePrimFunc changes semantics of the PrimFunc, which need different handling in analysis. This reverts commit 26cefab5df8f24af7dc43a3239dbfd0e858fd1a2. --- src/te/operation/create_primfunc.cc | 6 +-- .../unittest/test_meta_schedule_tune_relay.py | 22 ++++++---- .../unittest/test_te_create_primfunc.py | 42 +++++++------------ 3 files changed, 31 insertions(+), 39 deletions(-) diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index af9029dc7a2b..7e7dae855802 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -142,10 +142,8 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, const PrimExpr& dom_min = analyzer->Simplify(iter_var->dom->min); const PrimExpr& dom_extent = analyzer->Simplify(iter_var->dom->extent); - Range iter_var_dom = Range::FromMinExtent(dom_min, dom_extent); - analyzer->Bind(new_var, iter_var_dom); - iter_vars.push_back(IterVar(iter_var_dom, new_var, iter_var->iter_type, iter_var->thread_tag, - iter_var->span)); + iter_vars.push_back(IterVar(Range::FromMinExtent(dom_min, dom_extent), new_var, + iter_var->iter_type, iter_var->thread_tag, iter_var->span)); } }; f_push_block_vars(compute_op->axis); diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py index e154f9ff27b0..e5076af520f3 100644 --- a/tests/python/unittest/test_meta_schedule_tune_relay.py +++ b/tests/python/unittest/test_meta_schedule_tune_relay.py @@ -62,9 +62,14 @@ def main( # type: ignore for i0, i1, i2, i3, i4 in T.grid(1, 1, 16, 16, 3): with T.block("T_layout_trans"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) - T.reads(placeholder[0, ax4, ax2, ax3]) + T.reads(placeholder[ax0, ax1 * 3 + ax4, ax2, ax3]) T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4]) - T_layout_trans[ax0, ax1, ax2, ax3, ax4] = placeholder[0, ax4, ax2, ax3] + T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else( + ax0 < 1 and ax1 * 3 + ax4 < 3 and ax2 < 16 and ax3 < 16, # type: ignore + placeholder[ax0, ax1 * 3 + ax4, ax2, ax3], + T.float32(0), + dtype="float32", + ) @tvm.script.ir_module @@ -79,19 +84,18 @@ def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.B for i0, i1, i2, i3, i4 in T.grid(1, 1, 20, 20, 3): with T.block("data_pad"): i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) - T.reads(placeholder[0, 0, i2_1 - 2, i3_1 - 2, i4_1]) # type: ignore + T.reads(placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1]) T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1]) - data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[0, 0, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32") # type: ignore # pylint: disable=R1716 + data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32") # type: ignore # pylint: disable=R1716 for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(1, 2, 16, 16, 4, 3, 5, 5): with T.block("conv2d_NCHWc"): n, oc_chunk, oh, ow, oc_block, ic, kh, kw = T.axis.remap("SSSSSRRR", [i0, i1, i2, i3, i4, i5, i6, i7]) - T.reads(data_pad[0, 0, oh + kh, ow + kw, ic], placeholder_1[oc_chunk, 0, kh, kw, ic, oc_block]) # type: ignore + T.reads(data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3], placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]) # type: ignore T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block]) T.block_attr({"workload":["conv2d_NCHWc.x86", ["TENSOR", [1, 1, 16, 16, 3], "float32"], ["TENSOR", [2, 1, 5, 5, 3, 4], "float32"], [1, 1], [2, 2, 2, 2], [1, 1], "NCHW3c", "NCHW4c", "float32"]}) with T.init(): conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0) - conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[0, 0, oh + kh, ow + kw, ic] * placeholder_1[oc_chunk, 0, kh, kw, ic, oc_block] # type: ignore - + conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3] * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block] # type: ignore @tvm.script.ir_module class tvmgen_default_fused_layout_transform_1: @@ -104,9 +108,9 @@ def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T. for i0, i1, i2, i3 in T.grid(1, 8, 16, 16): with T.block("T_layout_trans"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(placeholder[0, ax1 // 4, ax2, ax3, ax1 % 4]) # type: ignore + T.reads(placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4]) # type: ignore T.writes(T_layout_trans[ax0, ax1, ax2, ax3]) - T_layout_trans[ax0, ax1, ax2, ax3] = placeholder[0, ax1 // 4, ax2, ax3, ax1 % 4] # type: ignore + T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < 1 and ax1 < 8 and ax2 < 16 and ax3 < 16, placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4], T.float32(0), dtype="float32") # type: ignore # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index 97cefc6b98db..014ca71a8112 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring -import sys -import pytest import numpy as np import tvm import tvm.testing @@ -526,28 +524,20 @@ def test_int64_indices(): assert loop.extent.dtype == "int64" -def te_reshape(): - A = te.placeholder((128, 128), name="A") - B = topi.reshape(A, [8, 16, 128]) - return [A, B] - - -@T.prim_func -def tir_reshape( - A: T.Buffer[(128, 128), "float32"], T_reshape: T.Buffer[(8, 16, 128), "float32"] -) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - for i0, i1, i2 in T.grid(8, 16, 128): - with T.block("T_reshape"): - ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(A[ax0 * 16 + ax1, ax2]) - T.writes(T_reshape[ax0, ax1, ax2]) - T_reshape[ax0, ax1, ax2] = A[ax0 * 16 + ax1, ax2] - - -def test_reshape(): - _check_workload(te_reshape, tir_reshape) - - if __name__ == "__main__": - sys.exit(pytest.main([__file__] + sys.argv[1:])) + test_unique_name_complete_block() + test_unique_name_reduction_block() + test_matmul() + test_element_wise() + test_conv2d() + test_multi_output() + test_extern() + test_arg_order() + test_error_reporting() + test_constant() + test_select_simplify() + test_tensor_attr() + test_tensor_layout_attr() + test_argmax_idx_val() + test_argmax_val_idx() + test_int64_indices()