Skip to content

Commit

Permalink
[TIR][TOPI][CI] Fix number of arguments in calls of llvm_pure_intrin (a…
Browse files Browse the repository at this point in the history
…pache#13881)

fix number of arguments in calls of llvm_pure_intrin

Co-authored-by: Valery Chernov <[email protected]>
  • Loading branch information
vvchernov and Valery Chernov authored Jan 31, 2023
1 parent 7374038 commit 206f085
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion python/tvm/tir/tensor_intrin/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def dot_product_16x4_u8i8i32_vnni(

C[T.ramp(T.int32(0), 1, 16)] = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
T.uint32(0),
T.uint32(3),
C_i32x16,
T.broadcast(A_i32, 16),
B_i32x16,
Expand Down
12 changes: 6 additions & 6 deletions python/tvm/topi/x86/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,14 @@ def _instr(index):
pair_reduction = tvm.tir.call_llvm_pure_intrin(
int_lx32,
pmaddubs,
tvm.tir.const(0, "uint32"),
tvm.tir.const(2, "uint32"),
vec_a,
vec_b,
)
quad_reduction = tvm.tir.call_llvm_pure_intrin(
int_32xl,
pmaddw,
tvm.tir.const(0, "uint32"),
tvm.tir.const(2, "uint32"),
pair_reduction,
vec_one,
)
Expand Down Expand Up @@ -215,7 +215,7 @@ def _instr(index):
pair_reduction = tvm.tir.call_llvm_pure_intrin(
"int16x32",
"llvm.x86.avx512.pmaddubs.w.512",
tvm.tir.const(0, "uint32"),
tvm.tir.const(2, "uint32"),
vec_a,
vec_b,
)
Expand Down Expand Up @@ -309,7 +309,7 @@ def _instr(index):
quad_reduction = tvm.tir.call_llvm_pure_intrin(
"int32x16",
"llvm.x86.avx512.vpdpbusd.512",
tvm.tir.const(0, "uint32"),
tvm.tir.const(3, "uint32"),
vec_c,
vec_ai32,
vec_bi32,
Expand All @@ -321,14 +321,14 @@ def _instr(index):
pair_reduction = tvm.tir.call_llvm_pure_intrin(
"int16x32",
"llvm.x86.avx512.pmaddubs.w.512",
tvm.tir.const(0, "uint32"),
tvm.tir.const(2, "uint32"),
vec_a,
vec_b,
)
quad_reduction = tvm.tir.call_llvm_pure_intrin(
"int32x16",
"llvm.x86.avx512.pmaddw.d.512",
tvm.tir.const(0, "uint32"),
tvm.tir.const(2, "uint32"),
pair_reduction,
vec_one,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def main(
C_i32x16 = C.vload([0], dtype="int32x16")
C[T.ramp(0, 1, 16)] = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
T.uint32(0),
T.uint32(3),
C_i32x16,
T.broadcast(A_i32, 16),
B_i32x16,
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_meta_schedule_trace_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,7 +1182,7 @@ def main(p0: T.Buffer[(1, 32, 7, 7, 16), "uint8"], p1: T.Buffer[(128, 32, 1, 1,
B_i8x64: T.int8x64 = B[0, 0:64]
B_i32x16: T.int32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
C_i32x16: T.int32x16 = C[0:16]
C[0:16] = T.call_llvm_pure_intrin(T.uint32(intrin_id), T.uint32(0), C_i32x16, T.broadcast(A_i32, 16), B_i32x16, dtype="int32x16")
C[0:16] = T.call_llvm_pure_intrin(T.uint32(intrin_id), T.uint32(3), C_i32x16, T.broadcast(A_i32, 16), B_i32x16, dtype="int32x16")
for ax0, ax1, ax2, ax3 in T.grid(1, 1, 1, 7):
for ax4_fused in T.vectorized(16):
with T.block("T_cast_8"):
Expand Down

0 comments on commit 206f085

Please sign in to comment.