[TIR][TOPI][x86][CI] Support skylake avx512 (apache#13621)
* add skylake-avx512 tests

* extend tests by skylake-avx512

* lint fixes

* fix typo

* typo fix

* TODOs for further development

* add temporarily commented-out tests for skylake-avx512, since schedules and postprocs are not implemented for it yet; add TODOs for further checks and development

* update int8-acc32 test for vnni and avx512 w/o it

* pylint fix

* once more pylint fix

* fix Feature init for skylake

* fix test

* fix intrin names for assert for skylake

* small fix

* bring back fast int8 intrinsic tests

* test connect of dense and batch_matmul to avx512 tensorization

* extend dense_alter_layout to avx512 (for now) instead of VNNI; rename some vnni to int8 for the sake of clarity

* more renaming of vnni to int8 in dense schedule, compute, and strategy for the sake of clarity

* update for batch_matmul with avx512

* extend space generator init for avx512; add default AVX512 schedule rules

* implement avx512 dot 16x4 intrin for the MS default schedule rule

* small fix

* update

* pylint fixes

* test workaround for const alloc in tir

* test fix (broadcasting)

* remove excess instructions from dot_product_16x4_u8i8i32_avx512

* pylint fix

* skip asm check for askew weight shapes

* fix pylint

* revert test fix

* set number of args

* test fix

* fix const allocation in tir for avx512 dot 16x4

* fix signature of dot_product_16x4_u8i8i32_avx512

* use script instead of tvm.tir for const allocation

* extend auto tensorize test by skylake-avx512 target

* clean code

* update test_op_level1, resolve TODO

* small update test_op_level2

* update test_op_level10, resolve TODO

* update qnn legalize pass test, resolve TODOs

* pylint fixes

* update ms test for avx512

* update more ms test for avx512

* try to fix i386 CI tests

* fix intrin name for check

* skip test due to model downloading issue

* fix test failure

* use ORT for conv2d check

* lint fix after rebasing

* comment ORT part of test

* extend TIR schedule analysis and transform tests for avx512; unify test classes

* extend test tir schedule tensorize for avx512

* extend test meta schedule vnni integration for avx512

* rename test file

* pylint fix

* tag fix

* update test meta schedule trace apply with avx512

* rollback test class unifying in utils

* pylint fixes

* separate TIRs for scheduled conv2d for vnni and avx512

* fix registering issue in test

* update conv+bias onnx model for intermediate test

* fix int16 overflow

* fix int16 overflow for dense test

* update input data for test of dense

* small rollback

* fix typo

* fix

* restart CI

* DefaultVNNI was renamed to DefaultLLVM for mutator

* rename test file for the sake of clarity

* DefaultVNNI was renamed to DefaultCPUTensorization for postproc

* remove resolved TODO

* DefaultVNNI and AVX512 for ScheduleRule were unified

* replace code to upstream with initial version

* fix arg type

* lint fix

* small fix

* lint fix

* fix typo

* rollback trace apply test for avx512 (reviewer remark)

* fix pylint

Co-authored-by: Valery Chernov <[email protected]>
2 people authored and fzi-peccia committed Mar 27, 2023
1 parent a466614 commit 3ef5a2f
Showing 26 changed files with 485 additions and 283 deletions.
2 changes: 0 additions & 2 deletions include/tvm/meta_schedule/mutator.h
@@ -131,8 +131,6 @@ class Mutator : public runtime::ObjectRef {
FApply f_apply, FClone f_clone, FAsString f_as_string);
/*! \brief Create default mutators for LLVM */
TVM_DLL static Map<Mutator, FloatImm, void> DefaultLLVM();
/*! \brief Create default mutators for x86 VNNI */
TVM_DLL static Map<Mutator, FloatImm, void> DefaultVNNI();
/*! \brief Create default mutators for CUDA */
TVM_DLL static Map<Mutator, FloatImm, void> DefaultCUDA();
/*! \brief Create default mutators for CUDA with TensorCore */
4 changes: 2 additions & 2 deletions include/tvm/meta_schedule/postproc.h
@@ -163,8 +163,8 @@ class Postproc : public runtime::ObjectRef {
TVM_DLL static Postproc RewriteLayout();
/*! \brief Create default postprocessors for LLVM */
TVM_DLL static Array<Postproc, void> DefaultLLVM();
/*! \brief Create default postprocessors for x86 VNNI */
TVM_DLL static Array<Postproc, void> DefaultVNNI();
/*! \brief Create default postprocessors for x86 (AVX512 and VNNI) */
TVM_DLL static Array<Postproc, void> DefaultCPUTensorization();
/*! \brief Create default postprocessors for CUDA */
TVM_DLL static Array<Postproc, void> DefaultCUDA();
/*! \brief Create default postprocessors for CUDA with TensorCore */
4 changes: 2 additions & 2 deletions include/tvm/meta_schedule/schedule_rule.h
@@ -290,8 +290,8 @@ class ScheduleRule : public runtime::ObjectRef {

/*! \brief Create default schedule rules for LLVM */
TVM_DLL static Array<ScheduleRule, void> DefaultLLVM();
/*! \brief Create default schedule rules for x86 VNNI */
TVM_DLL static Array<ScheduleRule, void> DefaultVNNI();
/*! \brief Create default schedule rules for x86 (AVX512 and VNNI) */
TVM_DLL static Array<ScheduleRule, void> DefaultX86(const String& type);
/*! \brief Create default schedule rules for CUDA */
TVM_DLL static Array<ScheduleRule, void> DefaultCUDA();
/*! \brief Create default postprocessors for CUDA with TensorCore */
4 changes: 3 additions & 1 deletion python/tvm/relay/qnn/op/legalizations.py
@@ -248,7 +248,7 @@ def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
Replacing QA + 128 with QA' and (zp_a + 128) with zp_a'
We get our new quantized uint8 tensor - scale * (QA' - zp_a')
Similarly we can convert from int8 to uint8.
Similarly we can convert from uint8 to int8.
Parameters
----------
@@ -449,6 +449,7 @@ def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):

@qnn_conv2d_legalize.register("cpu")
def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types):
# TODO(vvchernov): not only VNNI
# The VNNI transformations prefer uint8 x int8 datatypes.
if is_fast_int8_on_intel():
return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.conv2d)
@@ -457,6 +458,7 @@ def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types):

@qnn_dense_legalize.register("cpu")
def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
# TODO(vvchernov): not only VNNI
# The VNNI transformations prefer uint8 x int8 datatypes.
if is_fast_int8_on_intel():
return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense)
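
A small worked illustration of the dtype shift performed by helper_change_dtypes_to_uint8_int8 (toy values assumed): adding 128 to both the quantized tensor and its zero point converts int8 to uint8 while the represented real value scale * (Q - zp) stays the same.

import numpy as np

scale, zp_int8 = 0.05, -3
q_int8 = np.array([-128, -3, 0, 127], dtype=np.int32)

q_uint8 = q_int8 + 128    # QA' = QA + 128, now in [0, 255]
zp_uint8 = zp_int8 + 128  # zp_a' = zp_a + 128

# The dequantized values are unchanged by the shift.
assert np.allclose(scale * (q_int8 - zp_int8), scale * (q_uint8 - zp_uint8))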
29 changes: 29 additions & 0 deletions python/tvm/testing/utils.py
@@ -1027,6 +1027,28 @@ def _has_vnni():
return False


# check avx512 intrinsic groups for SkyLake X
def _has_slavx512():
# Check LLVM support
llvm_version = tvm.target.codegen.llvm_version_major()
is_llvm_support = llvm_version >= 8
arch = platform.machine()
# Only linux is supported for now.
if arch == "x86_64" and sys.platform.startswith("linux"):
with open("/proc/cpuinfo", "r") as content:
ctx = content.read()
check = (
"avx512f" in ctx
and "avx512cd" in ctx
and "avx512bw" in ctx
and "avx512dq" in ctx
and "avx512vl" in ctx
)
return check and is_llvm_support

return False


requires_arm_dot = Feature("arm_dot", "ARM dot product", run_time_check=_arm_dot_supported)


@@ -1035,6 +1057,13 @@ def _has_vnni():
)


requires_skylake_avx512 = Feature(
"skylake_avx512",
"x86 SkyLake AVX512",
run_time_check=lambda: _has_slavx512() and _is_intel(),
)
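
A minimal sketch of how a test can be gated on this new feature marker, assuming it is applied like the existing requires_* decorators in tvm.testing (e.g. requires_cascadelake); the test name and body are illustrative.

import tvm.testing


@tvm.testing.requires_skylake_avx512
def test_dense_u8i8i32_avx512():
    # Runs only when the host CPU exposes the Skylake-X AVX512 groups and LLVM >= 8.
    ...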


def _cmake_flag_enabled(flag):
flag = tvm.support.libinfo()[flag]

40 changes: 40 additions & 0 deletions python/tvm/tir/tensor_intrin/x86.py
@@ -67,8 +67,48 @@ def dot_product_16x4_u8i8i32_vnni(
)


@T.prim_func
def dot_product_16x4_u8i8i32_avx512(
A: T.Buffer((4,), "uint8", offset_factor=1),
B: T.Buffer((16, 4), "int8", offset_factor=1),
C: T.Buffer((16,), "int32", offset_factor=1),
) -> None:
with T.block("root"):
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
T.writes(C[0:16])

A_u8x4 = A.vload([0], "uint8x4")
A_i32 = T.reinterpret(A_u8x4, dtype="int32")
A_brdcst = T.broadcast(A_i32, 16)
A_u8x64 = T.reinterpret(A_brdcst, dtype="uint8x64")

B_i8x64 = B.vload([0, 0], dtype="int8x64")

Red = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddubs.w.512"),
T.uint32(2),
A_u8x64,
B_i8x64,
dtype="int16x32",
)

C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddw.d.512"),
T.uint32(2),
Red,
T.int16x32(1),
dtype="int32x16",
)


VNNI_DOT_16x4_INTRIN = "dot_16x4_vnni"

TensorIntrin.register(
VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_vnni
)

AVX512_DOT_16x4_INTRIN = "dot_16x4_avx512"

TensorIntrin.register(
AVX512_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_avx512
)
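
A minimal sketch of tensorizing a uint8 x int8 GEMM with the newly registered intrinsic, modeled on how the VNNI intrinsic is exercised in the TIR tensorize tests; the workload shape, packed layout, and loop split factors below are illustrative assumptions.

from tvm import te, tir
from tvm.tir.tensor_intrin.x86 import AVX512_DOT_16x4_INTRIN

m, n, k = 128, 128, 128
X = te.placeholder((m, k), name="X", dtype="uint8")
W = te.placeholder((n // 16, k // 4, 16, 4), name="packedW", dtype="int8")  # NC16n4c-packed weight
ak = te.reduce_axis((0, k), name="k")
matmul = te.compute(
    (m, n),
    lambda i, j: te.sum(
        X[i, ak].astype("int32") * W[j // 16, ak // 4, j % 16, ak % 4].astype("int32"),
        axis=ak,
    ),
    name="compute",
)

sch = tir.Schedule(te.create_prim_func([X, W, matmul]))
block = sch.get_block("compute")
_, j, kk = sch.get_loops(block)
_, ji = sch.split(j, factors=[None, 16])
ko, ki = sch.split(kk, factors=[None, 4])
sch.reorder(ko, ji, ki)
sch.decompose_reduction(block, ko)
# The innermost 16x4 (j, k) tile now matches dot_product_16x4_u8i8i32_desc.
sch.tensorize(ji, AVX512_DOT_16x4_INTRIN)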
25 changes: 13 additions & 12 deletions python/tvm/topi/x86/batch_matmul.py
@@ -25,12 +25,12 @@
from .. import generic, nn
from ..transform import layout_transform
from ..utils import get_const_tuple, get_max_power2_factor, traverse_inline
from .dense import dense_vnni_schedule, dense_amx_int8_schedule
from .dense import dense_int8_schedule, dense_amx_int8_schedule
from .injective import schedule_injective_from_existing
from .utils import target_has_vnni, target_has_amx
from .utils import target_has_avx512, target_has_amx


@autotvm.register_topi_compute("batch_matmul_vnni.x86")
@autotvm.register_topi_compute("batch_matmul_int8.x86")
def batch_matmul_int8_compute(cfg, x, y, *_):
"""Compute for uint8 x int8 -> int32 batch_matmul"""
batch, m, k = x.shape
@@ -39,8 +39,8 @@ def batch_matmul_int8_compute(cfg, x, y, *_):
_, n_o, _, n_i, _ = packed_y.shape
ak = te.reduce_axis((0, k), name="k")
mcpu = tvm.target.Target.current().mcpu
if target_has_vnni(mcpu):
attrs_info = {"schedule_rule": "batch_matmul_vnni"}
if target_has_avx512(mcpu):
attrs_info = {"schedule_rule": "batch_matmul_int8"}
else:
attrs_info = None

@@ -60,13 +60,14 @@ def batch_matmul_int8_compute(cfg, x, y, *_):
return z


def batch_matmul_vnni_schedule(cfg, s, C, O, layout_trans):
"""Schedule batch_matmul compute using VNNI vpdpbusd instruction"""
def batch_matmul_int8_schedule(cfg, s, C, O, layout_trans):
"""Schedule batch_matmul compute using avx512 or lower instructions
including VNNI vpdpbusd instruction if possible"""
# C: The output of batched GEMM
# O: The output of the fused op

# Schedule the GEMM part
s, fused_inner = dense_vnni_schedule(cfg, s, C, O, do_parallel=False)
s, fused_inner = dense_int8_schedule(cfg, s, C, O, do_parallel=False)
# Parallelize over batch
fused = s[O].fuse(O.op.axis[0], fused_inner)
s[O].parallel(fused)
@@ -228,9 +229,9 @@ def _callback(op):
return s


@autotvm.register_topi_schedule("batch_matmul_vnni.x86")
@autotvm.register_topi_schedule("batch_matmul_int8.x86")
def schedule_batch_matmul_int8(cfg, outs):
"""Schedule for batch_matmul_vnni"""
"""Schedule for batch_matmul_int8"""
s = te.create_schedule([x.op for x in outs])
mcpu = tvm.target.Target.current().mcpu

@@ -239,8 +240,8 @@ def _callback(op):
layout_trans = op.input_tensors[1]
if target_has_amx(mcpu):
batch_matmul_amx_schedule(cfg, s, op.output(0), outs[0], layout_trans)
elif target_has_vnni(mcpu):
batch_matmul_vnni_schedule(cfg, s, op.output(0), outs[0], layout_trans)
elif target_has_avx512(mcpu):
batch_matmul_int8_schedule(cfg, s, op.output(0), outs[0], layout_trans)

traverse_inline(s, outs[0].op, _callback)
return s
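
A companion sketch for the batch_matmul path (shapes and dtypes are illustrative assumptions): with uint8 x int8 inputs and transpose_b=True, the same AVX512 int8 schedule is picked up when building for a Skylake-X target.

import tvm
from tvm import relay

x = relay.var("x", shape=(8, 32, 96), dtype="uint8")
y = relay.var("y", shape=(8, 64, 96), dtype="int8")
bmm = relay.nn.batch_matmul(x, y, out_dtype="int32", transpose_b=True)
mod = tvm.IRModule.from_expr(relay.Function([x, y], bmm))

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm -mcpu=skylake-avx512")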
19 changes: 10 additions & 9 deletions python/tvm/topi/x86/dense.py
@@ -26,10 +26,10 @@

from .. import generic, tag
from ..utils import get_const_tuple, traverse_inline
from .tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake
from .tensor_intrin import dot_16x1x16_uint8_int8_int32
from .tensor_intrin import dot_32x128x32_u8s8s32_sapphirerapids
from .tensor_intrin import acc_32x32_int32_sapphirerapids
from .utils import get_simd_32bit_lanes, target_has_vnni, target_has_amx
from .utils import get_simd_32bit_lanes, target_has_avx512, target_has_amx


def _schedule_dense_pack_template(cfg, s, C, O):
@@ -302,8 +302,8 @@ def _callback(op):
if "dense_int8" in op.tag:
if target_has_amx(mcpu):
dense_amx_int8_schedule(cfg, s, op.output(0), outs[0])
elif target_has_vnni(mcpu):
dense_vnni_schedule(cfg, s, op.output(0), outs[0])
elif target_has_avx512(mcpu):
dense_int8_schedule(cfg, s, op.output(0), outs[0])

traverse_inline(s, outs[0].op, _callback)
return s
@@ -315,8 +315,8 @@ def dense_int8_compute(cfg, X, packed_w, bias=None):
n_o, _, n_i, _ = packed_w.shape
ak = te.reduce_axis((0, k), name="k")
mcpu = tvm.target.Target.current().mcpu
if target_has_vnni(mcpu):
target_attr = {"schedule_rule": "meta_schedule.x86.dense_vnni"}
if target_has_avx512(mcpu):
target_attr = {"schedule_rule": "meta_schedule.x86.dense_int8"}
else:
target_attr = None

@@ -339,8 +339,9 @@ def dense_int8_compute(cfg, X, packed_w, bias=None):
return C


def dense_vnni_schedule(cfg, s, C, O, do_parallel=True):
"""Schedule dense compute using VNNI vpdpbusd instruction"""
def dense_int8_schedule(cfg, s, C, O, do_parallel=True):
"""Schedule dense compute using avx512 or lower instructions
including VNNI vpdpbusd instruction if possible"""
# C: The output of GEMM
# O: The output of the fused op
def split_y(out):
@@ -361,7 +362,7 @@ def split_y(out):

s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki)

pc = dot_16x1x16_uint8_int8_int32_cascadelake()
pc = dot_16x1x16_uint8_int8_int32()
s[C].tensorize(a_xi, pc)

if C == O:
18 changes: 9 additions & 9 deletions python/tvm/topi/x86/dense_alter_op.py
@@ -24,14 +24,14 @@
from .dense import _default_dense_pack_config
from ..utils import get_const_tuple
from ..nn import dense_alter_layout
from .utils import target_has_vnni
from .utils import target_has_amx
from .utils import target_has_avx512, target_has_amx
from .. import nn


def check_inst_applicable(x, y, allow_padding=False):
def check_int8_applicable(x, y, allow_padding=False):
mcpu = tvm.target.Target.current().mcpu
simd_avai = target_has_vnni(mcpu) or target_has_amx(mcpu)
# TODO(vvchernov): maybe also target_has_avx2 or lower?
simd_avai = target_has_avx512(mcpu) or target_has_amx(mcpu)
return (
simd_avai
and "int8" in x.dtype
@@ -49,7 +49,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
M, K = get_const_tuple(data_tensor.shape)
N, _ = get_const_tuple(weight_tensor.shape)

if check_inst_applicable(data_tensor, weight_tensor) and data_tensor.dtype == "uint8":
if check_int8_applicable(data_tensor, weight_tensor) and data_tensor.dtype == "uint8":
weight_layout = "NC16n4c"
return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype)

@@ -86,10 +86,10 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
return None


def vnni_legalize(inputs, arg_types, op, attrs, need_expand=False):
def int8_int8_legalize(inputs, arg_types, op, attrs, need_expand=False):
"""Legalizes s8, s8 -> s32 GEMM op for VNNI."""
if (
check_inst_applicable(arg_types[0], arg_types[1], allow_padding=True)
check_int8_applicable(arg_types[0], arg_types[1], allow_padding=True)
and arg_types[0].dtype == "int8"
):
x, y = inputs
@@ -135,12 +135,12 @@ def vnni_legalize(inputs, arg_types, op, attrs, need_expand=False):
@nn.dense_legalize.register("cpu")
def _dense_legalize(attrs, inputs, arg_types):
"""Legalizes s8, s8 -> s32 dense for VNNI."""
return vnni_legalize(inputs, arg_types, relay.nn.dense, attrs)
return int8_int8_legalize(inputs, arg_types, relay.nn.dense, attrs)


@nn.batch_matmul_legalize.register("cpu")
def _batch_matmul_legalize(attrs, inputs, arg_types):
"""Legalizes s8, s8 -> s32 batch_matmul for VNNI."""
if attrs["transpose_a"] or not attrs["transpose_b"]:
return None
return vnni_legalize(inputs, arg_types, relay.nn.batch_matmul, attrs, need_expand=True)
return int8_int8_legalize(inputs, arg_types, relay.nn.batch_matmul, attrs, need_expand=True)
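
A hedged end-to-end sketch of the Relay path touched above (shapes and the uint8 x int8 dtypes are illustrative assumptions): when building for a Skylake-X target, legalization and dense_alter_layout pack the weight to NC16n4c and the int8 (AVX512, or VNNI when available) schedule is selected.

import tvm
from tvm import relay

data = relay.var("data", shape=(32, 96), dtype="uint8")
weight = relay.var("weight", shape=(128, 96), dtype="int8")
dense = relay.nn.dense(data, weight, out_dtype="int32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], dense))

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm -mcpu=skylake-avx512")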
2 changes: 0 additions & 2 deletions src/meta_schedule/mutator/mutator.cc
@@ -59,8 +59,6 @@ Map<Mutator, FloatImm> Mutator::DefaultLLVM() {
{Mutator::MutateParallel(/*max_jobs_per_core=*/16), FloatImm(DataType::Float(64), 0.02)}};
}

Map<Mutator, FloatImm> Mutator::DefaultVNNI() { return Mutator::DefaultLLVM(); }

Map<Mutator, FloatImm> Mutator::DefaultCUDA() {
return Map<Mutator, FloatImm>{
{Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)},
2 changes: 1 addition & 1 deletion src/meta_schedule/postproc/postproc.cc
@@ -59,7 +59,7 @@ Array<Postproc> Postproc::DefaultLLVM() {
};
}

Array<Postproc> Postproc::DefaultVNNI() {
Array<Postproc> Postproc::DefaultCPUTensorization() {
return Array<Postproc>{
Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(),
Postproc::RewriteReductionBlock(), Postproc::RewriteTensorize(/*vectorize_init_loop=*/true),
6 changes: 4 additions & 2 deletions src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -85,7 +85,9 @@ Array<ScheduleRule> ScheduleRule::DefaultLLVM() {
};
}

Array<ScheduleRule> ScheduleRule::DefaultVNNI() {
Array<ScheduleRule> ScheduleRule::DefaultX86(const String& type) {
static const Map<String, String> intrins = {{"vnni", "dot_16x4_vnni"},
{"avx512", "dot_16x4_avx512"}};
return {
ScheduleRule::ApplyCustomRule(),
ScheduleRule::InlineConstantScalars(),
Expand All @@ -101,7 +103,7 @@ Array<ScheduleRule> ScheduleRule::DefaultVNNI() {
/*max_jobs_per_core=*/16,
/*max_innermost_factor=*/Integer(64)),
ScheduleRule::MultiLevelTilingWithIntrin(
/*intrin_name=*/"dot_16x4_vnni",
/*intrin_name=*/intrins[type],
/*structure=*/"SSRSRS",
/*tile_binds=*/NullOpt,
/*max_innermost_factor=*/Integer(64),
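
A sketch of how the new default x86 rules get exercised end to end; the tuning entry points (ms.tune_tir, ms.tir_integration.compile_tir) and their arguments are assumptions based on the MetaSchedule API around the time of this change. Tuning a uint8 x int8 TIR workload for a Skylake-X target makes the space generator pick the "avx512" intrinsic variant registered above.

from tvm import meta_schedule as ms
from tvm.target import Target

target = Target("llvm -mcpu=skylake-avx512 -num-cores 4")
database = ms.tune_tir(
    mod=func,                # e.g. the uint8 x int8 GEMM PrimFunc sketched earlier
    target=target,
    work_dir="./ms_avx512",
    max_trials_global=32,
)
sch = ms.tir_integration.compile_tir(database, func, target)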