From dd2897cb69d56b36a2d15daf9a43cc22f116b4c7 Mon Sep 17 00:00:00 2001 From: mhyang-pllab <75776819+mhyang-pllab@users.noreply.github.com> Date: Sat, 28 May 2022 16:36:16 +0800 Subject: [PATCH] [OpenCL] Avoid SelectNode ambiguous overloading (#11488) * [OpenCL] Avoid SelectNode ambiguous overloading * Revert "[OpenCL] Avoid SelectNode ambiguous overloading" This reverts commit 60f68d2e7f750a0f8e62536da7b3327d1f5f29c1. * [OpenCL] Avoid SelectNode ambiguous codegen --- src/target/source/codegen_opencl.cc | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 1fdf1e7bed4e..5d04d00339fc 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -541,12 +541,26 @@ void CodeGenOpenCL::VisitExpr_(const OrNode* op, std::ostream& os) { } void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) { + std::ostringstream oss; os << "select("; - PrintExpr(op->false_value, os); + PrintExpr(op->false_value, oss); + os << CastFromTo(oss.str(), op->false_value.dtype(), op->dtype); + oss.str(""); os << ", "; - PrintExpr(op->true_value, os); + PrintExpr(op->true_value, oss); + os << CastFromTo(oss.str(), op->true_value.dtype(), op->dtype); + oss.str(""); os << ", "; - PrintExpr(op->condition, os); + PrintExpr(op->condition, oss); + if (op->dtype.is_float()) { + if (op->condition.dtype().is_uint() || op->condition.dtype().is_int()) { + os << oss.str(); + } else { + os << CastTo(oss.str(), DataType::Int(op->dtype.bits(), op->dtype.lanes())); + } + } else { + os << CastFromTo(oss.str(), op->condition.dtype(), op->dtype); + } os << ")"; }