Skip to content

Commit

Permalink
[CODEGEN][OPENCL] Fix compile error about ternary expression.
Browse files Browse the repository at this point in the history
  Code like this can't be built with NV OpenCL, and it needs an explicit type
  converison for ternary expression if return type is uchar.

       uchar i = 0, j = 0;
       uchar t = max((uchar)j, ((i > 0) ? (uchar)1 : (uchar)0));
  • Loading branch information
Li Xiaoquan committed Mar 16, 2019
1 parent 48780c0 commit 6b20878
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/codegen/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,25 @@ void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOL
os << "))";
}

void CodeGenOpenCL::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
/* Return type of ternary expression is not always same as its sub-expressions,
* add a cast */
if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
os << "(";
PrintType(op->args[2].type(), os);
os << ")";
}
CodeGenC::VisitExpr_(op, os);
}

void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*)
/* Return type of ternary expression is not always same as its sub-expressions,
* add a cast */
os << "(";
PrintType(op->true_value.type(), os);
os << ")";
CodeGenC::VisitExpr_(op, os);
}

runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry;
Expand Down
2 changes: 2 additions & 0 deletions src/codegen/codegen_opencl.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class CodeGenOpenCL final : public CodeGenC {

// overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Select* op, std::ostream& os) final; // NOLINT(*)

private:
// whether enable fp16 and fp64 extension
Expand Down
55 changes: 55 additions & 0 deletions tests/python/unittest/test_codegen_opencl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import tvm

target = 'opencl'

def test_opencl_ternary_expression():
def check_if_then_else(ctx, n, dtype):
A = tvm.placeholder((n,), name='A', dtype=dtype)
true_value = tvm.const(1, dtype=dtype)
false_value = tvm.const(3, dtype=dtype)
max_lhs = tvm.const(2, dtype=dtype)
max_rhs = tvm.if_then_else(A[0] > 0, true_value, false_value)
C = tvm.compute((n,), lambda i: tvm.max(max_lhs, max_rhs), name='C')
s = tvm.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], target)

a = tvm.nd.empty((n,), A.dtype, ctx)
c = tvm.nd.empty((n,), A.dtype, ctx)
# Only need to test compiling here
fun(a, c)

def check_select(ctx, n, dtype):
A = tvm.placeholder((n,), name='A', dtype=dtype)
true_value = tvm.const(1, dtype=dtype)
false_value = tvm.const(3, dtype=dtype)
max_lhs = tvm.const(2, dtype=dtype)
max_rhs = tvm.expr.Select(A[0] > 0, true_value, false_value)
C = tvm.compute((n,), lambda i: tvm.max(max_lhs, max_rhs), name='C')
s = tvm.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], target)

a = tvm.nd.empty((n,), A.dtype, ctx)
c = tvm.nd.empty((n,), A.dtype, ctx)
# Only need to test compiling here
fun(a, c)

if not tvm.module.enabled(target):
print("skip because opencl is not enabled..")
return

ctx = tvm.context(target, 0)

check_if_then_else(ctx, 1, 'int8')
check_if_then_else(ctx, 1, 'uint8')
check_if_then_else(ctx, 1, 'int16')
check_if_then_else(ctx, 1, 'uint16')
check_select(ctx, 1, 'int8')
check_select(ctx, 1, 'uint8')
check_select(ctx, 1, 'int16')
check_select(ctx, 1, 'uint16')


if __name__ == "__main__":
test_opencl_ternary_expression()

0 comments on commit 6b20878

Please sign in to comment.