From 6b2087818efa083ddd34282f9647c85a0f0df2f8 Mon Sep 17 00:00:00 2001 From: Li Xiaoquan Date: Fri, 15 Mar 2019 11:24:22 +0800 Subject: [PATCH] [CODEGEN][OPENCL] Fix compile error about ternary expression. Code like this can't be built with NV OpenCL, and it needs an explicit type converison for ternary expression if return type is uchar. uchar i = 0, j = 0; uchar t = max((uchar)j, ((i > 0) ? (uchar)1 : (uchar)0)); --- src/codegen/codegen_opencl.cc | 19 +++++++ src/codegen/codegen_opencl.h | 2 + tests/python/unittest/test_codegen_opencl.py | 55 ++++++++++++++++++++ 3 files changed, 76 insertions(+) create mode 100644 tests/python/unittest/test_codegen_opencl.py diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc index a0b3c2000a80..ae87bf9d9b93 100644 --- a/src/codegen/codegen_opencl.cc +++ b/src/codegen/codegen_opencl.cc @@ -208,6 +208,25 @@ void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOL os << "))"; } +void CodeGenOpenCL::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) + /* Return type of ternary expression is not always same as its sub-expressions, + * add a cast */ + if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { + os << "("; + PrintType(op->args[2].type(), os); + os << ")"; + } + CodeGenC::VisitExpr_(op, os); +} + +void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*) + /* Return type of ternary expression is not always same as its sub-expressions, + * add a cast */ + os << "("; + PrintType(op->true_value.type(), os); + os << ")"; + CodeGenC::VisitExpr_(op, os); +} runtime::Module BuildOpenCL(Array funcs) { using tvm::runtime::Registry; diff --git a/src/codegen/codegen_opencl.h b/src/codegen/codegen_opencl.h index 90569d176a0b..350b6c0f3402 100644 --- a/src/codegen/codegen_opencl.h +++ b/src/codegen/codegen_opencl.h @@ -38,6 +38,8 @@ class CodeGenOpenCL final : public CodeGenC { // overload visitor void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const Select* op, std::ostream& os) final; // NOLINT(*) private: // whether enable fp16 and fp64 extension diff --git a/tests/python/unittest/test_codegen_opencl.py b/tests/python/unittest/test_codegen_opencl.py new file mode 100644 index 000000000000..37aaadc5bb23 --- /dev/null +++ b/tests/python/unittest/test_codegen_opencl.py @@ -0,0 +1,55 @@ +import tvm + +target = 'opencl' + +def test_opencl_ternary_expression(): + def check_if_then_else(ctx, n, dtype): + A = tvm.placeholder((n,), name='A', dtype=dtype) + true_value = tvm.const(1, dtype=dtype) + false_value = tvm.const(3, dtype=dtype) + max_lhs = tvm.const(2, dtype=dtype) + max_rhs = tvm.if_then_else(A[0] > 0, true_value, false_value) + C = tvm.compute((n,), lambda i: tvm.max(max_lhs, max_rhs), name='C') + s = tvm.create_schedule(C.op) + s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x")) + fun = tvm.build(s, [A, C], target) + + a = tvm.nd.empty((n,), A.dtype, ctx) + c = tvm.nd.empty((n,), A.dtype, ctx) + # Only need to test compiling here + fun(a, c) + + def check_select(ctx, n, dtype): + A = tvm.placeholder((n,), name='A', dtype=dtype) + true_value = tvm.const(1, dtype=dtype) + false_value = tvm.const(3, dtype=dtype) + max_lhs = tvm.const(2, dtype=dtype) + max_rhs = tvm.expr.Select(A[0] > 0, true_value, false_value) + C = tvm.compute((n,), lambda i: tvm.max(max_lhs, max_rhs), name='C') + s = tvm.create_schedule(C.op) + s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x")) + fun = tvm.build(s, [A, C], target) + + a = tvm.nd.empty((n,), A.dtype, ctx) + c = tvm.nd.empty((n,), A.dtype, ctx) + # Only need to test compiling here + fun(a, c) + + if not tvm.module.enabled(target): + print("skip because opencl is not enabled..") + return + + ctx = tvm.context(target, 0) + + check_if_then_else(ctx, 1, 'int8') + check_if_then_else(ctx, 1, 'uint8') + check_if_then_else(ctx, 1, 'int16') + check_if_then_else(ctx, 1, 'uint16') + check_select(ctx, 1, 'int8') + check_select(ctx, 1, 'uint8') + check_select(ctx, 1, 'int16') + check_select(ctx, 1, 'uint16') + + +if __name__ == "__main__": + test_opencl_ternary_expression()