Skip to content

Commit

Permalink
add arm nhwc conv2d test
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 18, 2022
1 parent 40c14f3 commit eb147f3
Showing 1 changed file with 87 additions and 1 deletion.
88 changes: 87 additions & 1 deletion tests/python/unittest/test_tir_schedule_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
from tvm.script import tir as T
from tvm.tir.stmt_functor import pre_order_visit
from tvm.meta_schedule.testing import te_workload
from tvm.meta_schedule.relay_integration import extract_task_from_relay
from tvm.te import create_prim_func
from tvm import relay
import numpy as np
from tvm.meta_schedule.tune import Parse


def _make_vars(*args: str) -> List[Var]:
Expand Down Expand Up @@ -251,9 +255,91 @@ def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2)


def test_get_tensorize_loop_mapping_conv2d_nhwc_arm():
    """Check the loop mapping between an NHWC int8 conv2d task and a 4x4x4 GEMM description.

    The test builds a relay int8 conv2d, converts it to NHWC/HWIO layout, extracts the
    tuning tasks for an AArch64 (+dotprod) target, and then asserts that
    ``get_tensorize_loop_mapping`` maps the three loops of the GEMM description onto the
    last three loops surrounding block "C" of the extracted conv2d PrimFunc.
    """

    # Tensor-intrinsic description: C[i, j] += A[i, k] * B[j, k] over a 4x4x4 grid.
    # Inputs are int8 and the accumulator is int32, as the "i8i8i32" suffix and the
    # workload's out_dtype say; the operands are widened before the multiply so the
    # product is not computed in int8.
    @T.prim_func
    def gemm_4x4x4_i8i8i32(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (4, 4), offset_factor=1, dtype="int8")
        B = T.match_buffer(b, (4, 4), offset_factor=1, dtype="int8")
        C = T.match_buffer(c, (4, 4), offset_factor=1, dtype="int32")

        with T.block("root"):
            T.reads(C[0:4, 0:4], A[0:4, 0:4], B[0:4, 0:4])
            T.writes(C[0:4, 0:4])
            for i, j, k in T.grid(4, 4, 4):
                with T.block("update"):
                    vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
                    C[vii, vjj] = C[vii, vjj] + T.cast(A[vii, vkk], "int32") * T.cast(
                        B[vjj, vkk], "int32"
                    )

    def convert_conv2d_layout(mod, desired_layouts):
        """Run relay's ConvertLayout pass on ``mod`` at opt_level=3."""
        with tvm.transform.PassContext(opt_level=3):
            seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)])
            return seq(mod)

    data_shape = (8, 64, 56, 56)
    weight_shape = (64, 64, 3, 3)

    data_dtype = "int8"
    weight_dtype = "int8"
    out_dtype = "int32"

    # Build the relay workload: a single 3x3, stride-1, pad-1 conv2d.
    data = relay.var("data", shape=data_shape, dtype=data_dtype)
    weight = relay.var("weight", shape=weight_shape, dtype=weight_dtype)
    conv2d = relay.nn.conv2d(
        data=data,
        weight=weight,
        kernel_size=weight_shape[2:],
        channels=weight_shape[0],
        padding=(1, 1),
        strides=(1, 1),
        out_dtype=out_dtype,
    )

    relay_mod = tvm.IRModule.from_expr(conv2d)
    relay_mod = convert_conv2d_layout(relay_mod, {"nn.conv2d": ["NHWC", "HWIO"]})

    # Only the weight is bound as a parameter; the input stays a free relay variable.
    weight_np = np.random.randint(low=-127, high=128, size=weight_shape).astype("int8")
    params = {"weight": weight_np}

    target = "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod"
    extracted_tasks = extract_task_from_relay(relay_mod, target, params)

    conv2d_tasks = [task for task in extracted_tasks if "conv2d" in task.task_name]
    # Fail with a readable message instead of an IndexError on the lookup below.
    assert conv2d_tasks, "no conv2d task was extracted from the relay module"

    mod = Parse._mod(conv2d_tasks[0].dispatched[0])

    s = tvm.tir.Schedule(mod)
    block = s.get_block("C")

    info = get_tensorize_loop_mapping(s, block, gemm_4x4x4_i8i8i32)

    # info.loop_map is inverted here so it can be looked up by description loop.
    desc_loop_to_sref = {v: k for k, v in info.loop_map.items()}
    desc_loops = collect_loops(gemm_4x4x4_i8i8i32)

    # Block "C" is expected to sit under exactly four loops; the outermost one is not
    # part of the mapping checked here.
    _, i1_5, i2_4, i3_3 = s.get_loops(block)

    for idx, target_loop in enumerate((i1_5, i2_4, i3_3)):
        desc_loop = desc_loops[idx]
        assert desc_loop in desc_loop_to_sref
        assert s.get(desc_loop_to_sref[desc_loop]) == s.get(target_loop)


if __name__ == "__main__":
    # Entry point for running this file directly as a script.
    # NOTE(review): the tests listed below are disabled for direct runs; the reason is
    # not visible in this file section, so confirm before re-enabling them.
    # test_suggest_index_map_simple()
    # test_suggest_index_map_bijective()
    # test_get_tensorize_loop_mapping_dense_vnni()
    # test_get_tensorize_loop_mapping_conv2d_nchwc_vnni()
    test_get_tensorize_loop_mapping_matmul_mma()
    test_get_tensorize_loop_mapping_conv2d_nhwc_arm()

0 comments on commit eb147f3

Please sign in to comment.