diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 1fdf1e7bed4e..5d04d00339fc 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -541,12 +541,26 @@ void CodeGenOpenCL::VisitExpr_(const OrNode* op, std::ostream& os) { } void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) { + std::ostringstream oss; os << "select("; - PrintExpr(op->false_value, os); + PrintExpr(op->false_value, oss); + os << CastFromTo(oss.str(), op->false_value.dtype(), op->dtype); + oss.str(""); os << ", "; - PrintExpr(op->true_value, os); + PrintExpr(op->true_value, oss); + os << CastFromTo(oss.str(), op->true_value.dtype(), op->dtype); + oss.str(""); os << ", "; - PrintExpr(op->condition, os); + PrintExpr(op->condition, oss); + if (op->dtype.is_float()) { + if (op->condition.dtype().is_uint() || op->condition.dtype().is_int()) { + os << oss.str(); + } else { + os << CastTo(oss.str(), DataType::Int(op->dtype.bits(), op->dtype.lanes())); + } + } else { + os << CastFromTo(oss.str(), op->condition.dtype(), op->dtype); + } os << ")"; }