[MetaSchedule] Test for Rewrite Parallel-Vectorize-Unroll (apache#513)
* rebase

* fix

* fix

* fix
spectrometerHBH authored Nov 12, 2021
1 parent 4e1c53a commit bf1c5df
Showing 3 changed files with 77 additions and 45 deletions.
16 changes: 8 additions & 8 deletions src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
@@ -98,23 +98,23 @@ bool ParseAnnotation(const Block& block, ParsedAnnotation* parsed) {
   for (const auto& ann : block->annotations) {
     if (ann.first == attr::meta_schedule_parallel) {
       found = true;
-      if (const auto* str_imm = ann.second.as<StringImmNode>()) {
-        parsed->max_parallel_extent = std::atoi(str_imm->value.c_str());
+      if (const auto* imm = ann.second.as<tir::IntImmNode>()) {
+        parsed->max_parallel_extent = imm->value;
       }
     } else if (ann.first == attr::meta_schedule_vectorize) {
       found = true;
-      if (const auto* str_imm = ann.second.as<StringImmNode>()) {
-        parsed->max_vectorize_extent = std::atoi(str_imm->value.c_str());
+      if (const auto* imm = ann.second.as<tir::IntImmNode>()) {
+        parsed->max_vectorize_extent = imm->value;
       }
     } else if (ann.first == attr::meta_schedule_unroll_explicit) {
       found = true;
-      if (const auto* str_imm = ann.second.as<StringImmNode>()) {
-        parsed->unroll_explicit = std::atoi(str_imm->value.c_str());
+      if (const auto* imm = ann.second.as<tir::IntImmNode>()) {
+        parsed->unroll_explicit = imm->value;
       }
     } else if (ann.first == attr::meta_schedule_unroll_implicit) {
       found = true;
-      if (const auto* str_imm = ann.second.as<StringImmNode>()) {
-        parsed->unroll_implicit = std::atoi(str_imm->value.c_str());
+      if (const auto* imm = ann.second.as<tir::IntImmNode>()) {
+        parsed->unroll_implicit = imm->value;
       }
     }
   }
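The hunk above switches ParseAnnotation from string-valued annotations (previously parsed with std::atoi) to integer immediates (tir::IntImmNode). For illustration only, a minimal sketch of how such integer-valued annotations can be attached from the Python side, assuming the tir.Schedule.annotate primitive is available on this branch; copy_func and the toy kernel are hypothetical, and a plain Python int passed as the annotation value is stored as a tir.IntImm, which is what the parser above now expects:

from tvm.script import tir as T
from tvm.tir.schedule import Schedule

@T.prim_func
def copy_func(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [1024], dtype="float32")
    B = T.match_buffer(b, [1024], dtype="float32")
    for i in T.serial(0, 1024):
        with T.block("copy"):
            vi = T.axis.spatial(1024, i)
            B[vi] = A[vi]

sch = Schedule(copy_func)
block = sch.get_block("copy")
# Integer values are wrapped as tir.IntImm annotations (assumption: the
# annotate primitive exists here); the old parser expected strings like "128".
sch.annotate(block, "meta_schedule.parallel", 128)
sch.annotate(block, "meta_schedule.vectorize", 32)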
37 changes: 0 additions & 37 deletions tests/python/unittest/test_meta_schedule_postproc.py
@@ -45,43 +45,6 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


@tvm.script.ir_module
class Conv_cuda0:
    @T.prim_func
    def main(a: T.handle, b: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "T.noalias": True})
        # var definition
        threadIdx_x = T.env_thread("threadIdx.x")
        threadIdx_y = T.env_thread("threadIdx.y")
        blockIdx_x = T.env_thread("blockIdx.x")
        blockIdx_y = T.env_thread("blockIdx.y")
        blockIdx_z = T.env_thread("blockIdx.z")
        A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32")
        B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32")
        # body
        T.launch_thread(blockIdx_z, 196)
        B_local = T.allocate([64], "float32", "local")
        Apad_shared = T.allocate([512], "float32", "shared")
        Apad_shared_local = T.allocate([8], "float32", "local")
        T.launch_thread(blockIdx_y, 8)
        T.launch_thread(blockIdx_x, 4)
        T.launch_thread(threadIdx_y, 8)
        T.launch_thread(threadIdx_x, 8)
        for ff_c_init, nn_c_init in T.grid(8, 8):
            T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True)
        for rc_outer, ry, rx in T.grid(32, 3, 3):
            for ax3_inner_outer in T.serial(0, 2):
                T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4))
            for rc_inner in T.serial(0, 8):
                for ax3 in T.serial(0, 8):
                    T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True)
                for ff_c, nn_c in T.grid(8, 8):
                    T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True)
        for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
            T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)


# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable

69 changes: 69 additions & 0 deletions tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
@@ -15,3 +15,72 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import tvm
from tvm.script import tir as T

from tvm.meta_schedule.postproc import RewriteParallelVectorizeUnroll
from tvm.tir.schedule import Schedule

# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant
# fmt: off

@tvm.script.ir_module
class Move_PUV:
    @T.prim_func
    def main(a: T.handle, b: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main"})
        A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32")
        B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32")
        # body
        with T.block("root"):
            T.block_attr({"meta_schedule.parallel": 128, "meta_schedule.vectorize": 32})
            for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32):
                with T.block("move"):
                    vi = T.axis.spatial(1024, i0 * 16 + i1 * 4 + i2)
                    vj = T.axis.spatial(1024, j0 * 32 + j1 * 8 + j2)
                    vk = T.axis.spatial(1024, k0 * 32 + k1)
                    T.where((i0 * 4 + i1) * 4 + i2 < 1024 and (j0 * 4 + j1) * 8 + j2 < 1024 and k0 * 32 + k1 < 1024)
                    T.reads([A[vi, vj, vk]])
                    T.writes([B[vi, vj, vk]])
                    B[vi, vj, vk] = A[vi, vj, vk]


# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable


@T.prim_func
def Move_PUV0(a: T.handle, b: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main"})
    A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32")
    B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32")
    # body
    with T.block("root"):
        for i0_j0_fused in T.parallel(0, 8192):
            for i1, j1, k0, i2, j2 in T.grid(4, 4, 64, 4, 8):
                for k1_fused in T.vectorized(0, 32):
                    with T.block("move"):
                        vi = T.axis.spatial(1024, i0_j0_fused // 64 * 16 + i1 * 4 + i2)
                        vj = T.axis.spatial(1024, i0_j0_fused % 64 * 32 + j1 * 8 + j2)
                        vk = T.axis.spatial(1024, k0 * 32 + k1_fused)
                        T.where(
                            (i0_j0_fused // 64 % 128 * 4 + i1) * 4 + i2 < 1024
                            and (i0_j0_fused % 64 * 4 + j1) * 8 + j2 < 1024
                            and k0 * 32 + k1_fused % 32 < 1024
                        )
                        T.reads([A[vi, vj, vk]])
                        T.writes([B[vi, vj, vk]])
                        B[vi, vj, vk] = A[vi, vj, vk]


def test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize():
    postproc = RewriteParallelVectorizeUnroll()
    sch = Schedule(Move_PUV)
    assert postproc.apply(sch)
    tvm.ir.assert_structural_equal(sch.mod["main"], Move_PUV0)


if __name__ == "__main__":
    test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize()
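For reference, a hedged sketch of the manual schedule that yields the same loop structure the test expects in Move_PUV0. The loop order follows the T.grid in Move_PUV, and manual_rewrite is a hypothetical helper for illustration, not part of the test:

def manual_rewrite() -> Schedule:
    # Equivalent hand-written transformation (sketch): fuse the two outer
    # loops, parallelize the fusion, and vectorize the innermost loop.
    sch = Schedule(Move_PUV)
    block = sch.get_block("move")
    i0, j0, i1, j1, k0, i2, j2, k1 = sch.get_loops(block)
    fused = sch.fuse(i0, j0)  # 128 x 64 -> the 8192-iteration loop in Move_PUV0
    sch.parallel(fused)       # triggered by the "meta_schedule.parallel" annotation
    sch.vectorize(k1)         # triggered by "meta_schedule.vectorize": 32
    # Note: RewriteParallelVectorizeUnroll also strips the consumed
    # "meta_schedule.*" block annotations, which this sketch leaves in place.
    return sch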
