Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-553] Complete the support to cast string type to types like int, bigint, float, double #552

Merged
merged 11 commits into from
Nov 17, 2021
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
*.iml
*.class

# vscode config
.vscode

# Mobile Tools for Java (J2ME)
.mtj.tmp/

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.arrow.vector.types.FloatingPointPrecision
import org.apache.arrow.vector.types.pojo.Field
import org.apache.arrow.vector.types.DateUnit

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer._
Expand All @@ -48,9 +47,9 @@ import com.intel.oap.expression.ColumnarDateTimeExpressions.ColumnarUnixMillis
import com.intel.oap.expression.ColumnarDateTimeExpressions.ColumnarUnixSeconds
import com.intel.oap.expression.ColumnarDateTimeExpressions.ColumnarUnixTimestamp
import org.apache.arrow.vector.types.TimeUnit

import org.apache.spark.sql.catalyst.util.DateTimeConstants
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkSchemaUtils
import org.apache.spark.sql.internal.SQLConf

/**
* A version of add that supports columnar processing for longs.
Expand Down Expand Up @@ -431,25 +430,25 @@ class ColumnarCast(
}
} else if (datatype == IntegerType) {
val supported =
List(ByteType, ShortType, LongType, FloatType, DoubleType, DateType, DecimalType)
List(ByteType, ShortType, LongType, FloatType, DoubleType, DateType, DecimalType, StringType)
if (supported.indexOf(child.dataType) == -1 && !child.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(s"${child.dataType} is not supported in castINT")
}
} else if (datatype == LongType) {
val supported = List(IntegerType, FloatType, DoubleType, DateType, DecimalType, TimestampType)
val supported = List(IntegerType, FloatType, DoubleType, DateType, DecimalType, TimestampType, StringType)
if (supported.indexOf(child.dataType) == -1 &&
!child.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${child.dataType} is not supported in castBIGINT")
}
} else if (datatype == FloatType) {
val supported = List(IntegerType, LongType, DoubleType, DecimalType)
val supported = List(IntegerType, LongType, DoubleType, DecimalType, StringType)
if (supported.indexOf(child.dataType) == -1 && !child.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${child.dataType} is not supported in castFLOAT4")
}
} else if (datatype == DoubleType) {
val supported = List(IntegerType, LongType, FloatType, DecimalType)
val supported = List(IntegerType, LongType, FloatType, DecimalType, StringType)
if (supported.indexOf(child.dataType) == -1 &&
!child.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
Expand Down Expand Up @@ -479,6 +478,10 @@ class ColumnarCast(
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {

// For compatibility with Spark SQL ANSI mode
val ansiEnabled = SQLConf.get.ansiEnabled

val (child_node, childType): (TreeNode, ArrowType) =
child.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

Expand Down Expand Up @@ -534,6 +537,13 @@ class ColumnarCast(
Lists.newArrayList(round_down_node),
new ArrowType.Int(64, true))
TreeBuilder.makeFunction("castINT", Lists.newArrayList(long_node), toType)
case _: StringType =>
// Compatible with Spark ANSI mode
if (ansiEnabled) {
TreeBuilder.makeFunction("castINT", Lists.newArrayList(child_node0), toType)
} else {
TreeBuilder.makeFunction("castINTOrNull", Lists.newArrayList(child_node0), toType)
}
case other =>
TreeBuilder.makeFunction("castINT", Lists.newArrayList(child_node0), toType)
}
Expand All @@ -547,6 +557,13 @@ class ColumnarCast(
TreeBuilder.makeFunction("castBIGINT", Lists.newArrayList(child_node0),
toType),
TreeBuilder.makeLiteral(java.lang.Long.valueOf(1000L))), toType), toType)
case _: StringType =>
// Compatible with Spark ANSI mode
if (ansiEnabled) {
(TreeBuilder.makeFunction("castBIGINT", Lists.newArrayList(child_node0), toType), toType)
} else {
(TreeBuilder.makeFunction("castBIGINTOrNull", Lists.newArrayList(child_node0), toType), toType)
}
case _ => (TreeBuilder.makeFunction("castBIGINT",
Lists.newArrayList(child_node0), toType), toType)
}
Expand All @@ -558,13 +575,29 @@ class ColumnarCast(
Lists.newArrayList(child_node0),
new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))
TreeBuilder.makeFunction("castFLOAT4", Lists.newArrayList(double_node), toType)
case _: StringType =>
// Compatible with Spark ANSI mode
if (ansiEnabled) {
TreeBuilder.makeFunction("castFLOAT4", Lists.newArrayList(child_node0), toType)
} else {
TreeBuilder.makeFunction("castFLOAT4OrNull", Lists.newArrayList(child_node0), toType)
}
case other =>
TreeBuilder.makeFunction("castFLOAT4", Lists.newArrayList(child_node0), toType)
}
(funcNode, toType)
} else if (dataType == DoubleType) {
val funcNode =
TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(child_node0), toType)
val funcNode = child.dataType match {
case _: StringType =>
// Compatible with Spark ANSI mode
if (ansiEnabled) {
TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(child_node0), toType)
} else {
TreeBuilder.makeFunction("castFLOAT8OrNull", Lists.newArrayList(child_node0), toType)
}
case other =>
TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(child_node0), toType)
}
(funcNode, toType)
} else if (dataType == DateType) {
val funcNode = child.dataType match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql

import java.util.Locale

import com.intel.oap.execution.{ColumnarSortExec, ColumnarSortMergeJoinExec}
import com.intel.oap.execution.{ColumnarBroadcastHashJoinExec, ColumnarSortExec, ColumnarSortMergeJoinExec}

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
Expand Down Expand Up @@ -1049,7 +1049,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val right = Seq((1, 2), (3, 4)).toDF("c", "d")
val df = left.join(right, pythonTestUDF(left("a")) === pythonTestUDF(right.col("c")))

val joinNode = find(df.queryExecution.executedPlan)(_.isInstanceOf[BroadcastHashJoinExec])
val joinNode = find(df.queryExecution.executedPlan)(
_.isInstanceOf[ColumnarBroadcastHashJoinExec])
assert(joinNode.isDefined)

// There are two PythonUDFs which use attribute from left and right of join, individually.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ class AdaptiveQueryExecSuite
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 3)
val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 2)
assert(bhj.size == 3)

// A possible resulting query plan:
// BroadcastHashJoin
Expand Down Expand Up @@ -1360,7 +1360,7 @@ class AdaptiveQueryExecSuite
val plan = dfRepartition.queryExecution.executedPlan
// The top shuffle from repartition is optimized out.
assert(!hasRepartitionShuffle(plan))
val bhj = findTopLevelBroadcastHashJoin(plan)
val bhj = findTopLevelColumnarBroadcastHashJoin(plan)
assert(bhj.length == 1)
checkNumLocalShuffleReaders(plan, 1)
// Probe side is coalesced.
Expand All @@ -1374,7 +1374,7 @@ class AdaptiveQueryExecSuite
val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
// The top shuffle from repartition is optimized out.
assert(!hasRepartitionShuffle(planWithNum))
val bhjWithNum = findTopLevelBroadcastHashJoin(planWithNum)
val bhjWithNum = findTopLevelColumnarBroadcastHashJoin(planWithNum)
assert(bhjWithNum.length == 1)
checkNumLocalShuffleReaders(planWithNum, 1)
// Probe side is not coalesced.
Expand All @@ -1387,7 +1387,7 @@ class AdaptiveQueryExecSuite
// The top shuffle from repartition is not optimized out, and this is the only shuffle that
// does not have local shuffle reader.
assert(hasRepartitionShuffle(planWithNum2))
val bhjWithNum2 = findTopLevelBroadcastHashJoin(planWithNum2)
val bhjWithNum2 = findTopLevelColumnarBroadcastHashJoin(planWithNum2)
assert(bhjWithNum2.length == 1)
checkNumLocalShuffleReaders(planWithNum2, 1)
val customReader2 = bhjWithNum2.head.right.find(_.isInstanceOf[CustomShuffleReaderExec])
Expand All @@ -1407,7 +1407,7 @@ class AdaptiveQueryExecSuite
val plan = dfRepartition.queryExecution.executedPlan
// The top shuffle from repartition is optimized out.
assert(!hasRepartitionShuffle(plan))
val smj = findTopLevelSortMergeJoin(plan)
val smj = findTopLevelColumnarSortMergeJoin(plan)
assert(smj.length == 1)
// No skew join due to the repartition.
assert(!smj.head.isSkewJoin)
Expand All @@ -1423,7 +1423,7 @@ class AdaptiveQueryExecSuite
val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
// The top shuffle from repartition is optimized out.
assert(!hasRepartitionShuffle(planWithNum))
val smjWithNum = findTopLevelSortMergeJoin(planWithNum)
val smjWithNum = findTopLevelColumnarSortMergeJoin(planWithNum)
assert(smjWithNum.length == 1)
// No skew join due to the repartition.
assert(!smjWithNum.head.isSkewJoin)
Expand All @@ -1439,7 +1439,7 @@ class AdaptiveQueryExecSuite
val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan
// The top shuffle from repartition is not optimized out.
assert(hasRepartitionShuffle(planWithNum2))
val smjWithNum2 = findTopLevelSortMergeJoin(planWithNum2)
val smjWithNum2 = findTopLevelColumnarSortMergeJoin(planWithNum2)
assert(smjWithNum2.length == 1)
// Skew join can apply as the repartition is not optimized out.
assert(smjWithNum2.head.isSkewJoin)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,11 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
} else if (func_name.find("cast") != std::string::npos &&
func_name.compare("castDATE") != 0 &&
func_name.compare("castDECIMAL") != 0 &&
func_name.compare("castDECIMALNullOnOverflow") != 0) {
func_name.compare("castDECIMALNullOnOverflow") != 0 &&
func_name.compare("castINTOrNull") != 0 &&
func_name.compare("castBIGINTOrNull") != 0 &&
func_name.compare("castFLOAT4OrNull") != 0 &&
func_name.compare("castFLOAT8OrNull") != 0) {
codes_str_ = func_name + "_" + std::to_string(cur_func_id);
auto validity = codes_str_ + "_validity";
real_codes_str_ = codes_str_;
Expand Down Expand Up @@ -488,6 +492,40 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
prepare_str_ += prepare_ss.str();
check_str_ = validity;
header_list_.push_back(R"(#include "precompile/gandiva.h")");
} else if (func_name.compare("castINTOrNull") == 0 ||
func_name.compare("castBIGINTOrNull") == 0 ||
func_name.compare("castFLOAT4OrNull") == 0 ||
func_name.compare("castFLOAT8OrNull") == 0) {
codes_str_ = func_name + "_" + std::to_string(cur_func_id);
auto validity = codes_str_ + "_validity";
real_codes_str_ = codes_str_;
real_validity_str_ = validity;
std::stringstream prepare_ss;
prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";"
<< std::endl;
prepare_ss << "bool " << validity << " = " << child_visitor_list[0]->GetPreCheck()
<< ";" << std::endl;
prepare_ss << "if (" << validity << ") {" << std::endl;

std::string func_str;
if (func_name.compare("castINTOrNull") == 0) {
func_str = " = std::stoi";
} else if (func_name.compare("castBIGINTOrNull") == 0) {
func_str = " = std::stol";
} else if (func_name.compare("castFLOAT4OrNull") == 0) {
func_str = " = std::stof";
} else {
func_str = " = std::stod";
}
prepare_ss << codes_str_ << func_str << "(" << child_visitor_list[0]->GetResult()
<< ");" << std::endl;
prepare_ss << "}" << std::endl;

for (int i = 0; i < 1; i++) {
prepare_str_ += child_visitor_list[i]->GetPrepare();
}
prepare_str_ += prepare_ss.str();
check_str_ = validity;
} else if (func_name.compare("rescaleDECIMAL") == 0) {
codes_str_ = func_name + "_" + std::to_string(cur_func_id);
auto validity = codes_str_ + "_validity";
Expand Down