diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index c8e1a5a7b378a..62f013a921512 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -20,9 +20,10 @@ import tvm from tvm import te from tvm.topi import nn -from ..util import get_const_tuple +from tvm.autotvm.task.space import AnnotateEntity, ReorderEntity, OtherOptionEntity +from ..util import get_const_tuple, get_const_int from ..nn.util import get_pad_tuple -from .tensor_intrin import gemv_quantized, gemv_quantized_impl +from .tensor_intrin import gemm_quantized, gemm_quantized_impl def is_aarch64_arm(): """ Checks whether we are compiling for an AArch64 target. """ @@ -38,15 +39,15 @@ def compute_conv2d_gemm_without_weight_transform(cfg, executing GEMM and transforming the output back""" batches, IH, IW, IC = get_const_tuple(data.shape) - KH, KW = kernel_size - OC = output_channels + KH, KW = get_const_tuple(kernel_size) + OC = get_const_int(output_channels) K_AREA = KH * KW if isinstance(dilation, int): dilation_h = dilation_w = dilation else: - dilation_h, dilation_w = dilation + dilation_h, dilation_w = get_const_tuple(dilation) dilated_kernel_h = (KH - 1) * dilation_h + 1 dilated_kernel_w = (KW - 1) * dilation_w + 1 @@ -126,6 +127,28 @@ def compute_conv2d_gemm_without_weight_transform(cfg, out = te.compute(out_shape, lambda b, x, y, z: C(b, y + OW * x, z), name='conv2d_gemm_output') + + # Configuration space + x, y = cfg.axis(M_padded // 4), cfg.axis(K_padded // 16) + cfg.define_reorder('reorder_gemm', + [x, y], + policy='candidate', + candidate=[[x, y], + [y, x]]) + + outer_loop, inner_loop = cfg.axis(4), cfg.axis(16) + cfg.define_annotate("A_interleaved_unroll_vec", + [outer_loop, inner_loop], + policy="try_unroll_vec") + cfg.define_knob('gemm_quantized_unroll', [True, False]) + cfg.define_knob('gemm_quantized_interleave', [True, False]) + + # Fallback configuration + if cfg.is_fallback: + cfg['reorder_gemm'] = ReorderEntity([0, 1]) + cfg['A_interleaved_unroll_vec'] = AnnotateEntity(["unroll", "vec"]) + cfg['gemm_quantized_unroll'] = OtherOptionEntity(False) + cfg['gemm_quantized_interleave'] = OtherOptionEntity(True) return out # Schedules @@ -150,15 +173,22 @@ def schedule_conv2d_gemm(cfg, s, out, final_out): n_outer, n_inner = s[data_im2col].split(n, 16) s[data_im2col].unroll(n_outer) s[data_im2col].vectorize(n_inner) + b_m_fused = s[data_im2col].fuse(b, m) + s[data_im2col].parallel(b_m_fused) else: s[data_im2col].compute_inline() # Computation(through tensorize) b, xo, yo, xi, yi = C_interleaved.op.axis - s[C_interleaved].reorder(xo, yo, yi, xi) - s[C_interleaved].parallel(xo) - s[A_interleaved].compute_at(s[C_interleaved], xo) - s[A_interleaved].vectorize(A_interleaved.op.axis[4]) + outer_gemm, inner_gemm = cfg['reorder_gemm'].apply(s, C_interleaved, [xo, yo]) + s[C_interleaved].reorder(yi, xi) + b_outer_gemm_fused = s[C_interleaved].fuse(b, outer_gemm) + s[C_interleaved].parallel(b_outer_gemm_fused) + s[A_interleaved].compute_at(s[C_interleaved], b_outer_gemm_fused) + _, _, _, outer_A_interleaved, inner_A_interleaved = A_interleaved.op.axis + cfg['A_interleaved_unroll_vec'].apply(s, + A_interleaved, + [outer_A_interleaved, inner_A_interleaved]) in_type = A_interleaved.dtype out_type = C.dtype @@ -166,10 +196,16 @@ def schedule_conv2d_gemm(cfg, s, out, final_out): K = A_interleaved_input.shape[2] _, M, N = C.shape assert in_type in ['int8', 'uint8'], "Only int8 and uint8 gemm are supported" - - gem_v_dotprod = gemv_quantized(M, N, K, in_type, out_type) - s[C_interleaved].pragma(xo, "import_llvm", gemv_quantized_impl(M, N, in_type)) - s[C_interleaved].tensorize(yi, gem_v_dotprod) + unroll = cfg['gemm_quantized_unroll'].val + interleave = cfg['gemm_quantized_interleave'].val + gemm = gemm_quantized(M, N, K, unroll, interleave, in_type, out_type) + s[C_interleaved].pragma(b_outer_gemm_fused, "import_llvm", gemm_quantized_impl(M, + N, + K, + unroll, + interleave, + in_type)) + s[C_interleaved].tensorize(yi, gemm) # Output transform if out != final_out: @@ -177,6 +213,4 @@ def schedule_conv2d_gemm(cfg, s, out, final_out): _, inner = s[out].split(c, 4) s[C].compute_at(s[out], inner) s[out].vectorize(inner) - - return s diff --git a/python/tvm/topi/arm_cpu/conv2d_int8.py b/python/tvm/topi/arm_cpu/conv2d_int8.py index 89a37fa412947..9a6e8cc7e90b4 100644 --- a/python/tvm/topi/arm_cpu/conv2d_int8.py +++ b/python/tvm/topi/arm_cpu/conv2d_int8.py @@ -140,8 +140,10 @@ def schedule_conv2d_NHWC_quantized(cfg, outs): # Vectorize the output and then inline all the rest out = outs[0] n, h, w, c = out.op.axis + n_h_fused = s[out].fuse(n, h) outer, inner = s[out].split(c, 4) s[out].vectorize(inner) + s[out].parallel(n_h_fused) def _callback(op): """Traverse operators from computation graph""" diff --git a/python/tvm/topi/arm_cpu/tensor_intrin.py b/python/tvm/topi/arm_cpu/tensor_intrin.py index 270bfbe877664..52e67ad1532db 100644 --- a/python/tvm/topi/arm_cpu/tensor_intrin.py +++ b/python/tvm/topi/arm_cpu/tensor_intrin.py @@ -21,7 +21,186 @@ from tvm import te from tvm.contrib import util, clang -def gemv_quantized_impl(M, N, data_type='uint8'): +def gemm_quantized_4_4_batched(): + return """ + // First half + // Higher part of a0 * {b0,b1,b2,b3} + "umull v8.8h, v0.8b, v4.8b\\n" + "umull v9.8h, v0.8b, v5.8b\\n" + "umull v10.8h, v0.8b, v6.8b\\n" + "umull v11.8h, v0.8b, v7.8b\\n" + + // Higher part of a1 * {b0,b1,b2,b3} + "umull v12.8h, v1.8b, v4.8b\\n" + "umull v13.8h, v1.8b, v5.8b\\n" + "umull v14.8h, v1.8b, v6.8b\\n" + "umull v15.8h, v1.8b, v7.8b\\n" + + // Accumulate + "uadalp v16.4s, v8.8h\\n" + "uadalp v17.4s, v9.8h\\n" + "uadalp v18.4s, v10.8h\\n" + "uadalp v19.4s, v11.8h\\n" + "uadalp v20.4s, v12.8h\\n" + "uadalp v21.4s, v13.8h\\n" + "uadalp v22.4s, v14.8h\\n" + "uadalp v23.4s, v15.8h\\n" + + // Lower part of a0 * {b0,b1,b2,b3} + "umull2 v8.8h, v0.16b, v4.16b\\n" + "umull2 v9.8h, v0.16b, v5.16b\\n" + "umull2 v10.8h, v0.16b, v6.16b\\n" + "umull2 v11.8h, v0.16b, v7.16b\\n" + + // Lower part of a1 * {b0,b1,b2,b3} + "umull2 v12.8h, v1.16b, v4.16b\\n" + "umull2 v13.8h, v1.16b, v5.16b\\n" + "umull2 v14.8h, v1.16b, v6.16b\\n" + "umull2 v15.8h, v1.16b, v7.16b\\n" + + // Accumulate again + "uadalp v16.4s, v8.8h\\n" + "uadalp v17.4s, v9.8h\\n" + "uadalp v18.4s, v10.8h\\n" + "uadalp v19.4s, v11.8h\\n" + "uadalp v20.4s, v12.8h\\n" + "uadalp v21.4s, v13.8h\\n" + "uadalp v22.4s, v14.8h\\n" + "uadalp v23.4s, v15.8h\\n" + + // Second half + // Lower part of a2 * {b0,b1,b2,b3} + "umull v8.8h, v2.8b, v4.8b\\n" + "umull v9.8h, v2.8b, v5.8b\\n" + "umull v10.8h, v2.8b, v6.8b\\n" + "umull v11.8h, v2.8b, v7.8b\\n" + + // Lower part of a3 * {b0,b1,b2,b3} + "umull v12.8h, v3.8b, v4.8b\\n" + "umull v13.8h, v3.8b, v5.8b\\n" + "umull v14.8h, v3.8b, v6.8b\\n" + "umull v15.8h, v3.8b, v7.8b\\n" + + // Accumulate + "uadalp v24.4s, v8.8h\\n" + "uadalp v25.4s, v9.8h\\n" + "uadalp v26.4s, v10.8h\\n" + "uadalp v27.4s, v11.8h\\n" + "uadalp v28.4s, v12.8h\\n" + "uadalp v29.4s, v13.8h\\n" + "uadalp v30.4s, v14.8h\\n" + "uadalp v31.4s, v15.8h\\n" + + // Higher part of a2 * {b0,b1,b2,b3} + "umull2 v8.8h, v2.16b, v4.16b\\n" + "umull2 v9.8h, v2.16b, v5.16b\\n" + "umull2 v10.8h, v2.16b, v6.16b\\n" + "umull2 v11.8h, v2.16b, v7.16b\\n" + + // Higher part of a3 * {b0,b1,b2,b3} + "umull2 v12.8h, v3.16b, v4.16b\\n" + "umull2 v13.8h, v3.16b, v5.16b\\n" + "umull2 v14.8h, v3.16b, v6.16b\\n" + "umull2 v15.8h, v3.16b, v7.16b\\n" + + // Accumulate again + "uadalp v24.4s, v8.8h\\n" + "uadalp v25.4s, v9.8h\\n" + "uadalp v26.4s, v10.8h\\n" + "uadalp v27.4s, v11.8h\\n" + "uadalp v28.4s, v12.8h\\n" + "uadalp v29.4s, v13.8h\\n" + "uadalp v30.4s, v14.8h\\n" + "uadalp v31.4s, v15.8h\\n" + """ + +def gemm_quantized_4_4_interleaved(): + return """ + // First half + // Higher part of a0 * {b0,b1,b2,b3} and accumulate + "umull v8.8h, v0.8b, v4.8b\\n" + "uadalp v16.4s, v8.8h\\n" + "umull v9.8h, v0.8b, v5.8b\\n" + "uadalp v17.4s, v9.8h\\n" + "umull v10.8h, v0.8b, v6.8b\\n" + "uadalp v18.4s, v10.8h\\n" + "umull v11.8h, v0.8b, v7.8b\\n" + "uadalp v19.4s, v11.8h\\n" + + // Higher part of a1 * {b0,b1,b2,b3} and accumulate + "umull v12.8h, v1.8b, v4.8b\\n" + "uadalp v20.4s, v12.8h\\n" + "umull v13.8h, v1.8b, v5.8b\\n" + "uadalp v21.4s, v13.8h\\n" + "umull v14.8h, v1.8b, v6.8b\\n" + "uadalp v22.4s, v14.8h\\n" + "umull v15.8h, v1.8b, v7.8b\\n" + "uadalp v23.4s, v15.8h\\n" + + // Lower part of a0 * {b0,b1,b2,b3} and accumulate + "umull2 v8.8h, v0.16b, v4.16b\\n" + "uadalp v16.4s, v8.8h\\n" + "umull2 v9.8h, v0.16b, v5.16b\\n" + "uadalp v17.4s, v9.8h\\n" + "umull2 v10.8h, v0.16b, v6.16b\\n" + "uadalp v18.4s, v10.8h\\n" + "umull2 v11.8h, v0.16b, v7.16b\\n" + "uadalp v19.4s, v11.8h\\n" + + // Lower part of a1 * {b0,b1,b2,b3} and accumulate + "umull2 v12.8h, v1.16b, v4.16b\\n" + "uadalp v20.4s, v12.8h\\n" + "umull2 v13.8h, v1.16b, v5.16b\\n" + "uadalp v21.4s, v13.8h\\n" + "umull2 v14.8h, v1.16b, v6.16b\\n" + "uadalp v22.4s, v14.8h\\n" + "umull2 v15.8h, v1.16b, v7.16b\\n" + "uadalp v23.4s, v15.8h\\n" + + // Second half + // Higher part of a2 * {b0,b1,b2,b3} and accumulate + "umull v8.8h, v2.8b, v4.8b\\n" + "uadalp v24.4s, v8.8h\\n" + "umull v9.8h, v2.8b, v5.8b\\n" + "uadalp v25.4s, v9.8h\\n" + "umull v10.8h, v2.8b, v6.8b\\n" + "uadalp v26.4s, v10.8h\\n" + "umull v11.8h, v2.8b, v7.8b\\n" + "uadalp v27.4s, v11.8h\\n" + + // Higher part of a3 * {b0,b1,b2,b3} and accumulate + "umull v12.8h, v3.8b, v4.8b\\n" + "uadalp v28.4s, v12.8h\\n" + "umull v13.8h, v3.8b, v5.8b\\n" + "uadalp v29.4s, v13.8h\\n" + "umull v14.8h, v3.8b, v6.8b\\n" + "uadalp v30.4s, v14.8h\\n" + "umull v15.8h, v3.8b, v7.8b\\n" + "uadalp v31.4s, v15.8h\\n" + + // Lower part of a2 * {b0,b1,b2,b3} and accumulate + "umull2 v8.8h, v2.16b, v4.16b\\n" + "uadalp v24.4s, v8.8h\\n" + "umull2 v9.8h, v2.16b, v5.16b\\n" + "uadalp v25.4s, v9.8h\\n" + "umull2 v10.8h, v2.16b, v6.16b\\n" + "uadalp v26.4s, v10.8h\\n" + "umull2 v11.8h, v2.16b, v7.16b\\n" + "uadalp v27.4s, v11.8h\\n" + + // Lower part of a3 * {b0,b1,b2,b3} and accumulate + "umull2 v12.8h, v3.16b, v4.16b\\n" + "uadalp v28.4s, v12.8h\\n" + "umull2 v13.8h, v3.16b, v5.16b\\n" + "uadalp v29.4s, v13.8h\\n" + "umull2 v14.8h, v3.16b, v6.16b\\n" + "uadalp v30.4s, v14.8h\\n" + "umull2 v15.8h, v3.16b, v7.16b\\n" + "uadalp v31.4s, v15.8h\\n" + """ + + +def gemm_quantized_impl(M, N, K, unroll, interleave, data_type='uint8'): """ Assembly implementation of a blocked gemv. Given a block a of shape (4, k) and a block b' of shape (4, k) produces the output block c = a*b of shape (4,4) """ @@ -30,13 +209,21 @@ def gemv_quantized_impl(M, N, data_type='uint8'): stepB = min(4, N) assert data_type in ['uint8', 'int8'], 'Only uint8/int8 supported for this implementation' - cc_code = """ - extern "C" int gemv_{0}_{0}_int32_{1}_{2}(int *c_buffer, - unsigned char *a_buffer, - unsigned char *b_buffer, - int K, int m, int n) - """.format(data_type, stepA, stepB) + signature = """extern "C" int gemm_quantized_{0}_{0}_int32_{1}_{2}""".format(data_type, + stepA, + stepB) + if unroll: + signature += ("_" + str(K)) + if interleave: + signature += ("_interleaved") + + signature += """(int *c_buffer, + unsigned char *a_buffer, + unsigned char *b_buffer, + int K, int m, int n)""" + + cc_code = signature cc_code += """ { unsigned char * a_ptr = a_buffer; @@ -65,141 +252,58 @@ def gemv_quantized_impl(M, N, data_type='uint8'): "1:" """ - cc_code += ' "ldr q0, [%[a_ptr]]\\n" ' + main_loop = ' "ldr q0, [%[a_ptr]]\\n" ' if M > 1: - cc_code += ' "ldr q1, [%[a_ptr], #16]\\n" ' + main_loop += ' "ldr q1, [%[a_ptr], #16]\\n" ' else: - cc_code += ' "movi v1.4s, #0\\n" ' + main_loop += ' "movi v1.4s, #0\\n" ' if M > 2: - cc_code += ' "ldr q2, [%[a_ptr], #32]\\n" ' + main_loop += ' "ldr q2, [%[a_ptr], #32]\\n" ' else: - cc_code += ' "movi v2.4s, #0\\n" ' + main_loop += ' "movi v2.4s, #0\\n" ' if M > 3: - cc_code += ' "ldr q3, [%[a_ptr], #48]\\n" ' + main_loop += ' "ldr q3, [%[a_ptr], #48]\\n" ' else: - cc_code += ' "movi v3.4s, #0\\n" ' + main_loop += ' "movi v3.4s, #0\\n" ' - cc_code += ' "ldr q4, [%[b_ptr]]\\n" ' + main_loop += ' "ldr q4, [%[b_ptr]]\\n" ' if N > 1: - cc_code += ' "ldr q5, [%[b_ptr], #16]\\n" ' + main_loop += ' "ldr q5, [%[b_ptr], #16]\\n" ' if N > 2: - cc_code += ' "ldr q6, [%[b_ptr], #32]\\n" ' + main_loop += ' "ldr q6, [%[b_ptr], #32]\\n" ' if N > 3: - cc_code += ' "ldr q7, [%[b_ptr], #48]\\n" ' + main_loop += ' "ldr q7, [%[b_ptr], #48]\\n" ' + + # Main computation can interleave multiply/accumulate instructions + # or schedule them in batches (first all multiplies then all accumulates) + if interleave: + main_loop += gemm_quantized_4_4_interleaved() + else: + main_loop += gemm_quantized_4_4_batched() - cc_code += """ - // First half - // Higher part of a0 * {b0,b1,b2,b3} - "umull v8.8h, v0.8b, v4.8b\\n" - "umull v9.8h, v0.8b, v5.8b\\n" - "umull v10.8h, v0.8b, v6.8b\\n" - "umull v11.8h, v0.8b, v7.8b\\n" - - // Higher part of a1 * {b0,b1,b2,b3} - "umull v12.8h, v1.8b, v4.8b\\n" - "umull v13.8h, v1.8b, v5.8b\\n" - "umull v14.8h, v1.8b, v6.8b\\n" - "umull v15.8h, v1.8b, v7.8b\\n" - - // Accumulate - "uadalp v16.4s, v8.8h\\n" - "uadalp v17.4s, v9.8h\\n" - "uadalp v18.4s, v10.8h\\n" - "uadalp v19.4s, v11.8h\\n" - "uadalp v20.4s, v12.8h\\n" - "uadalp v21.4s, v13.8h\\n" - "uadalp v22.4s, v14.8h\\n" - "uadalp v23.4s, v15.8h\\n" - - // Lower part of a0 * {b0,b1,b2,b3} - "umull2 v8.8h, v0.16b, v4.16b\\n" - "umull2 v9.8h, v0.16b, v5.16b\\n" - "umull2 v10.8h, v0.16b, v6.16b\\n" - "umull2 v11.8h, v0.16b, v7.16b\\n" - - // Lower part of a1 * {b0,b1,b2,b3} - "umull2 v12.8h, v1.16b, v4.16b\\n" - "umull2 v13.8h, v1.16b, v5.16b\\n" - "umull2 v14.8h, v1.16b, v6.16b\\n" - "umull2 v15.8h, v1.16b, v7.16b\\n" - - // Accumulate again - "uadalp v16.4s, v8.8h\\n" - "uadalp v17.4s, v9.8h\\n" - "uadalp v18.4s, v10.8h\\n" - "uadalp v19.4s, v11.8h\\n" - "uadalp v20.4s, v12.8h\\n" - "uadalp v21.4s, v13.8h\\n" - "uadalp v22.4s, v14.8h\\n" - "uadalp v23.4s, v15.8h\\n" - - // Second half - - // Lower part of a2 * {b0,b1,b2,b3} - "umull v8.8h, v2.8b, v4.8b\\n" - "umull v9.8h, v2.8b, v5.8b\\n" - "umull v10.8h, v2.8b, v6.8b\\n" - "umull v11.8h, v2.8b, v7.8b\\n" - - // Lower part of a3 * {b0,b1,b2,b3} - "umull v12.8h, v3.8b, v4.8b\\n" - "umull v13.8h, v3.8b, v5.8b\\n" - "umull v14.8h, v3.8b, v6.8b\\n" - "umull v15.8h, v3.8b, v7.8b\\n" - - // Accumulate - "uadalp v24.4s, v8.8h\\n" - "uadalp v25.4s, v9.8h\\n" - "uadalp v26.4s, v10.8h\\n" - "uadalp v27.4s, v11.8h\\n" - "uadalp v28.4s, v12.8h\\n" - "uadalp v29.4s, v13.8h\\n" - "uadalp v30.4s, v14.8h\\n" - "uadalp v31.4s, v15.8h\\n" - - // Higher part of a2 * {b0,b1,b2,b3} - "umull2 v8.8h, v2.16b, v4.16b\\n" - "umull2 v9.8h, v2.16b, v5.16b\\n" - "umull2 v10.8h, v2.16b, v6.16b\\n" - "umull2 v11.8h, v2.16b, v7.16b\\n" - - // Higher part of a3 * {b0,b1,b2,b3} - "umull2 v12.8h, v3.16b, v4.16b\\n" - "umull2 v13.8h, v3.16b, v5.16b\\n" - "umull2 v14.8h, v3.16b, v6.16b\\n" - "umull2 v15.8h, v3.16b, v7.16b\\n" - - // Accumulate again - "uadalp v24.4s, v8.8h\\n" - "uadalp v25.4s, v9.8h\\n" - "uadalp v26.4s, v10.8h\\n" - "uadalp v27.4s, v11.8h\\n" - "uadalp v28.4s, v12.8h\\n" - "uadalp v29.4s, v13.8h\\n" - "uadalp v30.4s, v14.8h\\n" - "uadalp v31.4s, v15.8h\\n" - """ blockA = min(64, M * 16) blockB = min(64, N * 16) - - cc_code += """ - // Increment pointers and decrement k - "add %[a_ptr], %[a_ptr], #{0}\\n" - "add %[b_ptr], %[b_ptr], #{1}\\n" - "subs %w[k], %w[k], #1\\n" - """.format(blockA, blockB) - - stepC = min(4, N) - + main_loop += """// Increment pointers + "add %[a_ptr], %[a_ptr], #{0}\\n" + "add %[b_ptr], %[b_ptr], #{1}\\n" """.format(blockA, blockB) + + if unroll: + k = int(K//16) + for l in range(0, k): + cc_code += main_loop + else: + cc_code += main_loop + cc_code += """ + "subs %w[k], %w[k], #1\\n" + "cbnz %w[k], 1b\\n" + """ cc_code += """ - "cbnz %w[k], 1b\\n" - // Final additions // v16 contains the four partial sums of a[0, 0:K].*b[0,0:K], let's call them (a,b,c,d) @@ -237,6 +341,7 @@ def gemv_quantized_impl(M, N, data_type='uint8'): "str q16, [%[c_ptr]]\\n" """ + stepC = min(4, N) if M > 1: cc_code += ' "str q20, [%[c_ptr], #{0}]\\n" '.format(stepC * 4) @@ -272,7 +377,7 @@ def gemv_quantized_impl(M, N, data_type='uint8'): return ll_code -def gemv_quantized(M, N, K, in_type, out_type): +def gemm_quantized(M, N, K, unroll, interleave, in_type, out_type): """ Use integer ARM v8 instructions in order to produce a block c of 4x4 elements given two 4xK blocks a and b' (where b' is a Kx4 block transposed). The final @@ -331,23 +436,17 @@ def _instr(): cc = outs[0] stepA = min(4, M) stepB = min(4, N) - - if in_type == 'int8': - ib.emit(tvm.tir.call_extern("int32", - "gemv_int8_int8_int32_{0}_{1}".format(stepA, stepB), - outs[0].access_ptr("w"), - a_buffer.access_ptr("r"), - b_buffer.access_ptr("r"), - K)) - else: - ib.emit(tvm.tir.call_extern("int32", - "gemv_uint8_uint8_int32_{0}_{1}".format(stepA, stepB), - c_buffer.access_ptr("w"), - a_buffer.access_ptr("r"), - b_buffer.access_ptr("r"), - K, - C.shape[0], # m, very useful for debug - C.shape[1])) # n, very useful for debug + intrin_name = "gemm_quantized_{0}_{0}_int32_{1}_{2}".format(in_type, stepA, stepB) + if unroll: + intrin_name += ("_" + str(K)) + if interleave: + intrin_name += "_interleaved" + ib.emit(tvm.tir.call_extern("int32", + intrin_name, + outs[0].access_ptr("w"), + a_buffer.access_ptr("r"), + b_buffer.access_ptr("r"), + K)) return ib.get() # body, reset, update