From bf1c5df9762e75a52fb3ca5d14f58ac1b773c472 Mon Sep 17 00:00:00 2001
From: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Date: Thu, 11 Nov 2021 20:10:37 -0500
Subject: [PATCH] [MetaSchedule] Test for Rewrite Parallel-Vectorize-Unroll (#513)

* rebase

* fix

* fix

* fix
---
 .../rewrite_parallel_vectorize_unroll.cc      | 16 ++--
 .../unittest/test_meta_schedule_postproc.py   | 37 ----------
 ...tproc_rewrite_parallel_vectorize_unroll.py | 69 +++++++++++++++++++
 3 files changed, 77 insertions(+), 45 deletions(-)

diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
index 447837d36b7a..34c2684e029e 100644
--- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
+++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
@@ -98,23 +98,23 @@ bool ParseAnnotation(const Block& block, ParsedAnnotation* parsed) {
   for (const auto& ann : block->annotations) {
     if (ann.first == attr::meta_schedule_parallel) {
       found = true;
-      if (const auto* str_imm = ann.second.as<StringImmNode>()) {
-        parsed->max_parallel_extent = std::atoi(str_imm->value.c_str());
+      if (const auto* imm = ann.second.as<IntImmNode>()) {
+        parsed->max_parallel_extent = imm->value;
       }
     } else if (ann.first == attr::meta_schedule_vectorize) {
       found = true;
-      if (const auto* str_imm = ann.second.as<StringImmNode>()) {
-        parsed->max_vectorize_extent = std::atoi(str_imm->value.c_str());
+      if (const auto* imm = ann.second.as<IntImmNode>()) {
+        parsed->max_vectorize_extent = imm->value;
       }
     } else if (ann.first == attr::meta_schedule_unroll_explicit) {
       found = true;
-      if (const auto* str_imm = ann.second.as<StringImmNode>()) {
-        parsed->unroll_explicit = std::atoi(str_imm->value.c_str());
+      if (const auto* imm = ann.second.as<IntImmNode>()) {
+        parsed->unroll_explicit = imm->value;
       }
     } else if (ann.first == attr::meta_schedule_unroll_implicit) {
       found = true;
-      if (const auto* str_imm = ann.second.as<StringImmNode>()) {
-        parsed->unroll_implicit = std::atoi(str_imm->value.c_str());
+      if (const auto* imm = ann.second.as<IntImmNode>()) {
+        parsed->unroll_implicit = imm->value;
       }
     }
   }
diff --git a/tests/python/unittest/test_meta_schedule_postproc.py b/tests/python/unittest/test_meta_schedule_postproc.py
index a03cbbdbc3c7..6e17e7bac3f2 100644
--- a/tests/python/unittest/test_meta_schedule_postproc.py
+++ b/tests/python/unittest/test_meta_schedule_postproc.py
@@ -45,43 +45,6 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
                 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
 
 
-@tvm.script.ir_module
-class Conv_cuda0:
-    @T.prim_func
-    def main(a: T.handle, b: T.handle) -> None:
-        # function attr dict
-        T.func_attr({"global_symbol": "main", "T.noalias": True})
-        # var definition
-        threadIdx_x = T.env_thread("threadIdx.x")
-        threadIdx_y = T.env_thread("threadIdx.y")
-        blockIdx_x = T.env_thread("blockIdx.x")
-        blockIdx_y = T.env_thread("blockIdx.y")
-        blockIdx_z = T.env_thread("blockIdx.z")
-        A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32")
-        B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32")
-        # body
-        T.launch_thread(blockIdx_z, 196)
-        B_local = T.allocate([64], "float32", "local")
-        Apad_shared = T.allocate([512], "float32", "shared")
-        Apad_shared_local = T.allocate([8], "float32", "local")
-        T.launch_thread(blockIdx_y, 8)
-        T.launch_thread(blockIdx_x, 4)
-        T.launch_thread(threadIdx_y, 8)
-        T.launch_thread(threadIdx_x, 8)
-        for ff_c_init, nn_c_init in T.grid(8, 8):
-            T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True)
-        for rc_outer, ry, rx in T.grid(32, 3, 3):
-            for ax3_inner_outer in T.serial(0, 2):
-                T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4))
-            for rc_inner in T.serial(0, 8):
-                for ax3 in T.serial(0, 8):
-                    T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True)
-                for ff_c, nn_c in T.grid(8, 8):
-                    T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True)
-        for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
-            T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on
-
-
 # fmt: on
 # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable
 
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
index ae60803d08f3..8a215e88597a 100644
--- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
+++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
@@ -15,3 +15,72 @@
 # specific language governing permissions and limitations
 # under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +from tvm.script import tir as T + +from tvm.meta_schedule.postproc import RewriteParallelVectorizeUnroll +from tvm.tir.schedule import Schedule + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant +# fmt: off + +@tvm.script.ir_module +class Move_PUV: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32") + B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32") + # body + with T.block("root"): + T.block_attr({"meta_schedule.parallel":128, "meta_schedule.vectorize":32}) + for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32): + with T.block("move"): + vi = T.axis.spatial(1024, i0 * 16 + i1 * 4 + i2) + vj = T.axis.spatial(1024, j0 * 32 + j1 * 8 + j2) + vk = T.axis.spatial(1024, k0 * 32 + k1) + T.where((i0 * 4 + i1) * 4 + i2 < 1024 and (j0 * 4 + j1) * 8 + j2 < 1024 and k0 * 32 + k1 < 1024) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] + + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable + + +@T.prim_func +def Move_PUV0(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32") + B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32") + # body + with T.block("root"): + for i0_j0_fused in T.parallel(0, 8192): + for i1, j1, k0, i2, j2 in T.grid(4, 4, 64, 4, 8): + for k1_fused in T.vectorized(0, 32): + with T.block("move"): + vi = T.axis.spatial(1024, i0_j0_fused // 64 * 16 + i1 * 4 + i2) + vj = T.axis.spatial(1024, i0_j0_fused % 64 * 32 + j1 * 8 + j2) + vk = T.axis.spatial(1024, k0 * 32 + k1_fused) + T.where( + (i0_j0_fused // 64 % 128 * 4 + i1) * 4 + i2 < 1024 + and (i0_j0_fused % 64 * 4 + j1) * 8 + j2 < 1024 + and k0 * 32 + k1_fused % 32 < 1024 + ) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] + + +def test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize(): + postproc = RewriteParallelVectorizeUnroll() + sch = Schedule(Move_PUV) + assert postproc.apply(sch) + tvm.ir.assert_structural_equal(sch.mod["main"], Move_PUV0) + + +if __name__ == "__main__": + test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize()