From f7c0a8aece6f9def45cd357d6a24d572421f8e15 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 11 Apr 2017 20:21:04 +0800
Subject: [PATCH] [SPARK-20274][SQL] support compatible array element type in
 encoder

## What changes were proposed in this pull request?

This fixes a regression introduced by SPARK-19716. Before SPARK-19716, we would cast an array field to the expected array type. After SPARK-19716, that cast was removed, but we forgot to push the cast down to the element level.

## How was this patch tested?

New regression tests.

Author: Wenchen Fan

Closes #17587 from cloud-fan/array.
---
 .../spark/sql/catalyst/ScalaReflection.scala       | 18 +++++++++------
 .../sql/catalyst/analysis/Analyzer.scala           |  8 +++++--
 .../encoders/EncoderResolutionSuite.scala          | 23 +++++++++++++++++++
 3 files changed, 40 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 198122759e4ad..0c5a818f54f5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -132,7 +132,7 @@ object ScalaReflection extends ScalaReflection {
   def deserializerFor[T : TypeTag]: Expression = {
     val tpe = localTypeOf[T]
     val clsName = getClassNameFromType(tpe)
-    val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
+    val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
     deserializerFor(tpe, None, walkedTypePath)
   }
 
@@ -270,12 +270,14 @@ object ScalaReflection extends ScalaReflection {
 
       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
-        val Schema(_, elementNullable) = schemaFor(elementType)
+        val Schema(dataType, elementNullable) = schemaFor(elementType)
         val className = getClassNameFromType(elementType)
         val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
 
-        val mapFunction: Expression => Expression = p => {
-          val converter = deserializerFor(elementType, Some(p), newTypePath)
+        val mapFunction: Expression => Expression = element => {
+          // upcast the array element to the data type the encoder expected.
+          val casted = upCastToExpectedType(element, dataType, newTypePath)
+          val converter = deserializerFor(elementType, Some(casted), newTypePath)
           if (elementNullable) {
             converter
           } else {
@@ -305,12 +307,14 @@ object ScalaReflection extends ScalaReflection {
 
       case t if t <:< localTypeOf[Seq[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
-        val Schema(_, elementNullable) = schemaFor(elementType)
+        val Schema(dataType, elementNullable) = schemaFor(elementType)
        val className = getClassNameFromType(elementType)
         val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
 
-        val mapFunction: Expression => Expression = p => {
-          val converter = deserializerFor(elementType, Some(p), newTypePath)
+        val mapFunction: Expression => Expression = element => {
+          // upcast the array element to the data type the encoder expected.
+          val casted = upCastToExpectedType(element, dataType, newTypePath)
+          val converter = deserializerFor(elementType, Some(casted), newTypePath)
           if (elementNullable) {
             converter
           } else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b0cdef70297cf..9816b33ae8dff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.expressions.objects.{MapObjects, NewInstance, UnresolvedMapObjects}
+import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects}
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
 import org.apache.spark.sql.catalyst.plans._
@@ -2321,7 +2321,11 @@ class Analyzer(
    */
   object ResolveUpCast extends Rule[LogicalPlan] {
     private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
-      throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
+      val fromStr = from match {
+        case l: LambdaVariable => "array element"
+        case e => e.sql
+      }
+      throw new AnalysisException(s"Cannot up cast $fromStr from " +
         s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
         "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
         "You can either add an explicit cast to the input data or choose a higher precision " +
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index e5a3e1fd374dc..630e8a7990e7b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -33,6 +33,8 @@ case class StringIntClass(a: String, b: Int)
 
 case class ComplexClass(a: Long, b: StringLongClass)
 
+case class PrimitiveArrayClass(arr: Array[Long])
+
 case class ArrayClass(arr: Seq[StringIntClass])
 
 case class NestedArrayClass(nestedArr: Array[ArrayClass])
@@ -66,6 +68,27 @@ class EncoderResolutionSuite extends PlanTest {
     encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
   }
 
+  test("real type doesn't match encoder schema but they are compatible: primitive array") {
+    val encoder = ExpressionEncoder[PrimitiveArrayClass]
+    val attrs = Seq('arr.array(IntegerType))
+    val array = new GenericArrayData(Array(1, 2, 3))
+    encoder.resolveAndBind(attrs).fromRow(InternalRow(array))
+  }
+
+  test("the real type is not compatible with encoder schema: primitive array") {
+    val encoder = ExpressionEncoder[PrimitiveArrayClass]
+    val attrs = Seq('arr.array(StringType))
+    assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
+      s"""
+         |Cannot up cast array element from string to bigint as it may truncate
+         |The type path of the target object is:
+         |- array element class: "scala.Long"
+         |- field (class: "scala.Array", name: "arr")
+         |- root class: "org.apache.spark.sql.catalyst.encoders.PrimitiveArrayClass"
+         |You can either add an explicit cast to the input data or choose a higher precision type
+       """.stripMargin.trim + " of the field in the target object")
+  }
+
   test("real type doesn't match encoder schema but they are compatible: array") {
     val encoder = ExpressionEncoder[ArrayClass]
     val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", "int").add("c", "int")))