Commit c859b60

unit test

sameeragarwal committed Apr 24, 2017
1 parent 9206702 commit c859b60
Showing 1 changed file with 28 additions and 15 deletions.
@@ -68,25 +68,38 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
   }
 
   test("randomSplit on reordered partitions") {
-    // This test ensures that randomSplit does not create overlapping splits even when the
-    // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
-    // rows in each partition.
-    val data =
-      sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id")
-    val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
 
-    assert(splits.length == 2, "wrong number of splits")
+    def testNonOverlappingSplits(data: DataFrame): Unit = {
+      val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
+      assert(splits.length == 2, "wrong number of splits")
 
-    // Verify that the splits span the entire dataset
-    assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
+      // Verify that the splits span the entire dataset
+      assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
 
-    // Verify that the splits don't overlap
-    assert(splits(0).intersect(splits(1)).collect().isEmpty)
+      // Verify that the splits don't overlap
+      assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty)
 
-    // Verify that the results are deterministic across multiple runs
-    val firstRun = splits.toSeq.map(_.collect().toSeq)
-    val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
-    assert(firstRun == secondRun)
+      // Verify that the results are deterministic across multiple runs
+      val firstRun = splits.toSeq.map(_.collect().toSeq)
+      val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
+      assert(firstRun == secondRun)
+    }
+
+    // This test ensures that randomSplit does not create overlapping splits even when the
+    // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
+    // rows in each partition.
+    val dataWithInts = sparkContext.parallelize(1 to 600, 2)
+      .mapPartitions(scala.util.Random.shuffle(_)).toDF("int")
+    val dataWithMaps = sparkContext.parallelize(1 to 600, 2)
+      .map(i => (i, Map(i -> i.toString)))
+      .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "map")
+    val dataWithArrayOfMaps = sparkContext.parallelize(1 to 600, 2)
+      .map(i => (i, Array(Map(i -> i.toString))))
+      .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "arrayOfMaps")
+
+    testNonOverlappingSplits(dataWithInts)
+    testNonOverlappingSplits(dataWithMaps)
+    testNonOverlappingSplits(dataWithArrayOfMaps)
   }
 
   test("pearson correlation") {
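For readers outside the Spark test suite, the property this test guards can be reproduced in a small standalone program. The following is a minimal sketch, not part of the commit: the `RandomSplitCheck` object name and the local `SparkSession` setup are invented for illustration, while the shuffle-per-partition trick, the weights, and the overlap check mirror the test above.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

// Hypothetical standalone reproduction of the non-overlap property;
// not part of the commit under review.
object RandomSplitCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("randomSplit-overlap-check")
      .getOrCreate()
    import spark.implicits._

    // Two partitions whose rows are reshuffled on every materialization,
    // so downstream operators see a nondeterministic row order.
    val data: DataFrame = spark.sparkContext
      .parallelize(1 to 600, 2)
      .mapPartitions(scala.util.Random.shuffle(_))
      .toDF("id")

    // randomSplit normalizes its weights, so Array(2, 3) requests
    // roughly 40% / 60% of the rows.
    val Array(first, second) = data.randomSplit(Array[Double](2, 3), seed = 1)

    // A row appearing in both splits would mean the sampler saw a
    // different row order on the two passes over the data.
    val overlap = first.collect().toSeq.intersect(second.collect().toSeq)
    assert(overlap.isEmpty, s"found ${overlap.size} overlapping rows")

    spark.stop()
  }
}
```

As context for the three new input DataFrames: in the Spark 2.x sources of this era, `Dataset.randomSplit` defends against nondeterministic partition ordering by sorting each partition on its orderable output columns before sampling, and map-typed columns are not orderable, so they are filtered out of that sort order. That is why the test is exercised not only on plain integers but also on `Map` and `Array[Map]` columns, which would otherwise break the sort-based defense.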
