Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-3363][SQL] Type Coercion should support every type to have null value #2246

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,22 @@ object HiveTypeCoercion {
// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
// The conversions for integral and floating point types have a linear widening hierarchy:
val numericPrecedence =
Seq(NullType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
// Boolean is only wider than Void
val booleanPrecedence = Seq(NullType, BooleanType)
val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: booleanPrecedence :: Nil
Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil

/**
 * Finds the tightest common type of `t1` and `t2`, treating NullType as compatible with
 * every other type.
 *
 * @param t1 the first type
 * @param t2 the second type
 * @return Some(NullType) when both inputs are NullType; Some(t) when, ignoring NullType,
 *         only a single distinct type `t` remains; otherwise the member of the pair that
 *         appears later in a promotion sequence of `allPromotions` containing both types,
 *         or None when no such sequence exists.
 */
def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
  // NullType never constrains the result, so only the remaining types are compared.
  val nonNullTypes = Seq(t1, t2).filterNot(_ == NullType)

  nonNullTypes.distinct match {
    // Both inputs were NullType.
    case Seq() => Some(NullType)
    // Either one side was NullType or both sides are the same type.
    case Seq(onlyType) => Some(onlyType)
    // Two different non-null types: look for a promotion sequence containing both and
    // take whichever of the two comes last in it.
    case _ =>
      HiveTypeCoercion.allPromotions
        .find(promotion => promotion.contains(t1) && promotion.contains(t2))
        .map(promotion => promotion.filter(t => t == t1 || t == t2).last)
  }
}
}

/**
Expand All @@ -53,17 +65,6 @@ trait HiveTypeCoercion {
Division ::
Nil

/**
 * Mixin that provides a lookup for the tightest common type of two data types, based on
 * the promotion sequences in `HiveTypeCoercion.allPromotions`.
 */
trait TypeWidening {
/**
 * Returns the member of the pair (`t1`, `t2`) that appears later in the first promotion
 * sequence containing both types, or None when no promotion sequence contains both.
 */
def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
// Try and find a promotion rule that contains both types in question.
val applicableConversion =
HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))

// If found return the widest common type, otherwise None
// (`.last` is safe here: the matched sequence contains both t1 and t2, so the filtered
// sequence is never empty.)
applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
}
}

/**
* Applies any changes to [[AttributeReference]] data types that are made by other rules to
* instances higher in the query tree.
Expand Down Expand Up @@ -144,7 +145,8 @@ trait HiveTypeCoercion {
* - LongType to FloatType
* - LongType to DoubleType
*/
object WidenTypes extends Rule[LogicalPlan] with TypeWidening {
object WidenTypes extends Rule[LogicalPlan] {
import HiveTypeCoercion._

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
Expand Down Expand Up @@ -340,7 +342,9 @@ trait HiveTypeCoercion {
/**
* Coerces the type of different branches of a CASE WHEN statement to a common type.
*/
object CaseWhenCoercion extends Rule[LogicalPlan] with TypeWidening {
object CaseWhenCoercion extends Rule[LogicalPlan] {
import HiveTypeCoercion._

def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) =>
val valueTypes = branches.sliding(2, 2).map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,20 @@ import org.apache.spark.sql.catalyst.types._

class HiveTypeCoercionSuite extends FunSuite {

val rules = new HiveTypeCoercion { }
import rules._

test("tightest common bound for numeric and boolean types") {
test("tightest common bound for types") {
def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
var found = WidenTypes.findTightestCommonType(t1, t2)
var found = HiveTypeCoercion.findTightestCommonType(t1, t2)
assert(found == tightestCommon,
s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found")
// Test both directions to make sure the widening is symmetric.
found = WidenTypes.findTightestCommonType(t2, t1)
found = HiveTypeCoercion.findTightestCommonType(t2, t1)
assert(found == tightestCommon,
s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found")
}

// Null
widenTest(NullType, NullType, Some(NullType))

// Boolean
widenTest(NullType, BooleanType, Some(BooleanType))
widenTest(BooleanType, BooleanType, Some(BooleanType))
Expand All @@ -60,12 +60,28 @@ class HiveTypeCoercionSuite extends FunSuite {
widenTest(DoubleType, DoubleType, Some(DoubleType))

// Integral mixed with floating point.
widenTest(NullType, FloatType, Some(FloatType))
widenTest(NullType, DoubleType, Some(DoubleType))
widenTest(IntegerType, FloatType, Some(FloatType))
widenTest(IntegerType, DoubleType, Some(DoubleType))
widenTest(IntegerType, DoubleType, Some(DoubleType))
widenTest(LongType, FloatType, Some(FloatType))
widenTest(LongType, DoubleType, Some(DoubleType))

// StringType
widenTest(NullType, StringType, Some(StringType))
widenTest(StringType, StringType, Some(StringType))
widenTest(IntegerType, StringType, None)
widenTest(LongType, StringType, None)

// TimestampType
widenTest(NullType, TimestampType, Some(TimestampType))
widenTest(TimestampType, TimestampType, Some(TimestampType))
widenTest(IntegerType, TimestampType, None)
widenTest(StringType, TimestampType, None)

// ComplexType
widenTest(NullType, MapType(IntegerType, StringType, false), Some(MapType(IntegerType, StringType, false)))
widenTest(NullType, StructType(Seq()), Some(StructType(Seq())))
widenTest(StringType, MapType(IntegerType, StringType, true), None)
widenTest(ArrayType(IntegerType), StructType(Seq()), None)
}
}
51 changes: 22 additions & 29 deletions sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -125,38 +125,31 @@ private[sql] object JsonRDD extends Logging {
* Returns the most general data type for two given data types.
*/
private[json] def compatibleType(t1: DataType, t2: DataType): DataType = {
// Try and find a promotion rule that contains both types in question.
val applicableConversion = HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p
.contains(t2))

// If found return the widest common type, otherwise None
val returnType = applicableConversion.map(_.filter(t => t == t1 || t == t2).last)

if (returnType.isDefined) {
returnType.get
} else {
// t1 or t2 is a StructType, ArrayType, or an unexpected type.
(t1, t2) match {
case (other: DataType, NullType) => other
case (NullType, other: DataType) => other
case (StructType(fields1), StructType(fields2)) => {
val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
case (name, fieldTypes) => {
val dataType = fieldTypes.map(field => field.dataType).reduce(
(type1: DataType, type2: DataType) => compatibleType(type1, type2))
StructField(name, dataType, true)
HiveTypeCoercion.findTightestCommonType(t1, t2) match {
case Some(commonType) => commonType
case None =>
// t1 or t2 is a StructType, ArrayType, or an unexpected type.
(t1, t2) match {
case (other: DataType, NullType) => other
case (NullType, other: DataType) => other
case (StructType(fields1), StructType(fields2)) => {
val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
case (name, fieldTypes) => {
val dataType = fieldTypes.map(field => field.dataType).reduce(
(type1: DataType, type2: DataType) => compatibleType(type1, type2))
StructField(name, dataType, true)
}
}
StructType(newFields.toSeq.sortBy {
case StructField(name, _, _) => name
})
}
StructType(newFields.toSeq.sortBy {
case StructField(name, _, _) => name
})
case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
// TODO: We should use JsonObjectStringType to mark that values of field will be
// strings and every string is a Json object.
case (_, _) => StringType
}
case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
// TODO: We should use JsonObjectStringType to mark that values of field will be
// strings and every string is a Json object.
case (_, _) => StringType
}
}
}

Expand Down