Skip to content

Commit

Permalink
Prevent simplifing unit IterVar in CreatePrimFunc (apache#11292)
Browse files Browse the repository at this point in the history
Simplifying unit iter vars in CreatePrimFunc changes semantics of the PrimFunc, which need different handling in analysis.

This reverts commit 26cefab.
  • Loading branch information
vinx13 authored and Sergey Shtin committed May 17, 2022
1 parent b8b1c51 commit 1d8177a
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 39 deletions.
6 changes: 2 additions & 4 deletions src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,8 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,

const PrimExpr& dom_min = analyzer->Simplify(iter_var->dom->min);
const PrimExpr& dom_extent = analyzer->Simplify(iter_var->dom->extent);
Range iter_var_dom = Range::FromMinExtent(dom_min, dom_extent);
analyzer->Bind(new_var, iter_var_dom);
iter_vars.push_back(IterVar(iter_var_dom, new_var, iter_var->iter_type, iter_var->thread_tag,
iter_var->span));
iter_vars.push_back(IterVar(Range::FromMinExtent(dom_min, dom_extent), new_var,
iter_var->iter_type, iter_var->thread_tag, iter_var->span));
}
};
f_push_block_vars(compute_op->axis);
Expand Down
22 changes: 13 additions & 9 deletions tests/python/unittest/test_meta_schedule_tune_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,14 @@ def main( # type: ignore
for i0, i1, i2, i3, i4 in T.grid(1, 1, 16, 16, 3):
with T.block("T_layout_trans"):
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(placeholder[0, ax4, ax2, ax3])
T.reads(placeholder[ax0, ax1 * 3 + ax4, ax2, ax3])
T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4])
T_layout_trans[ax0, ax1, ax2, ax3, ax4] = placeholder[0, ax4, ax2, ax3]
T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(
ax0 < 1 and ax1 * 3 + ax4 < 3 and ax2 < 16 and ax3 < 16, # type: ignore
placeholder[ax0, ax1 * 3 + ax4, ax2, ax3],
T.float32(0),
dtype="float32",
)


@tvm.script.ir_module
Expand All @@ -79,19 +84,18 @@ def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.B
for i0, i1, i2, i3, i4 in T.grid(1, 1, 20, 20, 3):
with T.block("data_pad"):
i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(placeholder[0, 0, i2_1 - 2, i3_1 - 2, i4_1]) # type: ignore
T.reads(placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1])
T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1])
data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[0, 0, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32") # type: ignore # pylint: disable=R1716
data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32") # type: ignore # pylint: disable=R1716
for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(1, 2, 16, 16, 4, 3, 5, 5):
with T.block("conv2d_NCHWc"):
n, oc_chunk, oh, ow, oc_block, ic, kh, kw = T.axis.remap("SSSSSRRR", [i0, i1, i2, i3, i4, i5, i6, i7])
T.reads(data_pad[0, 0, oh + kh, ow + kw, ic], placeholder_1[oc_chunk, 0, kh, kw, ic, oc_block]) # type: ignore
T.reads(data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3], placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]) # type: ignore
T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block])
T.block_attr({"workload":["conv2d_NCHWc.x86", ["TENSOR", [1, 1, 16, 16, 3], "float32"], ["TENSOR", [2, 1, 5, 5, 3, 4], "float32"], [1, 1], [2, 2, 2, 2], [1, 1], "NCHW3c", "NCHW4c", "float32"]})
with T.init():
conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0)
conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[0, 0, oh + kh, ow + kw, ic] * placeholder_1[oc_chunk, 0, kh, kw, ic, oc_block] # type: ignore

conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3] * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block] # type: ignore

@tvm.script.ir_module
class tvmgen_default_fused_layout_transform_1:
Expand All @@ -104,9 +108,9 @@ def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T.
for i0, i1, i2, i3 in T.grid(1, 8, 16, 16):
with T.block("T_layout_trans"):
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(placeholder[0, ax1 // 4, ax2, ax3, ax1 % 4]) # type: ignore
T.reads(placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4]) # type: ignore
T.writes(T_layout_trans[ax0, ax1, ax2, ax3])
T_layout_trans[ax0, ax1, ax2, ax3] = placeholder[0, ax1 // 4, ax2, ax3, ax1 % 4] # type: ignore
T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < 1 and ax1 < 8 and ax2 < 16 and ax3 < 16, placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4], T.float32(0), dtype="float32") # type: ignore

# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
Expand Down
42 changes: 16 additions & 26 deletions tests/python/unittest/test_te_create_primfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
import sys
import pytest
import numpy as np
import tvm
import tvm.testing
Expand Down Expand Up @@ -526,28 +524,20 @@ def test_int64_indices():
assert loop.extent.dtype == "int64"


def te_reshape():
A = te.placeholder((128, 128), name="A")
B = topi.reshape(A, [8, 16, 128])
return [A, B]


@T.prim_func
def tir_reshape(
A: T.Buffer[(128, 128), "float32"], T_reshape: T.Buffer[(8, 16, 128), "float32"]
) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i0, i1, i2 in T.grid(8, 16, 128):
with T.block("T_reshape"):
ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
T.reads(A[ax0 * 16 + ax1, ax2])
T.writes(T_reshape[ax0, ax1, ax2])
T_reshape[ax0, ax1, ax2] = A[ax0 * 16 + ax1, ax2]


def test_reshape():
_check_workload(te_reshape, tir_reshape)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
test_unique_name_complete_block()
test_unique_name_reduction_block()
test_matmul()
test_element_wise()
test_conv2d()
test_multi_output()
test_extern()
test_arg_order()
test_error_reporting()
test_constant()
test_select_simplify()
test_tensor_attr()
test_tensor_layout_attr()
test_argmax_idx_val()
test_argmax_val_idx()
test_int64_indices()

0 comments on commit 1d8177a

Please sign in to comment.