[TIR][TOPI][x86][CI] Support skylake avx512 #13621

Merged Jan 17, 2023 (84 commits)

Commits
ef5d6c8
add skylake-avx512 tests
Dec 14, 2022
7c84e41
extend tests by skylake-avx512
Dec 15, 2022
a1c3b3c
lint fixes
Dec 15, 2022
57fbc5e
fix misprinting
Dec 15, 2022
094fd8d
misprinting fix
Dec 19, 2022
b7dff8f
TODOs for further development
Dec 19, 2022
78d5e25
add temporally commented tests for skylake-avx512 due to not implemen…
Dec 19, 2022
df49fe6
update int8-acc32 test for vnni and avx512 w/o it
Dec 20, 2022
d2e02e8
pylint fix
Dec 20, 2022
0610ea6
once more pylint fix
Dec 20, 2022
751f729
fix Feature init for skylake
Dec 20, 2022
3647631
fix test
Dec 20, 2022
a439103
fix intrin names for assert for skylake
Dec 20, 2022
a2f1587
small fix
Dec 20, 2022
877edd2
return back fast int8 intrinsic tests
Dec 20, 2022
f25254f
test connect of dense and batch_matmul to avx512 tensorization
Dec 21, 2022
d4d8bc3
extend dense_alter_layout on avx512 (currently) instead of VNNI. some…
Dec 21, 2022
7266380
more renaming vnni to int8 for dense schedule, compute, strategy for …
Dec 21, 2022
0a393b8
update for batch_matmul with avx512
Dec 21, 2022
2029f83
extend space generator init for avx512. Add Default AVX512 schedule r…
Dec 22, 2022
410c87b
avx512 dot 16x4 intrin was implemented for MS default schedule rule
Dec 22, 2022
a23198f
small fix
Dec 22, 2022
1fa84f4
update
Dec 22, 2022
c3c15d2
pylint fixes
Dec 23, 2022
582caa9
test workaround for const alloc in tir
Dec 23, 2022
6279ad8
test fix (broadcasting)
Dec 23, 2022
5d012fe
remove excess instructions from dot_product_16x4_u8i8i32_avx512
Dec 23, 2022
db282f0
pylint fix
Dec 23, 2022
40f8211
skip asm check for askew weight shapes
Dec 23, 2022
5d393e5
fix pylint
Dec 23, 2022
bd9fd2e
revert test fix
Dec 23, 2022
07666ec
set number of args
Dec 23, 2022
76a5e7e
test fix
Dec 23, 2022
c6548db
fix const allocation in tir for avx512 dot 16x4
Dec 23, 2022
b1889df
fix signature of dot_product_16x4_u8i8i32_avx512
Dec 26, 2022
0b890e9
use script instead of tvm.tir for const allocation
Dec 26, 2022
a49963d
extend auto tensorize test by skylake-avx512 target
Dec 29, 2022
9bd9df1
clean code
Dec 29, 2022
dbca309
update test_op_level1, resolve TODO
Dec 29, 2022
0e78ee8
small update test_op_level2
Dec 29, 2022
99e8d46
update test_op_level10, resolve TODO
Dec 29, 2022
bfe7424
update qnn legalize pass test, resolve TODOs
Dec 29, 2022
4211c0f
pylint fixes
Dec 29, 2022
bd24052
update ms test for avx512
Dec 29, 2022
52abeb0
update more ms test for avx512
Dec 29, 2022
315a947
try to fix i386 CI tests
Dec 29, 2022
b3f4749
fix intrin name for check
Dec 29, 2022
8fc39dc
skip test due to model downloading issue
Dec 29, 2022
eb97d6d
fix test failure
Dec 29, 2022
3f36476
use ORT for conv2d check
Dec 30, 2022
68fb495
lint fix after rebasing
Jan 7, 2023
61271df
comment ORT part of test
Jan 7, 2023
c5a88a1
extend tests tir schedule analysis and transform for avx512. unify te…
Jan 8, 2023
f887555
extend test tir schedule tensorize for avx512
Jan 8, 2023
bfdb2c2
extend test meta schedule vnni integration for avx512
Jan 8, 2023
de26b94
rename test file
Jan 8, 2023
5059fd6
pylint fix
Jan 8, 2023
206d458
tag fix
Jan 8, 2023
e66b0e5
update test meta schedule trace apply with avx512
Jan 8, 2023
b56bb45
rollback test class unifying in utils
Jan 8, 2023
8255f1f
pylint fixes
Jan 8, 2023
fea930c
separate TIRs for scheduled conv2d for vnni and avx512
Jan 9, 2023
9e5f6ee
fix registering issue in test
Jan 9, 2023
ea093b0
update conv+bias onnx model for intermediate test
Jan 10, 2023
293033d
fix int16 overflow
Jan 11, 2023
915fad7
fix int16 overflow for dense test
Jan 11, 2023
a289d4b
update input data for test of dense
Jan 11, 2023
e6ea691
small rollback
Jan 12, 2023
300c66d
fix misprinting
Jan 12, 2023
59bf956
fix
Jan 12, 2023
c550370
restart CI
Jan 15, 2023
b69946a
DefaultVNNI was renamed to DefaultLLVM for mutator
Jan 17, 2023
9c6054b
rename test file for the sake of clarity
Jan 17, 2023
76d9aff
DefaultVNNI was renamed to DefaultCPUTensorization for postproc
Jan 17, 2023
db7960e
remove resolved TODO
Jan 17, 2023
b7d3f8b
DefaultVNNI and AVX512 for ScheduleRule were unified
Jan 17, 2023
8c9e403
replace code to upstream with initial version
Jan 17, 2023
f8794c9
fix arg type
Jan 17, 2023
cddd1a1
lint fix
Jan 17, 2023
84b780d
small fix
Jan 17, 2023
2d772f6
lint fix
Jan 17, 2023
d2343ab
fix misprinting
Jan 17, 2023
18c3610
rollback trace apply test for avx512 (reviewer remark)
Jan 17, 2023
c06ee20
fix pylint
Jan 17, 2023
2 changes: 0 additions & 2 deletions include/tvm/meta_schedule/mutator.h
@@ -131,8 +131,6 @@ class Mutator : public runtime::ObjectRef {
FApply f_apply, FClone f_clone, FAsString f_as_string);
/*! \brief Create default mutators for LLVM */
TVM_DLL static Map<Mutator, FloatImm, void> DefaultLLVM();
/*! \brief Create default mutators for x86 VNNI */
TVM_DLL static Map<Mutator, FloatImm, void> DefaultVNNI();
/*! \brief Create default mutators for CUDA */
TVM_DLL static Map<Mutator, FloatImm, void> DefaultCUDA();
/*! \brief Create default mutators for CUDA with TensorCore */
4 changes: 2 additions & 2 deletions include/tvm/meta_schedule/postproc.h
@@ -163,8 +163,8 @@ class Postproc : public runtime::ObjectRef {
TVM_DLL static Postproc RewriteLayout();
/*! \brief Create default postprocessors for LLVM */
TVM_DLL static Array<Postproc, void> DefaultLLVM();
/*! \brief Create default postprocessors for x86 VNNI */
TVM_DLL static Array<Postproc, void> DefaultVNNI();
/*! \brief Create default postprocessors for x86 (AVX512 and VNNI) */
TVM_DLL static Array<Postproc, void> DefaultCPUTensorization();
/*! \brief Create default postprocessors for CUDA */
TVM_DLL static Array<Postproc, void> DefaultCUDA();
/*! \brief Create default postprocessors for CUDA with TensorCore */
4 changes: 2 additions & 2 deletions include/tvm/meta_schedule/schedule_rule.h
@@ -290,8 +290,8 @@ class ScheduleRule : public runtime::ObjectRef {

/*! \brief Create default schedule rules for LLVM */
TVM_DLL static Array<ScheduleRule, void> DefaultLLVM();
/*! \brief Create default schedule rules for x86 VNNI */
TVM_DLL static Array<ScheduleRule, void> DefaultVNNI();
/*! \brief Create default schedule rules for x86 (AVX512 and VNNI) */
TVM_DLL static Array<ScheduleRule, void> DefaultX86(const String& type);
/*! \brief Create default schedule rules for CUDA */
TVM_DLL static Array<ScheduleRule, void> DefaultCUDA();
/*! \brief Create default postprocessors for CUDA with TensorCore */
4 changes: 3 additions & 1 deletion python/tvm/relay/qnn/op/legalizations.py
@@ -248,7 +248,7 @@ def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
Replacing QA + 128 with QA' and (zp_a + 128) with zp_a'
We get our new quantized uint8 tensor - scale * (QA' - zp_a')

Similarly we can convert from int8 to uint8.
Similarly we can convert from uint8 to int8.

Parameters
----------
@@ -449,6 +449,7 @@ def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):

@qnn_conv2d_legalize.register("cpu")
def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types):
# TODO(vvchernov): not only VNNI
# The VNNI transformations prefer uint8 x int8 datatypes.
if is_fast_int8_on_intel():
return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.conv2d)
@@ -457,6 +458,7 @@ def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types):

@qnn_dense_legalize.register("cpu")
def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
# TODO(vvchernov): not only VNNI
# The VNNI transformations prefer uint8 x int8 datatypes.
if is_fast_int8_on_intel():
return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense)
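As a side note, the signedness flip that helper_change_dtypes_to_uint8_int8 relies on can be checked with a few lines of plain Python. The numbers below are hypothetical and only illustrate that adding 128 to both the quantized value and the zero point leaves the dequantized result unchanged; this sketch is not part of the PR.

# Illustrative only: hypothetical int8 quantization parameters and value.
scale, zp = 0.05, -10
qa = -3
real = scale * (qa - zp)          # dequantized value, 0.35

# Shift both the value and the zero point by 128 to move into uint8 range.
qa_u8, zp_u8 = qa + 128, zp + 128
assert abs(scale * (qa_u8 - zp_u8) - real) < 1e-9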
29 changes: 29 additions & 0 deletions python/tvm/testing/utils.py
@@ -1027,6 +1027,28 @@ def _has_vnni():
return False


# check avx512 intrinsic groups for SkyLake X
def _has_slavx512():
# Check LLVM support
llvm_version = tvm.target.codegen.llvm_version_major()
is_llvm_support = llvm_version >= 8
arch = platform.machine()
# Only linux is supported for now.
if arch == "x86_64" and sys.platform.startswith("linux"):
with open("/proc/cpuinfo", "r") as content:
ctx = content.read()
check = (
"avx512f" in ctx
and "avx512cd" in ctx
and "avx512bw" in ctx
and "avx512dq" in ctx
and "avx512vl" in ctx
)
return check and is_llvm_support

return False


requires_arm_dot = Feature("arm_dot", "ARM dot product", run_time_check=_arm_dot_supported)


@@ -1035,6 +1057,13 @@ def _has_vnni():
)


requires_skylake_avx512 = Feature(
"skylake_avx512",
"x86 SkyLake AVX512",
run_time_check=lambda: _has_slavx512() and _is_intel(),
)


def _cmake_flag_enabled(flag):
flag = tvm.support.libinfo()[flag]

40 changes: 40 additions & 0 deletions python/tvm/tir/tensor_intrin/x86.py
@@ -67,8 +67,48 @@ def dot_product_16x4_u8i8i32_vnni(
)


@T.prim_func
def dot_product_16x4_u8i8i32_avx512(
A: T.Buffer((4,), "uint8", offset_factor=1),
B: T.Buffer((16, 4), "int8", offset_factor=1),
C: T.Buffer((16,), "int32", offset_factor=1),
) -> None:
with T.block("root"):
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
T.writes(C[0:16])

A_u8x4 = A.vload([0], "uint8x4")
A_i32 = T.reinterpret(A_u8x4, dtype="int32")
A_brdcst = T.broadcast(A_i32, 16)
A_u8x64 = T.reinterpret(A_brdcst, dtype="uint8x64")

B_i8x64 = B.vload([0, 0], dtype="int8x64")

Red = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddubs.w.512"),
T.uint32(2),
A_u8x64,
B_i8x64,
dtype="int16x32",
)

C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddw.d.512"),
T.uint32(2),
Red,
T.int16x32(1),
dtype="int32x16",
)


VNNI_DOT_16x4_INTRIN = "dot_16x4_vnni"

TensorIntrin.register(
VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_vnni
)

AVX512_DOT_16x4_INTRIN = "dot_16x4_avx512"

TensorIntrin.register(
AVX512_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_avx512
)
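The new intrinsic follows the usual AVX512 emulation of a u8 x i8 dot product: pmaddubs.w multiplies unsigned/signed byte pairs and adds adjacent products into int16 lanes (with saturation), and pmaddw.d against a vector of ones widens and reduces those into the sixteen int32 accumulators. Once registered, the intrinsic can be looked up and applied by name; a minimal sketch is shown below, where the schedule and loop variables are hypothetical.

from tvm import tir
from tvm.tir.tensor_intrin.x86 import AVX512_DOT_16x4_INTRIN

# The registration above makes the intrinsic retrievable by name.
intrin = tir.TensorIntrin.get(AVX512_DOT_16x4_INTRIN)

# In a TensorIR schedule whose innermost loops already match the 16x4
# u8 x i8 -> i32 pattern of dot_product_16x4_u8i8i32_desc:
# sch.tensorize(inner_loop_rv, AVX512_DOT_16x4_INTRIN)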
25 changes: 13 additions & 12 deletions python/tvm/topi/x86/batch_matmul.py
@@ -25,12 +25,12 @@
from .. import generic, nn
from ..transform import layout_transform
from ..utils import get_const_tuple, get_max_power2_factor, traverse_inline
from .dense import dense_vnni_schedule, dense_amx_int8_schedule
from .dense import dense_int8_schedule, dense_amx_int8_schedule
from .injective import schedule_injective_from_existing
from .utils import target_has_vnni, target_has_amx
from .utils import target_has_avx512, target_has_amx


@autotvm.register_topi_compute("batch_matmul_vnni.x86")
@autotvm.register_topi_compute("batch_matmul_int8.x86")
def batch_matmul_int8_compute(cfg, x, y, *_):
"""Compute for uint8 x int8 -> int32 batch_matmul"""
batch, m, k = x.shape
@@ -39,8 +39,8 @@ def batch_matmul_int8_compute(cfg, x, y, *_):
_, n_o, _, n_i, _ = packed_y.shape
ak = te.reduce_axis((0, k), name="k")
mcpu = tvm.target.Target.current().mcpu
if target_has_vnni(mcpu):
attrs_info = {"schedule_rule": "batch_matmul_vnni"}
if target_has_avx512(mcpu):
attrs_info = {"schedule_rule": "batch_matmul_int8"}
else:
attrs_info = None

@@ -60,13 +60,14 @@ def batch_matmul_int8_compute(cfg, x, y, *_):
return z


def batch_matmul_vnni_schedule(cfg, s, C, O, layout_trans):
"""Schedule batch_matmul compute using VNNI vpdpbusd instruction"""
def batch_matmul_int8_schedule(cfg, s, C, O, layout_trans):
"""Schedule batch_matmul compute using avx512 or lower instructions
including VNNI vpdpbusd instruction if possible"""
# C: The output of batched GEMM
# O: The output of the fused op

# Schedule the GEMM part
s, fused_inner = dense_vnni_schedule(cfg, s, C, O, do_parallel=False)
s, fused_inner = dense_int8_schedule(cfg, s, C, O, do_parallel=False)
# Parallelize over batch
fused = s[O].fuse(O.op.axis[0], fused_inner)
s[O].parallel(fused)
@@ -228,9 +229,9 @@ def _callback(op):
return s


@autotvm.register_topi_schedule("batch_matmul_vnni.x86")
@autotvm.register_topi_schedule("batch_matmul_int8.x86")
def schedule_batch_matmul_int8(cfg, outs):
"""Schedule for batch_matmul_vnni"""
"""Schedule for batch_matmul_int8"""
s = te.create_schedule([x.op for x in outs])
mcpu = tvm.target.Target.current().mcpu

@@ -239,8 +240,8 @@ def _callback(op):
layout_trans = op.input_tensors[1]
if target_has_amx(mcpu):
batch_matmul_amx_schedule(cfg, s, op.output(0), outs[0], layout_trans)
elif target_has_vnni(mcpu):
batch_matmul_vnni_schedule(cfg, s, op.output(0), outs[0], layout_trans)
elif target_has_avx512(mcpu):
batch_matmul_int8_schedule(cfg, s, op.output(0), outs[0], layout_trans)

traverse_inline(s, outs[0].op, _callback)
return s
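The dispatch above keys off Target.current().mcpu. A quick way to see which path a given target would take is sketched below; the printed values are what one would expect for skylake-avx512 (AVX512 without VNNI), not verified output.

import tvm
from tvm.topi.x86.utils import target_has_avx512, target_has_vnni

# skylake-avx512 has the AVX512 groups but no VNNI, so it now takes the
# int8 schedule path via target_has_avx512.
with tvm.target.Target("llvm -mcpu=skylake-avx512"):
    mcpu = tvm.target.Target.current().mcpu
    print(target_has_avx512(mcpu), target_has_vnni(mcpu))  # expected: True False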
19 changes: 10 additions & 9 deletions python/tvm/topi/x86/dense.py
@@ -26,10 +26,10 @@

from .. import generic, tag
from ..utils import get_const_tuple, traverse_inline
from .tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake
from .tensor_intrin import dot_16x1x16_uint8_int8_int32
from .tensor_intrin import dot_32x128x32_u8s8s32_sapphirerapids
from .tensor_intrin import acc_32x32_int32_sapphirerapids
from .utils import get_simd_32bit_lanes, target_has_vnni, target_has_amx
from .utils import get_simd_32bit_lanes, target_has_avx512, target_has_amx


def _schedule_dense_pack_template(cfg, s, C, O):
@@ -302,8 +302,8 @@ def _callback(op):
if "dense_int8" in op.tag:
if target_has_amx(mcpu):
dense_amx_int8_schedule(cfg, s, op.output(0), outs[0])
elif target_has_vnni(mcpu):
dense_vnni_schedule(cfg, s, op.output(0), outs[0])
elif target_has_avx512(mcpu):
dense_int8_schedule(cfg, s, op.output(0), outs[0])

traverse_inline(s, outs[0].op, _callback)
return s
@@ -315,8 +315,8 @@ def dense_int8_compute(cfg, X, packed_w, bias=None):
n_o, _, n_i, _ = packed_w.shape
ak = te.reduce_axis((0, k), name="k")
mcpu = tvm.target.Target.current().mcpu
if target_has_vnni(mcpu):
target_attr = {"schedule_rule": "meta_schedule.x86.dense_vnni"}
if target_has_avx512(mcpu):
target_attr = {"schedule_rule": "meta_schedule.x86.dense_int8"}
else:
target_attr = None

@@ -339,8 +339,9 @@ def dense_int8_compute(cfg, X, packed_w, bias=None):
return C


def dense_vnni_schedule(cfg, s, C, O, do_parallel=True):
"""Schedule dense compute using VNNI vpdpbusd instruction"""
def dense_int8_schedule(cfg, s, C, O, do_parallel=True):
"""Schedule dense compute using avx512 or lower instructions
including VNNI vpdpbusd instruction if possible"""
# C: The output of GEMM
# O: The output of the fused op
def split_y(out):
@@ -361,7 +362,7 @@ def split_y(out):

s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki)

pc = dot_16x1x16_uint8_int8_int32_cascadelake()
pc = dot_16x1x16_uint8_int8_int32()
s[C].tensorize(a_xi, pc)

if C == O:
18 changes: 9 additions & 9 deletions python/tvm/topi/x86/dense_alter_op.py
@@ -24,14 +24,14 @@
from .dense import _default_dense_pack_config
from ..utils import get_const_tuple
from ..nn import dense_alter_layout
from .utils import target_has_vnni
from .utils import target_has_amx
from .utils import target_has_avx512, target_has_amx
from .. import nn


def check_inst_applicable(x, y, allow_padding=False):
def check_int8_applicable(x, y, allow_padding=False):
mcpu = tvm.target.Target.current().mcpu
simd_avai = target_has_vnni(mcpu) or target_has_amx(mcpu)
# TODO(vvchernov): may be also target_has_avx2 or lower?
simd_avai = target_has_avx512(mcpu) or target_has_amx(mcpu)
return (
simd_avai
and "int8" in x.dtype
@@ -49,7 +49,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
M, K = get_const_tuple(data_tensor.shape)
N, _ = get_const_tuple(weight_tensor.shape)

if check_inst_applicable(data_tensor, weight_tensor) and data_tensor.dtype == "uint8":
if check_int8_applicable(data_tensor, weight_tensor) and data_tensor.dtype == "uint8":
weight_layout = "NC16n4c"
return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype)

@@ -86,10 +86,10 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
return None


def vnni_legalize(inputs, arg_types, op, attrs, need_expand=False):
def int8_int8_legalize(inputs, arg_types, op, attrs, need_expand=False):
"""Legalizes s8, s8 -> s32 GEMM op for VNNI."""
if (
check_inst_applicable(arg_types[0], arg_types[1], allow_padding=True)
check_int8_applicable(arg_types[0], arg_types[1], allow_padding=True)
and arg_types[0].dtype == "int8"
):
x, y = inputs
@@ -135,12 +135,12 @@ def vnni_legalize(inputs, arg_types, op, attrs, need_expand=False):
@nn.dense_legalize.register("cpu")
def _dense_legalize(attrs, inputs, arg_types):
"""Legalizes s8, s8 -> s32 dense for VNNI."""
return vnni_legalize(inputs, arg_types, relay.nn.dense, attrs)
return int8_int8_legalize(inputs, arg_types, relay.nn.dense, attrs)


@nn.batch_matmul_legalize.register("cpu")
def _batch_matmul_legalize(attrs, inputs, arg_types):
"""Legalizes s8, s8 -> s32 batch_matmul for VNNI."""
if attrs["transpose_a"] or not attrs["transpose_b"]:
return None
return vnni_legalize(inputs, arg_types, relay.nn.batch_matmul, attrs, need_expand=True)
return int8_int8_legalize(inputs, arg_types, relay.nn.batch_matmul, attrs, need_expand=True)
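Taken together with the dense and batch_matmul changes above, the int8 path can be exercised end to end roughly as follows. This is a sketch under the assumption that TVM is built with LLVM support; the shapes, variable names, and random weights are illustrative only.

import numpy as np
import tvm
from tvm import relay

# uint8 x int8 -> int32 dense, the pattern the legalize/alter-layout hooks target.
data = relay.var("data", shape=(32, 64), dtype="uint8")
weight = relay.var("weight", shape=(128, 64), dtype="int8")
dense = relay.nn.dense(data, weight, out_dtype="int32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], dense))

weight_np = np.random.randint(-8, 8, size=(128, 64)).astype("int8")
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm -mcpu=skylake-avx512", params={"weight": weight_np})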
2 changes: 0 additions & 2 deletions src/meta_schedule/mutator/mutator.cc
@@ -59,8 +59,6 @@ Map<Mutator, FloatImm> Mutator::DefaultLLVM() {
{Mutator::MutateParallel(/*max_jobs_per_core=*/16), FloatImm(DataType::Float(64), 0.02)}};
}

Map<Mutator, FloatImm> Mutator::DefaultVNNI() { return Mutator::DefaultLLVM(); }

Map<Mutator, FloatImm> Mutator::DefaultCUDA() {
return Map<Mutator, FloatImm>{
{Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)},
2 changes: 1 addition & 1 deletion src/meta_schedule/postproc/postproc.cc
@@ -59,7 +59,7 @@ Array<Postproc> Postproc::DefaultLLVM() {
};
}

Array<Postproc> Postproc::DefaultVNNI() {
Array<Postproc> Postproc::DefaultCPUTensorization() {
return Array<Postproc>{
Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(),
Postproc::RewriteReductionBlock(), Postproc::RewriteTensorize(/*vectorize_init_loop=*/true),
6 changes: 4 additions & 2 deletions src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -85,7 +85,9 @@ Array<ScheduleRule> ScheduleRule::DefaultLLVM() {
};
}

Array<ScheduleRule> ScheduleRule::DefaultVNNI() {
Array<ScheduleRule> ScheduleRule::DefaultX86(const String& type) {
static const Map<String, String> intrins = {{"vnni", "dot_16x4_vnni"},
{"avx512", "dot_16x4_avx512"}};
return {
ScheduleRule::ApplyCustomRule(),
ScheduleRule::InlineConstantScalars(),
Expand All @@ -101,7 +103,7 @@ Array<ScheduleRule> ScheduleRule::DefaultVNNI() {
/*max_jobs_per_core=*/16,
/*max_innermost_factor=*/Integer(64)),
ScheduleRule::MultiLevelTilingWithIntrin(
/*intrin_name=*/"dot_16x4_vnni",
/*intrin_name=*/intrins[type],
/*structure=*/"SSRSRS",
/*tile_binds=*/NullOpt,
/*max_innermost_factor=*/Integer(64),