Skip to content

Commit

Permalink
[SPARK-20302][SQL] Short circuit cast when from and to types are structurally the same
Browse files Browse the repository at this point in the history

## What changes were proposed in this pull request?
When we perform a cast expression and the from and to types are structurally the same (having the same structure but different field names), we should be able to skip the actual cast.

## How was this patch tested?
Added unit tests for the newly introduced functions.

Author: Reynold Xin <[email protected]>

Closes #17614 from rxin/SPARK-20302.
  • Loading branch information
rxin committed Apr 12, 2017
1 parent 044f7ec commit ffc57b0
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -462,35 +462,54 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
})
}

private[this] def cast(from: DataType, to: DataType): Any => Any = to match {
case dt if dt == from => identity[Any]
case StringType => castToString(from)
case BinaryType => castToBinary(from)
case DateType => castToDate(from)
case decimal: DecimalType => castToDecimal(from, decimal)
case TimestampType => castToTimestamp(from)
case CalendarIntervalType => castToInterval(from)
case BooleanType => castToBoolean(from)
case ByteType => castToByte(from)
case ShortType => castToShort(from)
case IntegerType => castToInt(from)
case FloatType => castToFloat(from)
case LongType => castToLong(from)
case DoubleType => castToDouble(from)
case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
case map: MapType => castMap(from.asInstanceOf[MapType], map)
case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
case udt: UserDefinedType[_]
if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
identity[Any]
case _: UserDefinedType[_] =>
throw new SparkException(s"Cannot cast $from to $to.")
/**
 * Builds the function that converts a value of type `from` into a value of type `to`.
 *
 * When the two types are structurally equal (see `DataType.equalsStructurally`) the value is
 * returned untouched. Otherwise the conversion is chosen by the target type.
 *
 * @param from the data type of the incoming value
 * @param to   the data type to convert to
 * @return a function performing the conversion
 * @throws SparkException when `to` is a user-defined type whose user class differs from that
 *                        of `from`
 */
private[this] def cast(from: DataType, to: DataType): Any => Any = {
  // Per-target dispatch. The case order matters: the equality check must come first and the
  // guarded UDT case must precede the unguarded one.
  def converterFor(target: DataType): Any => Any = target match {
    case same if same == from => identity[Any]
    case StringType => castToString(from)
    case BinaryType => castToBinary(from)
    case DateType => castToDate(from)
    case decimalTo: DecimalType => castToDecimal(from, decimalTo)
    case TimestampType => castToTimestamp(from)
    case CalendarIntervalType => castToInterval(from)
    case BooleanType => castToBoolean(from)
    case ByteType => castToByte(from)
    case ShortType => castToShort(from)
    case IntegerType => castToInt(from)
    case FloatType => castToFloat(from)
    case LongType => castToLong(from)
    case DoubleType => castToDouble(from)
    case arrayTo: ArrayType =>
      val arrayFrom = from.asInstanceOf[ArrayType]
      castArray(arrayFrom.elementType, arrayTo.elementType)
    case mapTo: MapType => castMap(from.asInstanceOf[MapType], mapTo)
    case structTo: StructType => castStruct(from.asInstanceOf[StructType], structTo)
    case udtTo: UserDefinedType[_]
        if udtTo.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
      identity[Any]
    case _: UserDefinedType[_] =>
      throw new SparkException(s"Cannot cast $from to $to.")
  }

  // If the cast does not change the structure, there is nothing to convert: hand back what the
  // child produced. The codegen path (genCode) applies the same short circuit.
  if (DataType.equalsStructurally(from, to)) identity[Any] else converterFor(to)
}

// Conversion function for this expression, obtained from cast(child.dataType, dataType).
// Declared as a lazy val, so it is built on first access and then reused.
private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)

// Evaluates `input` by applying the conversion function held in the `cast` val.
// NOTE(review): presumably only invoked with non-null input, as the name suggests; confirm
// against the parent class.
protected override def nullSafeEval(input: Any): Any = cast(input)

/**
 * Generates code for this cast.
 *
 * When the child's type and the target type are structurally equal, no conversion code is
 * needed and the child's generated code is returned as is. The interpreted path applies the
 * same short circuit. Otherwise generation is delegated to the parent implementation.
 *
 * @param ctx the codegen context passed through to the child or the parent implementation
 * @return the generated code for this expression
 */
override def genCode(ctx: CodegenContext): ExprCode = {
  val sameShape = DataType.equalsStructurally(child.dataType, dataType)
  if (sameShape) child.genCode(ctx) else super.genCode(ctx)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,4 +288,30 @@ object DataType {
case (fromDataType, toDataType) => fromDataType == toDataType
}
}

/**
 * Checks whether two data types have the same "shape": identical types and nullability at
 * every level, while struct field names are allowed to differ.
 *
 * - Arrays match when `containsNull` agrees and the element types match structurally.
 * - Maps match when `valueContainsNull` agrees and both key and value types match structurally.
 * - Structs match when they have the same number of fields and each pair of fields at the same
 *   position agrees on nullability and matches structurally on data type.
 * - Any other pair of types matches only when the two are equal.
 *
 * @param from the first data type
 * @param to   the second data type
 * @return true if the two types are structurally the same
 */
