diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala index bf9d010c3..b752e437b 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala @@ -275,6 +275,17 @@ object ColumnarExpressionConverter extends Logging { ss.len, convertBoundRefToAttrRef = convertBoundRefToAttrRef), expr) + case st: StringTranslate => + logInfo(s"${expr.getClass} ${expr} is supported, no_cal is $check_if_no_calculation.") + ColumnarTernaryOperator.create( + replaceWithColumnarExpression(st.srcExpr, attributeSeq, + convertBoundRefToAttrRef = convertBoundRefToAttrRef), + replaceWithColumnarExpression(st.matchingExpr, attributeSeq, + convertBoundRefToAttrRef = convertBoundRefToAttrRef), + replaceWithColumnarExpression(st.replaceExpr, attributeSeq, + convertBoundRefToAttrRef = convertBoundRefToAttrRef), + expr + ) case u: UnaryExpression => logInfo(s"${expr.getClass} ${expr} is supported, no_cal is $check_if_no_calculation.") if (!u.isInstanceOf[CheckOverflow] || !u.child.isInstanceOf[Divide]) { @@ -384,8 +395,11 @@ object ColumnarExpressionConverter extends Logging { containsSubquery(b.left) || containsSubquery(b.right) case s: String2TrimExpression => s.children.map(containsSubquery).exists(_ == true) + case st: StringTranslate => + st.children.map(containsSubquery).exists(_ == true) case regexp: RegExpReplace => - containsSubquery(regexp.subject) || containsSubquery(regexp.regexp) || containsSubquery(regexp.rep) || containsSubquery(regexp.pos) + containsSubquery(regexp.subject) || containsSubquery( + regexp.regexp) || containsSubquery(regexp.rep) || containsSubquery(regexp.pos) case expr => throw new UnsupportedOperationException( s" --> ${expr.getClass} | ${expr} is not currently supported.") diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarTernaryOperator.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarTernaryOperator.scala index 49ca6b07c..b1bf3da05 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarTernaryOperator.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarTernaryOperator.scala @@ -110,15 +110,43 @@ class ColumnarStringSplit(child: Expression, regex: Expression, } } +class ColumnarStringTranslate(src: Expression, matchingExpr: Expression, + replaceExpr: Expression, original: Expression) + extends StringTranslate(src, matchingExpr, replaceExpr) with ColumnarExpression{ + buildCheck + + def buildCheck: Unit = { + val supportedTypes = List(StringType) + if (supportedTypes.indexOf(src.dataType) == -1) { + throw new UnsupportedOperationException(s"${src.dataType}" + + s" is not supported in ColumnarStringTranslate!") + } + } + + override def doColumnarCodeGen(args: java.lang.Object) : (TreeNode, ArrowType) = { + val (str_node, _): (TreeNode, ArrowType) = + src.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) + val (matchingExpr_node, _): (TreeNode, ArrowType) = + matchingExpr.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) + val (replaceExpr_node, _): (TreeNode, ArrowType) = + replaceExpr.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) + val resultType = new ArrowType.Utf8() + (TreeBuilder.makeFunction("translate", + Lists.newArrayList(str_node, matchingExpr_node, replaceExpr_node), resultType), resultType) + } +} + object ColumnarTernaryOperator { - def create(str: Expression, pos: Expression, len: Expression, + def create(src: Expression, arg1: Expression, arg2: Expression, original: Expression): Expression = original match { case ss: Substring => - new ColumnarSubString(str, pos, len, ss) + new ColumnarSubString(src, arg1, arg2, ss) // Currently not supported. // case a: StringSplit => // new ColumnarStringSplit(str, a.regex, a.limit, a) + case st: StringTranslate => + new ColumnarStringTranslate(src, arg1, arg2, st) case other => throw new UnsupportedOperationException(s"not currently supported: $other.") } diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc index 20d63da00..b52a522ec 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc @@ -282,6 +282,27 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) prepare_ss << "}" << std::endl; prepare_str_ += prepare_ss.str(); check_str_ = validity; + } else if (func_name.compare("translate") == 0) { + codes_str_ = func_name + "_" + std::to_string(cur_func_id); + auto validity = codes_str_ + "_validity"; + real_codes_str_ = codes_str_; + real_validity_str_ = validity; + std::stringstream prepare_ss; + prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";" + << std::endl; + prepare_ss << "bool " << validity << " = " << child_visitor_list[0]->GetPreCheck() + << ";" << std::endl; + prepare_ss << "if (" << validity << ") {" << std::endl; + prepare_ss << codes_str_ << " = translate" + << "(" << child_visitor_list[0]->GetResult() << ", " + << child_visitor_list[1]->GetResult() << ", " + << child_visitor_list[2]->GetResult() << ");" << std::endl; + prepare_ss << "}" << std::endl; + for (int i = 0; i < 1; i++) { + prepare_str_ += child_visitor_list[i]->GetPrepare(); + } + prepare_str_ += prepare_ss.str(); + check_str_ = validity; } else if (func_name.compare("substr") == 0) { ss << child_visitor_list[0]->GetResult() << ".substr(" << "((" << child_visitor_list[1]->GetResult() << " - 1) < 0 ? 0 : (" diff --git a/native-sql-engine/cpp/src/precompile/gandiva.h b/native-sql-engine/cpp/src/precompile/gandiva.h index d722f799d..a3d5e53a5 100644 --- a/native-sql-engine/cpp/src/precompile/gandiva.h +++ b/native-sql-engine/cpp/src/precompile/gandiva.h @@ -26,6 +26,7 @@ #include #include #include +#include #include "third_party/gandiva/decimal_ops.h" #include "third_party/gandiva/types.h" @@ -314,4 +315,30 @@ bool like(const std::string& data, const std::string& pattern) { std::string pcre_pattern = SqlLikePatternToPcre(pattern, 0); RE2 regex(pcre_pattern); return RE2::FullMatch(data, regex); +} + +const std::string translate(const std::string text, const std::string matching_str, + const std::string replace_str) { + char res[text.length()]; + std::unordered_map replace_map; + for (int i = 0; i < matching_str.length(); i++) { + if (i >= replace_str.length()) { + replace_map[matching_str[i]] = '\0'; + } else { + replace_map[matching_str[i]] = replace_str[i]; + } + } + int j = 0; + for (int i = 0; i < text.length(); i++) { + if (replace_map.find(text[i]) == replace_map.end()) { + res[j++] = text[i]; + continue; + } + char replace_char = replace_map[text[i]]; + if (replace_char != '\0') { + res[j++] = replace_char; + } + } + int out_len = j; + return std::string((char*)res, out_len); } \ No newline at end of file