Skip to content

Commit

Permalink
[SPARK-23912][SQL] add array_distinct
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Add array_distinct to remove duplicate value from the array.

## How was this patch tested?

Add unit tests

Author: Huaxin Gao <[email protected]>

Closes apache#21050 from huaxingao/spark-23912.
  • Loading branch information
huaxingao authored and ueshin committed Jun 21, 2018
1 parent 15747cf commit 9de11d3
Show file tree
Hide file tree
Showing 6 changed files with 368 additions and 0 deletions.
14 changes: 14 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1999,6 +1999,20 @@ def array_remove(col, element):
return Column(sc._jvm.functions.array_remove(_to_java_column(col), element))


@since(2.4)
def array_distinct(col):
"""
Collection function: removes duplicate values from the array.
:param col: name of column or expression
>>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data'])
>>> df.select(array_distinct(df.data)).collect()
[Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.array_distinct(_to_java_column(col)))


@since(1.4)
def explode(col):
"""Returns a new row for each element in the given array or map.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ object FunctionRegistry {
expression[Flatten]("flatten"),
expression[ArrayRepeat]("array_repeat"),
expression[ArrayRemove]("array_remove"),
expression[ArrayDistinct]("array_distinct"),
CreateStruct.registryEntry,

// mask functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
import org.apache.spark.util.collection.OpenHashSet

/**
* Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit
Expand Down Expand Up @@ -2355,3 +2356,281 @@ case class ArrayRemove(left: Expression, right: Expression)

override def prettyName: String = "array_remove"
}

/**
* Removes duplicate values from the array.
*/
@ExpressionDescription(
usage = "_FUNC_(array) - Removes duplicate values from the array.",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3, null, 3));
[1,2,3,null]
""", since = "2.4.0")
case class ArrayDistinct(child: Expression)
extends UnaryExpression with ExpectsInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

override def dataType: DataType = child.dataType

@transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType

@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(elementType)

override def checkInputDataTypes(): TypeCheckResult = {
super.checkInputDataTypes() match {
case f: TypeCheckResult.TypeCheckFailure => f
case TypeCheckResult.TypeCheckSuccess =>
TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName")
}
}

@transient private lazy val elementTypeSupportEquals = elementType match {
case BinaryType => false
case _: AtomicType => true
case _ => false
}

override def nullSafeEval(array: Any): Any = {
val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
if (elementTypeSupportEquals) {
new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
} else {
var foundNullElement = false
var pos = 0
for (i <- 0 until data.length) {
if (data(i) == null) {
if (!foundNullElement) {
foundNullElement = true
pos = pos + 1
}
} else {
var j = 0
var done = false
while (j <= i && !done) {
if (data(j) != null && ordering.equiv(data(j), data(i))) {
done = true
}
j = j + 1
}
if (i == j - 1) {
pos = pos + 1
}
}
}
new GenericArrayData(data.slice(0, pos))
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (array) => {
val i = ctx.freshName("i")
val j = ctx.freshName("j")
val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray")
val getValue1 = CodeGenerator.getValue(array, elementType, i)
val getValue2 = CodeGenerator.getValue(array, elementType, j)
val foundNullElement = ctx.freshName("foundNullElement")
val openHashSet = classOf[OpenHashSet[_]].getName
val hs = ctx.freshName("hs")
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
if (elementTypeSupportEquals) {
s"""
|int $sizeOfDistinctArray = 0;
|boolean $foundNullElement = false;
|$openHashSet $hs = new $openHashSet($classTag);
|for (int $i = 0; $i < $array.numElements(); $i ++) {
| if ($array.isNullAt($i)) {
| $foundNullElement = true;
| } else {
| $hs.add($getValue1);
| }
|}
|$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0);
|${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)}
""".stripMargin
} else {
s"""
|int $sizeOfDistinctArray = 0;
|boolean $foundNullElement = false;
|for (int $i = 0; $i < $array.numElements(); $i ++) {
| if ($array.isNullAt($i)) {
| if (!($foundNullElement)) {
| $sizeOfDistinctArray = $sizeOfDistinctArray + 1;
| $foundNullElement = true;
| }
| } else {
| int $j;
| for ($j = 0; $j < $i; $j ++) {
| if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) {
| break;
| }
| }
| if ($i == $j) {
| $sizeOfDistinctArray = $sizeOfDistinctArray + 1;
| }
| }
|}
|
|${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)}
""".stripMargin
}
})
}

private def setNull(
isPrimitive: Boolean,
foundNullElement: String,
distinctArray: String,
pos: String): String = {
val setNullValue =
if (!isPrimitive) {
s"$distinctArray[$pos] = null";
} else {
s"$distinctArray.setNullAt($pos)";
}

s"""
|if (!($foundNullElement)) {
| $setNullValue;
| $pos = $pos + 1;
| $foundNullElement = true;
|}
""".stripMargin
}

private def setNotNullValue(isPrimitive: Boolean,
distinctArray: String,
pos: String,
getValue1: String,
primitiveValueTypeName: String): String = {
if (!isPrimitive) {
s"$distinctArray[$pos] = $getValue1";
} else {
s"$distinctArray.set$primitiveValueTypeName($pos, $getValue1)";
}
}

private def setValueForFastEval(
isPrimitive: Boolean,
hs: String,
distinctArray: String,
pos: String,
getValue1: String,
primitiveValueTypeName: String): String = {
val setValue = setNotNullValue(isPrimitive,
distinctArray, pos, getValue1, primitiveValueTypeName)
s"""
|if (!($hs.contains($getValue1))) {
| $hs.add($getValue1);
| $setValue;
| $pos = $pos + 1;
|}
""".stripMargin
}

private def setValueForBruteForceEval(
isPrimitive: Boolean,
i: String,
j: String,
inputArray: String,
distinctArray: String,
pos: String,
getValue1: String,
isEqual: String,
primitiveValueTypeName: String): String = {
val setValue = setNotNullValue(isPrimitive,
distinctArray, pos, getValue1, primitiveValueTypeName)
s"""
|int $j;
|for ($j = 0; $j < $i; $j ++) {
| if (!$inputArray.isNullAt($j) && $isEqual) {
| break;
| }
|}
|if ($i == $j) {
| $setValue;
| $pos = $pos + 1;
|}
""".stripMargin
}

def genCodeForResult(
ctx: CodegenContext,
ev: ExprCode,
inputArray: String,
size: String): String = {
val distinctArray = ctx.freshName("distinctArray")
val i = ctx.freshName("i")
val j = ctx.freshName("j")
val pos = ctx.freshName("pos")
val getValue1 = CodeGenerator.getValue(inputArray, elementType, i)
val getValue2 = CodeGenerator.getValue(inputArray, elementType, j)
val isEqual = ctx.genEqual(elementType, getValue1, getValue2)
val foundNullElement = ctx.freshName("foundNullElement")
val hs = ctx.freshName("hs")
val openHashSet = classOf[OpenHashSet[_]].getName
if (!CodeGenerator.isPrimitiveType(elementType)) {
val arrayClass = classOf[GenericArrayData].getName
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
val setNullForNonPrimitive =
setNull(false, foundNullElement, distinctArray, pos)
if (elementTypeSupportEquals) {
val setValueForFast = setValueForFastEval(false, hs, distinctArray, pos, getValue1, "")
s"""
|int $pos = 0;
|Object[] $distinctArray = new Object[$size];
|boolean $foundNullElement = false;
|$openHashSet $hs = new $openHashSet($classTag);
|for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
| if ($inputArray.isNullAt($i)) {
| $setNullForNonPrimitive;
| } else {
| $setValueForFast;
| }
|}
|${ev.value} = new $arrayClass($distinctArray);
""".stripMargin
} else {
val setValueForBruteForce = setValueForBruteForceEval(
false, i, j, inputArray, distinctArray, pos, getValue1, isEqual, "")
s"""
|int $pos = 0;
|Object[] $distinctArray = new Object[$size];
|boolean $foundNullElement = false;
|for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
| if ($inputArray.isNullAt($i)) {
| $setNullForNonPrimitive;
| } else {
| $setValueForBruteForce;
| }
|}
|${ev.value} = new $arrayClass($distinctArray);
""".stripMargin
}
} else {
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
val setNullForPrimitive = setNull(true, foundNullElement, distinctArray, pos)
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()"
val setValueForFast =
setValueForFastEval(true, hs, distinctArray, pos, getValue1, primitiveValueTypeName)
s"""
|${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")}
|int $pos = 0;
|boolean $foundNullElement = false;
|$openHashSet $hs = new $openHashSet($classTag);
|for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
| if ($inputArray.isNullAt($i)) {
| $setNullForPrimitive;
| } else {
| $setValueForFast;
| }
|}
|${ev.value} = $distinctArray;
""".stripMargin
}
}

override def prettyName: String = "array_distinct"
}
Original file line number Diff line number Diff line change
Expand Up @@ -766,4 +766,49 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayRemove(c1, dataToRemove2), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)))
checkEvaluation(ArrayRemove(c2, dataToRemove2), Seq[Seq[Int]](null, Seq[Int](2, 1)))
}

test("Array Distinct") {
val a0 = Literal.create(Seq(2, 1, 2, 3, 4, 4, 5), ArrayType(IntegerType))
val a1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
val a2 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType))
val a3 = Literal.create(Seq("b", null, "a", null, "a", null), ArrayType(StringType))
val a4 = Literal.create(Seq(null, null, null), ArrayType(NullType))
val a5 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType))
val a6 = Literal.create(Seq(1.123, 0.1234, 1.121, 1.123, 1.1230, 1.121, 0.1234),
ArrayType(DoubleType))
val a7 = Literal.create(Seq(1.123f, 0.1234f, 1.121f, 1.123f, 1.1230f, 1.121f, 0.1234f),
ArrayType(FloatType))

checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5))
checkEvaluation(new ArrayDistinct(a1), Seq.empty[Integer])
checkEvaluation(new ArrayDistinct(a2), Seq("b", "a", "c"))
checkEvaluation(new ArrayDistinct(a3), Seq("b", null, "a"))
checkEvaluation(new ArrayDistinct(a4), Seq(null))
checkEvaluation(new ArrayDistinct(a5), Seq(true, false))
checkEvaluation(new ArrayDistinct(a6), Seq(1.123, 0.1234, 1.121))
checkEvaluation(new ArrayDistinct(a7), Seq(1.123f, 0.1234f, 1.121f))

// complex data types
val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2),
Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType))
val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null),
ArrayType(BinaryType))
val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), null, Array[Byte](1, 2),
null, Array[Byte](5, 6), null), ArrayType(BinaryType))

checkEvaluation(ArrayDistinct(b0), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)))
checkEvaluation(ArrayDistinct(b1), Seq[Array[Byte]](Array[Byte](2, 1), null))
checkEvaluation(ArrayDistinct(b2), Seq[Array[Byte]](Array[Byte](5, 6), null,
Array[Byte](1, 2)))

val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2),
Seq[Int](3, 4), Seq[Int](1, 2)), ArrayType(ArrayType(IntegerType)))
val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
ArrayType(ArrayType(IntegerType)))
val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1), null, null, Seq[Int](2, 1), null),
ArrayType(ArrayType(IntegerType)))
checkEvaluation(ArrayDistinct(c0), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)))
checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)))
checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1)))
}
}
7 changes: 7 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3189,6 +3189,13 @@ object functions {
ArrayRemove(column.expr, Literal(element))
}

/**
* Removes duplicate values from the array.
* @group collection_funcs
* @since 2.4.0
*/
def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) }

/**
* Creates a new row for each element in the given array or map column.
*
Expand Down
Loading

0 comments on commit 9de11d3

Please sign in to comment.