Skip to content

Commit

Permalink
[SPARK-3363][SQL] Type Coercion should promote null to all other types.
Browse files Browse the repository at this point in the history
Type Coercion should support every type to have null value

Author: Daoyuan Wang <[email protected]>
Author: Michael Armbrust <[email protected]>

Closes apache#2246 from adrian-wang/spark3363-0 and squashes the following commits:

c6241de [Daoyuan Wang] minor code clean
595b417 [Daoyuan Wang] Merge pull request #2 from marmbrus/pr/2246
832e640 [Michael Armbrust] reduce code duplication
ef6f986 [Daoyuan Wang] make double boolean miss in jsonRDD compatibleType
c619f0a [Daoyuan Wang] Type Coercion should support every type to have null value
  • Loading branch information
adrian-wang authored and marmbrus committed Sep 10, 2014
1 parent a028330 commit f0c87dc
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,22 @@ object HiveTypeCoercion {
// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
// The conversion for integral and floating point types have a linear widening hierarchy:
val numericPrecedence =
Seq(NullType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
// Boolean is only wider than Void
val booleanPrecedence = Seq(NullType, BooleanType)
val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: booleanPrecedence :: Nil
Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil

def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
val valueTypes = Seq(t1, t2).filter(t => t != NullType)
if (valueTypes.distinct.size > 1) {
// Try and find a promotion rule that contains both types in question.
val applicableConversion =
HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))

// If found return the widest common type, otherwise None
applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
} else {
Some(if (valueTypes.size == 0) NullType else valueTypes.head)
}
}
}

/**
Expand All @@ -53,17 +65,6 @@ trait HiveTypeCoercion {
Division ::
Nil

trait TypeWidening {
def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
// Try and find a promotion rule that contains both types in question.
val applicableConversion =
HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))

// If found return the widest common type, otherwise None
applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
}
}

/**
* Applies any changes to [[AttributeReference]] data types that are made by other rules to
* instances higher in the query tree.
Expand Down Expand Up @@ -144,7 +145,8 @@ trait HiveTypeCoercion {
* - LongType to FloatType
* - LongType to DoubleType
*/
object WidenTypes extends Rule[LogicalPlan] with TypeWidening {
object WidenTypes extends Rule[LogicalPlan] {
import HiveTypeCoercion._

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
Expand Down Expand Up @@ -352,7 +354,9 @@ trait HiveTypeCoercion {
/**
* Coerces the type of different branches of a CASE WHEN statement to a common type.
*/
object CaseWhenCoercion extends Rule[LogicalPlan] with TypeWidening {
object CaseWhenCoercion extends Rule[LogicalPlan] {
import HiveTypeCoercion._

def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) =>
val valueTypes = branches.sliding(2, 2).map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,20 @@ import org.apache.spark.sql.catalyst.types._

class HiveTypeCoercionSuite extends FunSuite {

val rules = new HiveTypeCoercion { }
import rules._

test("tightest common bound for numeric and boolean types") {
test("tightest common bound for types") {
def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
var found = WidenTypes.findTightestCommonType(t1, t2)
var found = HiveTypeCoercion.findTightestCommonType(t1, t2)
assert(found == tightestCommon,
s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found")
// Test both directions to make sure the widening is symmetric.
found = WidenTypes.findTightestCommonType(t2, t1)
found = HiveTypeCoercion.findTightestCommonType(t2, t1)
assert(found == tightestCommon,
s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found")
}

// Null
widenTest(NullType, NullType, Some(NullType))

// Boolean
widenTest(NullType, BooleanType, Some(BooleanType))
widenTest(BooleanType, BooleanType, Some(BooleanType))
Expand All @@ -60,12 +60,28 @@ class HiveTypeCoercionSuite extends FunSuite {
widenTest(DoubleType, DoubleType, Some(DoubleType))

// Integral mixed with floating point.
widenTest(NullType, FloatType, Some(FloatType))
widenTest(NullType, DoubleType, Some(DoubleType))
widenTest(IntegerType, FloatType, Some(FloatType))
widenTest(IntegerType, DoubleType, Some(DoubleType))
widenTest(IntegerType, DoubleType, Some(DoubleType))
widenTest(LongType, FloatType, Some(FloatType))
widenTest(LongType, DoubleType, Some(DoubleType))

// StringType
widenTest(NullType, StringType, Some(StringType))
widenTest(StringType, StringType, Some(StringType))
widenTest(IntegerType, StringType, None)
widenTest(LongType, StringType, None)

// TimestampType
widenTest(NullType, TimestampType, Some(TimestampType))
widenTest(TimestampType, TimestampType, Some(TimestampType))
widenTest(IntegerType, TimestampType, None)
widenTest(StringType, TimestampType, None)

// ComplexType
widenTest(NullType, MapType(IntegerType, StringType, false), Some(MapType(IntegerType, StringType, false)))
widenTest(NullType, StructType(Seq()), Some(StructType(Seq())))
widenTest(StringType, MapType(IntegerType, StringType, true), None)
widenTest(ArrayType(IntegerType), StructType(Seq()), None)
}
}
51 changes: 22 additions & 29 deletions sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -125,38 +125,31 @@ private[sql] object JsonRDD extends Logging {
* Returns the most general data type for two given data types.
*/
private[json] def compatibleType(t1: DataType, t2: DataType): DataType = {
// Try and find a promotion rule that contains both types in question.
val applicableConversion = HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p
.contains(t2))

// If found return the widest common type, otherwise None
val returnType = applicableConversion.map(_.filter(t => t == t1 || t == t2).last)

if (returnType.isDefined) {
returnType.get
} else {
// t1 or t2 is a StructType, ArrayType, or an unexpected type.
(t1, t2) match {
case (other: DataType, NullType) => other
case (NullType, other: DataType) => other
case (StructType(fields1), StructType(fields2)) => {
val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
case (name, fieldTypes) => {
val dataType = fieldTypes.map(field => field.dataType).reduce(
(type1: DataType, type2: DataType) => compatibleType(type1, type2))
StructField(name, dataType, true)
HiveTypeCoercion.findTightestCommonType(t1, t2) match {
case Some(commonType) => commonType
case None =>
// t1 or t2 is a StructType, ArrayType, or an unexpected type.
(t1, t2) match {
case (other: DataType, NullType) => other
case (NullType, other: DataType) => other
case (StructType(fields1), StructType(fields2)) => {
val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
case (name, fieldTypes) => {
val dataType = fieldTypes.map(field => field.dataType).reduce(
(type1: DataType, type2: DataType) => compatibleType(type1, type2))
StructField(name, dataType, true)
}
}
StructType(newFields.toSeq.sortBy {
case StructField(name, _, _) => name
})
}
StructType(newFields.toSeq.sortBy {
case StructField(name, _, _) => name
})
case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
// TODO: We should use JsonObjectStringType to mark that values of field will be
// strings and every string is a Json object.
case (_, _) => StringType
}
case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
// TODO: We should use JsonObjectStringType to mark that values of field will be
// strings and every string is a Json object.
case (_, _) => StringType
}
}
}

Expand Down

0 comments on commit f0c87dc

Please sign in to comment.