From 155e955fe14f0d320dc76258ae2c3bbe92d4bb71 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 25 Oct 2018 00:37:13 +0800 Subject: [PATCH] Fix int8x4 broadcast value codegen in cuda (#1959) --- src/codegen/codegen_cuda.cc | 10 ++++++++++ tests/python/unittest/test_codegen_cuda.py | 23 ++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 2ed8d8e3ff78..0ab56a116eab 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -273,6 +273,16 @@ void CodeGenCUDA::VisitExpr_(const Ramp* op, std::ostream& os) { } void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*) + if (op->type.is_int() && op->type.bits() == 8 && op->lanes == 4) { + // make_int8x4 + const int64_t *p = as_const_int(op->value); + CHECK(p); + int64_t v = *p & 0xFF; + v = (v << 24) | (v << 16) | (v << 8) | v; + os << "(int)" << v; + return; + } + std::string v = PrintExpr(op->value); os << "make_"; PrintType(op->type, os); diff --git a/tests/python/unittest/test_codegen_cuda.py b/tests/python/unittest/test_codegen_cuda.py index a0b1cf445ba6..d3b770790bdb 100644 --- a/tests/python/unittest/test_codegen_cuda.py +++ b/tests/python/unittest/test_codegen_cuda.py @@ -87,7 +87,30 @@ def check_cuda(dtype, n, lanes): check_cuda("int8", 64, 8) check_cuda("int8", 64, 16) +def test_cuda_make_int8x4(): + def check_cuda(n, value): + if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"): + print("skip because cuda is not enabled..") + return + lanes = 4 + dtype = 'int8' + ctx = tvm.gpu(0) + A = tvm.compute((n, lanes), lambda i,j: tvm.const(value, dtype=dtype)) + s = tvm.create_schedule(A.op) + y, x = s[A].op.axis + s[A].vectorize(x) + s[A].bind(y, tvm.thread_axis("blockIdx.x")) + fun = tvm.build(s, [A], "cuda", name="make_int8x4") + np_a = np.full((n, lanes), value, dtype=dtype) + a = tvm.nd.empty(np_a.shape, dtype, ctx) + fun(a) + np.testing.assert_equal(a.asnumpy(), np_a) + check_cuda(64, 0xAB) + check_cuda(64, 0) + check_cuda(64, -3) + if __name__ == "__main__": test_cuda_vectorize_add() test_cuda_multiply_add() test_cuda_vectorize_load() + test_cuda_make_int8x4()