Skip to content

Commit

Permalink
[SPARK-24313][SQL][BACKPORT-2.3] Fix collection operations' interpret…
Browse files Browse the repository at this point in the history
…ed evaluation for complex types

## What changes were proposed in this pull request?

The interpreted evaluation of several collection operations works only for simple datatypes. For complex data types, for instance, `array_contains` it returns always `false`. The list of the affected functions is `array_contains` and `GetMapValue`.

The PR fixes the behavior for all the datatypes.

## How was this patch tested?

added UT

Author: Marco Gaido <[email protected]>

Closes #21407 from mgaido91/SPARK-24313_2.3.
  • Loading branch information
mgaido91 authored and cloud-fan committed May 23, 2018
1 parent ed0060c commit ded6709
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.util.Comparator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -227,6 +227,9 @@ case class ArrayContains(left: Expression, right: Expression)

override def dataType: DataType = BooleanType

@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(right.dataType)

override def inputTypes: Seq[AbstractDataType] = right.dataType match {
case NullType => Seq.empty
case _ => left.dataType match {
Expand All @@ -243,7 +246,7 @@ case class ArrayContains(left: Expression, right: Expression)
TypeCheckResult.TypeCheckFailure(
"Arguments must be an array followed by a value of same type as the array members")
} else {
TypeCheckResult.TypeCheckSuccess
TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
}
}

Expand All @@ -256,7 +259,7 @@ case class ArrayContains(left: Expression, right: Expression)
arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
if (v == null) {
hasNull = true
} else if (v == value) {
} else if (ordering.equiv(v, value)) {
return true
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -301,7 +301,7 @@ case class GetMapValue(child: Expression, key: Expression)
var i = 0
var found = false
while (i < length && !found) {
if (keys.get(i, keyType) == ordinal) {
if (ordering.equiv(keys.get(i, keyType), ordinal)) {
found = true
} else {
i += 1
Expand Down Expand Up @@ -352,4 +352,15 @@ case class GetMapValue(child: Expression, key: Expression)
"""
})
}

@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(keyType)

override def checkInputDataTypes(): TypeCheckResult = {
super.checkInputDataTypes() match {
case f: TypeCheckResult.TypeCheckFailure => f
case TypeCheckResult.TypeCheckSuccess =>
TypeUtils.checkForOrderingExpr(keyType, s"function $prettyName")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,32 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper

checkEvaluation(ArrayContains(a3, Literal("")), null)
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)

// binary
val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)),
ArrayType(BinaryType))
val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)),
ArrayType(BinaryType))
val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null),
ArrayType(BinaryType))
val b3 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)),
ArrayType(BinaryType))
val be = Literal.create(Array[Byte](1, 2), BinaryType)
val nullBinary = Literal.create(null, BinaryType)

checkEvaluation(ArrayContains(b0, be), true)
checkEvaluation(ArrayContains(b1, be), false)
checkEvaluation(ArrayContains(b0, nullBinary), null)
checkEvaluation(ArrayContains(b2, be), null)
checkEvaluation(ArrayContains(b3, be), true)

// complex data types
val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
ArrayType(ArrayType(IntegerType)))
val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
ArrayType(ArrayType(IntegerType)))
val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType))
checkEvaluation(ArrayContains(aa0, aae), true)
checkEvaluation(ArrayContains(aa1, aae), false)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -331,4 +331,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
.analyze
comparePlans(Optimizer execute rel, expected)
}

test("SPARK-24313: support binary type as map keys in GetMapValue") {
val mb0 = Literal.create(
Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"),
MapType(BinaryType, StringType))
val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType))

checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](1, 2, 3))), null)

checkEvaluation(GetMapValue(mb1, Literal(Array[Byte](1, 2))), null)
checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2")
checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2270,4 +2270,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val df = spark.range(1).select($"id", new Column(Uuid()))
checkAnswer(df, df.collect())
}

test("SPARK-24313: access map with binary keys") {
val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1))
checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1))
}
}

0 comments on commit ded6709

Please sign in to comment.