[SPARK-23915][SQL] Add array_except function #21103

core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -28,9 +28,9 @@ import org.apache.spark.annotation.Private
* removed.
*
* The underlying implementation uses Scala compiler's specialization to generate optimized
- * storage for two primitive types (Long and Int). It is much faster than Java's standard HashSet
- * while incurring much less memory overhead. This can serve as building blocks for higher level
- * data structures such as an optimized HashMap.
+ * storage for four primitive types (Long, Int, Double, and Float). It is much faster than Java's
+ * standard HashSet while incurring much less memory overhead. This can serve as building blocks
+ * for higher level data structures such as an optimized HashMap.
*
* This OpenHashSet is designed to serve as building blocks for higher level data structures
* such as an optimized hash map. Compared with standard hash set implementations, this class
Expand All @@ -41,7 +41,7 @@ import org.apache.spark.annotation.Private
* to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing).
*/
@Private
-class OpenHashSet[@specialized(Long, Int) T: ClassTag](
+class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
initialCapacity: Int,
loadFactor: Double)
extends Serializable {
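For orientation, here is a minimal sketch of what the new specializations enable, using only the add/contains/size API exercised by the test suite below. With @specialized(Double, Float), the compiler generates variants that store unboxed primitives rather than boxed java.lang.Double and java.lang.Float values:

import org.apache.spark.util.collection.OpenHashSet

// Specialized variants back the set with primitive storage, avoiding boxing.
val doubles = new OpenHashSet[Double]
doubles.add(999.9)
assert(doubles.contains(999.9))
assert(!doubles.contains(10.1))

val floats = new OpenHashSet[Float]
floats.add(50.5f)
assert(floats.contains(50.5f))
assert(floats.size == 1)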
@@ -77,6 +77,10 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
       (new LongHasher).asInstanceOf[Hasher[T]]
     } else if (mt == ClassTag.Int) {
       (new IntHasher).asInstanceOf[Hasher[T]]
+    } else if (mt == ClassTag.Double) {
+      (new DoubleHasher).asInstanceOf[Hasher[T]]
+    } else if (mt == ClassTag.Float) {
+      (new FloatHasher).asInstanceOf[Hasher[T]]
     } else {
       new Hasher[T]
     }
@@ -293,7 +297,7 @@ object OpenHashSet {
* A set of specialized hash function implementation to avoid boxing hash code computation
* in the specialized implementation of OpenHashSet.
*/
-  sealed class Hasher[@specialized(Long, Int) T] extends Serializable {
+  sealed class Hasher[@specialized(Long, Int, Double, Float) T] extends Serializable {
def hash(o: T): Int = o.hashCode()
}

@@ -305,6 +309,17 @@ object OpenHashSet {
override def hash(o: Int): Int = o
}

+  class DoubleHasher extends Hasher[Double] {
+    override def hash(o: Double): Int = {
+      val bits = java.lang.Double.doubleToLongBits(o)
+      (bits ^ (bits >>> 32)).toInt
+    }
+  }
+
+  class FloatHasher extends Hasher[Float] {
+    override def hash(o: Float): Int = java.lang.Float.floatToIntBits(o)
+  }

private def grow1(newSize: Int) {}
private def move1(oldPos: Int, newPos: Int) { }

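Both new hashers mirror the JDK hash contracts without boxing: FloatHasher computes exactly java.lang.Float.hashCode (floatToIntBits), and DoubleHasher's fold of the high and low 32 bits of the IEEE-754 bit pattern computes exactly java.lang.Double.hashCode. A quick sanity check of that equivalence:

// The unboxed fold matches the boxed hashCode for Double...
val d = 999.9
val bits = java.lang.Double.doubleToLongBits(d)
assert((bits ^ (bits >>> 32)).toInt == java.lang.Double.valueOf(d).hashCode())

// ...and floatToIntBits matches it for Float.
val f = 50.5f
assert(java.lang.Float.floatToIntBits(f) == java.lang.Float.valueOf(f).hashCode())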
core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
@@ -112,6 +112,80 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers {
assert(!set.contains(10000L))
}

test("primitive float") {
val set = new OpenHashSet[Float]
assert(set.size === 0)
assert(!set.contains(10.1F))
assert(!set.contains(50.5F))
assert(!set.contains(999.9F))
assert(!set.contains(10000.1F))

set.add(10.1F)
assert(set.size === 1)
assert(set.contains(10.1F))
assert(!set.contains(50.5F))
assert(!set.contains(999.9F))
assert(!set.contains(10000.1F))

set.add(50.5F)
assert(set.size === 2)
assert(set.contains(10.1F))
assert(set.contains(50.5F))
assert(!set.contains(999.9F))
assert(!set.contains(10000.1F))

set.add(999.9F)
assert(set.size === 3)
assert(set.contains(10.1F))
assert(set.contains(50.5F))
assert(set.contains(999.9F))
assert(!set.contains(10000.1F))

set.add(50.5F)
assert(set.size === 3)
assert(set.contains(10.1F))
assert(set.contains(50.5F))
assert(set.contains(999.9F))
assert(!set.contains(10000.1F))
}

test("primitive double") {
val set = new OpenHashSet[Double]
assert(set.size === 0)
assert(!set.contains(10.1D))
assert(!set.contains(50.5D))
assert(!set.contains(999.9D))
assert(!set.contains(10000.1D))

set.add(10.1D)
assert(set.size === 1)
assert(set.contains(10.1D))
assert(!set.contains(50.5D))
assert(!set.contains(999.9D))
assert(!set.contains(10000.1D))

set.add(50.5D)
assert(set.size === 2)
assert(set.contains(10.1D))
assert(set.contains(50.5D))
assert(!set.contains(999.9D))
assert(!set.contains(10000.1D))

set.add(999.9D)
assert(set.size === 3)
assert(set.contains(10.1D))
assert(set.contains(50.5D))
assert(set.contains(999.9D))
assert(!set.contains(10000.1D))

set.add(50.5D)
assert(set.size === 3)
assert(set.contains(10.1D))
assert(set.contains(50.5D))
assert(set.contains(999.9D))
assert(!set.contains(10000.1D))
}

test("non-primitive") {
val set = new OpenHashSet[String]
assert(set.size === 0)
python/pyspark/sql/functions.py (19 additions, 0 deletions)
@@ -2052,6 +2052,25 @@ def array_union(col1, col2):
return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2)))


+@ignore_unicode_prefix
+@since(2.4)
+def array_except(col1, col2):
+    """
+    Collection function: returns an array of the elements in col1 but not in col2,
+    without duplicates.
+
+    :param col1: name of column containing array
+    :param col2: name of column containing array
+
+    >>> from pyspark.sql import Row
+    >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
+    >>> df.select(array_except(df.c1, df.c2)).collect()
+    [Row(array_except(c1, c2)=[u'b'])]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.array_except(_to_java_column(col1), _to_java_column(col2)))


@since(1.4)
def explode(col):
"""Returns a new row for each element in the given array or map.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -414,6 +414,7 @@ object FunctionRegistry {
expression[ArrayJoin]("array_join"),
expression[ArrayPosition]("array_position"),
expression[ArraySort]("array_sort"),
+    expression[ArrayExcept]("array_except"),
expression[ArrayUnion]("array_union"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
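Registering the expression under its SQL name makes array_except callable from SQL as well as through the language wrappers added in this PR. A usage sketch in Scala, assuming an active SparkSession named spark:

import spark.implicits._
import org.apache.spark.sql.functions.array_except

val df = Seq((Seq("b", "a", "c"), Seq("c", "d", "a", "f"))).toDF("c1", "c2")
df.select(array_except($"c1", $"c2")).show()
// Expected: a single row containing [b], matching the Python doctest above.

spark.sql("SELECT array_except(array(1, 2, 3), array(1, 3, 5))").show()
// Expected: a single row containing [2].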
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -709,14 +709,18 @@ trait ComplexTypeMergingExpression extends Expression {
@transient
lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType)

-  override def dataType: DataType = {
+  def dataTypeCheck: Unit = {
require(
inputTypesForMerging.nonEmpty,
"The collection of input data types must not be empty.")
require(
TypeCoercion.haveSameType(inputTypesForMerging),
"All input types must be the same except nullable, containsNull, valueContainsNull flags." +
s" The input types found are\n\t${inputTypesForMerging.mkString("\n\t")}")
+  }
+
+  override def dataType: DataType = {
+    dataTypeCheck
inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get)
}
}
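Extracting dataTypeCheck out of dataType lets a subclass reuse the same input validation while computing its result type differently, which is what an array set operation such as array_except needs. A sketch of the pattern; the class name is illustrative, and unrelated Expression members are omitted:

case class MyArrayOp(left: Expression, right: Expression)
    extends ComplexTypeMergingExpression {
  override def children: Seq[Expression] = Seq(left, right)

  override def dataType: DataType = {
    dataTypeCheck            // same requires as the default implementation
    children.head.dataType   // custom result type instead of the reduceLeft merge
  }
  // eval, doGenCode, nullable, etc. omitted for brevity
}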