From 9de11d3f901bc206a33b9da3e7499bcd43e0142a Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 21 Jun 2018 12:24:53 +0900 Subject: [PATCH] [SPARK-23912][SQL] add array_distinct ## What changes were proposed in this pull request? Add array_distinct to remove duplicate value from the array. ## How was this patch tested? Add unit tests Author: Huaxin Gao Closes #21050 from huaxingao/spark-23912. --- python/pyspark/sql/functions.py | 14 + .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 279 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 45 +++ .../org/apache/spark/sql/functions.scala | 7 + .../spark/sql/DataFrameFunctionsSuite.scala | 22 ++ 6 files changed, 368 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e6346691fb1d4..11b179fe26bfc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1999,6 +1999,20 @@ def array_remove(col, element): return Column(sc._jvm.functions.array_remove(_to_java_column(col), element)) +@since(2.4) +def array_distinct(col): + """ + Collection function: removes duplicate values from the array. + :param col: name of column or expression + + >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data']) + >>> df.select(array_distinct(df.data)).collect() + [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_distinct(_to_java_column(col))) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3700c63d817ea..4b09b9a7e75df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -433,6 +433,7 @@ object FunctionRegistry { expression[Flatten]("flatten"), expression[ArrayRepeat]("array_repeat"), expression[ArrayRemove]("array_remove"), + expression[ArrayDistinct]("array_distinct"), CreateStruct.registryEntry, // mask functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d76f3013f0c41..7c064a130ff35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +import org.apache.spark.util.collection.OpenHashSet /** * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit @@ -2355,3 +2356,281 @@ case class ArrayRemove(left: Expression, right: Expression) override def prettyName: String = "array_remove" } + +/** + * Removes duplicate values from the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Removes duplicate values from the array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, null, 3)); + [1,2,3,null] + """, since = "2.4.0") +case class ArrayDistinct(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + override def dataType: DataType = child.dataType + + @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName") + } + } + + @transient private lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + + override def nullSafeEval(array: Any): Any = { + val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) + if (elementTypeSupportEquals) { + new GenericArrayData(data.distinct.asInstanceOf[Array[Any]]) + } else { + var foundNullElement = false + var pos = 0 + for (i <- 0 until data.length) { + if (data(i) == null) { + if (!foundNullElement) { + foundNullElement = true + pos = pos + 1 + } + } else { + var j = 0 + var done = false + while (j <= i && !done) { + if (data(j) != null && ordering.equiv(data(j), data(i))) { + done = true + } + j = j + 1 + } + if (i == j - 1) { + pos = pos + 1 + } + } + } + new GenericArrayData(data.slice(0, pos)) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (array) => { + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray") + val getValue1 = CodeGenerator.getValue(array, elementType, i) + val getValue2 = CodeGenerator.getValue(array, elementType, j) + val foundNullElement = ctx.freshName("foundNullElement") + val openHashSet = classOf[OpenHashSet[_]].getName + val hs = ctx.freshName("hs") + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" + if (elementTypeSupportEquals) { + s""" + |int $sizeOfDistinctArray = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $array.numElements(); $i ++) { + | if ($array.isNullAt($i)) { + | $foundNullElement = true; + | } else { + | $hs.add($getValue1); + | } + |} + |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0); + |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + """.stripMargin + } else { + s""" + |int $sizeOfDistinctArray = 0; + |boolean $foundNullElement = false; + |for (int $i = 0; $i < $array.numElements(); $i ++) { + | if ($array.isNullAt($i)) { + | if (!($foundNullElement)) { + | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; + | $foundNullElement = true; + | } + | } else { + | int $j; + | for ($j = 0; $j < $i; $j ++) { + | if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) { + | break; + | } + | } + | if ($i == $j) { + | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; + | } + | } + |} + | + |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + """.stripMargin + } + }) + } + + private def setNull( + isPrimitive: Boolean, + foundNullElement: String, + distinctArray: String, + pos: String): String = { + val setNullValue = + if (!isPrimitive) { + s"$distinctArray[$pos] = null"; + } else { + s"$distinctArray.setNullAt($pos)"; + } + + s""" + |if (!($foundNullElement)) { + | $setNullValue; + | $pos = $pos + 1; + | $foundNullElement = true; + |} + """.stripMargin + } + + private def setNotNullValue(isPrimitive: Boolean, + distinctArray: String, + pos: String, + getValue1: String, + primitiveValueTypeName: String): String = { + if (!isPrimitive) { + s"$distinctArray[$pos] = $getValue1"; + } else { + s"$distinctArray.set$primitiveValueTypeName($pos, $getValue1)"; + } + } + + private def setValueForFastEval( + isPrimitive: Boolean, + hs: String, + distinctArray: String, + pos: String, + getValue1: String, + primitiveValueTypeName: String): String = { + val setValue = setNotNullValue(isPrimitive, + distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |if (!($hs.contains($getValue1))) { + | $hs.add($getValue1); + | $setValue; + | $pos = $pos + 1; + |} + """.stripMargin + } + + private def setValueForBruteForceEval( + isPrimitive: Boolean, + i: String, + j: String, + inputArray: String, + distinctArray: String, + pos: String, + getValue1: String, + isEqual: String, + primitiveValueTypeName: String): String = { + val setValue = setNotNullValue(isPrimitive, + distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |int $j; + |for ($j = 0; $j < $i; $j ++) { + | if (!$inputArray.isNullAt($j) && $isEqual) { + | break; + | } + |} + |if ($i == $j) { + | $setValue; + | $pos = $pos + 1; + |} + """.stripMargin + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + size: String): String = { + val distinctArray = ctx.freshName("distinctArray") + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val pos = ctx.freshName("pos") + val getValue1 = CodeGenerator.getValue(inputArray, elementType, i) + val getValue2 = CodeGenerator.getValue(inputArray, elementType, j) + val isEqual = ctx.genEqual(elementType, getValue1, getValue2) + val foundNullElement = ctx.freshName("foundNullElement") + val hs = ctx.freshName("hs") + val openHashSet = classOf[OpenHashSet[_]].getName + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" + val setNullForNonPrimitive = + setNull(false, foundNullElement, distinctArray, pos) + if (elementTypeSupportEquals) { + val setValueForFast = setValueForFastEval(false, hs, distinctArray, pos, getValue1, "") + s""" + |int $pos = 0; + |Object[] $distinctArray = new Object[$size]; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForNonPrimitive; + | } else { + | $setValueForFast; + | } + |} + |${ev.value} = new $arrayClass($distinctArray); + """.stripMargin + } else { + val setValueForBruteForce = setValueForBruteForceEval( + false, i, j, inputArray, distinctArray, pos, getValue1, isEqual, "") + s""" + |int $pos = 0; + |Object[] $distinctArray = new Object[$size]; + |boolean $foundNullElement = false; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForNonPrimitive; + | } else { + | $setValueForBruteForce; + | } + |} + |${ev.value} = new $arrayClass($distinctArray); + """.stripMargin + } + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val setNullForPrimitive = setNull(true, foundNullElement, distinctArray, pos) + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()" + val setValueForFast = + setValueForFastEval(true, hs, distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} + |int $pos = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForPrimitive; + | } else { + | $setValueForFast; + | } + |} + |${ev.value} = $distinctArray; + """.stripMargin + } + } + + override def prettyName: String = "array_distinct" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 85e692bdc4ef1..f377f9c8cd533 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -766,4 +766,49 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRemove(c1, dataToRemove2), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) checkEvaluation(ArrayRemove(c2, dataToRemove2), Seq[Seq[Int]](null, Seq[Int](2, 1))) } + + test("Array Distinct") { + val a0 = Literal.create(Seq(2, 1, 2, 3, 4, 4, 5), ArrayType(IntegerType)) + val a1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a2 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) + val a3 = Literal.create(Seq("b", null, "a", null, "a", null), ArrayType(StringType)) + val a4 = Literal.create(Seq(null, null, null), ArrayType(NullType)) + val a5 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType)) + val a6 = Literal.create(Seq(1.123, 0.1234, 1.121, 1.123, 1.1230, 1.121, 0.1234), + ArrayType(DoubleType)) + val a7 = Literal.create(Seq(1.123f, 0.1234f, 1.121f, 1.123f, 1.1230f, 1.121f, 0.1234f), + ArrayType(FloatType)) + + checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5)) + checkEvaluation(new ArrayDistinct(a1), Seq.empty[Integer]) + checkEvaluation(new ArrayDistinct(a2), Seq("b", "a", "c")) + checkEvaluation(new ArrayDistinct(a3), Seq("b", null, "a")) + checkEvaluation(new ArrayDistinct(a4), Seq(null)) + checkEvaluation(new ArrayDistinct(a5), Seq(true, false)) + checkEvaluation(new ArrayDistinct(a6), Seq(1.123, 0.1234, 1.121)) + checkEvaluation(new ArrayDistinct(a7), Seq(1.123f, 0.1234f, 1.121f)) + + // complex data types + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), + Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), null, Array[Byte](1, 2), + null, Array[Byte](5, 6), null), ArrayType(BinaryType)) + + checkEvaluation(ArrayDistinct(b0), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2))) + checkEvaluation(ArrayDistinct(b1), Seq[Array[Byte]](Array[Byte](2, 1), null)) + checkEvaluation(ArrayDistinct(b2), Seq[Array[Byte]](Array[Byte](5, 6), null, + Array[Byte](1, 2))) + + val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2), + Seq[Int](3, 4), Seq[Int](1, 2)), ArrayType(ArrayType(IntegerType))) + val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1), null, null, Seq[Int](2, 1), null), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayDistinct(c0), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4))) + checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) + checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8551058ec58ce..965dbb69c8efb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3189,6 +3189,13 @@ object functions { ArrayRemove(column.expr, Literal(element)) } + /** + * Removes duplicate values from the array. + * @group collection_funcs + * @since 2.4.0 + */ + def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4e5c1c56e2673..3dc696bd01eeb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1216,6 +1216,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) } + test("array_distinct functions") { + val df = Seq( + (Array[Int](2, 1, 3, 4, 3, 5), Array("b", "c", "a", "c", "b", "", "")), + (Array.empty[Int], Array.empty[String]), + (null, null) + ).toDF("a", "b") + checkAnswer( + df.select(array_distinct($"a"), array_distinct($"b")), + Seq( + Row(Seq(2, 1, 3, 4, 5), Seq("b", "c", "a", "")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + checkAnswer( + df.selectExpr("array_distinct(a)", "array_distinct(b)"), + Seq( + Row(Seq(2, 1, 3, 4, 5), Seq("b", "c", "a", "")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {