From 959120d3c1bb6c651132ae2574b4e49553f734c1 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Wed, 19 Aug 2020 20:54:24 +0530 Subject: [PATCH] More test case added --- .../unittest/test_tir_transform_hoist_if.py | 502 +++++++++++++++++- 1 file changed, 501 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py index 4ca952af00d40..b4ffee702f4b2 100644 --- a/tests/python/unittest/test_tir_transform_hoist_if.py +++ b/tests/python/unittest/test_tir_transform_hoist_if.py @@ -16,6 +16,9 @@ # under the License. import tvm from tvm import te +from tvm import relay +import numpy as np +from tvm.relay.testing import ctx_list var_list = [] @@ -255,6 +258,487 @@ def test_multi_if(): ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)} verify_structure(new_stmt, expected_struct) +def test_no_hoisting_1(): + ib = tvm.tir.ir_builder.create() + data = ib.pointer("float32", name="data") + n = te.var("n") + + with ib.for_range(0, 10, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.for_range(0, 10, "k") as k: + with ib.if_scope(k >= 3): + data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + +def test_no_hoisting_2(): + ib = tvm.tir.ir_builder.create() + data = ib.pointer("float32", name="data") + n = te.var("n") + x = te.var("x") + + with ib.for_range(0, 10, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.for_range(0, 10, "k") as k: + with ib.if_scope(i >= 3): + data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.3 + data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + +def test_no_hoisting_3(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + dshape_inner = (33, 63) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + ib.scope_attr(tx, "thread_extent", dshape_inner[0]) + ib.scope_attr(bx, "thread_extent", dshape_inner[1]) + with ib.if_scope(tx < 3): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + +def test_no_hoisting_4(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + dshape_inner = (33, 63) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + ib.scope_attr(tx, "thread_extent", dshape_inner[0]) + with ib.if_scope(tx < 3): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + +def test_no_hoisting_5(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + dshape_inner = (33, 63) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + ib.scope_attr(bx, "thread_extent", dshape_inner[1]) + with ib.for_range(0, n, "k") as k: + ib.scope_attr(tx, "thread_extent", dshape_inner[0]) + with ib.if_scope(tx < 3): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + +def test_no_hoisting_6(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope((tx + k) < 3): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + +def test_no_hoisting_7(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.if_scope((tx + j) < 9): + with ib.for_range(0, n, "k") as k: + with ib.if_scope((tx + k) < 3): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + +def test_hoisting_block_scope_1(): + n = te.size_var("n") + m = te.size_var("m") + A = te.placeholder((n, m), name='A') + k = te.reduce_axis((0, m), "k") + B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B") + s = te.create_schedule(B.op) + ko, ki = s[B].split(B.op.reduce_axis[0], factor=16) + BF = s.rfactor(B, ki) + xo, xi = s[B].split(s[B].op.axis[0], factor=32) + s[B.op].bind(xo, te.thread_axis("blockIdx.x")) + s[B.op].bind(xi, te.thread_axis("threadIdx.y")) + s[B].bind(s[B].op.reduce_axis[0], te.thread_axis("threadIdx.x")) + s[BF].compute_at(s[B], s[B].op.reduce_axis[0]) + func = tvm.driver.build_module.form_irmodule( + s, [A, B], "main", None)["main"] + stmt = func.body + new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body + assert(not tvm.ir.structural_equal(new_stmt, stmt)) + +def test_hoisting_block_scope_2(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + dshape_inner = (33, 63) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + #ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.if_scope(tx < 3): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + #tvm.ir.assert_structural_equal(new_stmt, stmt) + assert(not tvm.ir.structural_equal(new_stmt, stmt)) + +def test_hoisting_block_scope_3(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + dshape_inner = (33, 63) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + ib.scope_attr(tx, "thread_extent", dshape_inner[0]) + ib.scope_attr(bx, "thread_extent", dshape_inner[1]) + with ib.for_range(0, n, "k") as k: + with ib.if_scope(tx < 3): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + #tvm.ir.assert_structural_equal(new_stmt, stmt) + assert(not tvm.ir.structural_equal(new_stmt, stmt)) + +def test_hoisting_block_scope_4(): + nn = 1024 + n = tvm.runtime.convert(nn) + A = te.placeholder((n,), name='A') + B = te.placeholder((n,), name='B') + AA = te.compute((n,), lambda *i: A(*i), name='A') + BB = te.compute((n,), lambda *i: B(*i), name='B') + T = te.compute(A.shape, lambda *i: AA(*i) + BB(*i), name='T') + C = te.compute(A.shape, lambda *i: T(*i), name='C') + s = te.create_schedule(C.op) + xo, xi = s[C].split(C.op.axis[0], factor=4) + xo1, xo2 = s[C].split(xo, factor=13) + s[C].parallel(xo2) + s[C].pragma(xo1, "parallel_launch_point") + s[C].pragma(xo2, "parallel_stride_pattern") + s[C].pragma(xo2, "parallel_barrier_when_finish") + s[C].vectorize(xi) + func = tvm.driver.build_module.form_irmodule( + s, [A, B, C], "main", None)["main"] + stmt = func.body + new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body + assert(not tvm.ir.structural_equal(new_stmt, stmt)) + +def test_hoisting_block_scope_5(): + ib = tvm.tir.ir_builder.create() + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + g = te.var('g') + + ib.scope_attr(data, "storage_scope", "global") + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope(data[g] < 3): + data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + 0.3 + with ib.else_scope(): + data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + 1.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + assert(not tvm.ir.structural_equal(new_stmt, stmt)) + + stmt = new_stmt + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + +def test_hoisting_block_scope_6(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope((tx + n) < 3): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + assert(not tvm.ir.structural_equal(new_stmt, stmt)) + +def test_hoisting_block_scope_7(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope((tx + i) < 3): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + assert(not tvm.ir.structural_equal(new_stmt, stmt)) + +def test_hoisting_op_conv(): + dtype = "float32" + dshape = (1, 80, 73, 73) + kshape = (192, 80, 3, 3) + padding=(1, 1) + groups=1 + dilation=(1, 1) + kernel_size=(3, 3) + channels=192 + scale=1 + x = relay.var("x", shape=dshape, dtype=dtype) + w = relay.var("w", shape=kshape, dtype=dtype) + y = relay.nn.conv2d(x, w, padding=padding, + dilation=dilation, + groups=groups, + channels=channels, + kernel_size=kernel_size) + + func = relay.Function([x, w], y) + mod = tvm.IRModule() + mod['main'] = func + mod = relay.transform.InferType()(mod) + + data = np.random.uniform(-scale, scale, size=dshape).astype(dtype) + kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype) + + params = {'w': tvm.nd.array(kernel)} + for target, ctx in ctx_list(): + with tvm.transform.PassContext(opt_level=3): + graph, lib, params = relay.build_module.build(mod, target=target, params=params) + m = tvm.contrib.graph_runtime.create(graph, lib, ctx) + x = np.random.uniform(size=dshape) + data_tvm = tvm.nd.array(data) + m.set_input('x', data_tvm) + m.set_input(**params) + m.run() + e = m.module.time_evaluator("run", ctx, number=300, repeat=3) + t1 = e(data_tvm).results + t1 = np.array(t1) * 1000 + print('{} ms'.format(t1.mean())) + + with tvm.transform.PassContext(opt_level=3, config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + graph, lib, params = relay.build_module.build(mod, target=target, params=params) + m = tvm.contrib.graph_runtime.create(graph, lib, ctx) + x = np.random.uniform(size=dshape) + data_tvm = tvm.nd.array(data) + m.set_input('x', data_tvm) + m.set_input(**params) + m.run() + e = m.module.time_evaluator("run", ctx, number=300, repeat=3) + t2 = e(data_tvm).results + t2 = np.array(t2) * 1000 + + print('{} ms'.format(t2.mean())) + tvm.testing.assert_allclose(t1.mean(), t2.mean(), atol=1, rtol=1e-1) if __name__ == "__main__": test_hoist_top_for() @@ -265,4 +749,20 @@ def test_multi_if(): test_nested_for() test_if_block() test_multi_if() - + test_no_hoisting_1() + test_no_hoisting_2() + test_no_hoisting_3() + test_no_hoisting_4() + test_no_hoisting_5() + test_no_hoisting_6() + test_no_hoisting_7() + test_hoisting_block_scope_1() + test_hoisting_block_scope_2() + test_hoisting_block_scope_3() + test_hoisting_block_scope_4() + test_hoisting_block_scope_5() + test_hoisting_block_scope_6() + test_hoisting_block_scope_7() + + # Test with Conv Op + test_hoisting_op_conv()