From 3ef5a2f4d6832fae6174d716136aae75a61ec943 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 18 Jan 2023 01:26:28 +0400 Subject: [PATCH] [TIR][TOPI][x86][CI] Support skylake avx512 (#13621) * add skylake-avx512 tests * extend tests by skylake-avx512 * lint fixes * fix misprinting * misprinting fix * TODOs for further development * add temporally commented tests for skylake-avx512 due to not implemented shedules and postprocs for it. add TODOs for further check and development * update int8-acc32 test for vnni and avx512 w/o it * pylint fix * once more pylint fix * fix Feature init for skylake * fix test * fix intrin names for assert for skylake * small fix * return back fast int8 intrinsic tests * test connect of dense and batch_matmul to avx512 tensorization * extend dense_alter_layout on avx512 (currently) instead of VNNI. some renaming vnni to int8 for the sake of clarity * more renaming vnni to int8 for dense schedule, compute, strategy for the sake of clarity * update for batch_matmul with avx512 * extend space generator init for avx512. Add Default AVX512 schedule rules * avx512 dot 16x4 intrin was implemented for MS default schedule rule * small fix * update * pylint fixes * test workaround for const alloc in tir * test fix (broadcasting) * remove excess instructions from dot_product_16x4_u8i8i32_avx512 * pylint fix * skip asm check for askew weight shapes * fix pylint * revert test fix * set number of args * test fix * fix const allocation in tir for avx512 dot 16x4 * fix signature of dot_product_16x4_u8i8i32_avx512 * use script instead of tvm.tir for const allocation * extend auto tensorize test by skylake-avx512 target * clean code * update test_op_level1, resolve TODO * small update test_op_level2 * update test_op_level10, resolve TODO * update qnn legalize pass test, resolve TODOs * pylint fixes * update ms test for avx512 * update more ms test for avx512 * try to fix i386 CI tests * fix intrin name for check * skip test due to model downloading issue * fix test failure * use ORT for conv2d check * lint fix after rebasing * comment ORT part of test * extend tests tir schedule analysis and transform for avx512. unify test classes * extend test tir schedule tensorize for avx512 * extend test meta schedule vnni integration for avx512 * rename test file * pylint fix * tag fix * update test meta schedule trace apply with avx512 * rollback test class unifying in utils * pylint fixes * separate TIRs for scheduled conv2d for vnni and avx512 * fix registering issue in test * update conv+bias onnx model for intermediate test * fix int16 overflow * fix int16 overflow for dense test * update input data for test of dense * small rollback * fix misprinting * fix * restart CI * DefaultVNNI was renamed to DefaultLLVM for mutator * rename test file for the sake of clarity * DefaultVNNI was renamed to DefaultCPUTensorization for postproc * remove resolved TODO * DefaultVNNI and AVX512 for ScheduleRule were unified * replace code to upstream with initial version * fix arg type * lint fix * small fix * lint fix * fix misprinting * rollback trace apply test for avx512 (reviewer remark) * fix pylint Co-authored-by: Valery Chernov --- include/tvm/meta_schedule/mutator.h | 2 - include/tvm/meta_schedule/postproc.h | 4 +- include/tvm/meta_schedule/schedule_rule.h | 4 +- python/tvm/relay/qnn/op/legalizations.py | 4 +- python/tvm/testing/utils.py | 29 ++++ python/tvm/tir/tensor_intrin/x86.py | 40 +++++ python/tvm/topi/x86/batch_matmul.py | 25 +-- python/tvm/topi/x86/dense.py | 19 ++- python/tvm/topi/x86/dense_alter_op.py | 18 +- src/meta_schedule/mutator/mutator.cc | 2 - src/meta_schedule/postproc/postproc.cc | 2 +- .../schedule_rule/schedule_rule.cc | 6 +- .../space_generator/space_generator.cc | 19 ++- tests/python/contrib/test_gemm_acc32_vnni.py | 160 +++++++++--------- .../python/integration/test_auto_tensorize.py | 136 +++++++++------ tests/python/relay/test_op_level1.py | 24 ++- tests/python/relay/test_op_level10.py | 45 +++-- tests/python/relay/test_op_level2.py | 24 ++- tests/python/relay/test_pass_qnn_legalize.py | 26 +-- ... => test_meta_schedule_cpu_dot_product.py} | 62 ++++--- .../test_meta_schedule_relay_integration.py | 19 ++- ..._meta_schedule_schedule_rule_mlt_intrin.py | 23 +-- .../test_meta_schedule_trace_apply.py | 8 +- .../unittest/test_tir_schedule_analysis.py | 15 +- .../unittest/test_tir_schedule_tensorize.py | 14 +- .../unittest/test_tir_schedule_transform.py | 38 +++-- 26 files changed, 485 insertions(+), 283 deletions(-) rename tests/python/unittest/{test_meta_schedule_vnni_integration.py => test_meta_schedule_cpu_dot_product.py} (83%) diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h index 498b2797ada5..1560c00f3907 100644 --- a/include/tvm/meta_schedule/mutator.h +++ b/include/tvm/meta_schedule/mutator.h @@ -131,8 +131,6 @@ class Mutator : public runtime::ObjectRef { FApply f_apply, FClone f_clone, FAsString f_as_string); /*! \brief Create default mutators for LLVM */ TVM_DLL static Map DefaultLLVM(); - /*! \brief Create default mutators for x86 VNNI */ - TVM_DLL static Map DefaultVNNI(); /*! \brief Create default mutators for CUDA */ TVM_DLL static Map DefaultCUDA(); /*! \brief Create default mutators for CUDA with TensorCore */ diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index 06fa086c4bca..85fb9003e87f 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -163,8 +163,8 @@ class Postproc : public runtime::ObjectRef { TVM_DLL static Postproc RewriteLayout(); /*! \brief Create default postprocessors for LLVM */ TVM_DLL static Array DefaultLLVM(); - /*! \brief Create default postprocessors for x86 VNNI */ - TVM_DLL static Array DefaultVNNI(); + /*! \brief Create default postprocessors for x86 (AVX512 and VNNI) */ + TVM_DLL static Array DefaultCPUTensorization(); /*! \brief Create default postprocessors for CUDA */ TVM_DLL static Array DefaultCUDA(); /*! \brief Create default postprocessors for CUDA with TensorCore */ diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 16202e18bf95..7995d1fceeb6 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -290,8 +290,8 @@ class ScheduleRule : public runtime::ObjectRef { /*! \brief Create default schedule rules for LLVM */ TVM_DLL static Array DefaultLLVM(); - /*! \brief Create default schedule rules for x86 VNNI */ - TVM_DLL static Array DefaultVNNI(); + /*! \brief Create default schedule rules for x86 (AVX512 and VNNI) */ + TVM_DLL static Array DefaultX86(const String& type); /*! \brief Create default schedule rules for CUDA */ TVM_DLL static Array DefaultCUDA(); /*! \brief Create default postprocessors for CUDA with TensorCore */ diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 9baabf36a9d8..ef368a016e0c 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -248,7 +248,7 @@ def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op): Replacing QA + 128 with QA' and (zp_a + 128) with zp_a' We get our new quantized uint8 tensor - scale * (QA' - zp_a') - Similarly we can convert from int8 to uint8. + Similarly we can convert from uint8 to int8. Parameters ---------- @@ -449,6 +449,7 @@ def _qnn_dense_legalize_arm_cpu(attrs, inputs, types): @qnn_conv2d_legalize.register("cpu") def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types): + # TODO(vvchernov): not only VNNI # The VNNI transformations prefer uint8 x int8 datatypes. if is_fast_int8_on_intel(): return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.conv2d) @@ -457,6 +458,7 @@ def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types): @qnn_dense_legalize.register("cpu") def _qnn_dense_legalize_intel_cpu(attrs, inputs, types): + # TODO(vvchernov): not only VNNI # The VNNI transformations prefer uint8 x int8 datatypes. if is_fast_int8_on_intel(): return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 899b05440388..19669cd60cf4 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1027,6 +1027,28 @@ def _has_vnni(): return False +# check avx512 intrinsic groups for SkyLake X +def _has_slavx512(): + # Check LLVM support + llvm_version = tvm.target.codegen.llvm_version_major() + is_llvm_support = llvm_version >= 8 + arch = platform.machine() + # Only linux is supported for now. + if arch == "x86_64" and sys.platform.startswith("linux"): + with open("/proc/cpuinfo", "r") as content: + ctx = content.read() + check = ( + "avx512f" in ctx + and "avx512cd" in ctx + and "avx512bw" in ctx + and "avx512dq" in ctx + and "avx512vl" in ctx + ) + return check and is_llvm_support + + return False + + requires_arm_dot = Feature("arm_dot", "ARM dot product", run_time_check=_arm_dot_supported) @@ -1035,6 +1057,13 @@ def _has_vnni(): ) +requires_skylake_avx512 = Feature( + "skylake_avx512", + "x86 SkyLake AVX512", + run_time_check=lambda: _has_slavx512() and _is_intel(), +) + + def _cmake_flag_enabled(flag): flag = tvm.support.libinfo()[flag] diff --git a/python/tvm/tir/tensor_intrin/x86.py b/python/tvm/tir/tensor_intrin/x86.py index d93167f9e614..c527d0d21008 100644 --- a/python/tvm/tir/tensor_intrin/x86.py +++ b/python/tvm/tir/tensor_intrin/x86.py @@ -67,8 +67,48 @@ def dot_product_16x4_u8i8i32_vnni( ) +@T.prim_func +def dot_product_16x4_u8i8i32_avx512( + A: T.Buffer((4,), "uint8", offset_factor=1), + B: T.Buffer((16, 4), "int8", offset_factor=1), + C: T.Buffer((16,), "int32", offset_factor=1), +) -> None: + with T.block("root"): + T.reads(C[0:16], A[0:4], B[0:16, 0:4]) + T.writes(C[0:16]) + + A_u8x4 = A.vload([0], "uint8x4") + A_i32 = T.reinterpret(A_u8x4, dtype="int32") + A_brdcst = T.broadcast(A_i32, 16) + A_u8x64 = T.reinterpret(A_brdcst, dtype="uint8x64") + + B_i8x64 = B.vload([0, 0], dtype="int8x64") + + Red = T.call_llvm_pure_intrin( + T.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddubs.w.512"), + T.uint32(2), + A_u8x64, + B_i8x64, + dtype="int16x32", + ) + + C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin( + T.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddw.d.512"), + T.uint32(2), + Red, + T.int16x32(1), + dtype="int32x16", + ) + + VNNI_DOT_16x4_INTRIN = "dot_16x4_vnni" TensorIntrin.register( VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_vnni ) + +AVX512_DOT_16x4_INTRIN = "dot_16x4_avx512" + +TensorIntrin.register( + AVX512_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_avx512 +) diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index 9f3bc2951524..95408a924f28 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -25,12 +25,12 @@ from .. import generic, nn from ..transform import layout_transform from ..utils import get_const_tuple, get_max_power2_factor, traverse_inline -from .dense import dense_vnni_schedule, dense_amx_int8_schedule +from .dense import dense_int8_schedule, dense_amx_int8_schedule from .injective import schedule_injective_from_existing -from .utils import target_has_vnni, target_has_amx +from .utils import target_has_avx512, target_has_amx -@autotvm.register_topi_compute("batch_matmul_vnni.x86") +@autotvm.register_topi_compute("batch_matmul_int8.x86") def batch_matmul_int8_compute(cfg, x, y, *_): """Compute for uint8 x int8 -> int32 batch_matmul""" batch, m, k = x.shape @@ -39,8 +39,8 @@ def batch_matmul_int8_compute(cfg, x, y, *_): _, n_o, _, n_i, _ = packed_y.shape ak = te.reduce_axis((0, k), name="k") mcpu = tvm.target.Target.current().mcpu - if target_has_vnni(mcpu): - attrs_info = {"schedule_rule": "batch_matmul_vnni"} + if target_has_avx512(mcpu): + attrs_info = {"schedule_rule": "batch_matmul_int8"} else: attrs_info = None @@ -60,13 +60,14 @@ def batch_matmul_int8_compute(cfg, x, y, *_): return z -def batch_matmul_vnni_schedule(cfg, s, C, O, layout_trans): - """Schedule batch_matmul compute using VNNI vpdpbusd instruction""" +def batch_matmul_int8_schedule(cfg, s, C, O, layout_trans): + """Schedule batch_matmul compute using avx512 or lower instructions + including VNNI vpdpbusd instruction if possible""" # C: The output of batched GEMM # O: The output of the fused op # Schedule the GEMM part - s, fused_inner = dense_vnni_schedule(cfg, s, C, O, do_parallel=False) + s, fused_inner = dense_int8_schedule(cfg, s, C, O, do_parallel=False) # Parallelize over batch fused = s[O].fuse(O.op.axis[0], fused_inner) s[O].parallel(fused) @@ -228,9 +229,9 @@ def _callback(op): return s -@autotvm.register_topi_schedule("batch_matmul_vnni.x86") +@autotvm.register_topi_schedule("batch_matmul_int8.x86") def schedule_batch_matmul_int8(cfg, outs): - """Schedule for batch_matmul_vnni""" + """Schedule for batch_matmul_int8""" s = te.create_schedule([x.op for x in outs]) mcpu = tvm.target.Target.current().mcpu @@ -239,8 +240,8 @@ def _callback(op): layout_trans = op.input_tensors[1] if target_has_amx(mcpu): batch_matmul_amx_schedule(cfg, s, op.output(0), outs[0], layout_trans) - elif target_has_vnni(mcpu): - batch_matmul_vnni_schedule(cfg, s, op.output(0), outs[0], layout_trans) + elif target_has_avx512(mcpu): + batch_matmul_int8_schedule(cfg, s, op.output(0), outs[0], layout_trans) traverse_inline(s, outs[0].op, _callback) return s diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index bb99a632811b..b697cf98a625 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -26,10 +26,10 @@ from .. import generic, tag from ..utils import get_const_tuple, traverse_inline -from .tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake +from .tensor_intrin import dot_16x1x16_uint8_int8_int32 from .tensor_intrin import dot_32x128x32_u8s8s32_sapphirerapids from .tensor_intrin import acc_32x32_int32_sapphirerapids -from .utils import get_simd_32bit_lanes, target_has_vnni, target_has_amx +from .utils import get_simd_32bit_lanes, target_has_avx512, target_has_amx def _schedule_dense_pack_template(cfg, s, C, O): @@ -302,8 +302,8 @@ def _callback(op): if "dense_int8" in op.tag: if target_has_amx(mcpu): dense_amx_int8_schedule(cfg, s, op.output(0), outs[0]) - elif target_has_vnni(mcpu): - dense_vnni_schedule(cfg, s, op.output(0), outs[0]) + elif target_has_avx512(mcpu): + dense_int8_schedule(cfg, s, op.output(0), outs[0]) traverse_inline(s, outs[0].op, _callback) return s @@ -315,8 +315,8 @@ def dense_int8_compute(cfg, X, packed_w, bias=None): n_o, _, n_i, _ = packed_w.shape ak = te.reduce_axis((0, k), name="k") mcpu = tvm.target.Target.current().mcpu - if target_has_vnni(mcpu): - target_attr = {"schedule_rule": "meta_schedule.x86.dense_vnni"} + if target_has_avx512(mcpu): + target_attr = {"schedule_rule": "meta_schedule.x86.dense_int8"} else: target_attr = None @@ -339,8 +339,9 @@ def dense_int8_compute(cfg, X, packed_w, bias=None): return C -def dense_vnni_schedule(cfg, s, C, O, do_parallel=True): - """Schedule dense compute using VNNI vpdpbusd instruction""" +def dense_int8_schedule(cfg, s, C, O, do_parallel=True): + """Schedule dense compute using avx512 or lower instructions + including VNNI vpdpbusd instruction if possible""" # C: The output of GEMM # O: The output of the fused op def split_y(out): @@ -361,7 +362,7 @@ def split_y(out): s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki) - pc = dot_16x1x16_uint8_int8_int32_cascadelake() + pc = dot_16x1x16_uint8_int8_int32() s[C].tensorize(a_xi, pc) if C == O: diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index 2cb46b8291fb..a380b7fc9ff7 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -24,14 +24,14 @@ from .dense import _default_dense_pack_config from ..utils import get_const_tuple from ..nn import dense_alter_layout -from .utils import target_has_vnni -from .utils import target_has_amx +from .utils import target_has_avx512, target_has_amx from .. import nn -def check_inst_applicable(x, y, allow_padding=False): +def check_int8_applicable(x, y, allow_padding=False): mcpu = tvm.target.Target.current().mcpu - simd_avai = target_has_vnni(mcpu) or target_has_amx(mcpu) + # TODO(vvchernov): may be also target_has_avx2 or lower? + simd_avai = target_has_avx512(mcpu) or target_has_amx(mcpu) return ( simd_avai and "int8" in x.dtype @@ -49,7 +49,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): M, K = get_const_tuple(data_tensor.shape) N, _ = get_const_tuple(weight_tensor.shape) - if check_inst_applicable(data_tensor, weight_tensor) and data_tensor.dtype == "uint8": + if check_int8_applicable(data_tensor, weight_tensor) and data_tensor.dtype == "uint8": weight_layout = "NC16n4c" return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype) @@ -86,10 +86,10 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): return None -def vnni_legalize(inputs, arg_types, op, attrs, need_expand=False): +def int8_int8_legalize(inputs, arg_types, op, attrs, need_expand=False): """Legalizes s8, s8 -> s32 GEMM op for VNNI.""" if ( - check_inst_applicable(arg_types[0], arg_types[1], allow_padding=True) + check_int8_applicable(arg_types[0], arg_types[1], allow_padding=True) and arg_types[0].dtype == "int8" ): x, y = inputs @@ -135,7 +135,7 @@ def vnni_legalize(inputs, arg_types, op, attrs, need_expand=False): @nn.dense_legalize.register("cpu") def _dense_legalize(attrs, inputs, arg_types): """Legalizes s8, s8 -> s32 dense for VNNI.""" - return vnni_legalize(inputs, arg_types, relay.nn.dense, attrs) + return int8_int8_legalize(inputs, arg_types, relay.nn.dense, attrs) @nn.batch_matmul_legalize.register("cpu") @@ -143,4 +143,4 @@ def _batch_matmul_legalize(attrs, inputs, arg_types): """Legalizes s8, s8 -> s32 batch_matmul for VNNI.""" if attrs["transpose_a"] or not attrs["transpose_b"]: return None - return vnni_legalize(inputs, arg_types, relay.nn.batch_matmul, attrs, need_expand=True) + return int8_int8_legalize(inputs, arg_types, relay.nn.batch_matmul, attrs, need_expand=True) diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc index 3cf43e11260e..ddc2d73590f9 100644 --- a/src/meta_schedule/mutator/mutator.cc +++ b/src/meta_schedule/mutator/mutator.cc @@ -59,8 +59,6 @@ Map Mutator::DefaultLLVM() { {Mutator::MutateParallel(/*max_jobs_per_core=*/16), FloatImm(DataType::Float(64), 0.02)}}; } -Map Mutator::DefaultVNNI() { return Mutator::DefaultLLVM(); } - Map Mutator::DefaultCUDA() { return Map{ {Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)}, diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc index 7730e4372fa9..bcd0cef4dd69 100644 --- a/src/meta_schedule/postproc/postproc.cc +++ b/src/meta_schedule/postproc/postproc.cc @@ -59,7 +59,7 @@ Array Postproc::DefaultLLVM() { }; } -Array Postproc::DefaultVNNI() { +Array Postproc::DefaultCPUTensorization() { return Array{ Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(), Postproc::RewriteReductionBlock(), Postproc::RewriteTensorize(/*vectorize_init_loop=*/true), diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 113703272031..e25f0b12210d 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -85,7 +85,9 @@ Array ScheduleRule::DefaultLLVM() { }; } -Array ScheduleRule::DefaultVNNI() { +Array ScheduleRule::DefaultX86(const String& type) { + static const Map intrins = {{"vnni", "dot_16x4_vnni"}, + {"avx512", "dot_16x4_avx512"}}; return { ScheduleRule::ApplyCustomRule(), ScheduleRule::InlineConstantScalars(), @@ -101,7 +103,7 @@ Array ScheduleRule::DefaultVNNI() { /*max_jobs_per_core=*/16, /*max_innermost_factor=*/Integer(64)), ScheduleRule::MultiLevelTilingWithIntrin( - /*intrin_name=*/"dot_16x4_vnni", + /*intrin_name=*/intrins[type], /*structure=*/"SSRSRS", /*tile_binds=*/NullOpt, /*max_innermost_factor=*/Integer(64), diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 926f86cc4ff9..2ce8d8fa1103 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -29,6 +29,14 @@ String GetRuleKindFromTarget(const Target& target) { if (target->GetAttr("mcpu") && (*f_check_vnni)(target->GetAttr("mcpu").value())) { return "vnni"; + } else { + static const PackedFunc* f_check_avx512 = + runtime::Registry::Get("tvm.topi.x86.utils.target_has_avx512"); + ICHECK(f_check_avx512 != nullptr) << "The `target_has_avx512` func is not in tvm registry."; + if (target->GetAttr("mcpu") && + (*f_check_avx512)(target->GetAttr("mcpu").value())) { + return "avx512"; + } } return "llvm"; } @@ -73,6 +81,7 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) { Array default_sch_rules; Array default_postprocs; Map default_mutator_probs; + // for target with skylake-avx512 if (kind == "llvm") { default_sch_rules = ScheduleRule::DefaultLLVM(); default_postprocs = Postproc::DefaultLLVM(); @@ -90,9 +99,13 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) { default_postprocs = Postproc::DefaultHexagon(); default_mutator_probs = Mutator::DefaultHexagon(); } else if (kind == "vnni") { - default_sch_rules = ScheduleRule::DefaultVNNI(); - default_postprocs = Postproc::DefaultVNNI(); - default_mutator_probs = Mutator::DefaultVNNI(); + default_sch_rules = ScheduleRule::DefaultX86("vnni"); + default_postprocs = Postproc::DefaultCPUTensorization(); + default_mutator_probs = Mutator::DefaultLLVM(); + } else if (kind == "avx512") { + default_sch_rules = ScheduleRule::DefaultX86("avx512"); + default_postprocs = Postproc::DefaultCPUTensorization(); + default_mutator_probs = Mutator::DefaultLLVM(); } else if (kind == "c") { default_sch_rules = ScheduleRule::DefaultMicro(); default_postprocs = Postproc::DefaultMicro(); diff --git a/tests/python/contrib/test_gemm_acc32_vnni.py b/tests/python/contrib/test_gemm_acc32_vnni.py index 9cec823cc58a..c01f7758cb45 100644 --- a/tests/python/contrib/test_gemm_acc32_vnni.py +++ b/tests/python/contrib/test_gemm_acc32_vnni.py @@ -14,106 +14,102 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition import tvm import tvm.testing from tvm import te import numpy as np -from tvm.topi.x86.tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake from tvm.topi.x86.tensor_intrin import dot_16x1x16_uint8_int8_int32 -import pytest -@tvm.testing.requires_llvm -@pytest.mark.skip("skip because feature not enabled") -def test_fc_int8_acc32(): - m = 1024 - n = 1024 - k = 1024 - +def verify_fc_int8_acc32(m=1024, n=1024, k=1024, target="llvm -mcpu=cascadelake"): X = te.placeholder((m, k), name="X", dtype="uint8") - W = te.placeholder((n, k), name="W", dtype="int8") + # W = te.placeholder((n, k), name="W", dtype="int8") + + if not tvm.testing.device_enabled(target): + print("skip because %s is not enabled..." % target) + return + + dev = tvm.device(target, 0) + # workaround for Target.current() + with tvm.target.Target(target) as target: + pc = dot_16x1x16_uint8_int8_int32() + + ak = te.reduce_axis((0, k), name="k") + packedW = te.placeholder((n // 16, 16 * (k // 4), 4), name="packedW", dtype="int8") + + t_fc = te.compute( + (m, n), + lambda i, j: te.sum( + X[i, ak].astype("int32") + * packedW[ + tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4) * 16 + j % 16, ak % 4 + ].astype("int32"), + axis=ak, + ), + name="F", + ) + t_sch = te.create_schedule(t_fc.op) + a_x, a_y = t_fc.op.axis + (a_k,) = t_fc.op.reduce_axis + + a_yo, a_yi = t_sch[t_fc].split(a_y, factor=16) + a_xo, a_xi = t_sch[t_fc].split(a_x, factor=32) + a_ko, a_ki = t_sch[t_fc].split(a_k, factor=4) + a_koo, a_koi = t_sch[t_fc].split(a_ko, factor=4) + t_sch[t_fc].reorder(a_yo, a_xo, a_xi, a_koo, a_koi, a_yi, a_ki) + + t_sch[t_fc].unroll(a_koi) + t_sch[t_fc].tensorize(a_yi, pc) + + t_func = tvm.build(t_sch, [X, packedW, t_fc], target, name="intrinsic") + t_evaluator = t_func.time_evaluator(t_func.entry_name, dev, number=10) + + # generate the plain data + a_ = np.random.uniform(1, 10, size=(m, k)).astype("uint8") + b_ = np.random.uniform(1, 10, size=(n, k)).astype("int8") + + packW = np.random.uniform(1, 10, size=(n // 16, 16 * (k // 4), 4)).astype("int8") + # This occurs in pre_compute stage + for r_idx in range(n // 16): + for s_idx in range(16 * (k // 4)): + for t_idx in range(4): + packW[r_idx][s_idx][t_idx] = b_[r_idx * 16 + s_idx % 16][(s_idx // 16) * 4 + t_idx] + + x = tvm.nd.array(a_, dev) + w = tvm.nd.array(packW, dev) + y = tvm.nd.array(np.zeros((m, n), dtype="int32"), dev) + result = t_evaluator(x, w, y) peak = 280 print("Peak {} Gops/s".format(peak)) - memory_ops = m * k + n * k + 2 * m * n + # memory_ops = m * k + n * k + 2 * m * n gops_per_mm = 2 * m * n * k + gops_per_sec = gops_per_mm / result.mean / 1e9 + # verify the correctness + tvm.testing.assert_allclose(y.numpy(), np.dot(a_, b_.T), rtol=0) + print( + "Tensorization: running time: {:.3f} ms, {:.2f} Gops/s, effiency: {:.2f}".format( + result.mean * 1000, gops_per_sec, gops_per_sec / peak + ) + ) + # t_func.export_library("tensorize_acc32.o") + + +@tvm.testing.requires_cascadelake +def test_fc_int8_acc32_vnni(): # For LLVM < 8.0, it shows "'cascadelake' is not a recognized processor for this target # (ignoring processor)" error with the following setting. After LLVM 8.0 is enabled in the # test, we should use cascadelake setting. - def verify(target="llvm -mcpu=cascadelake"): - if not tvm.testing.device_enabled(target): - print("skip because %s is not enabled..." % target) - return - - dev = tvm.device(target, 0) - pc = dot_16x1x16_uint8_int8_int32_cascadelake() - ak = te.reduce_axis((0, k), name="k") - packedW = te.placeholder((n // 16, 16 * (k // 4), 4), name="packedW", dtype="int8") - - t_fc = te.compute( - (m, n), - lambda i, j: te.sum( - X[i, ak].astype("int32") - * packedW[ - tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4) * 16 + j % 16, ak % 4 - ].astype("int32"), - axis=ak, - ), - name="F", - ) - t_sch = te.create_schedule(t_fc.op) - a_x, a_y = t_fc.op.axis - (a_k,) = t_fc.op.reduce_axis - - a_yo, a_yi = t_sch[t_fc].split(a_y, factor=16) - a_xo, a_xi = t_sch[t_fc].split(a_x, factor=32) - a_ko, a_ki = t_sch[t_fc].split(a_k, factor=4) - a_koo, a_koi = t_sch[t_fc].split(a_ko, factor=4) - t_sch[t_fc].reorder(a_yo, a_xo, a_xi, a_koo, a_koi, a_yi, a_ki) - - t_sch[t_fc].unroll(a_koi) - t_sch[t_fc].tensorize(a_yi, pc) - - t_func = tvm.build(t_sch, [X, packedW, t_fc], target, name="intrinsic") - t_evaluator = t_func.time_evaluator(t_func.entry_name, dev, number=10) - - # generate the plain data - a_ = np.random.uniform(1, 10, size=(m, k)).astype("uint8") - b_ = np.random.uniform(1, 10, size=(n, k)).astype("int8") - - packW = np.random.uniform(1, 10, size=(n // 16, 16 * (k // 4), 4)).astype("int8") - # This occurs in pre_compute stage - for r_idx in range(n // 16): - for s_idx in range(16 * (k // 4)): - for t_idx in range(4): - packW[r_idx][s_idx][t_idx] = b_[r_idx * 16 + s_idx % 16][ - (s_idx // 16) * 4 + t_idx - ] - - x = tvm.nd.array(a_, dev) - w = tvm.nd.array(packW, dev) - y = tvm.nd.array(np.zeros((m, n), dtype="int32"), dev) - result = t_evaluator(x, w, y) - - gops_per_sec = gops_per_mm / result.mean / 1e9 - # verify the correctness - tvm.testing.assert_allclose(y.numpy(), np.dot(a_, b_.T), rtol=0) - print( - "Tensorization: running time: {:.3f} ms, {:.2f} Gops/s, effiency: {:.2f}".format( - result.mean * 1000, gops_per_sec, gops_per_sec / peak - ) - ) - t_func.export_library("tensorize_acc32.o") + verify_fc_int8_acc32() - verify() +@tvm.testing.requires_skylake_avx512 +def test_fc_int8_acc32_avx512(): + verify_fc_int8_acc32(target="llvm -mcpu=skylake-avx512") -if __name__ == "__main__": - # The test requires Cascade Lake and newer Intel machines to generate the - # correct AVX512 VNNI instruction. So, disabling the test. - # test_fc_int8_acc32() - pass +if __name__ == "__main__": + test_fc_int8_acc32_vnni() + test_fc_int8_acc32_avx512() diff --git a/tests/python/integration/test_auto_tensorize.py b/tests/python/integration/test_auto_tensorize.py index 572da53b34fd..70b2b875c124 100644 --- a/tests/python/integration/test_auto_tensorize.py +++ b/tests/python/integration/test_auto_tensorize.py @@ -29,52 +29,63 @@ from tvm.tir.tensor_intrin.arm_cpu import DP4A_INTRIN from tvm.tir.tensor_intrin.rocm import AMDGPU_SDOT4_INTRIN from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN +from tvm.tir.tensor_intrin.x86 import AVX512_DOT_16x4_INTRIN as AVX512_INTRIN -SCH_RULES_FOR_VNNI = [ - ms.schedule_rule.ApplyCustomRule(), - ms.schedule_rule.AutoInline( - into_producer=False, - into_consumer=True, - inline_const_tensor=True, - disallow_if_then_else=True, - require_injective=True, - require_ordered=True, - disallow_op=["tir.exp"], - ), - ms.schedule_rule.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64), - ms.schedule_rule.MultiLevelTilingWithIntrin( - VNNI_INTRIN, - structure="SSRSRS", - tile_binds=None, - max_innermost_factor=64, - vector_load_lens=None, - reuse_read=None, - reuse_write=ms.schedule_rule.ReuseType( - req="may", - levels=[1, 2], - scope="global", + +CASCADELAKE_VNNI_TARGET = "llvm -mcpu=cascadelake -num-cores 4" +SKYLAKE_AVX512_TARGET = "llvm -mcpu=skylake-avx512 -num-cores 4" + + +def _get_schedule_rules_for_x86(intrin): + return [ + ms.schedule_rule.ApplyCustomRule(), + ms.schedule_rule.AutoInline( + into_producer=False, + into_consumer=True, + inline_const_tensor=True, + disallow_if_then_else=True, + require_injective=True, + require_ordered=True, + disallow_op=["tir.exp"], + ), + ms.schedule_rule.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64), + ms.schedule_rule.MultiLevelTilingWithIntrin( + intrin, + structure="SSRSRS", + tile_binds=None, + max_innermost_factor=64, + vector_load_lens=None, + reuse_read=None, + reuse_write=ms.schedule_rule.ReuseType( + req="may", + levels=[1, 2], + scope="global", + ), + ), + ms.schedule_rule.MultiLevelTiling( + structure="SSRSRS", + tile_binds=None, + max_innermost_factor=64, + vector_load_lens=None, + reuse_read=None, + reuse_write=ms.schedule_rule.ReuseType( + req="may", + levels=[1, 2], + scope="global", + ), ), - ), - ms.schedule_rule.MultiLevelTiling( - structure="SSRSRS", - tile_binds=None, - max_innermost_factor=64, - vector_load_lens=None, - reuse_read=None, - reuse_write=ms.schedule_rule.ReuseType( - req="may", - levels=[1, 2], - scope="global", + ms.schedule_rule.ParallelizeVectorizeUnroll( + max_jobs_per_core=16, + max_vectorize_extent=64, + unroll_max_steps=[0, 16, 64, 512], + unroll_explicit=True, ), - ), - ms.schedule_rule.ParallelizeVectorizeUnroll( - max_jobs_per_core=16, - max_vectorize_extent=64, - unroll_max_steps=[0, 16, 64, 512], - unroll_explicit=True, - ), - ms.schedule_rule.RandomComputeLocation(), -] + ms.schedule_rule.RandomComputeLocation(), + ] + + +SCH_RULES_FOR_VNNI = _get_schedule_rules_for_x86(VNNI_INTRIN) +SCH_RULES_FOR_AVX512 = _get_schedule_rules_for_x86(AVX512_INTRIN) def _get_sch_rules_for_dp4a(intrin): @@ -177,6 +188,11 @@ def tune_and_test(relay_mod, data_np, weight_np, op_name, target, sch_rules, pos asm = lib.lib.get_source("asm") assert "vpdpbusd" in asm + if "skylake-avx512" in target: + asm = lib.lib.get_source("asm") + assert "pmaddubs" in asm + assert "pmaddw" in asm + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) runtime.set_input("data", data_np) runtime.run() @@ -273,9 +289,12 @@ def _test_bert_int8(relay_mod, params, input_info, target, sch_rules, postprocs) @tvm.testing.requires_cascadelake def test_vnni_dense(): - _test_dense( - "uint8", SCH_RULES_FOR_VNNI, POSTPROCS_FOR_VNNI, "llvm -mcpu=cascadelake -num-cores 4" - ) + _test_dense("uint8", SCH_RULES_FOR_VNNI, POSTPROCS_FOR_VNNI, CASCADELAKE_VNNI_TARGET) + + +@tvm.testing.requires_skylake_avx512 +def test_avx512_dense(): + _test_dense("uint8", SCH_RULES_FOR_AVX512, POSTPROCS_FOR_VNNI, SKYLAKE_AVX512_TARGET) @pytest.mark.skip("Only tested locally on sm_86 (for cuda) which is not supported by CI") @@ -293,9 +312,12 @@ def test_dp4a_dense(): @tvm.testing.requires_cascadelake def test_vnni_conv2d(): - _test_conv2d( - "uint8", SCH_RULES_FOR_VNNI, POSTPROCS_FOR_VNNI, "llvm -mcpu=cascadelake -num-cores 4" - ) + _test_conv2d("uint8", SCH_RULES_FOR_VNNI, POSTPROCS_FOR_VNNI, CASCADELAKE_VNNI_TARGET) + + +@tvm.testing.requires_skylake_avx512 +def test_avx512_conv2d(): + _test_conv2d("uint8", SCH_RULES_FOR_AVX512, POSTPROCS_FOR_VNNI, SKYLAKE_AVX512_TARGET) @pytest.mark.skip("Only tested locally on sm_86 (for cuda) which is not supported by CI") @@ -319,12 +341,26 @@ def test_vnni_bert_int8(): relay_mod, params, input_info, - "llvm -mcpu=cascadelake -num-cores 4", + CASCADELAKE_VNNI_TARGET, SCH_RULES_FOR_VNNI, POSTPROCS_FOR_VNNI, ) +@tvm.testing.requires_skylake_avx512 +@pytest.mark.skip("Due to quantized BERT download issue") +def test_avx512_bert_int8(): + relay_mod, params, input_info = load_quantized_bert_base() + _test_bert_int8( + relay_mod, + params, + input_info, + SKYLAKE_AVX512_TARGET, + SCH_RULES_FOR_AVX512, + POSTPROCS_FOR_VNNI, + ) + + @tvm.testing.requires_gpu @pytest.mark.skip("Slow on CI") def test_dp4a_bert_int8(): diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 3bb9918c7c77..0549f4f2fbcc 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -760,9 +760,7 @@ def test_bitserial_dense(): assert yy.checked_type == relay.TensorType((m, 32), "int16") -@tvm.testing.requires_cascadelake -@pytest.mark.parametrize("m,n,k", [(32, 128, 96), (32, 128, 97)]) -def test_dense_vnni(m, n, k): +def dense_x86_test(m, n, k, target="llvm -mcpu=cascadelake", intrins=["vpdpbusd"]): data_shape = (m, k) weight_shape = (n, k) @@ -774,12 +772,14 @@ def test_dense_vnni(m, n, k): out = relay.nn.bias_add(dense, bias) mod = tvm.IRModule.from_expr(out) - target = "llvm -mcpu=cascadelake" with tvm.transform.PassContext(opt_level=3): lib = relay.build(mod, target=target) - asm = lib.lib.get_source("asm") - assert "vpdpbusd" in asm + # TODO(vvchernov): needs for avx512 arch, can be extended + if n % 16 == 0 and k % 4 == 0: + asm = lib.lib.get_source("asm") + for intrin in intrins: + assert intrin in asm dev = tvm.device(target, 0) runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) @@ -846,6 +846,18 @@ def test_dense_amx_int8(): np.testing.assert_equal(out, ref) +@tvm.testing.requires_cascadelake +@pytest.mark.parametrize("m,n,k", [(32, 128, 96), (32, 128, 97)]) +def test_dense_vnni(m, n, k): + dense_x86_test(m, n, k) + + +@tvm.testing.requires_skylake_avx512 +@pytest.mark.parametrize("m,n,k", [(32, 128, 96), (32, 128, 97)]) +def test_dense_skylake_avx512(m, n, k): + dense_x86_test(m, n, k, "llvm -mcpu=skylake-avx512", ["pmaddubs", "pmaddw", "vpaddd"]) + + @pytest.mark.skip("Requires GFX10 AMDGPU") def test_dense_rocm_sdot4(): data_shape = (32, 96) diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index cdf4e734842b..ed044989ac18 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -473,16 +473,7 @@ def test_batch_matmul(executor_kind): verify_batch_matmul_with_inputs(executor_kind, x, x, x_np, x_np, (10, 27, 27)) -@tvm.testing.requires_cascadelake -@pytest.mark.parametrize( - "b,m,n,k", - [ - (16, 32, 128, 96), - (16, 32, 128, 97), - (16, 32, 129, 96), - ], -) -def test_batch_matmul_vnni(b, m, n, k): +def batch_matmul_x86_test(b, m, n, k, target="llvm -mcpu=cascadelake", intrins=["vpdpbusd"]): x_shape = (b, m, k) y_shape = (b, n, k) z_shape = (b, m, n) @@ -495,12 +486,14 @@ def test_batch_matmul_vnni(b, m, n, k): out = bmm + z mod = tvm.IRModule.from_expr(out) - target = "llvm -mcpu=cascadelake" with tvm.transform.PassContext(opt_level=3): lib = relay.build(mod, target=target) - asm = lib.lib.get_source("asm") - assert "vpdpbusd" in asm + # TODO(vvchernov): needs for avx512 arch, can be extended + if n % 16 == 0 and k % 4 == 0: + asm = lib.lib.get_source("asm") + for intrin in intrins: + assert intrin in asm dev = tvm.device(target, 0) runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) @@ -575,6 +568,32 @@ def test_batch_matmul_amx(b, m, n, k): np.testing.assert_equal(out, ref) +@tvm.testing.requires_cascadelake +@pytest.mark.parametrize( + "b,m,n,k", + [ + (16, 32, 128, 96), + (16, 32, 128, 97), + (16, 32, 129, 96), + ], +) +def test_batch_matmul_vnni(b, m, n, k): + batch_matmul_x86_test(b, m, n, k) + + +@tvm.testing.requires_skylake_avx512 +@pytest.mark.parametrize( + "b,m,n,k", + [ + (16, 32, 128, 96), + (16, 32, 128, 97), + (16, 32, 129, 96), + ], +) +def test_batch_matmul_skylake_avx512(b, m, n, k): + batch_matmul_x86_test(b, m, n, k, "llvm -mcpu=skylake-avx512", ["pmaddubs", "pmaddw", "vpaddd"]) + + @pytest.mark.skip("Requires GFX10 AMDGPU") def test_batch_matmul_rocm_sdot4(): x_shape = (16, 32, 96) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index ca1adf940029..f7cfc81fb2d3 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1696,7 +1696,7 @@ def fast_int8_intrinsic(self, target): elif "cascadelake" in target: return "vpdpbusd" else: - assert False, "Target should be Skylake or Cascadelake" + assert False, "Target should be Nehalem or core-avx2 or Skylake or Cascadelake" @tvm.testing.fixture def assembly( @@ -2137,7 +2137,7 @@ def get_subgraph(dtype): np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) -def _test_conv2d_int8_alter_dtype(data_dtype, target, dot_product_instr): +def _test_conv2d_int8_alter_dtype(data_dtype, target, dot_product_instrs): def get_conv2d_nchw( d_shape, w_shape, @@ -2168,16 +2168,16 @@ def get_conv2d_nchw( bias = relay.var("bias", shape=bias_shape, dtype="int32") bias_np = np.random.randint(low=-127, high=128, size=bias_shape).astype("int32") - weight_np = np.random.uniform(-128, 127, size=weight_shape).astype("int8") + weight_np = np.random.uniform(-32, 32, size=weight_shape).astype("int8") conv2d = get_conv2d_nchw(data_shape, weight_shape, data_dtype) bias_add = relay.add(conv2d, bias) mod = tvm.IRModule.from_expr(bias_add) if data_dtype == "uint8": - data_np = np.random.uniform(0, 255, size=data_shape).astype("uint8") + data_np = np.random.uniform(0, 64, size=data_shape).astype("uint8") else: - data_np = np.random.uniform(-128, 127, size=data_shape).astype("int8") + data_np = np.random.uniform(-32, 32, size=data_shape).astype("int8") params = {"weight": weight_np, "bias": bias_np} @@ -2194,7 +2194,8 @@ def get_conv2d_nchw( ): lib = relay.build(mod, target=target, params=params) - assert dot_product_instr in lib.lib.get_source("asm") + for dot_product_instr in dot_product_instrs: + assert dot_product_instr in lib.lib.get_source("asm") rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) @@ -2210,13 +2211,20 @@ def get_conv2d_nchw( @tvm.testing.requires_arm_dot def test_conv2d_int8_alter_dtype_arm(): _test_conv2d_int8_alter_dtype( - "uint8", "llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod", "sdot" + "uint8", "llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod", ["sdot"] ) @tvm.testing.requires_cascadelake def test_conv2d_int8_alter_dtype_vnni(): - _test_conv2d_int8_alter_dtype("int8", "llvm -mcpu=cascadelake", "vpdpbusd") + _test_conv2d_int8_alter_dtype("int8", "llvm -mcpu=cascadelake", ["vpdpbusd"]) + + +@tvm.testing.requires_skylake_avx512 +def test_conv2d_int8_alter_dtype_avx512(): + _test_conv2d_int8_alter_dtype( + "int8", "llvm -mcpu=skylake-avx512", ["pmaddubs", "pmaddw", "vpaddd"] + ) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py index a30cd1e73e3f..c64b30a2128b 100644 --- a/tests/python/relay/test_pass_qnn_legalize.py +++ b/tests/python/relay/test_pass_qnn_legalize.py @@ -136,11 +136,12 @@ def _get_mod(data_dtype, kernel_dtype): ############################################################# # Check transformations for platforms with fast Int8 support. ############################################################# - # Check that Intel VNNI gets picked up. - with tvm.target.Target("llvm -mcpu=skylake-avx512"): - mod = relay.transform.InferType()(mod) - legalized_mod = relay.qnn.transform.Legalize()(mod) - assert "cast" in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext() + # Check that Intel AVX512 (with or w/o VNNI) gets picked up. + for target in ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]: + with tvm.target.Target(target): + mod = relay.transform.InferType()(mod) + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert "cast" in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext() # Since same dtype, there should not be any transformation with tvm.target.Target( @@ -167,7 +168,7 @@ def _get_mod(data_dtype, kernel_dtype): ############################################################# # Check transformations for platforms with fast Int8 support. ############################################################# - # Check no transformation for Intel VNNI. + # Check no transformation for Intel AVX512. with tvm.target.Target("llvm -mcpu=skylake-avx512"): mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) @@ -229,11 +230,12 @@ def _get_mod(data_dtype, kernel_dtype): ############################################################# # Check transformations for platforms with fast Int8 support. ############################################################# - # Check that Intel VNNI gets picked up. - with tvm.target.Target("llvm -mcpu=skylake-avx512"): - mod = relay.transform.InferType()(mod) - legalized_mod = relay.qnn.transform.Legalize()(mod) - assert "cast" in legalized_mod.astext() and "qnn.dense" in legalized_mod.astext() + # Check that Intel AVX512 (with or w/o VNNI) gets picked up. + for target in ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]: + with tvm.target.Target(target): + mod = relay.transform.InferType()(mod) + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert "cast" in legalized_mod.astext() and "qnn.dense" in legalized_mod.astext() # Since same dtype, there should not be any transformation with tvm.target.Target( @@ -260,7 +262,7 @@ def _get_mod(data_dtype, kernel_dtype): ############################################################# # Check transformations for platforms with fast Int8 support. ############################################################# - # Check no transformation for Intel VNNI. + # Check no transformation for Intel AVX512. with tvm.target.Target("llvm -mcpu=skylake-avx512"): mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) diff --git a/tests/python/unittest/test_meta_schedule_vnni_integration.py b/tests/python/unittest/test_meta_schedule_cpu_dot_product.py similarity index 83% rename from tests/python/unittest/test_meta_schedule_vnni_integration.py rename to tests/python/unittest/test_meta_schedule_cpu_dot_product.py index 3bbe916472f5..6dc72d69336f 100644 --- a/tests/python/unittest/test_meta_schedule_vnni_integration.py +++ b/tests/python/unittest/test_meta_schedule_cpu_dot_product.py @@ -28,6 +28,7 @@ from tvm.tir.schedule import BlockRV, Schedule from tvm.tir.schedule.analysis import has_block from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN +from tvm.tir.tensor_intrin.x86 import AVX512_DOT_16x4_INTRIN as AVX512_INTRIN logging.basicConfig( format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", @@ -36,9 +37,9 @@ logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) -def _schedule_dense(m: Optional[int], do_tune: bool): +def _schedule_dense(m: Optional[int], do_tune: bool, intrin=VNNI_INTRIN): """Manually schedule a dense block, created from TE compute op via CreatePrimFunc, - using VNNI instruction. + using VNNI or AVX512 instructions. """ def schedule_fn(sch, dense_block: Optional[BlockRV] = None) -> bool: @@ -47,7 +48,7 @@ def schedule_fn(sch, dense_block: Optional[BlockRV] = None) -> bool: if dense_block is None: assert has_block(sch, "compute") dense_block = sch.get_block("compute") - assert "dense_vnni" in sch.get(dense_block).annotations["schedule_rule"] + assert "dense_int8" in sch.get(dense_block).annotations["schedule_rule"] post_blocks = sch.get_consumers(dense_block) if len(post_blocks) > 0: @@ -90,7 +91,7 @@ def schedule_fn(sch, dense_block: Optional[BlockRV] = None) -> bool: dec = sch.decompose_reduction(dense_block, a_ko) init_loop = sch.get_loops(dec)[-1] sch.vectorize(init_loop) - sch.tensorize(a_xi, VNNI_INTRIN) + sch.tensorize(a_xi, intrin) return True return schedule_fn @@ -109,10 +110,10 @@ def _relay_dense(m, n, k): out_dtype="int32", ) relay_mod = tvm.IRModule.from_expr(out) - data = np.random.uniform(1, 10, size=(m, k)).astype("uint8") + data = np.random.randint(0, 5, size=(m, k), dtype="uint8") params = { - "weight": np.random.uniform(1, 10, size=(n, k)).astype("int8"), - "bias": np.random.uniform(1, 10, size=(n,)).astype("int32"), + "weight": np.random.randint(0, 5, size=(n, k), dtype="int8"), + "bias": np.random.randint(0, 5, size=(n,), dtype="int32"), } def f_check(lib, dev): @@ -135,10 +136,7 @@ def f_check(lib, dev): return relay_mod, params, f_check -@tvm.testing.requires_cascadelake -def test_vnni_schedule_fn_database(): - m, n, k = 1024, 1024, 1024 - target = tvm.target.Target("llvm -mcpu=cascadelake -num-cores 4") +def schedule_16x4_dense_fn_database(target, intrin, m=1024, n=1024, k=1024): dev = tvm.cpu(0) relay_mod, params, f_check = _relay_dense(m, n, k) @@ -146,6 +144,7 @@ def test_vnni_schedule_fn_database(): _schedule_dense( m=m, do_tune=False, + intrin=intrin, ) ), tvm.transform.PassContext( opt_level=3, @@ -167,21 +166,32 @@ def test_vnni_schedule_fn_database(): @tvm.testing.requires_cascadelake -def test_vnni_schedule_fn_tune(): +def test_vnni_schedule_fn_database(): + target = tvm.target.Target("llvm -keys=x86,cpu -mcpu=cascadelake -num-cores=4") + schedule_16x4_dense_fn_database(target, VNNI_INTRIN) + + +@tvm.testing.requires_skylake_avx512 +def test_avx512_schedule_fn_database(): + target = tvm.target.Target("llvm -keys=x86,cpu -mcpu=skylake-avx512 -num-cores=4") + schedule_16x4_dense_fn_database(target, AVX512_INTRIN, 16, 16, 16) + + +def schedule_16x4_dense_fn_tune(target, intrin, m=1024, n=1024, k=1024): # pylint: disable=W0105 """ We can inject and apply a custom TIR scheduling to a TE compute of interest, using the "schedule_rule" annotation. For example, in topi/x86/dense.py we have the following - declaration for int8 dense targeting the VNNI instruction. + declaration for int8 dense targeting the VNNI or AVX512 instructions. C = te.compute( ... - attrs={"schedule_rule": "meta_schedule.x86.dense_vnni"}, + attrs={"schedule_rule": "meta_schedule.x86.dense_int8"}, ) When the MetaSchedule encounters a TensorIR block with the "schedule_rule" annotation, it looks up the packed func registry for a function that is associated with the given schedule - rule key ("meta_schedule.x86.dense_vnni" in this example). The signature of such custom + rule key ("meta_schedule.x86.dense_int8" in this example). The signature of such custom schedule functions must be (tir.schedule.Schedule, tir.schedule.BlockRV) -> [tir.schedule.Schedule]. @@ -191,14 +201,12 @@ def test_vnni_schedule_fn_tune(): The relevant code is in `src/meta_schedule/space_generator/apply_custom_rule.cc`. """ - def schedule_rule_dense_vnni(sch: Schedule, dense_block: BlockRV): - _schedule_dense(m=None, do_tune=True)(sch, dense_block) + def schedule_rule_dense_16x4(sch: Schedule, dense_block: BlockRV): + _schedule_dense(m=None, do_tune=True, intrin=intrin)(sch, dense_block) return [sch] - register_func("meta_schedule.x86.dense_vnni", schedule_rule_dense_vnni) + register_func("meta_schedule.x86.dense_int8", schedule_rule_dense_16x4, override=True) - m, n, k = 1024, 1024, 1024 - target = tvm.target.Target("llvm -keys=x86,cpu -mcpu=cascadelake -num-cores=4") dev = tvm.cpu(0) relay_mod, params, f_check = _relay_dense(m, n, k) @@ -247,6 +255,20 @@ def schedule_rule_dense_vnni(sch: Schedule, dense_block: BlockRV): f_check(lib, dev) +@tvm.testing.requires_cascadelake +def test_vnni_schedule_fn_tune(): + target = tvm.target.Target("llvm -keys=x86,cpu -mcpu=cascadelake -num-cores=4") + schedule_16x4_dense_fn_tune(target, VNNI_INTRIN) + + +@tvm.testing.requires_skylake_avx512 +def test_avx512_schedule_fn_tune(): + target = tvm.target.Target("llvm -keys=x86,cpu -mcpu=skylake-avx512 -num-cores=4") + schedule_16x4_dense_fn_tune(target, AVX512_INTRIN, 16, 16, 16) + + if __name__ == """__main__""": test_vnni_schedule_fn_database() + test_avx512_schedule_fn_database() test_vnni_schedule_fn_tune() + test_avx512_schedule_fn_tune() diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py index d3731cfa1be8..795890de083e 100644 --- a/tests/python/unittest/test_meta_schedule_relay_integration.py +++ b/tests/python/unittest/test_meta_schedule_relay_integration.py @@ -316,9 +316,8 @@ def traverse(t): assert t.task_name in expected_task_names, t.task_name -@pytest.mark.skip("Too slow on CI") -def extract_task_qbert(): - def _test(mod, params, target): +def extract_task_qbert(target, sch_rule_tag): + def _test(mod, params, target, sch_rule_tag): extracted_tasks = ms.relay_integration.extract_tasks(mod, target, params) tune_tasks = list( filter( @@ -341,10 +340,20 @@ def _test(mod, params, target): annotations = sch.get(block).annotations assert "schedule_rule" in annotations - assert "vnni" in annotations["schedule_rule"] + assert sch_rule_tag in annotations["schedule_rule"] mod, params, _ = load_quantized_bert_base(batch_size=1, seq_len=128) - _test(mod, params, target="llvm -mcpu=cascadelake") + _test(mod, params, target=target, sch_rule_tag=sch_rule_tag) + + +@pytest.mark.skip("Too slow on CI") +def extract_task_qbert_vnni(): + extract_task_qbert("llvm -mcpu=cascadelake", "vnni") + + +@pytest.mark.skip("Too slow on CI") +def extract_task_qbert_avx512(): + extract_task_qbert("llvm -mcpu=skylake-avx512", "avx512") @tvm.testing.skip_if_32bit(reason="Apparently the LLVM version on i386 image is too old") diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py index 54f342c3a5d8..4667626f1706 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py @@ -26,9 +26,10 @@ from tvm.target import Target from tvm.tir.tensor_intrin.arm_cpu import DP4A_INTRIN from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN +from tvm.tir.tensor_intrin.x86 import AVX512_DOT_16x4_INTRIN as AVX512_INTRIN -def test_vnni_conv2d_nchwc(): +def test_x86_conv2d_nchwc(intrin=VNNI_INTRIN, target="llvm -mcpu=cascadelake -num-cores=4"): @T.prim_func def conv2d_nchwc( placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], @@ -68,7 +69,7 @@ def conv2d_nchwc( # fmt: off @T.prim_func - def vnni_conv2d_nchwc_0(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"]) -> None: + def x86_conv2d_nchwc_0(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"]) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) conv2d_NCHWc_int8_global = T.alloc_buffer([1, 16, 56, 56, 16], dtype="int32") for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1 in T.grid(1, 8, 28, 56, 1, 1, 2, 1, 1, 1): @@ -86,7 +87,7 @@ def vnni_conv2d_nchwc_0(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], plac ic_s_inner_o = T.axis.reduce(1, i9_0_1 + i9_0_0) T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4]) T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, 0 : 16]) - T.block_attr({"meta_schedule.auto_tensorize":"dot_16x4_vnni"}) + T.block_attr({"meta_schedule.auto_tensorize":intrin}) with T.init(): for i4_1 in T.serial(16): with T.block("conv2d_NCHWc_int8_init"): @@ -113,7 +114,7 @@ def vnni_conv2d_nchwc_0(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], plac conv2d_NCHWc_int8[v0, v1, v2, v3, v4] = conv2d_NCHWc_int8_global[v0, v1, v2, v3, v4] @T.prim_func - def vnni_conv2d_nchwc_1(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"]) -> None: + def x86_conv2d_nchwc_1(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"]) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) conv2d_NCHWc_int8_global = T.alloc_buffer([1, 16, 56, 56, 16], dtype="int32") for i0_0, i1_0, i2_0, i3_0, i4_0_0 in T.grid(1, 8, 28, 56, 1): @@ -131,7 +132,7 @@ def vnni_conv2d_nchwc_1(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], plac ic_s_inner_o = T.axis.reduce(1, i9_0_1 + i9_0_0) T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4]) T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, 0 : 16]) - T.block_attr({"meta_schedule.auto_tensorize":"dot_16x4_vnni"}) + T.block_attr({"meta_schedule.auto_tensorize":intrin}) with T.init(): for i4_1 in T.serial(16): with T.block("conv2d_NCHWc_int8_init"): @@ -158,7 +159,7 @@ def vnni_conv2d_nchwc_1(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], plac conv2d_NCHWc_int8[v0, v1, v2, v3, v4] = conv2d_NCHWc_int8_global[v0, v1, v2, v3, v4] @T.prim_func - def vnni_conv2d_nchwc_2(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"]) -> None: + def x86_conv2d_nchwc_2(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"]) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 8, 28, 56, 1, 1, 2, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1): with T.block("conv2d_NCHWc_int8_o"): @@ -174,7 +175,7 @@ def vnni_conv2d_nchwc_2(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], plac ic_s_inner_o = T.axis.reduce(1, i9_0_1 + i9_0_0) T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4]) T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16]) - T.block_attr({"meta_schedule.auto_tensorize":"dot_16x4_vnni"}) + T.block_attr({"meta_schedule.auto_tensorize":intrin}) with T.init(): for i4_1 in T.serial(16): with T.block("conv2d_NCHWc_int8_init"): @@ -228,7 +229,6 @@ def vnni_conv2d_nchwc_2(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], plac ] mod = conv2d_nchwc - target = Target("llvm -mcpu=cascadelake -num-cores=4") actual = generate_design_space( kind="llvm", mod=mod, @@ -236,7 +236,7 @@ def vnni_conv2d_nchwc_2(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], plac types=None, sch_rules=[ ms.schedule_rule.MultiLevelTilingWithIntrin( - VNNI_INTRIN, + intrin, structure="SSRSRS", tile_binds=None, max_innermost_factor=64, @@ -249,7 +249,7 @@ def vnni_conv2d_nchwc_2(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], plac check_sketches( mod, sketches=actual, - expected_mods=[vnni_conv2d_nchwc_0, vnni_conv2d_nchwc_1, vnni_conv2d_nchwc_2], + expected_mods=[x86_conv2d_nchwc_0, x86_conv2d_nchwc_1, x86_conv2d_nchwc_2], expected_decisions=[decision_0, decision_1, decision_2], ) @@ -417,7 +417,8 @@ def test_dp4a_dense_no_tensorize_2(): if __name__ == "__main__": - test_vnni_conv2d_nchwc() + test_x86_conv2d_nchwc() + test_x86_conv2d_nchwc(AVX512_INTRIN, "llvm -mcpu=skylake-avx512 -num-cores=4") test_dp4a_dense() test_dp4a_dense_no_tensorize_1() test_dp4a_dense_no_tensorize_2() diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py b/tests/python/unittest/test_meta_schedule_trace_apply.py index 9a62207fa261..43b9eb8bbb19 100644 --- a/tests/python/unittest/test_meta_schedule_trace_apply.py +++ b/tests/python/unittest/test_meta_schedule_trace_apply.py @@ -25,6 +25,8 @@ from tvm.target import Target from tvm.target.codegen import llvm_lookup_intrinsic_id +from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN + # fmt: off @tvm.script.ir_module @@ -2553,9 +2555,7 @@ def apply_trace(sch): l36, l37, l38, l39, l40, l41, l42, l43, l44, l45, l46, l47 = sch.get_loops(block=b1) sch.reorder(l42, l43, l44, l45, l46, l35, l33) b48 = sch.blockize(loop=l35) - sch.annotate( - block_or_loop=b48, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni" - ) + sch.annotate(block_or_loop=b48, ann_key="meta_schedule.auto_tensorize", ann_val=VNNI_INTRIN) l49, l50, l51, l52, l53, l54, l55, l56, l57, l58 = sch.get_loops(block=b48) v59, v60, v61, v62 = sch.sample_perfect_tile( loop=l49, n=4, max_innermost_factor=64, decision=[1, 1, 1, 1] @@ -2729,7 +2729,7 @@ def apply_trace(sch): sch.vectorize(loop=l193) b194 = sch.get_block(name="conv2d_NCHWc_int8_o_update", func_name="main") sch.unannotate(block_or_loop=b194, ann_key="meta_schedule.auto_tensorize") - sch.tensorize(block_or_loop=b194, tensor_intrin="dot_16x4_vnni") + sch.tensorize(block_or_loop=b194, tensor_intrin=VNNI_INTRIN) vnni_id = llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512") verify( diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index e0667da6fe92..38bd4bba1418 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -146,7 +146,7 @@ def test_suggest_index_map_winograd(): @tvm.script.ir_module -class DenseVNNIModule: +class DenseTIRModule: @T.prim_func def main( placeholder: T.Buffer[(1024, 1024), "uint8"], @@ -170,7 +170,7 @@ def main( @tvm.script.ir_module -class Conv2dNCHWcVNNIModule: +class Conv2dNCHWcTIRModule: @T.prim_func def main( placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], @@ -202,7 +202,8 @@ def main( conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[ n, oc_chunk, oh, ow, oc_block ] + T.cast( - placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32" + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], + "int32", ) * T.cast( placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], "int32", @@ -222,8 +223,8 @@ def callback(node): return loops -def test_get_tensorize_loop_mapping_dense_vnni(): - s = Schedule(DenseVNNIModule) +def test_get_tensorize_loop_mapping_dense_16x4(): + s = Schedule(DenseTIRModule) block = s.get_block("compute") info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc) @@ -240,8 +241,8 @@ def test_get_tensorize_loop_mapping_dense_vnni(): assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(loop_k) -def test_get_tensorize_loop_mapping_conv2d_nchwc_vnni(): - s = Schedule(Conv2dNCHWcVNNIModule) +def test_get_tensorize_loop_mapping_conv2d_nchwc_16x4(): + s = Schedule(Conv2dNCHWcTIRModule) block = s.get_block("conv2d_NCHWc_int8") info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc) diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py index fc0bdc146c88..4847f261a32c 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize.py +++ b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -29,7 +29,7 @@ ARM_DOT_4x4_i8_SDOT_INTRIN, ) from tvm.tir.tensor_intrin.rocm import AMDGPU_SDOT4_INTRIN -from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN +from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN, AVX512_DOT_16x4_INTRIN from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN, VDMPY_i16i16i32_INTRIN # fmt: off @@ -557,7 +557,7 @@ def get_matmul_packed(m, n, k, lhs_type, rhs_dtype="int8"): return te.create_prim_func([X, W, matmul]) -def test_tensorize_vnni(): +def tensorize_16x4_test(intrin=VNNI_DOT_16x4_INTRIN): m, n, k = 128, 128, 128 func = get_matmul_packed(m, n, k, "uint8") @@ -572,11 +572,19 @@ def test_tensorize_vnni(): sch.reorder(ko, ji, ki) sch.decompose_reduction(block, ko) - sch.tensorize(ji, VNNI_DOT_16x4_INTRIN) + sch.tensorize(ji, intrin) verify_trace_roundtrip(sch=sch, mod=func) +def test_tensorize_vnni(): + tensorize_16x4_test() + + +def test_tensorize_avx512(): + tensorize_16x4_test(AVX512_DOT_16x4_INTRIN) + + def test_tensorize_arm_dot(): m, n, k = 128, 128, 128 diff --git a/tests/python/unittest/test_tir_schedule_transform.py b/tests/python/unittest/test_tir_schedule_transform.py index e812587e6676..c068385f0a46 100644 --- a/tests/python/unittest/test_tir_schedule_transform.py +++ b/tests/python/unittest/test_tir_schedule_transform.py @@ -18,11 +18,11 @@ from tvm.script import tir as T from tvm.tir import Schedule from tvm.tir.schedule.transform import tile_with_tensor_intrin -from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN +from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN, AVX512_DOT_16x4_INTRIN @tvm.script.ir_module -class DenseVNNIModule: +class DenseTIRModule: @T.prim_func def main( placeholder: T.Buffer[(1024, 1024), "uint8"], @@ -46,7 +46,7 @@ def main( @tvm.script.ir_module -class DenseVNNIModuleTiled: +class DenseTIRModuleTiled: @T.prim_func def main( placeholder: T.Buffer[(1024, 1024), "uint8"], @@ -72,7 +72,7 @@ def main( @tvm.script.ir_module -class Conv2dNCHWcVNNIModule: +class Conv2dNCHWcTIRModule: @T.prim_func def main( placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], @@ -104,7 +104,8 @@ def main( conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[ n, oc_chunk, oh, ow, oc_block ] + T.cast( - placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32" + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], + "int32", ) * T.cast( placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], "int32", @@ -112,7 +113,7 @@ def main( @tvm.script.ir_module -class Conv2dNCHWcVNNIModuleTiled: +class Conv2dNCHWcTIRModuleTiled: @T.prim_func def main( placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], @@ -141,35 +142,38 @@ def main( conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[ n, oc_chunk, oh, ow, oc_block ] + T.cast( - placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32" + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], + "int32", ) * T.cast( placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], "int32", ) -def test_tile_with_tensor_intrin_dense_vnni(): - s = Schedule(DenseVNNIModule) +def test_tile_with_tensor_intrin_dense(intrin=VNNI_DOT_16x4_INTRIN): + s = Schedule(DenseTIRModule) block = s.get_block("compute") - tiled_loop = tile_with_tensor_intrin(s, block, VNNI_DOT_16x4_INTRIN) + tiled_loop = tile_with_tensor_intrin(s, block, intrin) _, _, _, i1_1, _ = s.get_loops(block) assert s.get(tiled_loop) == s.get(i1_1) - tvm.ir.assert_structural_equal(s.mod, DenseVNNIModuleTiled) + tvm.ir.assert_structural_equal(s.mod, DenseTIRModuleTiled) -def test_tile_with_tensor_intrin_conv2d_nchwc_vnni(): - s = Schedule(Conv2dNCHWcVNNIModule) +def test_tile_with_tensor_intrin_conv2d_nchwc(intrin=VNNI_DOT_16x4_INTRIN): + s = Schedule(Conv2dNCHWcTIRModule) block = s.get_block("conv2d_NCHWc_int8") - tiled_loop = tile_with_tensor_intrin(s, block, VNNI_DOT_16x4_INTRIN) + tiled_loop = tile_with_tensor_intrin(s, block, intrin) tiled_loops = s.get_loops(block) assert len(tiled_loops) == 12 assert s.get(tiled_loop) == s.get(tiled_loops[-2]) - tvm.ir.assert_structural_equal(s.mod, Conv2dNCHWcVNNIModuleTiled) + tvm.ir.assert_structural_equal(s.mod, Conv2dNCHWcTIRModuleTiled) if __name__ == "__main__": - test_tile_with_tensor_intrin_dense_vnni() - test_tile_with_tensor_intrin_conv2d_nchwc_vnni() + test_tile_with_tensor_intrin_dense() + test_tile_with_tensor_intrin_dense(AVX512_DOT_16x4_INTRIN) + test_tile_with_tensor_intrin_conv2d_nchwc() + test_tile_with_tensor_intrin_conv2d_nchwc(AVX512_DOT_16x4_INTRIN)