Fix conv2_gemm after target structure update (#6037)
After the target structure changed in this RFC:

https://discuss.tvm.ai/t/rfc-tvm-target-specification/6844/42

the conv2d optimizations were broken for the following reasons:
- the "target" attribute is now called "mtriple" (this changes how we test
  whether the architecture is AArch64; see the first sketch below)
- when we invoke "clang.create_llvm" we still need to specify the
  "--target" option explicitly, set to aarch64-linux-gnu (see the second
  sketch below)

This submission fixes both regressions.

Change-Id: I04c597b91ca5800ddf4471255e2a358c60bc048e
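
For context on the first bullet, here is a minimal illustrative sketch (not part of the commit) of the renamed attribute, assuming a post-RFC TVM build; the target string is the same one used by the new test below:

    import tvm

    # Enter a target context whose triple marks an AArch64 machine.
    with tvm.target.create("llvm --device arm_cpu --mtriple aarch64-linux-gnu"):
        target = tvm.target.Target.current(allow_none=False)
        # The triple now lives under the "mtriple" attribute; the old
        # "target" key is gone, so .get() with a "" default keeps the
        # check safe on targets that carry no triple at all.
        assert 'aarch64' in target.attrs.get("mtriple", "")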
Giuseppe Rossini authored Jul 14, 2020
1 parent 99c52f3 commit 67ed6d0
Showing 4 changed files with 67 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/qnn/op/legalizations.py
@@ -239,7 +239,7 @@ def is_fast_int8_on_arm():
 def is_aarch64_arm():
     """ Checks whether we are compiling for an AArch64 target. """
     target = tvm.target.Target.current(allow_none=False)
-    return 'aarch64' in target.attrs.get("target", "")
+    return 'aarch64' in target.attrs.get("mtriple", "")

 ########################
 # ARM CPU legalizations.
2 changes: 1 addition & 1 deletion topi/python/topi/arm_cpu/conv2d_gemm.py
@@ -27,7 +27,7 @@
 def is_aarch64_arm():
     """ Checks whether we are compiling for an AArch64 target. """
     target = tvm.target.Target.current(allow_none=False)
-    return 'aarch64' in target.attrs.get("target", "")
+    return 'aarch64' in target.attrs.get("mtriple", "")


 # Compute function
2 changes: 1 addition & 1 deletion topi/python/topi/arm_cpu/tensor_intrin.py
@@ -267,7 +267,7 @@ def gemv_quantized_impl(M, N, data_type='uint8'):
     ll_path = temp.relpath("temp.ll")
     # Create LLVM ir from c source code
     ll_code = clang.create_llvm(cc_code,
-                                options=["-mtriple=aarch64-linux-gnu -mattr=+neon"],
+                                options=["--target=aarch64-linux-gnu -mattr=+neon"],
                                 output=ll_path)
     return ll_code

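For the second bullet of the commit message, a minimal standalone sketch of the clang invocation above (illustrative only: the C stub here is hypothetical, whereas the real cc_code is generated by gemv_quantized_impl):

    from tvm.contrib import clang, util

    # Hypothetical one-line C kernel; only the option plumbing matters here.
    cc_code = "int answer() { return 42; }"
    temp = util.tempdir()
    ll_path = temp.relpath("temp.ll")
    # Even though the TVM target now exposes the triple as "mtriple",
    # clang still needs an explicit --target to emit AArch64 LLVM IR.
    ll_code = clang.create_llvm(cc_code,
                                options=["--target=aarch64-linux-gnu -mattr=+neon"],
                                output=ll_path)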
64 changes: 64 additions & 0 deletions topi/tests/python/test_topi_conv2d_int8.py
@@ -26,9 +26,70 @@
 from tvm.contrib.pickle_memoize import memoize
 from topi.nn.util import get_pad_tuple
 from topi.util import get_const_tuple
+from topi.arm_cpu.conv2d_gemm import is_aarch64_arm

 from common import get_all_backend, Int8Fallback

+def compile_conv2d_NHWC_gemm_int8_arm(batch, in_channel, in_size, num_filter, kernel, stride, padding,
+                                      dilation=1, add_bias=False, add_relu=False):
+    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+    padding_sum = pad_top + pad_left + pad_bottom + pad_right
+    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter,
+                                                          kernel, stride, padding_sum, dilation))
+
+    in_height = in_width = in_size
+    A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='int8')
+    W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype='int8')
+    bias = te.placeholder((num_filter,), name='bias', dtype='int8')
+    dtype = 'int32'
+    device = "llvm --device arm_cpu --mtriple aarch64-linux-gnu"
+
+    ctx = tvm.context(device, 0)
+    if not ctx.exist:
+        print("Skip because %s is not enabled" % device)
+        return
+    print("Compiling on arm AArch64 target: %s" % device)
+    with tvm.target.create(device):
+        assert is_aarch64_arm(), "AArch64 target not recognized"
+
+        C = topi.arm_cpu.compute_conv2d_NHWC_quantized(A, W, (stride, stride), padding,
+                                                       (dilation, dilation), dtype)
+        if add_bias:
+            C = topi.add(C, bias)
+        if add_relu:
+            C = topi.nn.relu(C)
+        s = topi.arm_cpu.schedule_conv2d_NHWC_quantized([C])
+
+    if add_bias:
+        func = tvm.build(s, [A, W, bias, C], device,
+                         name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size,
+                                                                num_filter, kernel, stride,
+                                                                padding_sum, dilation))
+    else:
+        func = tvm.build(s, [A, W, C], device,
+                         name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size,
+                                                                num_filter, kernel, stride,
+                                                                padding_sum, dilation))
+
 def verify_conv2d_NHWC_gemm_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding,
                                  dilation=1, add_bias=False, add_relu=False):
     pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
@@ -409,6 +470,9 @@ def test_conv2d_nhwc():
     verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 448, 1, 1, 'SAME', add_bias=True, add_relu=True)
     verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 192, 1, 1, 'SAME', add_bias=True)

+    # Let's also verify that it compiles fine on AArch64 targets
+    compile_conv2d_NHWC_gemm_int8_arm(1, 3, 299, 32, 3, 2, 'SAME')
+

 if __name__ == "__main__":
     test_conv2d_nchw()
