Skip to content

Commit

Permalink
Fix vmlal.s16 code generation for int8 x int8 -> int32
Browse files Browse the repository at this point in the history
  • Loading branch information
ajtulloch committed Mar 8, 2019
1 parent fe06049 commit 17112b5
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 1 deletion.
18 changes: 17 additions & 1 deletion src/pass/lower_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,23 @@ class IntrinInjecter : public IRMutator {
// on ARM.
if (const Broadcast* bcast = e.as<Broadcast>()) {
if (const Cast* cast = bcast->value.as<Cast>()) {
if (cast->type.bits() == cast->value.type().bits() * 2) {
auto should_swap = [&]() {
// Maintain behaviour (int8 -> int16, fp16 -> fp32).
if (cast->type.bits() == cast->value.type().bits() * 2) {
return true;
}
// Check both operands are integer-like.
if (!cast->type.is_uint() && !cast->type.is_int()) {
return false;
}
if (!cast->value.type().is_uint() && !cast->value.type().is_int()) {
return false;
}
// If both are integer-like, swap if we have a widening cast.
return cast->type.bits() > cast->value.type().bits();
};

if (should_swap()) {
Expr new_bcast = Broadcast::make(cast->value, bcast->lanes);
return Cast::make(bcast->type, new_bcast);
}
Expand Down
44 changes: 44 additions & 0 deletions tests/python/unittest/test_codegen_arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,49 @@ def check_correct_assembly(type, elements, counts):
check_correct_assembly('uint32', 2, 2)
check_correct_assembly('uint64', 2, 3)

def test_vmlal_s16():
target = 'llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon'

def check_correct_assembly(N):
K = tvm.var("K")
A = tvm.placeholder((K, N), dtype="int8", name='A')
B = tvm.placeholder((K, N), dtype="int8", name='A')
k = tvm.reduce_axis((0, K))
C = tvm.compute((N, ), lambda n: tvm.sum(
A[k, n].astype("int32") * B[k, n].astype("int32"), axis=[k]), name='C')
s = tvm.create_schedule(C.op)
s[C].vectorize(s[C].op.axis[0])
f = tvm.build(s, [A, B, C], target)

# Verify we see the correct number of vmlal.s16 instructions
assembly = f.get_source('asm')
matches = re.findall("vmlal.s16", assembly)
assert (len(matches) == N // 4)
check_correct_assembly(4)
check_correct_assembly(8)
check_correct_assembly(16)

def check_broadcast_correct_assembly(N):
K = tvm.var("K")
A = tvm.placeholder((K, N), dtype="int8", name='A')
B = tvm.placeholder((K,), dtype="int8", name='A')
k = tvm.reduce_axis((0, K))
C = tvm.compute((N, ), lambda n: tvm.sum(
A[k, n].astype("int32") * B[k].astype("int32"),
axis=[k]), name='C')
s = tvm.create_schedule(C.op)
s[C].vectorize(s[C].op.axis[0])
f = tvm.build(s, [A, B, C], target)

# Verify we see the correct number of vmlal.s16 instructions
assembly = f.get_source('asm')
matches = re.findall("vmlal.s16", assembly)
assert len(matches) == N // 4
check_broadcast_correct_assembly(8)
check_broadcast_correct_assembly(16)
check_broadcast_correct_assembly(32)
check_broadcast_correct_assembly(64)

if __name__ == "__main__":
test_popcount()
test_vmlal_s16()

0 comments on commit 17112b5

Please sign in to comment.