Skip to content

Commit

Permalink
Fix int8x4 broadcast value codegen in cuda (apache#1959)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 authored and AWS Neo committed Feb 20, 2019
1 parent 0cc8126 commit 625627c
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/codegen/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,16 @@ void CodeGenCUDA::VisitExpr_(const Ramp* op, std::ostream& os) {
}

void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
if (op->type.is_int() && op->type.bits() == 8 && op->lanes == 4) {
// make_int8x4
const int64_t *p = as_const_int(op->value);
CHECK(p);
int64_t v = *p & 0xFF;
v = (v << 24) | (v << 16) | (v << 8) | v;
os << "(int)" << v;
return;
}

std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->type, os);
Expand Down
23 changes: 23 additions & 0 deletions tests/python/unittest/test_codegen_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,30 @@ def check_cuda(dtype, n, lanes):
check_cuda("int8", 64, 8)
check_cuda("int8", 64, 16)

def test_cuda_make_int8x4():
def check_cuda(n, value):
if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
print("skip because cuda is not enabled..")
return
lanes = 4
dtype = 'int8'
ctx = tvm.gpu(0)
A = tvm.compute((n, lanes), lambda i,j: tvm.const(value, dtype=dtype))
s = tvm.create_schedule(A.op)
y, x = s[A].op.axis
s[A].vectorize(x)
s[A].bind(y, tvm.thread_axis("blockIdx.x"))
fun = tvm.build(s, [A], "cuda", name="make_int8x4")
np_a = np.full((n, lanes), value, dtype=dtype)
a = tvm.nd.empty(np_a.shape, dtype, ctx)
fun(a)
np.testing.assert_equal(a.asnumpy(), np_a)
check_cuda(64, 0xAB)
check_cuda(64, 0)
check_cuda(64, -3)

if __name__ == "__main__":
test_cuda_vectorize_add()
test_cuda_multiply_add()
test_cuda_vectorize_load()
test_cuda_make_int8x4()

0 comments on commit 625627c

Please sign in to comment.