forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Check sample size and move computeFraction
Check that the sample size is within the supported range. Moved computeFraction into a private util class in util.random
- Loading branch information
Showing
4 changed files
with
108 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
50 changes: 50 additions & 0 deletions
50
core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.util.random | ||
|
||
private[spark] object SamplingUtils {

  /**
   * Computes an inflated sampling rate q for drawing a sample of `num` elements out of `total`.
   *
   * With p = num / total, the returned q is chosen larger than p so that the probability of
   * ending up with fewer than `num` sampled elements stays below 0.0001:
   *   - With replacement, each datapoint is drawn with a count ~ Pois(q). The rate is
   *     q = p + k * sqrt(p / total), where k = 5 once num reaches 12 and k = 9 (determined
   *     empirically) for smaller sample sizes.
   *   - Without replacement, the sample size is ~ Binomial(total, q). With delta = 1e-4 and
   *     gamma = -ln(delta) / total, the rate is
   *     q = p + gamma + sqrt(gamma^2 + 2 * gamma * p), capped at 1.
   *
   * @param num sample size
   * @param total size of RDD
   * @param withReplacement whether sampling with replacement
   * @return a sampling rate that guarantees sufficient sample size with 99.99% success rate
   */
  def computeFraction(num: Int, total: Long, withReplacement: Boolean): Double = {
    val p = num.toDouble / total
    if (!withReplacement) {
      // Binomial case: widen p by a delta-dependent term, never exceeding a rate of 1.
      val failureRate = 1e-4
      val gamma = -math.log(failureRate) / total
      val root = math.sqrt(gamma * gamma + 2 * gamma * p)
      math.min(1, p + gamma + root)
    } else {
      // Poisson case: pad p by a number of standard deviations; small samples need more.
      val stdDevs = if (num < 12) 9 else 5
      p + stdDevs * math.sqrt(p / total)
    }
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
46 changes: 46 additions & 0 deletions
46
core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.util.random | ||
|
||
import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} | ||
import org.scalatest.FunSuite | ||
|
||
class SamplingUtilsSuite extends FunSuite {

  test("computeFraction") {
    // Verify that the computed fraction guarantees enough datapoints in the sample
    // with a failure rate <= 0.0001. For each requested sample size s, the 0.0001
    // quantile of the resulting sample-size distribution must be at least s.
    val n = 100000

    // Sampling with replacement: the total sample size is Poisson with mean frac * n.
    // Sizes 1 to 15 cover both sides of the small-sample cutoff at 12.
    for (s <- (1 to 15) ++ List(20, 100, 1000)) {
      val frac = SamplingUtils.computeFraction(s, n, true)
      val poisson = new PoissonDistribution(frac * n)
      assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
    }

    // Sampling without replacement: the total sample size is Binomial(n, frac).
    // inverseCumulativeProbability already returns a number of successes (a count),
    // so it is compared with s directly; scaling it by n would make the check vacuous.
    for (s <- List(1, 10, 100, 1000)) {
      val frac = SamplingUtils.computeFraction(s, n, false)
      val binomial = new BinomialDistribution(n, frac)
      assert(binomial.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
    }
  }
}