Skip to content

Commit

Permalink
[OpenCL] Avoid SelectNode ambiguous overloading
Browse files Browse the repository at this point in the history
  • Loading branch information
mhyang-pllab committed May 27, 2022
1 parent 4a769c1 commit 60f68d2
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 13 deletions.
10 changes: 0 additions & 10 deletions src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -540,16 +540,6 @@ void CodeGenOpenCL::VisitExpr_(const OrNode* op, std::ostream& os) {
os << ")";
}

void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) {
os << "select(";
PrintExpr(op->false_value, os);
os << ", ";
PrintExpr(op->true_value, os);
os << ", ";
PrintExpr(op->condition, os);
os << ")";
}

void CodeGenOpenCL::SetTextureScope(
const std::unordered_map<const VarNode*, std::string>& scope) { // NOLINT(*)
for (auto& texture : scope) {
Expand Down
1 change: 0 additions & 1 deletion src/target/source/codegen_opencl.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ class CodeGenOpenCL final : public CodeGenC {
void VisitExpr_(const MaxNode* op, std::ostream& os) final;
void VisitExpr_(const AndNode* op, std::ostream& os) final;
void VisitExpr_(const OrNode* op, std::ostream& os) final;
void VisitExpr_(const SelectNode* op, std::ostream& os) final;

private:
// whether enable fp16 and fp64 extension
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_target_codegen_opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ def check_type_casting(ctx, n, dtype):
lcond = "(convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))"
rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))"
cond = "({} && {})".format(lcond, rcond)
select = "select({}, {}, {})".format(false_branch, true_branch, cond)
count = assembly.count(select)
ternary = "{} ? {} : {}".format(cond, true_branch, false_branch)
count = assembly.count(ternary)
assert count == 1

fun(c)
Expand Down

0 comments on commit 60f68d2

Please sign in to comment.