[SPARK-20451] Filter out nested mapType datatypes from sort order in randomSplit #17751

Closed
wants to merge 3 commits
18 changes: 13 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1726,15 +1726,23 @@ class Dataset[T] private[sql](
// It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
// constituent partitions each time a split is materialized which could result in
// overlapping splits. To prevent this, we explicitly sort each input partition to make the
// ordering deterministic.
// MapType cannot be sorted.
val sorted = Sort(logicalPlan.output.filterNot(_.dataType.isInstanceOf[MapType])
.map(SortOrder(_, Ascending)), global = false, logicalPlan)
// ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out
// from the sort order.
val sortOrder = logicalPlan.output
.filter(attr => RowOrdering.isOrderable(attr.dataType))
.map(SortOrder(_, Ascending))
val plan = if (sortOrder.nonEmpty) {
Sort(sortOrder, global = false, logicalPlan)
} else {
// SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism
Review comment (Member Author):
We actually discussed materialization in https://issues.apache.org/jira/browse/SPARK-12662 so that ticket should provide direct context.

cache()
logicalPlan
}
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
new Dataset[T](
sparkSession, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder)
sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan)(), encoder)
}.toArray
}
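
For reference, a small worked example of the weight handling at the end of this hunk (illustrative values only, not part of the patch): with weights Array(2, 3), the cumulative normalized bounds come out to (0.0, 0.4, 1.0), and sliding(2) turns them into the (lowerBound, upperBound) pairs handed to Sample.

// Worked example of the cumulative-weight computation used by randomSplit.
val weights = Array(2.0, 3.0)
val sum = weights.sum // 5.0
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
// normalizedCumWeights: Array(0.0, 0.4, 1.0)
val bounds = normalizedCumWeights.sliding(2).map(x => (x(0), x(1))).toArray
// bounds: Array((0.0, 0.4), (0.4, 1.0)), i.e. one (lowerBound, upperBound) range per split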

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

@@ -68,25 +68,38 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
}

test("randomSplit on reordered partitions") {
// This test ensures that randomSplit does not create overlapping splits even when the
// underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
// rows in each partition.
val data =
sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id")
val splits = data.randomSplit(Array[Double](2, 3), seed = 1)

assert(splits.length == 2, "wrong number of splits")
def testNonOverlappingSplits(data: DataFrame): Unit = {
val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
assert(splits.length == 2, "wrong number of splits")

// Verify that the splits span the entire dataset
assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)

// Verify that the splits span the entire dataset
assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
// Verify that the splits don't overlap
assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty)

// Verify that the splits don't overlap
assert(splits(0).intersect(splits(1)).collect().isEmpty)
// Verify that the results are deterministic across multiple runs
val firstRun = splits.toSeq.map(_.collect().toSeq)
val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
assert(firstRun == secondRun)
}

// Verify that the results are deterministic across multiple runs
val firstRun = splits.toSeq.map(_.collect().toSeq)
val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
assert(firstRun == secondRun)
// This test ensures that randomSplit does not create overlapping splits even when the
// underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
// rows in each partition.
val dataWithInts = sparkContext.parallelize(1 to 600, 2)
.mapPartitions(scala.util.Random.shuffle(_)).toDF("int")
val dataWithMaps = sparkContext.parallelize(1 to 600, 2)
.map(i => (i, Map(i -> i.toString)))
.mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "map")
val dataWithArrayOfMaps = sparkContext.parallelize(1 to 600, 2)
.map(i => (i, Array(Map(i -> i.toString))))
.mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "arrayOfMaps")

testNonOverlappingSplits(dataWithInts)
testNonOverlappingSplits(dataWithMaps)
testNonOverlappingSplits(dataWithArrayOfMaps)
}
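
A note on why dataWithArrayOfMaps is part of this test: the old check filterNot(_.dataType.isInstanceOf[MapType]) only skipped top-level map columns, while RowOrdering.isOrderable also rejects map types nested inside other types, which is the "nested mapType" case in the title. A rough sketch of the expected results against the internal Catalyst API (the class locations below are taken from the Spark source tree and are an assumption, not a stable public API):

import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types._

RowOrdering.isOrderable(IntegerType) // true -> column kept in the sort order
RowOrdering.isOrderable(MapType(IntegerType, StringType)) // false -> pruned
RowOrdering.isOrderable(ArrayType(MapType(IntegerType, StringType))) // false -> pruned; the old
// isInstanceOf[MapType] check would have missed this nested case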

test("pearson correlation") {
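
Finally, on the branch in Dataset.scala where sortOrder comes out empty: if no column is orderable there is nothing to sort on, so the patch falls back to cache(), per the SPARK-12662 discussion, to keep row order stable across the repeated Sample scans. A rough sketch of such an input (assumes a local SparkSession named spark and import spark.implicits._; expected behavior, not an excerpt from the patch):

// A DataFrame whose only column is a MapType, so the sort order is empty.
val onlyMaps = spark.sparkContext
  .parallelize(1 to 600, 2)
  .map(i => Tuple1(Map(i -> i.toString)))
  .toDF("map")

// randomSplit cannot build a per-partition sort here, so the dataset is cached instead;
// the splits are still expected to be disjoint and deterministic for a fixed seed.
val Array(first, second) = onlyMaps.randomSplit(Array(0.5, 0.5), seed = 1)
assert(first.collect().toSeq.intersect(second.collect().toSeq).isEmpty)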