Skip to content

Commit

Permalink
[TIR] Bind iter domain in analyzer in CreatePrimFunc (apache#11187)
Browse files Browse the repository at this point in the history
* [TIR] Bind iter domain in analyzer in CreatePrimFunc

* lint

* fix test
  • Loading branch information
vinx13 authored and Boblest Sebastian (ETAS-DEV/XPC-Fe1) committed May 27, 2022
1 parent 202ba46 commit a11ea23
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 31 deletions.
6 changes: 4 additions & 2 deletions src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,10 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,

const PrimExpr& dom_min = analyzer->Simplify(iter_var->dom->min);
const PrimExpr& dom_extent = analyzer->Simplify(iter_var->dom->extent);
iter_vars.push_back(IterVar(Range::FromMinExtent(dom_min, dom_extent), new_var,
iter_var->iter_type, iter_var->thread_tag, iter_var->span));
Range iter_var_dom = Range::FromMinExtent(dom_min, dom_extent);
analyzer->Bind(new_var, iter_var_dom);
iter_vars.push_back(IterVar(iter_var_dom, new_var, iter_var->iter_type, iter_var->thread_tag,
iter_var->span));
}
};
f_push_block_vars(compute_op->axis);
Expand Down
22 changes: 9 additions & 13 deletions tests/python/unittest/test_meta_schedule_tune_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,9 @@ def main( # type: ignore
for i0, i1, i2, i3, i4 in T.grid(1, 1, 16, 16, 3):
with T.block("T_layout_trans"):
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(placeholder[ax0, ax1 * 3 + ax4, ax2, ax3])
T.reads(placeholder[0, ax4, ax2, ax3])
T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4])
T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(
ax0 < 1 and ax1 * 3 + ax4 < 3 and ax2 < 16 and ax3 < 16, # type: ignore
placeholder[ax0, ax1 * 3 + ax4, ax2, ax3],
T.float32(0),
dtype="float32",
)
T_layout_trans[ax0, ax1, ax2, ax3, ax4] = placeholder[0, ax4, ax2, ax3]


@tvm.script.ir_module
Expand All @@ -82,18 +77,19 @@ def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.B
for i0, i1, i2, i3, i4 in T.grid(1, 1, 20, 20, 3):
with T.block("data_pad"):
i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1])
T.reads(placeholder[0, 0, i2_1 - 2, i3_1 - 2, i4_1]) # type: ignore
T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1])
data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32") # type: ignore # pylint: disable=R1716
data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[0, 0, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32") # type: ignore # pylint: disable=R1716
for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(1, 2, 16, 16, 4, 3, 5, 5):
with T.block("conv2d_NCHWc"):
n, oc_chunk, oh, ow, oc_block, ic, kh, kw = T.axis.remap("SSSSSRRR", [i0, i1, i2, i3, i4, i5, i6, i7])
T.reads(data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3], placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]) # type: ignore
T.reads(data_pad[0, 0, oh + kh, ow + kw, ic], placeholder_1[oc_chunk, 0, kh, kw, ic, oc_block]) # type: ignore
T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block])
T.block_attr({"workload":["conv2d_NCHWc.x86", ["TENSOR", [1, 1, 16, 16, 3], "float32"], ["TENSOR", [2, 1, 5, 5, 3, 4], "float32"], [1, 1], [2, 2, 2, 2], [1, 1], "NCHW3c", "NCHW4c", "float32"]})
with T.init():
conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0)
conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3] * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block] # type: ignore
conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[0, 0, oh + kh, ow + kw, ic] * placeholder_1[oc_chunk, 0, kh, kw, ic, oc_block] # type: ignore


@tvm.script.ir_module
class tvmgen_default_fused_layout_transform_1:
Expand All @@ -106,9 +102,9 @@ def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T.
for i0, i1, i2, i3 in T.grid(1, 8, 16, 16):
with T.block("T_layout_trans"):
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4]) # type: ignore
T.reads(placeholder[0, ax1 // 4, ax2, ax3, ax1 % 4]) # type: ignore
T.writes(T_layout_trans[ax0, ax1, ax2, ax3])
T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < 1 and ax1 < 8 and ax2 < 16 and ax3 < 16, placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4], T.float32(0), dtype="float32") # type: ignore
T_layout_trans[ax0, ax1, ax2, ax3] = placeholder[0, ax1 // 4, ax2, ax3, ax1 % 4] # type: ignore

# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
Expand Down
42 changes: 26 additions & 16 deletions tests/python/unittest/test_te_create_primfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
import sys
import pytest
import numpy as np
import tvm
import tvm.testing
Expand Down Expand Up @@ -524,20 +526,28 @@ def test_int64_indices():
assert loop.extent.dtype == "int64"


def te_reshape():
A = te.placeholder((128, 128), name="A")
B = topi.reshape(A, [8, 16, 128])
return [A, B]


@T.prim_func
def tir_reshape(
A: T.Buffer[(128, 128), "float32"], T_reshape: T.Buffer[(8, 16, 128), "float32"]
) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i0, i1, i2 in T.grid(8, 16, 128):
with T.block("T_reshape"):
ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
T.reads(A[ax0 * 16 + ax1, ax2])
T.writes(T_reshape[ax0, ax1, ax2])
T_reshape[ax0, ax1, ax2] = A[ax0 * 16 + ax1, ax2]


def test_reshape():
_check_workload(te_reshape, tir_reshape)


if __name__ == "__main__":
test_unique_name_complete_block()
test_unique_name_reduction_block()
test_matmul()
test_element_wise()
test_conv2d()
test_multi_output()
test_extern()
test_arg_order()
test_error_reporting()
test_constant()
test_select_simplify()
test_tensor_attr()
test_tensor_layout_attr()
test_argmax_idx_val()
test_argmax_val_idx()
test_int64_indices()
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit a11ea23

Please sign in to comment.