Skip to content

Commit

Permalink
Use auto-tuner to improve conv2d_gemm performance
Browse files Browse the repository at this point in the history
The following tuning entities have been introduced:
- Unrolling and vectorizing input matrix transform
- Reordering gemm to exploit parallel threads
- Unrolling `gemm_quantized` intrinsic
- Interleaving `gemm_quantized` intrinsic

Change-Id: Icd3ab005663f78a80672e71ef368f6d0efa4a401
  • Loading branch information
Giuseppe Rossini committed Jul 20, 2020
1 parent a3b600a commit f3565f2
Show file tree
Hide file tree
Showing 3 changed files with 287 additions and 155 deletions.
62 changes: 47 additions & 15 deletions topi/python/topi/arm_cpu/conv2d_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
"""GEMM Convolution schedule on ARM"""
import tvm
from tvm import te
from tvm.autotvm.task.space import AnnotateEntity, ReorderEntity, OtherOptionEntity
from topi import nn
from ..util import get_const_tuple
from ..util import get_const_tuple, get_const_int
from ..nn.util import get_pad_tuple
from .tensor_intrin import gemv_quantized, gemv_quantized_impl
from .tensor_intrin import gemm_quantized, gemm_quantized_impl

def is_aarch64_arm():
""" Checks whether we are compiling for an AArch64 target. """
Expand All @@ -38,15 +39,15 @@ def compute_conv2d_gemm_without_weight_transform(cfg,
executing GEMM and transforming the output back"""
batches, IH, IW, IC = get_const_tuple(data.shape)

KH, KW = kernel_size
OC = output_channels
KH, KW = get_const_tuple(kernel_size)
OC = get_const_int(output_channels)

K_AREA = KH * KW

if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
dilation_h, dilation_w = dilation
dilation_h, dilation_w = get_const_tuple(dilation)

dilated_kernel_h = (KH - 1) * dilation_h + 1
dilated_kernel_w = (KW - 1) * dilation_w + 1
Expand Down Expand Up @@ -126,6 +127,28 @@ def compute_conv2d_gemm_without_weight_transform(cfg,
out = te.compute(out_shape, lambda b, x, y, z: C(b, y + OW * x, z),
name='conv2d_gemm_output')


# Configuration space
x, y = cfg.axis(M_padded // 4), cfg.axis(K_padded // 16)
cfg.define_reorder('reorder_gemm',
[x, y],
policy='candidate',
candidate=[[x, y],
[y, x]])

outer_loop, inner_loop = cfg.axis(4), cfg.axis(16)
cfg.define_annotate("A_interleaved_unroll_vec",
[outer_loop, inner_loop],
policy="try_unroll_vec")
cfg.define_knob('gemm_quantized_unroll', [True, False])
cfg.define_knob('gemm_quantized_interleave', [True, False])

# Fallback configuration
if cfg.is_fallback:
cfg['reorder_gemm'] = ReorderEntity([0, 1])
cfg['A_interleaved_unroll_vec'] = AnnotateEntity(["unroll", "vec"])
cfg['gemm_quantized_unroll'] = OtherOptionEntity(False)
cfg['gemm_quantized_interleave'] = OtherOptionEntity(True)
return out

# Schedules
Expand All @@ -150,33 +173,42 @@ def schedule_conv2d_gemm(cfg, s, out, final_out):
n_outer, n_inner = s[data_im2col].split(n, 16)
s[data_im2col].unroll(n_outer)
s[data_im2col].vectorize(n_inner)
s[data_im2col].parallel(m)
else:
s[data_im2col].compute_inline()

# Computation(through tensorize)
b, xo, yo, xi, yi = C_interleaved.op.axis
s[C_interleaved].reorder(xo, yo, yi, xi)
s[C_interleaved].parallel(xo)
s[A_interleaved].compute_at(s[C_interleaved], xo)
s[A_interleaved].vectorize(A_interleaved.op.axis[4])
outer_gemm, inner_gemm = cfg['reorder_gemm'].apply(s, C_interleaved, [xo, yo])
s[C_interleaved].reorder(yi, xi)
s[C_interleaved].parallel(outer_gemm)
s[A_interleaved].compute_at(s[C_interleaved], outer_gemm)
_, _, _, outer_A_interleaved, inner_A_interleaved = A_interleaved.op.axis
cfg['A_interleaved_unroll_vec'].apply(s,
A_interleaved,
[outer_A_interleaved, inner_A_interleaved])

in_type = A_interleaved.dtype
out_type = C.dtype
if is_aarch64_arm() and out_type == 'int32':
K = A_interleaved_input.shape[2]
_, M, N = C.shape
assert in_type in ['int8', 'uint8'], "Only int8 and uint8 gemm are supported"

gem_v_dotprod = gemv_quantized(M, N, K, in_type, out_type)
s[C_interleaved].pragma(xo, "import_llvm", gemv_quantized_impl(M, N, in_type))
s[C_interleaved].tensorize(yi, gem_v_dotprod)
unroll = cfg['gemm_quantized_unroll'].val
interleave = cfg['gemm_quantized_interleave'].val
gemm = gemm_quantized(M, N, K, unroll, interleave, in_type, out_type)
s[C_interleaved].pragma(xo, "import_llvm", gemm_quantized_impl(M,
N,
K,
unroll,
interleave,
in_type))
s[C_interleaved].tensorize(yi, gemm)

# Output transform
if out != final_out:
n, h, w, c = out.op.axis
_, inner = s[out].split(c, 4)
s[C].compute_at(s[out], inner)
s[out].vectorize(inner)


return s
1 change: 1 addition & 0 deletions topi/python/topi/arm_cpu/conv2d_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def schedule_conv2d_NHWC_quantized(cfg, outs):
n, h, w, c = out.op.axis
outer, inner = s[out].split(c, 4)
s[out].vectorize(inner)
s[out].parallel(h)

def _callback(op):
"""Traverse operators from computation graph"""
Expand Down
Loading

0 comments on commit f3565f2

Please sign in to comment.