diff --git a/.gitignore b/.gitignore
index fc167d012..2eade2116 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,9 @@
 *.iml
 *.class
 
+# vscode config
+.vscode
+
 # Mobile Tools for Java (J2ME)
 .mtj.tmp/
 
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala
index 94633cbd3..f41c8ab6a 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala
@@ -25,7 +25,6 @@ import org.apache.arrow.vector.types.pojo.ArrowType
 import org.apache.arrow.vector.types.FloatingPointPrecision
 import org.apache.arrow.vector.types.pojo.Field
 import org.apache.arrow.vector.types.DateUnit
-
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer._
@@ -48,9 +47,9 @@ import com.intel.oap.expression.ColumnarDateTimeExpressions.ColumnarUnixMillis
 import com.intel.oap.expression.ColumnarDateTimeExpressions.ColumnarUnixSeconds
 import com.intel.oap.expression.ColumnarDateTimeExpressions.ColumnarUnixTimestamp
 import org.apache.arrow.vector.types.TimeUnit
-
 import org.apache.spark.sql.catalyst.util.DateTimeConstants
 import org.apache.spark.sql.execution.datasources.v2.arrow.SparkSchemaUtils
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * A version of add that supports columnar processing for longs.
  */
@@ -431,25 +430,25 @@ class ColumnarCast(
       }
     } else if (datatype == IntegerType) {
       val supported =
-        List(ByteType, ShortType, LongType, FloatType, DoubleType, DateType, DecimalType)
+        List(ByteType, ShortType, LongType, FloatType, DoubleType, DateType, DecimalType, StringType)
       if (supported.indexOf(child.dataType) == -1 && !child.dataType.isInstanceOf[DecimalType]) {
         throw new UnsupportedOperationException(s"${child.dataType} is not supported in castINT")
       }
     } else if (datatype == LongType) {
-      val supported = List(IntegerType, FloatType, DoubleType, DateType, DecimalType, TimestampType)
+      val supported = List(IntegerType, FloatType, DoubleType, DateType, DecimalType, TimestampType, StringType)
       if (supported.indexOf(child.dataType) == -1 && !child.dataType.isInstanceOf[DecimalType]) {
         throw new UnsupportedOperationException(
           s"${child.dataType} is not supported in castBIGINT")
       }
     } else if (datatype == FloatType) {
-      val supported = List(IntegerType, LongType, DoubleType, DecimalType)
+      val supported = List(IntegerType, LongType, DoubleType, DecimalType, StringType)
       if (supported.indexOf(child.dataType) == -1 && !child.dataType.isInstanceOf[DecimalType]) {
         throw new UnsupportedOperationException(
           s"${child.dataType} is not supported in castFLOAT4")
       }
     } else if (datatype == DoubleType) {
-      val supported = List(IntegerType, LongType, FloatType, DecimalType)
+      val supported = List(IntegerType, LongType, FloatType, DecimalType, StringType)
       if (supported.indexOf(child.dataType) == -1 && !child.dataType.isInstanceOf[DecimalType]) {
         throw new UnsupportedOperationException(
@@ -479,6 +478,10 @@ class ColumnarCast(
   }
 
   override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
+
+    // To be compatible with Spark SQL ANSI mode
+    val ansiEnabled = SQLConf.get.ansiEnabled
+
     val (child_node, childType): (TreeNode, ArrowType) =
       child.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
 
@@ -534,6 +537,13 @@ class ColumnarCast(
             Lists.newArrayList(round_down_node),
             new ArrowType.Int(64, true))
           TreeBuilder.makeFunction("castINT", Lists.newArrayList(long_node), toType)
+        case _: StringType =>
+          // Compatible with Spark ANSI mode
+          if (ansiEnabled) {
+            TreeBuilder.makeFunction("castINT", Lists.newArrayList(child_node0), toType)
+          } else {
+            TreeBuilder.makeFunction("castINTOrNull", Lists.newArrayList(child_node0), toType)
+          }
         case other =>
           TreeBuilder.makeFunction("castINT", Lists.newArrayList(child_node0), toType)
       }
@@ -547,6 +557,13 @@ class ColumnarCast(
               TreeBuilder.makeFunction("castBIGINT", Lists.newArrayList(child_node0), toType),
               TreeBuilder.makeLiteral(java.lang.Long.valueOf(1000L))), toType),
           toType)
+        case _: StringType =>
+          // Compatible with Spark ANSI mode
+          if (ansiEnabled) {
+            (TreeBuilder.makeFunction("castBIGINT", Lists.newArrayList(child_node0), toType), toType)
+          } else {
+            (TreeBuilder.makeFunction("castBIGINTOrNull", Lists.newArrayList(child_node0), toType), toType)
+          }
         case _ =>
           (TreeBuilder.makeFunction("castBIGINT", Lists.newArrayList(child_node0), toType), toType)
       }
@@ -558,13 +575,29 @@ class ColumnarCast(
             Lists.newArrayList(child_node0),
             new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))
           TreeBuilder.makeFunction("castFLOAT4", Lists.newArrayList(double_node), toType)
+        case _: StringType =>
+          // Compatible with Spark ANSI mode
+          if (ansiEnabled) {
+            TreeBuilder.makeFunction("castFLOAT4", Lists.newArrayList(child_node0), toType)
+          } else {
+            TreeBuilder.makeFunction("castFLOAT4OrNull", Lists.newArrayList(child_node0), toType)
+          }
         case other =>
           TreeBuilder.makeFunction("castFLOAT4", Lists.newArrayList(child_node0), toType)
       }
       (funcNode, toType)
     } else if (dataType == DoubleType) {
-      val funcNode =
-        TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(child_node0), toType)
+      val funcNode = child.dataType match {
+        case _: StringType =>
+          // Compatible with Spark ANSI mode
+          if (ansiEnabled) {
+            TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(child_node0), toType)
+          } else {
+            TreeBuilder.makeFunction("castFLOAT8OrNull", Lists.newArrayList(child_node0), toType)
+          }
+        case other =>
+          TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(child_node0), toType)
+      }
       (funcNode, toType)
     } else if (dataType == DateType) {
       val funcNode = child.dataType match {
diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index fe8d26a8a..1659b3ea1 100644
--- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
 
 import java.util.Locale
 
-import com.intel.oap.execution.{ColumnarSortExec, ColumnarSortMergeJoinExec}
+import com.intel.oap.execution.{ColumnarBroadcastHashJoinExec, ColumnarSortExec, ColumnarSortMergeJoinExec}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ListBuffer
@@ -1049,7 +1049,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
       val right = Seq((1, 2), (3, 4)).toDF("c", "d")
       val df = left.join(right, pythonTestUDF(left("a")) === pythonTestUDF(right.col("c")))
 
-      val joinNode = find(df.queryExecution.executedPlan)(_.isInstanceOf[BroadcastHashJoinExec])
+      val joinNode = find(df.queryExecution.executedPlan)(
+        _.isInstanceOf[ColumnarBroadcastHashJoinExec])
       assert(joinNode.isDefined)
 
       // There are two PythonUDFs which use attribute from left and right of join, individually.
diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index d4259952f..3670d7c76 100644
--- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -411,7 +411,7 @@ class AdaptiveQueryExecSuite
       val smj = findTopLevelSortMergeJoin(plan)
       assert(smj.size == 3)
       val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan)
-      assert(bhj.size == 2)
+      assert(bhj.size == 3)
 
       // A possible resulting query plan:
       // BroadcastHashJoin
@@ -1360,7 +1360,7 @@ class AdaptiveQueryExecSuite
       val plan = dfRepartition.queryExecution.executedPlan
       // The top shuffle from repartition is optimized out.
       assert(!hasRepartitionShuffle(plan))
-      val bhj = findTopLevelBroadcastHashJoin(plan)
+      val bhj = findTopLevelColumnarBroadcastHashJoin(plan)
       assert(bhj.length == 1)
       checkNumLocalShuffleReaders(plan, 1)
       // Probe side is coalesced.
@@ -1374,7 +1374,7 @@ class AdaptiveQueryExecSuite
       val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
       // The top shuffle from repartition is optimized out.
       assert(!hasRepartitionShuffle(planWithNum))
-      val bhjWithNum = findTopLevelBroadcastHashJoin(planWithNum)
+      val bhjWithNum = findTopLevelColumnarBroadcastHashJoin(planWithNum)
       assert(bhjWithNum.length == 1)
       checkNumLocalShuffleReaders(planWithNum, 1)
       // Probe side is not coalesced.
@@ -1387,7 +1387,7 @@ class AdaptiveQueryExecSuite
       // The top shuffle from repartition is not optimized out, and this is the only shuffle that
       // does not have local shuffle reader.
       assert(hasRepartitionShuffle(planWithNum2))
-      val bhjWithNum2 = findTopLevelBroadcastHashJoin(planWithNum2)
+      val bhjWithNum2 = findTopLevelColumnarBroadcastHashJoin(planWithNum2)
       assert(bhjWithNum2.length == 1)
       checkNumLocalShuffleReaders(planWithNum2, 1)
       val customReader2 = bhjWithNum2.head.right.find(_.isInstanceOf[CustomShuffleReaderExec])
@@ -1407,7 +1407,7 @@ class AdaptiveQueryExecSuite
       val plan = dfRepartition.queryExecution.executedPlan
       // The top shuffle from repartition is optimized out.
       assert(!hasRepartitionShuffle(plan))
-      val smj = findTopLevelSortMergeJoin(plan)
+      val smj = findTopLevelColumnarSortMergeJoin(plan)
       assert(smj.length == 1)
       // No skew join due to the repartition.
       assert(!smj.head.isSkewJoin)
@@ -1423,7 +1423,7 @@ class AdaptiveQueryExecSuite
       val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
       // The top shuffle from repartition is optimized out.
       assert(!hasRepartitionShuffle(planWithNum))
-      val smjWithNum = findTopLevelSortMergeJoin(planWithNum)
+      val smjWithNum = findTopLevelColumnarSortMergeJoin(planWithNum)
       assert(smjWithNum.length == 1)
       // No skew join due to the repartition.
       assert(!smjWithNum.head.isSkewJoin)
@@ -1439,7 +1439,7 @@ class AdaptiveQueryExecSuite
       val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan
       // The top shuffle from repartition is not optimized out.
       assert(hasRepartitionShuffle(planWithNum2))
-      val smjWithNum2 = findTopLevelSortMergeJoin(planWithNum2)
+      val smjWithNum2 = findTopLevelColumnarSortMergeJoin(planWithNum2)
       assert(smjWithNum2.length == 1)
       // Skew join can apply as the repartition is not optimized out.
       assert(smjWithNum2.head.isSkewJoin)
diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc
index 39ecd9054..a98cd85ae 100644
--- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc
+++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc
@@ -308,7 +308,11 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
   } else if (func_name.find("cast") != std::string::npos &&
              func_name.compare("castDATE") != 0 &&
             func_name.compare("castDECIMAL") != 0 &&
-             func_name.compare("castDECIMALNullOnOverflow") != 0) {
+             func_name.compare("castDECIMALNullOnOverflow") != 0 &&
+             func_name.compare("castINTOrNull") != 0 &&
+             func_name.compare("castBIGINTOrNull") != 0 &&
+             func_name.compare("castFLOAT4OrNull") != 0 &&
+             func_name.compare("castFLOAT8OrNull") != 0) {
     codes_str_ = func_name + "_" + std::to_string(cur_func_id);
     auto validity = codes_str_ + "_validity";
     real_codes_str_ = codes_str_;
@@ -488,6 +492,40 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
     prepare_str_ += prepare_ss.str();
     check_str_ = validity;
     header_list_.push_back(R"(#include "precompile/gandiva.h")");
+  } else if (func_name.compare("castINTOrNull") == 0 ||
+             func_name.compare("castBIGINTOrNull") == 0 ||
+             func_name.compare("castFLOAT4OrNull") == 0 ||
+             func_name.compare("castFLOAT8OrNull") == 0) {
+    codes_str_ = func_name + "_" + std::to_string(cur_func_id);
+    auto validity = codes_str_ + "_validity";
+    real_codes_str_ = codes_str_;
+    real_validity_str_ = validity;
+    std::stringstream prepare_ss;
+    prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";"
+               << std::endl;
+    prepare_ss << "bool " << validity << " = " << child_visitor_list[0]->GetPreCheck()
+               << ";" << std::endl;
+    prepare_ss << "if (" << validity << ") {" << std::endl;
+
+    std::string func_str;
+    if (func_name.compare("castINTOrNull") == 0) {
+      func_str = " = std::stoi";
+    } else if (func_name.compare("castBIGINTOrNull") == 0) {
+      func_str = " = std::stol";
+    } else if (func_name.compare("castFLOAT4OrNull") == 0) {
+      func_str = " = std::stof";
+    } else {
+      func_str = " = std::stod";
+    }
+    prepare_ss << codes_str_ << func_str << "(" << child_visitor_list[0]->GetResult()
+               << ");" << std::endl;
+    prepare_ss << "}" << std::endl;
+
+    for (int i = 0; i < 1; i++) {
+      prepare_str_ += child_visitor_list[i]->GetPrepare();
+    }
+    prepare_str_ += prepare_ss.str();
+    check_str_ = validity;
   } else if (func_name.compare("rescaleDECIMAL") == 0) {
     codes_str_ = func_name + "_" + std::to_string(cur_func_id);
     auto validity = codes_str_ + "_validity";
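
Note: the snippet below is an illustrative sketch only, not part of the patch. It shows the stock Spark SQL behavior that the ansiEnabled branches above are meant to mirror: with spark.sql.ansi.enabled=false an unparsable string cast yields NULL (the cast*OrNull path), while ANSI mode fails the query (the plain cast* path). The object and application names are made up for the example.

import org.apache.spark.sql.SparkSession

object StringCastAnsiSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("string-cast-ansi-sketch")
      .getOrCreate()

    // Non-ANSI (default): invalid numeric strings become NULL,
    // which is what castINTOrNull / castFLOAT8OrNull emulate.
    spark.conf.set("spark.sql.ansi.enabled", "false")
    spark.sql("SELECT CAST('123' AS INT) AS ok, CAST('abc' AS DOUBLE) AS bad").show()
    // expected roughly: ok = 123, bad = null

    // ANSI mode: the same invalid cast is expected to raise a runtime error instead,
    // so the plain castINT / castFLOAT8 functions are emitted in that case.
    spark.conf.set("spark.sql.ansi.enabled", "true")
    // spark.sql("SELECT CAST('abc' AS DOUBLE)").show()  // would fail with a cast error

    spark.stop()
  }
}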