def equalsStructurally(from: DataType, to: DataType): Boolean = (from, to) match {
  case (fromArray: ArrayType, toArray: ArrayType) =>
    fromArray.containsNull == toArray.containsNull &&
      equalsStructurally(fromArray.elementType, toArray.elementType)

  case (fromMap: MapType, toMap: MapType) =>
    fromMap.valueContainsNull == toMap.valueContainsNull &&
      equalsStructurally(fromMap.keyType, toMap.keyType) &&
      equalsStructurally(fromMap.valueType, toMap.valueType)

  case (StructType(fromFields), StructType(toFields)) =>
    // Fields are compared by position only; names are deliberately ignored.
    fromFields.length == toFields.length &&
      fromFields.indices.forall { i =>
        val fromField = fromFields(i)
        val toField = toFields(i)
        fromField.nullable == toField.nullable &&
          equalsStructurally(fromField.dataType, toField.dataType)
      }

  case _ => from == to
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -813,4 +813,18 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(cast(1.0.toFloat, DateType).checkInputDataTypes().isFailure)
assert(cast(1.0, DateType).checkInputDataTypes().isFailure)
}

test("SPARK-20302 cast with same structure") {
  // Source and target share field types and nesting; only the field names differ.
  val nestedSource = new StructType().add("b1", LongType)
  val sourceType = new StructType().add("a", IntegerType).add("b", nestedSource)

  val nestedTarget = new StructType().add("b11", LongType)
  val targetType = new StructType().add("a1", IntegerType).add("b1", nestedTarget)

  val row = Row(10, Row(12L))
  val castExpr = cast(Literal.create(row, sourceType), targetType)

  // The cast is expected to evaluate to the very same row.
  checkEvaluation(castExpr, row)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -411,4 +411,35 @@ class DataTypeSuite extends SparkFunSuite {
checkCatalogString(ArrayType(createStruct(40)))
checkCatalogString(MapType(IntegerType, StringType))
checkCatalogString(MapType(IntegerType, createStruct(40)))

/**
 * Registers a test asserting that `DataType.equalsStructurally(from, to)` returns `expected`.
 * The test name embeds both types, so each (from, to) pair may be registered only once.
 */
def checkEqualsStructurally(from: DataType, to: DataType, expected: Boolean): Unit = {
  test(s"equalsStructurally: (from: $from, to: $to)") {
    val actual = DataType.equalsStructurally(from, to)
    assert(actual === expected)
  }
}

// Atomic types: structural equality reduces to plain equality.
checkEqualsStructurally(BooleanType, BooleanType, true)
checkEqualsStructurally(IntegerType, IntegerType, true)
checkEqualsStructurally(IntegerType, LongType, false)

// Arrays: element type and containsNull must both agree.
checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, true), true)
checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, false), false)
checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(LongType, true), false)

// Arrays of structs: the element comparison recurses and still ignores field names.
checkEqualsStructurally(
  ArrayType(new StructType().add("f1", IntegerType), true),
  ArrayType(new StructType().add("f2", IntegerType), true),
  true)

// Maps: key type, value type and valueContainsNull must all agree.
checkEqualsStructurally(
  MapType(IntegerType, StringType, true),
  MapType(IntegerType, StringType, true),
  true)
checkEqualsStructurally(
  MapType(IntegerType, StringType, true),
  MapType(IntegerType, StringType, false),
  false)
checkEqualsStructurally(
  MapType(IntegerType, StringType, true),
  MapType(LongType, StringType, true),
  false)
checkEqualsStructurally(
  MapType(IntegerType, StringType, true),
  MapType(IntegerType, LongType, true),
  false)

// Maps with struct values: field names inside the value type are ignored.
checkEqualsStructurally(
  MapType(IntegerType, new StructType().add("f1", StringType), true),
  MapType(IntegerType, new StructType().add("f2", StringType), true),
  true)

// Structs: field names are ignored, field nullability is not.
checkEqualsStructurally(
  new StructType().add("f1", IntegerType),
  new StructType().add("f2", IntegerType),
  true)
checkEqualsStructurally(
  new StructType().add("f1", IntegerType),
  new StructType().add("f2", IntegerType, false),
  false)

// Structs with a different number of fields never match.
checkEqualsStructurally(
  new StructType().add("f1", IntegerType),
  new StructType().add("f1", IntegerType).add("f2", IntegerType),
  false)

// Nested structs: the same rules apply recursively.
checkEqualsStructurally(
  new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType)),
  new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)),
  true)
checkEqualsStructurally(
  new StructType()
    .add("f1", IntegerType)
    .add("f", new StructType().add("f2", StringType, false)),
  new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)),
  false)
}

0 comments on commit ffc57b0

Please sign in to comment.