Skip to content

Commit

Permalink
[CODEGEN][OPENCL] Fix compile error about ternary expression. (#2821)
Browse files Browse the repository at this point in the history
Code like this can't be built with NV OpenCL, and it needs an explicit type
  converison for ternary expression if return type is uchar.

       uchar i = 0, j = 0;
       uchar t = max((uchar)j, ((i > 0) ? (uchar)1 : (uchar)0));
  • Loading branch information
lixiaoquan authored and kazum committed Mar 18, 2019
1 parent 0f6989f commit fa70983
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/codegen/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,25 @@ void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOL
os << "))";
}

void CodeGenOpenCL::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
/* Return type of ternary expression is not always same as its sub-expressions,
* add a cast */
if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
os << "(";
PrintType(op->args[2].type(), os);
os << ")";
}
CodeGenC::VisitExpr_(op, os);
}

void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*)
/* Return type of ternary expression is not always same as its sub-expressions,
* add a cast */
os << "(";
PrintType(op->true_value.type(), os);
os << ")";
CodeGenC::VisitExpr_(op, os);
}

runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry;
Expand Down
2 changes: 2 additions & 0 deletions src/codegen/codegen_opencl.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class CodeGenOpenCL final : public CodeGenC {

// overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Select* op, std::ostream& os) final; // NOLINT(*)

private:
// whether enable fp16 and fp64 extension
Expand Down
55 changes: 55 additions & 0 deletions tests/python/unittest/test_codegen_opencl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import tvm

target = 'opencl'

def test_opencl_ternary_expression():
def check_if_then_else(ctx, n, dtype):
A = tvm.placeholder((n,), name='A', dtype=dtype)
true_value = tvm.const(1, dtype=dtype)
false_value = tvm.const(3, dtype=dtype)
max_lhs = tvm.const(2, dtype=dtype)
max_rhs = tvm.if_then_else(A[0] > 0, true_value, false_value)
C = tvm.compute((n,), lambda i: tvm.max(max_lhs, max_rhs), name='C')
s = tvm.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], target)

a = tvm.nd.empty((n,), A.dtype, ctx)
c = tvm.nd.empty((n,), A.dtype, ctx)
# Only need to test compiling here
fun(a, c)

def check_select(ctx, n, dtype):
A = tvm.placeholder((n,), name='A', dtype=dtype)
true_value = tvm.const(1, dtype=dtype)
false_value = tvm.const(3, dtype=dtype)
max_lhs = tvm.const(2, dtype=dtype)
max_rhs = tvm.expr.Select(A[0] > 0, true_value, false_value)
C = tvm.compute((n,), lambda i: tvm.max(max_lhs, max_rhs), name='C')
s = tvm.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], target)

a = tvm.nd.empty((n,), A.dtype, ctx)
c = tvm.nd.empty((n,), A.dtype, ctx)
# Only need to test compiling here
fun(a, c)

if not tvm.module.enabled(target):
print("skip because opencl is not enabled..")
return

ctx = tvm.context(target, 0)

check_if_then_else(ctx, 1, 'int8')
check_if_then_else(ctx, 1, 'uint8')
check_if_then_else(ctx, 1, 'int16')
check_if_then_else(ctx, 1, 'uint16')
check_select(ctx, 1, 'int8')
check_select(ctx, 1, 'uint8')
check_select(ctx, 1, 'int16')
check_select(ctx, 1, 'uint16')


if __name__ == "__main__":
test_opencl_ternary_expression()

0 comments on commit fa70983

Please sign in to comment.