From 30482dade1b0289175224cda3a37f5813edfc2a0 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 29 Apr 2022 18:22:49 -0700 Subject: [PATCH] [TIR] Bind iter domain in analyzer in CreatePrimFunc (#11187) * [TIR] Bind iter domain in analyzer in CreatePrimFunc * lint * fix test --- src/te/operation/create_primfunc.cc | 6 ++- .../unittest/test_meta_schedule_tune_relay.py | 22 ++++------ .../unittest/test_te_create_primfunc.py | 42 ++++++++++++------- 3 files changed, 39 insertions(+), 31 deletions(-) diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 7e7dae855802..af9029dc7a2b 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -142,8 +142,10 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, const PrimExpr& dom_min = analyzer->Simplify(iter_var->dom->min); const PrimExpr& dom_extent = analyzer->Simplify(iter_var->dom->extent); - iter_vars.push_back(IterVar(Range::FromMinExtent(dom_min, dom_extent), new_var, - iter_var->iter_type, iter_var->thread_tag, iter_var->span)); + Range iter_var_dom = Range::FromMinExtent(dom_min, dom_extent); + analyzer->Bind(new_var, iter_var_dom); + iter_vars.push_back(IterVar(iter_var_dom, new_var, iter_var->iter_type, iter_var->thread_tag, + iter_var->span)); } }; f_push_block_vars(compute_op->axis); diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py index 6b45ad6f07a5..23f5ebac2c86 100644 --- a/tests/python/unittest/test_meta_schedule_tune_relay.py +++ b/tests/python/unittest/test_meta_schedule_tune_relay.py @@ -60,14 +60,9 @@ def main( # type: ignore for i0, i1, i2, i3, i4 in T.grid(1, 1, 16, 16, 3): with T.block("T_layout_trans"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) - T.reads(placeholder[ax0, ax1 * 3 + ax4, ax2, ax3]) + T.reads(placeholder[0, ax4, ax2, ax3]) T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4]) - T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else( - ax0 < 1 and ax1 * 3 + ax4 < 3 and ax2 < 16 and ax3 < 16, # type: ignore - placeholder[ax0, ax1 * 3 + ax4, ax2, ax3], - T.float32(0), - dtype="float32", - ) + T_layout_trans[ax0, ax1, ax2, ax3, ax4] = placeholder[0, ax4, ax2, ax3] @tvm.script.ir_module @@ -82,18 +77,19 @@ def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.B for i0, i1, i2, i3, i4 in T.grid(1, 1, 20, 20, 3): with T.block("data_pad"): i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) - T.reads(placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1]) + T.reads(placeholder[0, 0, i2_1 - 2, i3_1 - 2, i4_1]) # type: ignore T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1]) - data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32") # type: ignore # pylint: disable=R1716 + data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[0, 0, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32") # type: ignore # pylint: disable=R1716 for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(1, 2, 16, 16, 4, 3, 5, 5): with T.block("conv2d_NCHWc"): n, oc_chunk, oh, ow, oc_block, ic, kh, kw = T.axis.remap("SSSSSRRR", [i0, i1, i2, i3, i4, i5, i6, i7]) - T.reads(data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3], placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]) # type: ignore + T.reads(data_pad[0, 0, oh + kh, ow + kw, ic], placeholder_1[oc_chunk, 0, kh, kw, ic, oc_block]) # type: ignore T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block]) T.block_attr({"workload":["conv2d_NCHWc.x86", ["TENSOR", [1, 1, 16, 16, 3], "float32"], ["TENSOR", [2, 1, 5, 5, 3, 4], "float32"], [1, 1], [2, 2, 2, 2], [1, 1], "NCHW3c", "NCHW4c", "float32"]}) with T.init(): conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0) - conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3] * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block] # type: ignore + conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[0, 0, oh + kh, ow + kw, ic] * placeholder_1[oc_chunk, 0, kh, kw, ic, oc_block] # type: ignore + @tvm.script.ir_module class tvmgen_default_fused_layout_transform_1: @@ -106,9 +102,9 @@ def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T. for i0, i1, i2, i3 in T.grid(1, 8, 16, 16): with T.block("T_layout_trans"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4]) # type: ignore + T.reads(placeholder[0, ax1 // 4, ax2, ax3, ax1 % 4]) # type: ignore T.writes(T_layout_trans[ax0, ax1, ax2, ax3]) - T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < 1 and ax1 < 8 and ax2 < 16 and ax3 < 16, placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4], T.float32(0), dtype="float32") # type: ignore + T_layout_trans[ax0, ax1, ax2, ax3] = placeholder[0, ax1 // 4, ax2, ax3, ax1 % 4] # type: ignore # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index 014ca71a8112..97cefc6b98db 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring +import sys +import pytest import numpy as np import tvm import tvm.testing @@ -524,20 +526,28 @@ def test_int64_indices(): assert loop.extent.dtype == "int64" +def te_reshape(): + A = te.placeholder((128, 128), name="A") + B = topi.reshape(A, [8, 16, 128]) + return [A, B] + + +@T.prim_func +def tir_reshape( + A: T.Buffer[(128, 128), "float32"], T_reshape: T.Buffer[(8, 16, 128), "float32"] +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i0, i1, i2 in T.grid(8, 16, 128): + with T.block("T_reshape"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(A[ax0 * 16 + ax1, ax2]) + T.writes(T_reshape[ax0, ax1, ax2]) + T_reshape[ax0, ax1, ax2] = A[ax0 * 16 + ax1, ax2] + + +def test_reshape(): + _check_workload(te_reshape, tir_reshape) + + if __name__ == "__main__": - test_unique_name_complete_block() - test_unique_name_reduction_block() - test_matmul() - test_element_wise() - test_conv2d() - test_multi_output() - test_extern() - test_arg_order() - test_error_reporting() - test_constant() - test_select_simplify() - test_tensor_attr() - test_tensor_layout_attr() - test_argmax_idx_val() - test_argmax_val_idx() - test_int64_indices() + sys.exit(pytest.main([__file__] + sys.argv[1:]))