# [SPARK-20274][SQL] support compatible array element type in encoder
## What changes were proposed in this pull request?

This is a regression caused by SPARK-19716.

Before SPARK-19716, we would cast an array field to the expected array type. After SPARK-19716 that cast was removed, but we forgot to push the cast down to the element level.
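To make the regression concrete, here is a hedged, self-contained sketch (not part of the patch) of the user-facing case: an `array<int>` column read into a field typed `Array[Long]`, which needs an element-level upcast from `int` to `bigint`. The `LongArrayHolder` class and the local SparkSession setup are illustrative assumptions; the patch's own regression tests exercise the catalyst-internal API instead.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative target class (hypothetical; the patch's tests use
// PrimitiveArrayClass with the same shape).
case class LongArrayHolder(arr: Array[Long])

object CompatibleArrayElementDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // "arr" has Catalyst type array<int>, but the encoder for LongArrayHolder
    // expects array<bigint>. With this fix the cast is pushed down to each
    // element; before it, analysis failed with "Cannot up cast".
    val ds = spark.createDataset(Seq(Seq(1, 2, 3))).toDF("arr").as[LongArrayHolder]
    ds.collect().foreach(h => println(h.arr.mkString(", ")))

    spark.stop()
  }
}
```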

## How was this patch tested?

New regression tests in `EncoderResolutionSuite`, shown in the diff below.

Author: Wenchen Fan <[email protected]>

Closes apache#17587 from cloud-fan/array.
cloud-fan authored and Mingjie Tang committed Apr 18, 2017
1 parent d684bc8 commit f7c0a8a
Showing 3 changed files with 40 additions and 9 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

@@ -132,7 +132,7 @@ object ScalaReflection extends ScalaReflection {
   def deserializerFor[T : TypeTag]: Expression = {
     val tpe = localTypeOf[T]
     val clsName = getClassNameFromType(tpe)
-    val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
+    val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
     deserializerFor(tpe, None, walkedTypePath)
   }

@@ -270,12 +270,14 @@ object ScalaReflection extends ScalaReflection {

       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
-        val Schema(_, elementNullable) = schemaFor(elementType)
+        val Schema(dataType, elementNullable) = schemaFor(elementType)
         val className = getClassNameFromType(elementType)
         val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath

-        val mapFunction: Expression => Expression = p => {
-          val converter = deserializerFor(elementType, Some(p), newTypePath)
+        val mapFunction: Expression => Expression = element => {
+          // upcast the array element to the data type the encoder expected.
+          val casted = upCastToExpectedType(element, dataType, newTypePath)
+          val converter = deserializerFor(elementType, Some(casted), newTypePath)
           if (elementNullable) {
             converter
           } else {

@@ -305,12 +307,14 @@ object ScalaReflection extends ScalaReflection {

       case t if t <:< localTypeOf[Seq[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
-        val Schema(_, elementNullable) = schemaFor(elementType)
+        val Schema(dataType, elementNullable) = schemaFor(elementType)
         val className = getClassNameFromType(elementType)
         val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath

-        val mapFunction: Expression => Expression = p => {
-          val converter = deserializerFor(elementType, Some(p), newTypePath)
+        val mapFunction: Expression => Expression = element => {
+          // upcast the array element to the data type the encoder expected.
+          val casted = upCastToExpectedType(element, dataType, newTypePath)
+          val converter = deserializerFor(elementType, Some(casted), newTypePath)
           if (elementNullable) {
             converter
           } else {
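For readers unfamiliar with `upCastToExpectedType`: it is a small private helper in `ScalaReflection`. A hedged, approximate sketch of the behavior the fix relies on (not copied from the source) is:

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, UpCast}
import org.apache.spark.sql.types.{DataType, StructType}

// Approximate sketch: pass struct types through (their fields are resolved
// recursively), otherwise wrap the expression in UpCast so that the
// analyzer's ResolveUpCast rule can later insert a safe Cast -- or fail
// with the readable "Cannot up cast" error shown in the tests below.
def upCastToExpectedTypeSketch(
    expr: Expression,
    expected: DataType,
    walkedTypePath: Seq[String]): Expression = expected match {
  case _: StructType => expr
  case _ => UpCast(expr, expected, walkedTypePath)
}
```

The key point of the patch is where this is applied: inside `mapFunction`, so the upcast happens per array element rather than on the array as a whole.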
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.expressions.objects.{MapObjects, NewInstance, UnresolvedMapObjects}
+import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects}
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
 import org.apache.spark.sql.catalyst.plans._

@@ -2321,7 +2321,11 @@ class Analyzer(
    */
   object ResolveUpCast extends Rule[LogicalPlan] {
     private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
-      throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
+      val fromStr = from match {
+        case l: LambdaVariable => "array element"
+        case e => e.sql
+      }
+      throw new AnalysisException(s"Cannot up cast $fromStr from " +
         s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
         "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
         "You can either add an explicit cast to the input data or choose a higher precision " +
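Why the `LambdaVariable` special case: while `MapObjects` deserializes an array, each element is bound to a synthetic `LambdaVariable`, whose `sql` form is a machine-generated name that would be meaningless in an error message. Printing "array element" instead keeps the "Cannot up cast" message readable, as the new test below verifies.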
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala

@@ -33,6 +33,8 @@ case class StringIntClass(a: String, b: Int)

 case class ComplexClass(a: Long, b: StringLongClass)

+case class PrimitiveArrayClass(arr: Array[Long])
+
 case class ArrayClass(arr: Seq[StringIntClass])

 case class NestedArrayClass(nestedArr: Array[ArrayClass])

@@ -66,6 +68,27 @@ class EncoderResolutionSuite extends PlanTest {
     encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
   }

+  test("real type doesn't match encoder schema but they are compatible: primitive array") {
+    val encoder = ExpressionEncoder[PrimitiveArrayClass]
+    val attrs = Seq('arr.array(IntegerType))
+    val array = new GenericArrayData(Array(1, 2, 3))
+    encoder.resolveAndBind(attrs).fromRow(InternalRow(array))
+  }
+
+  test("the real type is not compatible with encoder schema: primitive array") {
+    val encoder = ExpressionEncoder[PrimitiveArrayClass]
+    val attrs = Seq('arr.array(StringType))
+    assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
+      s"""
+        |Cannot up cast array element from string to bigint as it may truncate
+        |The type path of the target object is:
+        |- array element class: "scala.Long"
+        |- field (class: "scala.Array", name: "arr")
+        |- root class: "org.apache.spark.sql.catalyst.encoders.PrimitiveArrayClass"
+        |You can either add an explicit cast to the input data or choose a higher precision type
+      """.stripMargin.trim + " of the field in the target object")
+  }
+
   test("real type doesn't match encoder schema but they are compatible: array") {
     val encoder = ExpressionEncoder[ArrayClass]
     val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", "int").add("c", "int")))
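To run just this suite locally, something like `build/sbt "catalyst/testOnly org.apache.spark.sql.catalyst.encoders.EncoderResolutionSuite"` from the repository root should work, though the exact invocation depends on your build setup.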
