Skip to content

Commit

Permalink
"reviewer comment addressed"
Browse files — browse the repository at this point in the history
  • Loading branch information
dorx committed Jun 10, 2014
1 parent f80f270 commit 0a9b3e3
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 30 deletions.
11 changes: 6 additions & 5 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ abstract class RDD[T: ClassTag](
seed: Long = Utils.random.nextLong): Array[T] = {
var fraction = 0.0
var total = 0
val numStDev = 10.0
val initialCount = this.count()

if (num < 0) {
Expand All @@ -406,15 +407,15 @@ abstract class RDD[T: ClassTag](
"sampling without replacement")
}

if (initialCount > Integer.MAX_VALUE - 1) {
val maxSelected = Integer.MAX_VALUE - (5.0 * math.sqrt(Integer.MAX_VALUE)).toInt
if (initialCount > Int.MaxValue - 1) {
val maxSelected = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt
if (num > maxSelected) {
throw new IllegalArgumentException("Cannot support a sample size > Integer.MAX_VALUE - " +
"5.0 * math.sqrt(Integer.MAX_VALUE)")
throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " +
s"$numStDev * math.sqrt(Int.MaxValue)")
}
}

fraction = SamplingUtils.computeFraction(num, initialCount, withReplacement)
fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, withReplacement)
total = num

val rand = new Random(seed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ private[spark] object SamplingUtils {
* ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
* rate, where success rate is defined the same as in sampling with replacement.
*
* @param num sample size
* @param sampleSizeLowerBound sample size
* @param total size of RDD
* @param withReplacement whether sampling with replacement
* @return a sampling rate that guarantees sufficient sample size with 99.99% success rate
*/
def computeFraction(num: Int, total: Long, withReplacement: Boolean): Double = {
val fraction = num.toDouble / total
def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long, withReplacement: Boolean): Double = {
val fraction = sampleSizeLowerBound.toDouble / total
if (withReplacement) {
val numStDev = if (num < 12) 9 else 5
val numStDev = if (sampleSizeLowerBound < 12) 9 else 5
fraction + numStDev * math.sqrt(fraction / total)
} else {
val delta = 1e-4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ class SamplingUtilsSuite extends FunSuite{
val n = 100000

for (s <- 1 to 15) {
val frac = SamplingUtils.computeFraction(s, n, true)
val frac = SamplingUtils.computeFractionForSampleSize(s, n, true)
val poisson = new PoissonDistribution(frac * n)
assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
}
for (s <- List(20, 100, 1000)) {
val frac = SamplingUtils.computeFraction(s, n, true)
val frac = SamplingUtils.computeFractionForSampleSize(s, n, true)
val poisson = new PoissonDistribution(frac * n)
assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
}
for (s <- List(1, 10, 100, 1000)) {
val frac = SamplingUtils.computeFraction(s, n, false)
val frac = SamplingUtils.computeFractionForSampleSize(s, n, false)
val binomial = new BinomialDistribution(n, frac)
assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low")
}
Expand Down
5 changes: 0 additions & 5 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,6 @@
<artifactId>commons-codec</artifactId>
<version>1.5</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.3</version>
</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
Expand Down
24 changes: 11 additions & 13 deletions python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,27 +366,25 @@ def takeSample(self, withReplacement, num, seed=None):

fraction = 0.0
total = 0
multiplier = 3.0
numStDev = 10.0
initialCount = self.count()
maxSelected = 0

if (num < 0):
if num < 0:
raise ValueError

if (initialCount == 0):
if initialCount == 0:
return list()

if (not withReplacement) and num > initialCount:
raise ValueError

if initialCount > sys.maxint - 1:
maxSelected = sys.maxint - 1
else:
maxSelected = initialCount
maxSelected = sys.maxint - int(numStDev * sqrt(sys.maxint))
if num > maxSelected:
raise ValueError

if num > initialCount and not withReplacement:
total = maxSelected
fraction = multiplier * (maxSelected + 1) / initialCount
else:
fraction = self._computeFraction(num, initialCount, withReplacement)
total = num
fraction = self._computeFraction(num, initialCount, withReplacement)
total = num

samples = self.sample(withReplacement, fraction, seed).collect()

Expand Down

0 comments on commit 0a9b3e3

Please sign in to comment.