From 60f68d2e7f750a0f8e62536da7b3327d1f5f29c1 Mon Sep 17 00:00:00 2001 From: Ming-Han Yang Date: Fri, 27 May 2022 16:07:19 +0800 Subject: [PATCH] [OpenCL] Avoid SelectNode ambiguous overloading --- src/target/source/codegen_opencl.cc | 10 ---------- src/target/source/codegen_opencl.h | 1 - tests/python/unittest/test_target_codegen_opencl.py | 4 ++-- 3 files changed, 2 insertions(+), 13 deletions(-) diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 1fdf1e7bed4e5..c29b18bd88e76 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -540,16 +540,6 @@ void CodeGenOpenCL::VisitExpr_(const OrNode* op, std::ostream& os) { os << ")"; } -void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) { - os << "select("; - PrintExpr(op->false_value, os); - os << ", "; - PrintExpr(op->true_value, os); - os << ", "; - PrintExpr(op->condition, os); - os << ")"; -} - void CodeGenOpenCL::SetTextureScope( const std::unordered_map& scope) { // NOLINT(*) for (auto& texture : scope) { diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index a7f4483ee2a90..72c1873ef4374 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -72,7 +72,6 @@ class CodeGenOpenCL final : public CodeGenC { void VisitExpr_(const MaxNode* op, std::ostream& os) final; void VisitExpr_(const AndNode* op, std::ostream& os) final; void VisitExpr_(const OrNode* op, std::ostream& os) final; - void VisitExpr_(const SelectNode* op, std::ostream& os) final; private: // whether enable fp16 and fp64 extension diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py index c25b3c2c86eab..2ff16eef44a5b 100644 --- a/tests/python/unittest/test_target_codegen_opencl.py +++ b/tests/python/unittest/test_target_codegen_opencl.py @@ -173,8 +173,8 @@ def check_type_casting(ctx, n, dtype): lcond = "(convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))" rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))" cond = "({} && {})".format(lcond, rcond) - select = "select({}, {}, {})".format(false_branch, true_branch, cond) - count = assembly.count(select) + ternary = "{} ? {} : {}".format(cond, true_branch, false_branch) + count = assembly.count(ternary) assert count == 1 fun(c)